Add DDP support to hivemind.optim #475
base: master
@@ -0,0 +1,190 @@
import time
from typing import Callable, Optional, Union

import torch
from torch.distributed.distributed_c10d import _get_default_group, _get_default_store

from hivemind.dht import DHT
from hivemind.optim.grad_scaler import GradScaler
from hivemind.optim.optimizer import Optimizer
from hivemind.optim.state_averager import OptimizerFactory, Parameters, ParamGroups, TorchOptimizer, TrainingStateAverager
from hivemind.utils import get_logger

logger = get_logger(__name__)


class DDPOptimizer(Optimizer):
Review comment: Note to self: a better way to do it is to give hivemind.Optimizer a _create_optimizer() method and forward __init__'s kwargs there, expose an opt property, and add a __getattr__ that can forward attrs to opt.
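A rough sketch of what that refactor could look like, purely for illustration (the class name and structure below are assumptions; only _create_optimizer(), the opt property, and __getattr__ forwarding come from the note above):

import torch.distributed


class LazyOptimizerWrapper:
    """Hypothetical wrapper: only the DDP leader constructs the real hivemind.Optimizer."""

    def __init__(self, **optimizer_kwargs):
        # Forward __init__'s kwargs to a factory method instead of subclassing hivemind.Optimizer directly
        self._opt = self._create_optimizer(**optimizer_kwargs) if self.is_ddp_leader() else None

    @staticmethod
    def is_ddp_leader() -> bool:
        return not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0

    def _create_optimizer(self, **optimizer_kwargs):
        import hivemind

        return hivemind.Optimizer(**optimizer_kwargs)

    @property
    def opt(self):
        return self._opt

    def __getattr__(self, name):
        # Called only when normal attribute lookup fails; forward the lookup to the wrapped optimizer
        opt = self.__dict__.get("_opt")
        if opt is None:
            raise AttributeError(f"{name} is only available on the DDP leader")
        return getattr(opt, name)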

    _DDP_LEADER_RANK = 0
    _BROADCAST_BUFFER_SIZE = 250 * 1024 ** 2
Review comment: New PyTorch seems to have finally implemented broadcast_coalesced in distributed.
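For reference, a per-tensor fallback using only the public torch.distributed API would look roughly like this (an illustrative sketch, not part of the PR):

import torch
import torch.distributed as dist


def broadcast_params_uncoalesced(parameters, src_rank: int = 0) -> None:
    """Broadcast each tensor separately from src_rank: simpler than coalescing, but one collective per tensor."""
    with torch.no_grad():
        for param in parameters:
            dist.broadcast(param.data, src=src_rank)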

    @staticmethod
    def is_ddp_enabled():
        return torch.distributed.is_initialized()

    @staticmethod
    def is_ddp_leader():
Review comment: nit: would recommend reusing the same terminology as elsewhere, such as inside DistributedDataParallel. For instance, the above DDP uses ...
        return not torch.distributed.is_initialized() or torch.distributed.get_rank() == DDPOptimizer._DDP_LEADER_RANK

    def __init__(
        self,
        *,
        dht: Optional[DHT] = None,
        optimizer: Union[TorchOptimizer, OptimizerFactory],
        params: Optional[Union[Parameters, ParamGroups]] = None,
        reuse_grad_buffers: bool = False,
        use_local_updates: bool = False,
        **kwargs
    ):
        if self.is_ddp_leader() != (dht is not None):
            class_name = self.__class__.__name__
            raise ValueError(
                f"{class_name}(dht=...) is expected to be a hivemind.DHT instance "
                f"if {class_name}.is_ddp_leader(), None otherwise. "
                f"Please write code as follows:\n\n"
                f"if {class_name}.is_ddp_leader():\n"
                f"    dht = hivemind.DHT(...)\n"
                f"else:\n"
                f"    dht = None\n"
                f"optimizer = {class_name}(dht=dht, ...)"
            )

        if self.is_ddp_leader():
            super().__init__(
                dht,
                optimizer,
                params,
                reuse_grad_buffers,
                use_local_updates,
                **kwargs
            )
            self._main_parameters = self.state_averager.main_parameters
        else:
            self._param_groups, self._main_parameters, _ = TrainingStateAverager.check_params(optimizer, params)
            self.reuse_grad_buffers, self.use_local_updates = reuse_grad_buffers, use_local_updates

        self._checksum_counter = 0
        self._prev_version = self._prev_epoch = -1
        self._sync_among_ddp_ranks()

        # Collect fields of DDPOptimizer and its descendants
        self._ddp_aware_fields = set(self.__dict__.keys())
        for klass in type(self).__mro__:  # instances have no __mro__ attribute; resolve it through the class
            self._ddp_aware_fields.update(klass.__dict__.keys())
            if klass is DDPOptimizer:
                break

    def __getattribute__(self, name: str):
        """
        This works as usual on leaders, but denies access to non DDP-aware fields
        (i.e., fields defined in DDPOptimizer ancestors) on followers.
        """
        if (
            not name.startswith("_") and
            name not in self._ddp_aware_fields and
            not DDPOptimizer.is_ddp_leader()
        ):
            raise RuntimeError(
                f"{self.__class__.__name__}.{name} is only available on the DDP leader. "
                f"Please access it only if DDPOptimizer.is_ddp_leader() == True"
            )

        return super().__getattribute__(name)

    def is_alive(self) -> bool:
        # On followers, this always returns False since there's nothing to shut down in __del__()
        return self.is_ddp_leader() and super().is_alive()
Review comment: if leader: ...

    def _compute_state_version(self) -> int:
        """Return a non-decreasing integer that goes up whenever model params and/or buffers were updated"""
Review comment: This function is meant as a workaround to catch the moment when the optimizer has updated parameters (load from peers, apply optimizer step, average params). All changes to state are currently handled in StateAverager.
        assert self.is_ddp_leader()
        return sum(state["step"] for state in self.opt.state.values())

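A hypothetical alternative to summing optimizer step counts, following the note above: the component that actually mutates the state (the state averager) would bump an explicit counter, and this class would only compare integers. The hook-up into TrainingStateAverager is assumed; only the counter itself is sketched:

class StateVersionTracker:
    """Illustrative counter that the state-mutating component would bump on every update."""

    def __init__(self) -> None:
        self.version = 0  # non-decreasing; incremented whenever params/buffers change

    def bump(self) -> None:
        self.version += 1

    def has_changed_since(self, last_seen_version: int) -> bool:
        return self.version > last_seen_version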
    def _has_updated_params_after_sync(self) -> bool:
        if not self.is_ddp_enabled():
            return False

        store = _get_default_store()
        if self.is_ddp_leader():
            current_version = self._compute_state_version()
            if current_version == self._prev_version and self.local_epoch > self._prev_epoch + 1:
                logger.warning("Model state version has not changed during a full epoch; "
                               "parameter synchronization across torch.distributed ranks may be broken")

            should_broadcast = (current_version != self._prev_version or self.local_epoch > self._prev_epoch + 1)

            store.set(f"_hivemind_should_broadcast_state", str(int(should_broadcast)))
            torch.distributed.barrier()
            return should_broadcast
        else:
            torch.distributed.barrier()
            raw_should_broadcast = store.get(f"_hivemind_should_broadcast_state")
            return bool(int(raw_should_broadcast))

    def _sync_among_ddp_ranks(self) -> None:
        """Synchronize model params and buffers from the DDP leader"""
        if not self.is_ddp_enabled():
            return

        t_start = time.perf_counter()
        with torch.no_grad():
            torch.distributed._broadcast_coalesced(
                _get_default_group(), self._main_parameters, self._BROADCAST_BUFFER_SIZE, self._DDP_LEADER_RANK
            )
        if self.is_ddp_leader():
            self._prev_version = self._compute_state_version()
            self._prev_epoch = self.local_epoch
        elapsed = time.perf_counter() - t_start
        logger.debug(f"Broadcasting leader params among DDP ranks took {elapsed:.2f} sec")

    def step(
        self,
        closure: Optional[Callable[[], torch.Tensor]] = None,
        batch_size: Optional[int] = None,
        grad_scaler: Optional[GradScaler] = None,
    ):
        if self.is_ddp_leader():
            loss = super().step(closure, batch_size, grad_scaler)

        if self._has_updated_params_after_sync():
            self._sync_among_ddp_ranks()
        else:
            logger.debug("No need to broadcast leader params among DDP ranks")

        if self.is_ddp_enabled():
            self._checksum_counter += 1
            if self._checksum_counter % 100 == 0:
                rank = torch.distributed.get_rank()
                checksum = sum(p.sum().item() for p in self._main_parameters)
                logger.debug(f"Parameter checksum (ddp_rank={rank}): {float(checksum)}")

        return loss if self.is_ddp_leader() else None

    def load_state_from_peers(self, **kwargs) -> None:
        if self.is_ddp_leader():
            super().load_state_from_peers(**kwargs)

        self._sync_among_ddp_ranks()
Review comment: We should not synchronize here: non-master ranks cannot call this, and we will deadlock. We should only sync in step() -- after step, check IF the master updated/loaded/averaged state and only then broadcast.
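A sketch of how load_state_from_peers() could look after the change this comment asks for (illustrative, not an actual commit): the leader loads state, nobody broadcasts here, and the next step(), which every rank calls, detects the new state version and performs the collective.

    # Sketch of the suggested restructuring (would replace the method above):
    def load_state_from_peers(self, **kwargs) -> None:
        # Only the leader talks to the swarm; followers do nothing here, so no rank
        # ends up waiting on a collective that the other ranks never enter.
        if self.is_ddp_leader():
            super().load_state_from_peers(**kwargs)
        # No _sync_among_ddp_ranks() here: _has_updated_params_after_sync() will notice
        # the new state version on the next step() and broadcast on all ranks together.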

    def load_state_dict(self, state_dict: dict) -> None:
        if self.is_ddp_leader():
            super().load_state_dict(state_dict)

        self._sync_among_ddp_ranks()
Review comment: We should not synchronize here: non-master ranks cannot call this and we will deadlock; see load_state_from_peers.

    @property
    def param_groups(self) -> ParamGroups:
        if self.is_ddp_leader():
            return super().param_groups
        else:
            return self._param_groups

    def zero_grad(self, set_to_none: bool = False):
        # We explicitly define this method to mark that it should be available on the DDP followers
        super().zero_grad(set_to_none)

    def shutdown(self):
        if self.is_ddp_leader():
            super().shutdown()
Review comment: Optional: else raise NotImplementedError or warn?
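For context, the calling pattern this class expects on each DDP worker, reconstructed from the ValueError message in __init__ (the model and the extra kwargs below are placeholders forwarded to hivemind.Optimizer):

import hivemind
import torch
import torch.distributed as dist

dist.init_process_group(backend="gloo")  # assumption: launched via torchrun or an equivalent launcher

model = torch.nn.Linear(16, 4)  # placeholder model
dht = hivemind.DHT(start=True) if DDPOptimizer.is_ddp_leader() else None

optimizer = DDPOptimizer(
    dht=dht,
    optimizer=lambda params: torch.optim.Adam(params, lr=1e-3),
    params=model.parameters(),
    run_id="ddp_demo",            # placeholder kwargs forwarded to hivemind.Optimizer
    batch_size_per_step=32,
    target_batch_size=4096,
)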
Second file in the diff (TrainingStateAverager.check_params):
@@ -102,7 +102,7 @@ def __init__(
         if reuse_tensors and delta_rule_averaging:
             raise ValueError("reuse_tensors and delta_rule_averaging are mutually exclusive")

-        param_groups, main_parameters, parameter_names = self._check_params(optimizer, params, parameter_names)
+        param_groups, main_parameters, parameter_names = self.check_params(optimizer, params, parameter_names)

         self.status_loglevel = status_loglevel
         self.offload_optimizer, self.custom_gradients = offload_optimizer, custom_gradients

@@ -131,10 +131,10 @@ def __init__(
         )

     @staticmethod
-    def _check_params(
+    def check_params(
Review comment: Suggested change ...
         optimizer: Union[TorchOptimizer, OptimizerFactory],
-        param_groups: Optional[Union[Parameters, ParamGroups]],
-        parameter_names: Optional[Sequence[str]],
+        param_groups: Optional[Union[Parameters, ParamGroups]] = None,
+        parameter_names: Optional[Sequence[str]] = None,
     ) -> Tuple[ParamGroups, Sequence[torch.Tensor], Sequence[str]]:
         """Get and verify parameters, groups and names"""
         if param_groups is None:
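Making check_params public (and giving the trailing arguments defaults) is what lets the DDP follower branch above validate parameters without constructing a TrainingStateAverager; a minimal call with a placeholder module might look like this:

import torch
from hivemind.optim.state_averager import TrainingStateAverager

model = torch.nn.Linear(8, 2)  # placeholder module
param_groups, main_parameters, parameter_names = TrainingStateAverager.check_params(
    optimizer=lambda params: torch.optim.SGD(params, lr=0.1),
    param_groups=list(model.parameters()),
)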