
Commit

WIP implemented vanilla ZeRO-1
xrsrke committed Oct 11, 2023
1 parent f2da44d commit ce4e8de
Showing 4 changed files with 35 additions and 34 deletions.
60 changes: 31 additions & 29 deletions pipegoose/optim/zero/optim.py
@@ -1,46 +1,38 @@
+from torch._utils import _flatten_dense_tensors
 from torch.optim import Optimizer
 
 from pipegoose.distributed.functional import broadcast
 from pipegoose.distributed.parallel_context import ParallelContext
 from pipegoose.distributed.parallel_mode import ParallelMode
 from pipegoose.optim import BaseDistributedOptimizer
-from pipegoose.optim.zero.sharding import ParameterSharding
+from pipegoose.optim.zero.sharding import OptimizerStateSharding
 
 
 class DistributedOptimizer(BaseDistributedOptimizer):
     """ZeRO-1 optimizer that works natively in 3D parallelism."""
 
-    def __init__(
-        self,
-        optim: Optimizer,
-        parallel_context: ParallelContext,
-    ):
+    def __init__(self, optim: Optimizer, parallel_context: ParallelContext):
         self.optim = optim
         self.parallel_context = parallel_context
 
-        self._master_params = None
-
+        self._master_params = {}
         self._setup_local_optim()
 
-    # def _sync_hyperparams(self, source: List[Dict[Any, Any]], destination: List[Dict[Any, Any]]):
-    #     for source_group, destination_group in zip(source, destination):
-    #         for k in source_group.keys():
-    #             if k != "params":
-    #                 destination_group[k] = source_group[k]
-
     def _setup_local_optim(self):
         """Setup local optimizer."""
-        local_rank = self.parallel_context.get_local_rank(ParallelMode.DATA)
-
-        # optim = self._optim_constructor(self.params, **self.default)
-        # NOTE: shard and assign the corresponding local parameters to the local optimizer
-        sharded_param_groups = ParameterSharding(self.optim.param_groups, self.parallel_context, ParallelMode.DATA).shard()
-        self._master_params = {rank: params for rank, params in enumerate(sharded_param_groups)}
-        self.optim.param_groups = sharded_param_groups[local_rank]
+        for i, param_groups in enumerate(self.optim.param_groups):
+            self._master_params[i] = param_groups["params"]
+
+        sharded_param_groups = OptimizerStateSharding(
+            self.optim.param_groups, self.parallel_context, ParallelMode.DATA
+        ).shard()
+        ranks_in_group = self.parallel_context.get_ranks_in_group(ParallelMode.DATA)
+        self._rank_to_params = {rank: params for rank, params in zip(ranks_in_group, sharded_param_groups)}
 
-    # def _construct_local_optim(self, local_params: Dict[str, torch.Tensor]) -> Optimizer:
-    #     optim = self._optim_constructor(local_params, **self.default)
-    #     return optim
+        local_rank = self.parallel_context.get_local_rank(ParallelMode.DATA)
+        self.optim.param_groups = self._rank_to_params[local_rank]
 
     @property
     def defaults(self):
@@ -60,22 +52,32 @@ def load_state_dict(self, *args, **kwargs):
"""Load the optimizer state."""
self.optim.load_state_dict(*args, **kwargs)

def state_dict(self):
def state_dict(self, *args, **kwargs):
"""Return the state of the optimizer"""
return self.optim.state_dict()
return self.optim.state_dict(*args, **kwargs)

# def _update_master_params(self):
# """Update the master parameters from the updated local parameters."""
# local_rank = self.parallel_context.get_local_rank(ParallelMode.DATA)

# for i, param_groups in enumerate(self.optim.param_groups):
# updated_params = all_gather(param_groups["params"], self.parallel_context, ParallelMode.DATA)
# self._master_params[i] = updated_params[local_rank]

def step(self, *args, **kwargs):
# NOTE: each rank updates its subset of parameters using the local optimizer
self.optim.step(*args, **kwargs)

# NOTE: update the master parameters from the updated local parameters
# self._update_master_params()

# NOTE: gather the full updated parameters from all ranks

# NOTE: each model replicas broadcast the updated parameters to other model replicas
rank = self.parallel_context.get_local_rank(ParallelMode.DATA)
for group in self.optim.param_groups:
for p in group["params"]:
if p.requires_grad is True:
broadcast(p, src=rank, parallel_context=self.parallel_context, parallel_mode=ParallelMode.DATA)
for rank, param_groups in self._rank_to_params.items():
flatten_params = _flatten_dense_tensors(param_groups[0]["params"])
broadcast(flatten_params, src=rank, parallel_context=self.parallel_context, parallel_mode=ParallelMode.DATA)
assert 1 == 1

def zero_grad(self, *args, **kwargs):
"""Zero out gradients."""
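To make the intent of the new `step()` easier to follow, here is a minimal, self-contained sketch of the same ZeRO-1 pattern written against plain `torch.distributed` rather than pipegoose's wrappers. It assumes an already-initialized data-parallel process group and gradients produced by a prior backward pass; `partition_params` and `zero1_step` are illustrative stand-ins, not pipegoose APIs. The sketch also unflattens and copies each received shard back into the receiving replica, a step the WIP `step()` above does not yet perform after broadcasting `flatten_params`.

```python
# Illustrative ZeRO-1 sketch (assumptions: torch.distributed is initialized and
# every parameter already has a .grad from a backward pass). Not pipegoose code.
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.optim import SGD


def partition_params(params, world_size):
    # Hypothetical helper: round-robin the parameters into one bucket per rank.
    buckets = [[] for _ in range(world_size)]
    for i, p in enumerate(params):
        buckets[i % world_size].append(p)
    return buckets


def zero1_step(model, lr=0.01):
    rank, world_size = dist.get_rank(), dist.get_world_size()
    buckets = partition_params(list(model.parameters()), world_size)

    # ZeRO-1: each rank keeps optimizer state only for the shard it owns...
    local_optim = SGD(buckets[rank], lr=lr)
    local_optim.step()

    # ...then every rank broadcasts its freshly updated shard so replicas stay in sync.
    for src, bucket in enumerate(buckets):
        flat = _flatten_dense_tensors(bucket)
        dist.broadcast(flat, src=src)
        if src != rank:
            for p, updated in zip(bucket, _unflatten_dense_tensors(flat, bucket)):
                p.data.copy_(updated)
```

Creating the local optimizer inside `step()` keeps the sketch short; a real wrapper, like the `DistributedOptimizer` above, builds it once in `_setup_local_optim` and reuses it across steps.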
2 changes: 1 addition & 1 deletion pipegoose/optim/zero/sharding.py
@@ -7,7 +7,7 @@
 from pipegoose.distributed.parallel_mode import ParallelMode
 
 
-class ParameterSharding:
+class OptimizerStateSharding:
     """
     Shard optimizer parameters across parallelism dimension.
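The diff above only shows the rename from `ParameterSharding` to `OptimizerStateSharding`; the partitioning logic itself is not part of this commit. As a rough mental model, a `.shard()` of this kind could look like the hedged sketch below, which splits every param group's `"params"` greedily by element count and returns one param-group list per data-parallel rank. `shard_param_groups` is a hypothetical name, not the class's real implementation.

```python
# Hedged sketch of a ZeRO-1 style optimizer-state sharder: one param-group list
# per rank, hyperparameters copied, tensors assigned greedily by element count.
# Illustrative only; the actual OptimizerStateSharding internals are not shown here.
from typing import Any, Dict, List


def shard_param_groups(param_groups: List[Dict[str, Any]], world_size: int) -> List[List[Dict[str, Any]]]:
    shards = [[{**group, "params": []} for group in param_groups] for _ in range(world_size)]
    loads = [0] * world_size  # number of elements currently assigned to each rank

    for group_idx, group in enumerate(param_groups):
        for param in group["params"]:
            rank = loads.index(min(loads))  # give the tensor to the least-loaded rank
            shards[rank][group_idx]["params"].append(param)
            loads[rank] += param.numel()
    return shards
```

A wrapper like `DistributedOptimizer._setup_local_optim` would then keep only `shards[local_rank]` as its live `param_groups`, which is what `self._rank_to_params[local_rank]` does in the optim.py diff above.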
3 changes: 1 addition & 2 deletions tests/nn/optim/zero/test_optim.py
@@ -44,15 +44,14 @@ def run_dist_optim(

     # NOTE: test whether the model parameters are updated correctly
     for p1, p2 in zip(model.parameters(), ORIG_UPDATED_MODEL.parameters()):
-        assert not torch.allclose(p1, p2), f"p1: {p1}, p2: {p2}"
+        assert torch.allclose(p1, p2), f"p1: {p1}, p2: {p2}"
 
     # dist_optimizer.zero_grad()
 
     # for p in model.parameters():
     #     assert p.grad is None
 
 
-@pytest.mark.skip
 @pytest.mark.parametrize("data_parallel_size", [2, 5])
 def test_dist_optim(data_parallel_size):
     TENSOR_PARALLEL_SIZE = 1
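The test change above drops `@pytest.mark.skip` and flips the assertion: with ZeRO-1 in place, the distributed optimizer is now expected to reproduce the reference update exactly. The snippet below is a hedged, single-process emulation of that invariant, reusing the hypothetical `shard_param_groups` sketch from the sharding.py section; it is not the pipegoose test itself.

```python
# Single-process emulation of the invariant the un-skipped test checks: stepping
# each rank's shard with its own optimizer must match one plain SGD step over
# the whole model. Assumes the hypothetical shard_param_groups defined earlier.
import copy

import torch
from torch import nn
from torch.optim import SGD

torch.manual_seed(0)
ref_model = nn.Linear(4, 4)
model = copy.deepcopy(ref_model)
x = torch.randn(2, 4)

# Reference: ordinary SGD over all parameters.
ref_optim = SGD(ref_model.parameters(), lr=0.1)
ref_model(x).sum().backward()
ref_optim.step()

# "Sharded" version: one optimizer per rank-shard, each updating only what it owns.
model(x).sum().backward()
for shard in shard_param_groups(SGD(model.parameters(), lr=0.1).param_groups, world_size=2):
    SGD(shard, lr=0.1).step()

for p1, p2 in zip(model.parameters(), ref_model.parameters()):
    assert torch.allclose(p1, p2), f"p1: {p1}, p2: {p2}"
```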
4 changes: 2 additions & 2 deletions tests/nn/optim/zero/test_sharding.py
@@ -6,7 +6,7 @@
 from transformers import AutoModel
 
 from pipegoose.distributed.parallel_mode import ParallelMode
-from pipegoose.optim.zero.sharding import ParameterSharding
+from pipegoose.optim.zero.sharding import OptimizerStateSharding
 from pipegoose.testing.utils import init_parallel_context, spawn


@@ -35,7 +35,7 @@ def calculate_total_sharded_elements(sharded_params):
     optim = SGD(model.parameters(), lr=0.01)
     param_groups = optim.param_groups
 
-    sharder = ParameterSharding(param_groups, parallel_context, ParallelMode.DATA)
+    sharder = OptimizerStateSharding(param_groups, parallel_context, ParallelMode.DATA)
     sharded_params = sharder.shard()
 
     assert len(sharded_params) == world_size
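The assertion above checks that sharding produces one partition per rank, and the `calculate_total_sharded_elements` helper suggests the test also verifies that no parameters are lost or duplicated. A quick, hedged way to express the same conservation property against the hypothetical `shard_param_groups` sketch from earlier:

```python
# Conservation check in the spirit of calculate_total_sharded_elements: the
# shards must partition the parameters, so their element counts add up to the
# model's total. Uses the hypothetical shard_param_groups sketch, not the real class.
from torch import nn
from torch.optim import SGD

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
param_groups = SGD(model.parameters(), lr=0.01).param_groups

world_size = 4
shards = shard_param_groups(param_groups, world_size)

assert len(shards) == world_size
total_sharded = sum(p.numel() for shard in shards for group in shard for p in group["params"])
assert total_sharded == sum(p.numel() for p in model.parameters())
```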

0 comments on commit ce4e8de
