
[core][compiled graphs] Support reduce scatter collective in compiled graph #49404

Open · wants to merge 8 commits into master

Conversation

@anyadontfly (Contributor) commented on Dec 22, 2024

Why are these changes needed?

Currently, allreduce is the only collective operation available in Ray Compiled Graphs. This PR adds support for reduce-scatter; we plan to add the remaining collective operations required by FSDP in the future.

Proposed API:

import ray.experimental.collective as collective
from ray.dag import InputNode, MultiOutputNode
from ray.experimental.util.types import ReduceOp

# `workers` is a list of actor handles participating in the collective.
with InputNode() as inp:
    dag = [worker.return_tensor.bind(inp) for worker in workers]
    dag = collective.reducescatter.bind(dag, ReduceOp.SUM)
    dag = MultiOutputNode(dag)
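
Not part of the PR text, but for readers unfamiliar with the collective, here is a small worked example of reduce-scatter semantics in plain PyTorch (no Ray; tensor values are purely illustrative):

import torch

# Two participants each contribute a tensor of shape (4,).
t0 = torch.tensor([1.0, 2.0, 3.0, 4.0])
t1 = torch.tensor([10.0, 20.0, 30.0, 40.0])

# Reduce step: elementwise sum across participants.
reduced = t0 + t1  # tensor([11., 22., 33., 44.])

# Scatter step: the reduced tensor is split along its first dimension and each
# participant receives one shard, which is why that dimension must be divisible
# by the number of participants.
shard0, shard1 = reduced.chunk(2)  # participant 0 gets [11., 22.], participant 1 gets [33., 44.]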

Related issue number

Meta-issue: #47983

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

    communicator.allreduce(send_buf, recv_buf, self._op)
elif self.comm_op == self.REDUCESCATTER:
    world_size = len(self._actor_handles)
    assert (

@anyadontfly (Author): I'm planning to add a test that an error is raised when the input tensors' first dimension is not divisible by the number of reduce-scatter participants. This surfaces as a RayTaskError, and I'm not sure what the expected behavior of the compiled graph is when that error is generated.

Reviewer (Contributor): Sounds good. Raise ValueError for now.

@jcotant1 added the core (Issues that should be addressed in Ray Core) and compiled-graphs labels on Dec 23, 2024.
@dengwxn (Contributor) left a comment:

Overall looks good! Left some comments to polish.

PARENT_CLASS_NODE_KEY,
)
from ray.experimental.channel.torch_tensor_type import Communicator, TorchTensorType
from ray.experimental.util.types import ReduceOp

Reviewer (Contributor): Create two new enums, AllReduceOp and ReduceScatterOp, that inherit from ReduceOp. They don't need to introduce new enum values; just use pass. That way we don't need a string to tell whether this is an allreduce or a reducescatter; we can simply check whether the type is AllReduceOp or ReduceScatterOp.

def bind(
    self,
    input_nodes: List["ray.dag.DAGNode"],
    op: ReduceOp = ReduceOp.SUM,

Reviewer (Contributor): After introducing the two new enums, this would be ReduceScatterOp.SUM. Also update allreduce to use AllReduceOp.SUM. Make sure to update all the places, including allreduce.py and the tests.
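
A sketch of the suggested signature after the enum split (remaining parameters elided):

def bind(
    self,
    input_nodes: List["ray.dag.DAGNode"],
    op: ReduceScatterOp = ReduceScatterOp.SUM,
    # ... remaining parameters unchanged ...
):
    ...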



@pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True)
def test_torch_tensor_nccl_reduce_scatter_different_shapes_among_participants(

Reviewer (Contributor): Is this test the same as all_reduce_wrong_shape? If so, we should use the same naming for both; wrong_shape is fine to keep.

    self,
    send_buf: "torch.Tensor",
    recv_buf: "torch.Tensor",
    op: ReduceOp = ReduceOp.SUM,

Reviewer (Contributor): After introducing the two new enums, we need to update all of these.

@@ -31,9 +31,13 @@ class _CollectiveOperation:
3. Actor handles match the custom NCCL group if specified.
"""

ALLREDUCE = "ar"

Reviewer (Contributor): After introducing the two new enums, we don't need these string constants.
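
A sketch of how the dispatch could then look, keying off the op's type instead of a string tag (the communicator.reducescatter method name is assumed here to mirror allreduce):

if isinstance(self._op, AllReduceOp):
    communicator.allreduce(send_buf, recv_buf, self._op)
elif isinstance(self._op, ReduceScatterOp):
    communicator.reducescatter(send_buf, recv_buf, self._op)
else:
    raise ValueError(f"Unknown collective operation: {self._op}")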

world_size = len(self._actor_handles)
assert (
    send_buf.shape[0] % world_size == 0
), "Input tensor's first dimension should be divisible by "

Reviewer (Contributor): Raise a ValueError here. For all user-input errors, we raise a ValueError instead of asserting; assert is reserved for checking internal state.
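
A sketch of the change being requested, reusing the variables from the snippet above (the exact error message wording is the author's choice):

world_size = len(self._actor_handles)
if send_buf.shape[0] % world_size != 0:
    # User-input error: raise ValueError rather than assert, which is
    # reserved for internal invariants.
    raise ValueError(
        "Input tensor's first dimension must be divisible by the number of "
        f"reduce-scatter participants ({world_size}), got {send_buf.shape[0]}."
    )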


@anyadontfly changed the title from "[compiled graphs] Support reduce scatter collective in compiled graph" to "[core][compiled graphs] Support reduce scatter collective in compiled graph" on Dec 23, 2024.
@@ -9,6 +9,23 @@ class _CollectiveOp(Enum):

@PublicAPI
class ReduceOp(_CollectiveOp):

@anyadontfly (Author): I changed the types to this because Python has strict requirements on subclassing Enum.
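
For context, a minimal standalone illustration of the restriction the author refers to (member values here are illustrative): Python only allows subclassing an Enum that defines no members, which is why a member-less base such as _CollectiveOp can be extended while ReduceOp itself cannot.

from enum import Enum


class _CollectiveOp(Enum):
    # A member-less Enum may be subclassed.
    pass


class ReduceOp(_CollectiveOp):
    SUM = 0
    PRODUCT = 1


try:
    # TypeError: an Enum that already defines members cannot be extended.
    class ReduceScatterOp(ReduceOp):
        pass
except TypeError as exc:
    print(exc)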

Labels: core (Issues that should be addressed in Ray Core), compiled-graphs
Projects: None yet
3 participants