[core][compiled graphs] Support reduce scatter collective in compiled graph #49404
base: master
Conversation
rebase to updated main branch
Signed-off-by: Puyuan Yao <[email protected]>
Signed-off-by: Puyuan Yao <[email protected]>
python/ray/dag/collective_node.py
Outdated
    communicator.allreduce(send_buf, recv_buf, self._op)
elif self.comm_op == self.REDUCESCATTER:
    world_size = len(self._actor_handles)
    assert (
I'm planning on adding a test that checks an error is thrown when the input tensors' first dimension is not divisible by the number of reduce-scatter participants. This surfaces as a RayTaskError, and I'm not sure what the expected behavior of the compiled graph is when this error is generated.
Sounds good. Raise ValueError for now.
Overall looks good! Left some comments to polish.
    PARENT_CLASS_NODE_KEY,
)
from ray.experimental.channel.torch_tensor_type import Communicator, TorchTensorType
from ray.experimental.util.types import ReduceOp
Create two new enums, AllReduceOp and ReduceScatterOp, inherited from ReduceOp. They do not need to introduce new enum values; just use pass. That way we don't need a str to tell whether this is an allreduce or a reducescatter; simply check whether the type is AllReduceOp or ReduceScatterOp.
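A minimal sketch of the type-based dispatch this suggestion is after; the class and member names here are assumptions from the thread, not the final Ray API. Because Python forbids extending an enum that already defines members (the constraint the author raises later), this sketch repeats the reduce kinds on each subclass instead of inheriting them:

from enum import Enum


class _CollectiveOp(Enum):
    # Member-less base: Python only allows subclassing enums without members.
    pass


class AllReduceOp(_CollectiveOp):
    SUM = "sum"
    PRODUCT = "product"
    MAX = "max"
    MIN = "min"


class ReduceScatterOp(_CollectiveOp):
    SUM = "sum"
    PRODUCT = "product"
    MAX = "max"
    MIN = "min"


def collective_kind(op: _CollectiveOp) -> str:
    # Type-based dispatch: no string tag needed to tell the collectives apart.
    if isinstance(op, AllReduceOp):
        return "allreduce"
    if isinstance(op, ReduceScatterOp):
        return "reducescatter"
    raise ValueError(f"Unsupported collective op: {op}")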
def bind(
    self,
    input_nodes: List["ray.dag.DAGNode"],
    op: ReduceOp = ReduceOp.SUM,
After introducing the two new enums, this would be ReduceScatterOp.SUM. Also update allreduce to use AllReduceOp.SUM. Make sure to update all the places, including allreduce.py and the tests.
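A hypothetical sketch of how the bind signature in the diff above might read after that change; the stub enum, parameter list, and return annotation are assumptions, not the final API:

from enum import Enum
from typing import List


class ReduceScatterOp(Enum):  # stub standing in for the suggested enum
    SUM = "sum"


def bind(
    self,
    input_nodes: List["ray.dag.DAGNode"],
    op: ReduceScatterOp = ReduceScatterOp.SUM,
) -> List["ray.dag.DAGNode"]:
    # Default switches from ReduceOp.SUM to the reduce-scatter-specific enum;
    # the allreduce counterpart would default to AllReduceOp.SUM.
    ...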
@pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True)
def test_torch_tensor_nccl_reduce_scatter_different_shapes_among_participants(
Is this test the same as all_reduce_wrong_shape? If so, we should use the same names for both. wrong_shape is fine to keep.
    self,
    send_buf: "torch.Tensor",
    recv_buf: "torch.Tensor",
    op: ReduceOp = ReduceOp.SUM,
After introducing two new enums, we need to update all of these.
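Purely as an illustration, the communicator-side method in the diff above might end up looking like this after the enum change; this is a sketch under the assumption that a ReduceScatterOp enum exists as suggested, not the actual Ray signature:

from enum import Enum


class ReduceScatterOp(Enum):  # stub standing in for the suggested enum
    SUM = "sum"


class Communicator:
    def reducescatter(
        self,
        send_buf: "torch.Tensor",
        recv_buf: "torch.Tensor",
        op: ReduceScatterOp = ReduceScatterOp.SUM,
    ) -> None:
        # The op's type already identifies the collective, so no extra
        # string tag has to be threaded through the implementation.
        ...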
python/ray/dag/collective_node.py
Outdated
@@ -31,9 +31,13 @@ class _CollectiveOperation:
    3. Actor handles match the custom NCCL group if specified.
    """

    ALLREDUCE = "ar"
After introducing the two new enums, we don't need these.
python/ray/dag/collective_node.py
Outdated
world_size = len(self._actor_handles)
assert (
    send_buf.shape[0] % world_size == 0
), "Input tensor's first dimension should be divisible by "
Raise a ValueError here. For all user-input errors we raise a ValueError instead of asserting; asserts are for checking internal state.
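A minimal sketch of the suggested replacement, reusing the names from the diff above; the exact error message wording is an assumption:

world_size = len(self._actor_handles)
if send_buf.shape[0] % world_size != 0:
    # User-input validation: raise instead of asserting.
    raise ValueError(
        "Input tensor's first dimension must be divisible by the number "
        f"of reduce-scatter participants ({world_size}), but got shape "
        f"{tuple(send_buf.shape)}."
    )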
@@ -9,6 +9,23 @@ class _CollectiveOp(Enum):


@PublicAPI
class ReduceOp(_CollectiveOp):
I changed the types to this because Python has strict requirements on subclassing enums.
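For context, a minimal illustration of the restriction being referred to; the class names mirror the diff and the member value is made up:

from enum import Enum


class _CollectiveOp(Enum):
    pass  # member-less, so it can still be subclassed


class ReduceOp(_CollectiveOp):
    SUM = 0  # once members are defined, the enum is sealed


# Subclassing ReduceOp now fails at class-creation time with a TypeError,
# because an Enum that defines members cannot be extended:
#   class ReduceScatterOp(ReduceOp):
#       pass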
Why are these changes needed?
Currently, allreduce is the only collective operation available in Ray Compiled Graphs; we plan to add the other collective operations required by FSDP in the future.
Proposed API:
Related issue number
Meta-issue: #47983
Checks
I've signed off every commit (git commit -s) in this PR.
I've run scripts/format.sh to lint the changes in this PR.
If I've added a method in Tune, I've added it in doc/source/tune/api/ under the corresponding .rst file.