From ce4e8de624f1e7f7ecb483f255bbfba7406c4fb7 Mon Sep 17 00:00:00 2001
From: xrsrke
Date: Wed, 11 Oct 2023 08:00:41 +0700
Subject: [PATCH] WIP implemented vanilla ZeRO-1

---
 pipegoose/optim/zero/optim.py        | 60 ++++++++++++++--------------
 pipegoose/optim/zero/sharding.py     |  2 +-
 tests/nn/optim/zero/test_optim.py    |  3 +-
 tests/nn/optim/zero/test_sharding.py |  4 +-
 4 files changed, 35 insertions(+), 34 deletions(-)

diff --git a/pipegoose/optim/zero/optim.py b/pipegoose/optim/zero/optim.py
index fabfc4c..7dd5098 100644
--- a/pipegoose/optim/zero/optim.py
+++ b/pipegoose/optim/zero/optim.py
@@ -1,46 +1,38 @@
+from torch._utils import _flatten_dense_tensors
 from torch.optim import Optimizer
 
 from pipegoose.distributed.functional import broadcast
 from pipegoose.distributed.parallel_context import ParallelContext
 from pipegoose.distributed.parallel_mode import ParallelMode
 from pipegoose.optim import BaseDistributedOptimizer
-from pipegoose.optim.zero.sharding import ParameterSharding
+from pipegoose.optim.zero.sharding import OptimizerStateSharding
 
 
 class DistributedOptimizer(BaseDistributedOptimizer):
     """ZeRO-1 optimizer that works natively in 3D parallelism."""
 
-    def __init__(
-        self,
-        optim: Optimizer,
-        parallel_context: ParallelContext,
-    ):
+    def __init__(self, optim: Optimizer, parallel_context: ParallelContext):
         self.optim = optim
         self.parallel_context = parallel_context
 
-        self._master_params = None
-
+        self._master_params = {}
         self._setup_local_optim()
 
-    # def _sync_hyperparams(self, source: List[Dict[Any, Any]], destination: List[Dict[Any, Any]]):
-    #     for source_group, destination_group in zip(source, destination):
-    #         for k in source_group.keys():
-    #             if k != "params":
-    #                 destination_group[k] = source_group[k]
-
     def _setup_local_optim(self):
         """Setup local optimizer."""
-        local_rank = self.parallel_context.get_local_rank(ParallelMode.DATA)
-        # optim = self._optim_constructor(self.params, **self.default)
         # NOTE: shard and assign the corresponding local parameters to the local optimizer
-        sharded_param_groups = ParameterSharding(self.optim.param_groups, self.parallel_context, ParallelMode.DATA).shard()
-        self._master_params = {rank: params for rank, params in enumerate(sharded_param_groups)}
-        self.optim.param_groups = sharded_param_groups[local_rank]
+        for i, param_groups in enumerate(self.optim.param_groups):
+            self._master_params[i] = param_groups["params"]
+
+        sharded_param_groups = OptimizerStateSharding(
+            self.optim.param_groups, self.parallel_context, ParallelMode.DATA
+        ).shard()
+        ranks_in_group = self.parallel_context.get_ranks_in_group(ParallelMode.DATA)
+        self._rank_to_params = {rank: params for rank, params in zip(ranks_in_group, sharded_param_groups)}
 
-    # def _construct_local_optim(self, local_params: Dict[str, torch.Tensor]) -> Optimizer:
-    #     optim = self._optim_constructor(local_params, **self.default)
-    #     return optim
+        local_rank = self.parallel_context.get_local_rank(ParallelMode.DATA)
+        self.optim.param_groups = self._rank_to_params[local_rank]
 
     @property
     def defaults(self):
@@ -60,22 +52,32 @@ def load_state_dict(self, *args, **kwargs):
         """Load the optimizer state."""
         self.optim.load_state_dict(*args, **kwargs)
 
-    def state_dict(self):
+    def state_dict(self, *args, **kwargs):
         """Return the state of the optimizer"""
-        return self.optim.state_dict()
+        return self.optim.state_dict(*args, **kwargs)
+
+    # def _update_master_params(self):
+    #     """Update the master parameters from the updated local parameters."""
+    #     local_rank = self.parallel_context.get_local_rank(ParallelMode.DATA)
+
+    #     for i, param_groups in enumerate(self.optim.param_groups):
+    #         updated_params = all_gather(param_groups["params"], self.parallel_context, ParallelMode.DATA)
+    #         self._master_params[i] = updated_params[local_rank]
 
     def step(self, *args, **kwargs):
         # NOTE: each rank updates its subset of parameters using the local optimizer
         self.optim.step(*args, **kwargs)
 
         # NOTE: update the master parameters from the updated local parameters
+        # self._update_master_params()
+
+        # NOTE: gather the full updated parameters from all ranks
         # NOTE: each model replicas broadcast the updated parameters to other model replicas
-        rank = self.parallel_context.get_local_rank(ParallelMode.DATA)
-        for group in self.optim.param_groups:
-            for p in group["params"]:
-                if p.requires_grad is True:
-                    broadcast(p, src=rank, parallel_context=self.parallel_context, parallel_mode=ParallelMode.DATA)
+        for rank, param_groups in self._rank_to_params.items():
+            flatten_params = _flatten_dense_tensors(param_groups[0]["params"])
+            broadcast(flatten_params, src=rank, parallel_context=self.parallel_context, parallel_mode=ParallelMode.DATA)
 
     def zero_grad(self, *args, **kwargs):
         """Zero out gradients."""
diff --git a/pipegoose/optim/zero/sharding.py b/pipegoose/optim/zero/sharding.py
index 582a4ce..3adae71 100644
--- a/pipegoose/optim/zero/sharding.py
+++ b/pipegoose/optim/zero/sharding.py
@@ -7,7 +7,7 @@
 from pipegoose.distributed.parallel_mode import ParallelMode
 
 
-class ParameterSharding:
+class OptimizerStateSharding:
     """
     Shard optimizer parameters across parallelism dimension.
 
diff --git a/tests/nn/optim/zero/test_optim.py b/tests/nn/optim/zero/test_optim.py
index ba4d313..2624bb7 100644
--- a/tests/nn/optim/zero/test_optim.py
+++ b/tests/nn/optim/zero/test_optim.py
@@ -44,7 +44,7 @@ def run_dist_optim(
 
     # NOTE: test whether the model parameters are updated correctly
     for p1, p2 in zip(model.parameters(), ORIG_UPDATED_MODEL.parameters()):
-        assert not torch.allclose(p1, p2), f"p1: {p1}, p2: {p2}"
+        assert torch.allclose(p1, p2), f"p1: {p1}, p2: {p2}"
 
     # dist_optimizer.zero_grad()
 
@@ -52,7 +52,6 @@ def run_dist_optim(
     #     assert p.grad is None
 
 
-@pytest.mark.skip
 @pytest.mark.parametrize("data_parallel_size", [2, 5])
 def test_dist_optim(data_parallel_size):
     TENSOR_PARALLEL_SIZE = 1
diff --git a/tests/nn/optim/zero/test_sharding.py b/tests/nn/optim/zero/test_sharding.py
index 21a3dc3..d1dfe62 100644
--- a/tests/nn/optim/zero/test_sharding.py
+++ b/tests/nn/optim/zero/test_sharding.py
@@ -6,7 +6,7 @@
 from transformers import AutoModel
 
 from pipegoose.distributed.parallel_mode import ParallelMode
-from pipegoose.optim.zero.sharding import ParameterSharding
+from pipegoose.optim.zero.sharding import OptimizerStateSharding
 from pipegoose.testing.utils import init_parallel_context, spawn
 
 
@@ -35,7 +35,7 @@ def calculate_total_sharded_elements(sharded_params):
     optim = SGD(model.parameters(), lr=0.01)
     param_groups = optim.param_groups
 
-    sharder = OptimizerStateSharding(param_groups, parallel_context, ParallelMode.DATA)
+    sharder = OptimizerStateSharding(param_groups, parallel_context, ParallelMode.DATA)
     sharded_params = sharder.shard()
 
     assert len(sharded_params) == world_size
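
For context on what the new `step` path does, below is a minimal, self-contained sketch of the same ZeRO-1 flow written against plain `torch.distributed` rather than pipegoose's `ParallelContext`. The helper names (`shard_params`, `zero1_step`) and the round-robin partitioning are assumptions made for illustration, not pipegoose APIs; the sketch only mirrors the shard -> local step -> broadcast pattern that `DistributedOptimizer` implements, and it includes the unflatten/copy-back step needed after broadcasting the flattened shards.

    # Illustrative ZeRO-1 sketch using plain torch.distributed (assumes an already
    # initialized process group). Helper names and the round-robin sharding scheme
    # are assumptions; pipegoose's OptimizerStateSharding may partition differently.
    import torch.distributed as dist
    from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


    def shard_params(params, world_size):
        # Assign each parameter to a rank round-robin; every rank computes the same mapping.
        shards = [[] for _ in range(world_size)]
        for i, param in enumerate(params):
            shards[i % world_size].append(param)
        return shards


    def zero1_step(local_optim, shards):
        # 1) Each rank updates only its own shard: `local_optim` was built from
        #    shards[dist.get_rank()], so optimizer state exists only for that subset.
        local_optim.step()

        # 2) Every shard owner broadcasts its updated parameters, flattened into one
        #    tensor, so all data-parallel replicas end up with identical weights.
        for src, shard in enumerate(shards):
            if not shard:
                continue
            # Non-owner ranks reuse their stale copies as receive buffers of the right shape.
            buffers = [param.data for param in shard]
            flat = _flatten_dense_tensors(buffers)
            dist.broadcast(flat, src=src)
            for param, updated in zip(shard, _unflatten_dense_tensors(flat, buffers)):
                param.data.copy_(updated)

A caller would build `shards = shard_params(list(model.parameters()), dist.get_world_size())` once, construct the local optimizer from `shards[dist.get_rank()]` only, and call `zero1_step` after the backward pass; `_setup_local_optim` plays the role of the first two steps via `OptimizerStateSharding`.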