Add DDP support to hivemind.optim #475

Status: Draft. Wants to merge 2 commits into base: master. Changes from 1 commit.
190 changes: 190 additions & 0 deletions hivemind/optim/ddp.py
@@ -0,0 +1,190 @@
import time
from typing import Callable, Optional, Union

import torch
from torch.distributed.distributed_c10d import _get_default_group, _get_default_store

from hivemind.dht import DHT
from hivemind.optim.grad_scaler import GradScaler
from hivemind.optim.optimizer import Optimizer
from hivemind.optim.state_averager import OptimizerFactory, Parameters, ParamGroups, TorchOptimizer, TrainingStateAverager
from hivemind.utils import get_logger

logger = get_logger(__name__)


class DDPOptimizer(Optimizer):
Review comment (Member Author):

Note to self: A better way to do it is:

  • Don't inherit hivemind.Optimizer
  • Make _create_optimizer() method and forward __init__'s kwargs there
  • Make opt property
  • Maybe create __getattr__ that can forward attrs to opt
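For reference, a minimal sketch of the composition-based design this note describes; the names DDPOptimizerWrapper, _opt_kwargs, and _create_optimizer are illustrative, not part of the PR:

import hivemind


class DDPOptimizerWrapper:
    """Sketch: wrap hivemind.Optimizer instead of inheriting from it."""

    def __init__(self, *, dht=None, **kwargs):
        self._opt_kwargs = dict(dht=dht, **kwargs)
        # Only the DDP leader receives a DHT and therefore creates the real optimizer
        self._opt = self._create_optimizer() if dht is not None else None

    def _create_optimizer(self) -> "hivemind.Optimizer":
        # Forward __init__'s kwargs to the wrapped optimizer
        return hivemind.Optimizer(**self._opt_kwargs)

    @property
    def opt(self) -> "hivemind.Optimizer":
        if self._opt is None:
            raise RuntimeError("hivemind.Optimizer exists only on the DDP leader")
        return self._opt

    def __getattr__(self, name: str):
        # Called only when regular lookup fails: forward unknown public attributes to the wrapped optimizer
        if name.startswith("_"):
            raise AttributeError(name)
        return getattr(self.opt, name)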
_DDP_LEADER_RANK = 0
_BROADCAST_BUFFER_SIZE = 250 * 1024 ** 2
Review comment (Member):

New pytorch seems to have finally implemented broadcast_coalesced in distributed; we could use it directly (https://pytorch.org/docs/stable/_modules/torch/nn/parallel/comm.html#broadcast_coalesced) as long as we bump the minimal pytorch version. What do you think?
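If bumping the minimal pytorch version is not desirable, a portable (if uncoalesced) fallback can be written against the public API alone. This is a sketch, not what the PR currently does, and broadcast_params_uncoalesced is a hypothetical helper:

import torch
import torch.distributed as dist


def broadcast_params_uncoalesced(parameters, src_rank: int = 0) -> None:
    """Broadcast each tensor separately via the public torch.distributed API.

    Unlike torch.distributed._broadcast_coalesced, this does not pack tensors into
    large flat buffers, so it issues one (smaller) collective per tensor.
    """
    with torch.no_grad():
        for param in parameters:
            dist.broadcast(param, src=src_rank)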


    @staticmethod
    def is_ddp_enabled():
        return torch.distributed.is_initialized()

    @staticmethod
    def is_ddp_leader():
Review comment (Member):

nit: would recommend reusing the same terminology as elsewhere, such as inside DistributedDataParallel.

For instance, DDP uses:

  • leader rank -> authoritative rank
  • is_ddp_enabled -> _initialized

        return not torch.distributed.is_initialized() or torch.distributed.get_rank() == DDPOptimizer._DDP_LEADER_RANK

    def __init__(
        self,
        *,
        dht: Optional[DHT] = None,
        optimizer: Union[TorchOptimizer, OptimizerFactory],
        params: Optional[Union[Parameters, ParamGroups]] = None,
        reuse_grad_buffers: bool = False,
        use_local_updates: bool = False,
        **kwargs
    ):
        if self.is_ddp_leader() != (dht is not None):
            class_name = self.__class__.__name__
            raise ValueError(
                f"{class_name}(dht=...) is expected to be a hivemind.DHT instance "
                f"if {class_name}.is_ddp_leader(), None otherwise. "
                f"Please write code as follows:\n\n"
                f"if {class_name}.is_ddp_leader():\n"
                f"    dht = hivemind.DHT(...)\n"
                f"else:\n"
                f"    dht = None\n"
                f"optimizer = {class_name}(dht=dht, ...)"
            )

        if self.is_ddp_leader():
            super().__init__(
                dht,
                optimizer,
                params,
                reuse_grad_buffers,
                use_local_updates,
                **kwargs
            )
            self._main_parameters = self.state_averager.main_parameters
        else:
            self._param_groups, self._main_parameters, _ = TrainingStateAverager.check_params(optimizer, params)
            self.reuse_grad_buffers, self.use_local_updates = reuse_grad_buffers, use_local_updates

        self._checksum_counter = 0
        self._prev_version = self._prev_epoch = -1
        self._sync_among_ddp_ranks()

        # Collect fields of DDPOptimizer and its descendants
        self._ddp_aware_fields = set(self.__dict__.keys())
        for klass in self.__class__.__mro__:  # instances do not expose __mro__, so walk the class hierarchy
            self._ddp_aware_fields.update(klass.__dict__.keys())
            if klass is DDPOptimizer:
                break

    def __getattribute__(self, name: str):
        """
        This works as usual on leaders, but denies access to non DDP-aware fields
        (i.e., fields defined in DDPOptimizer ancestors) on followers.
        """

        # Check the leader status first and read the field set via __dict__, so attribute access
        # keeps working while __init__ has not populated self._ddp_aware_fields yet
        if not name.startswith("_") and not DDPOptimizer.is_ddp_leader():
            ddp_aware_fields = self.__dict__.get("_ddp_aware_fields")
            if ddp_aware_fields is not None and name not in ddp_aware_fields:
                raise RuntimeError(
                    f"{self.__class__.__name__}.{name} is only available on the DDP leader. "
                    f"Please access it only if DDPOptimizer.is_ddp_leader() == True"
                )

        return super().__getattribute__(name)

    def is_alive(self) -> bool:
        # On followers, this always returns False since there's nothing to shut down in __del__()
        return self.is_ddp_leader() and super().is_alive()
Review comment (Member):

if leader:
    return is_alive
else:
    raise NotImplementedError?


    def _compute_state_version(self) -> int:
        """Return a non-decreasing integer that goes up whenever model params and/or buffers were updated"""
Review comment (Member):

This function is meant as a workaround to catch the moment when the optimizer has updated parameters (load from peers, apply optimizer step, average params).

All changes to state are currently handled in StateAverager. Maybe we can implement StateAverager.local_version that gets incremented every time StateAverager loads, averages, or updates state by the optimizer.
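A rough sketch of that idea; local_version and the method names below are hypothetical, not part of hivemind:

class VersionedStateAveragerSketch:
    """Illustration only: bump a counter whenever parameters or buffers may have changed."""

    def __init__(self):
        self.local_version = 0

    def _bump_version(self) -> None:
        self.local_version += 1

    def load_state_from_peers(self) -> None:
        ...  # download and apply the newest state from other peers
        self._bump_version()

    def apply_optimizer_step(self) -> None:
        ...  # run the wrapped torch optimizer
        self._bump_version()

    def average_state(self) -> None:
        ...  # all-reduce parameters with the swarm
        self._bump_version()

With such a counter, _compute_state_version() could simply return self.state_averager.local_version instead of summing the optimizer step counts.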


        assert self.is_ddp_leader()
        return sum(state["step"] for state in self.opt.state.values())

    def _has_updated_params_after_sync(self) -> bool:
        if not self.is_ddp_enabled():
            return False

        store = _get_default_store()
        if self.is_ddp_leader():
            current_version = self._compute_state_version()
            if current_version == self._prev_version and self.local_epoch > self._prev_epoch + 1:
                logger.warning("Model state version has not changed during a full epoch; "
                               "broadcasting parameters between torch.distributed synchronization may be broken")

            should_broadcast = (current_version != self._prev_version or self.local_epoch > self._prev_epoch + 1)

            store.set("_hivemind_should_broadcast_state", str(int(should_broadcast)))
            torch.distributed.barrier()
            return should_broadcast
        else:
            torch.distributed.barrier()
            raw_should_broadcast = store.get("_hivemind_should_broadcast_state")
            return bool(int(raw_should_broadcast))

    def _sync_among_ddp_ranks(self) -> None:
        """Synchronize model params and buffers from the DDP leader"""

        if not self.is_ddp_enabled():
            return

        t_start = time.perf_counter()
        with torch.no_grad():
            torch.distributed._broadcast_coalesced(
                _get_default_group(), self._main_parameters, self._BROADCAST_BUFFER_SIZE, self._DDP_LEADER_RANK
            )
        if self.is_ddp_leader():
            self._prev_version = self._compute_state_version()
            self._prev_epoch = self.local_epoch
        elapsed = time.perf_counter() - t_start
        logger.debug(f"Broadcasting leader params among DDP ranks took {elapsed:.2f} sec")

    def step(
        self,
        closure: Optional[Callable[[], torch.Tensor]] = None,
        batch_size: Optional[int] = None,
        grad_scaler: Optional[GradScaler] = None,
    ):
        if self.is_ddp_leader():
            loss = super().step(closure, batch_size, grad_scaler)

        if self._has_updated_params_after_sync():
            self._sync_among_ddp_ranks()
        else:
            logger.debug("No need to broadcast leader params among DDP ranks")

        if self.is_ddp_enabled():
            self._checksum_counter += 1
            if self._checksum_counter % 100 == 0:
                rank = torch.distributed.get_rank()
                checksum = sum(p.sum().item() for p in self._main_parameters)
                logger.debug(f"Parameter checksum (ddp_rank={rank}): {float(checksum)}")

        return loss if self.is_ddp_leader() else None

    def load_state_from_peers(self, **kwargs) -> None:
        if self.is_ddp_leader():
            super().load_state_from_peers(**kwargs)

        self._sync_among_ddp_ranks()
Review comment (Member):

We should not synchronize here: non-master ranks cannot call this, and we will deadlock.

We should only sync in step -- and after step, check IF the master updated/loaded/averaged the state, and then broadcast.
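A sketch of that suggestion, written as an illustrative mixin rather than the PR's actual code; the key point is that followers never reach a collective call inside this method:

class LeaderOnlyLoadSketch:
    """Illustrative override: leader-only loading, with cross-rank sync deferred to step()."""

    def load_state_from_peers(self, **kwargs) -> None:
        if self.is_ddp_leader():
            super().load_state_from_peers(**kwargs)
        # Intentionally no _sync_among_ddp_ranks() here: followers do not call this method,
        # so starting a broadcast/barrier would deadlock. The next step() notices the changed
        # state version via _has_updated_params_after_sync() and broadcasts to all ranks.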


    def load_state_dict(self, state_dict: dict) -> None:
        if self.is_ddp_leader():
            super().load_state_dict(state_dict)

        self._sync_among_ddp_ranks()
Review comment (Member):

We should not synchronize here: non-master ranks cannot call this, and we will deadlock; see load_state_from_peers.


    @property
    def param_groups(self) -> ParamGroups:
        if self.is_ddp_leader():
            return super().param_groups
        else:
            return self._param_groups

    def zero_grad(self, set_to_none: bool = False):
        # We explicitly define this method to mark that it should be available on the DDP followers
        super().zero_grad(set_to_none)

    def shutdown(self):
        if self.is_ddp_leader():
            super().shutdown()
Review comment (Member, @justheuristic, May 31, 2022):

Optional: else raise NotImplemented or warn?
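For context, a sketch of how the optimizer is meant to be constructed under torch.distributed, following the pattern from the ValueError message in __init__; the model and the hivemind.Optimizer keyword arguments below are illustrative:

import hivemind
import torch
import torch.distributed as dist

from hivemind.optim.ddp import DDPOptimizer  # the module added by this PR

dist.init_process_group(backend="gloo")  # or "nccl" for multi-GPU training
model = torch.nn.Linear(16, 1)

# Only the leader rank joins the hivemind swarm; followers pass dht=None
dht = hivemind.DHT(start=True) if DDPOptimizer.is_ddp_leader() else None

optimizer = DDPOptimizer(
    dht=dht,
    optimizer=lambda params: torch.optim.SGD(params, lr=0.1),
    params=model.parameters(),
    run_id="ddp_demo_run",
    batch_size_per_step=32,
    target_batch_size=1024,
)

Every rank then calls optimizer.step() as usual: the leader talks to the swarm, while the other ranks receive refreshed parameters via the broadcast inside step().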

10 changes: 3 additions & 7 deletions hivemind/optim/optimizer.py
@@ -13,7 +13,6 @@
 from hivemind.dht import DHT
 from hivemind.optim.grad_averager import GradientAverager, GradientAveragerFactory
 from hivemind.optim.grad_scaler import GradScaler
-from hivemind.optim.power_sgd_averager import PowerSGDGradientAverager
 from hivemind.optim.progress_tracker import LocalTrainingProgress, ProgressTracker
 from hivemind.optim.state_averager import (
     LRSchedulerBase,
@@ -238,6 +237,7 @@ def __init__(
         self.delay_state_averaging, self.average_state_every = delay_state_averaging, average_state_every
         self.matchmaking_time, self.offload_optimizer = matchmaking_time, offload_optimizer
         self.delay_grad_averaging, self.delay_optimizer_step = delay_grad_averaging, delay_optimizer_step
+        self.reuse_grad_buffers, self.use_local_updates = reuse_grad_buffers, use_local_updates
 
         self.averaging_timeout, self.allreduce_timeout = averaging_timeout, allreduce_timeout
         self.load_state_timeout, self.shutdown_timeout = load_state_timeout, shutdown_timeout
@@ -358,13 +358,9 @@ def local_epoch(self) -> int:
     def local_progress(self) -> LocalTrainingProgress:
         return self.tracker.local_progress
 
-    @property
-    def use_local_updates(self) -> bool:
-        return self.grad_averager is None
-
     @property
     def use_gradient_averaging(self) -> bool:
-        return self.grad_averager is not None
+        return not self.use_local_updates
 
     def step(
         self,
@@ -637,7 +633,7 @@ def _load_local_gradients_into_optimizer(self):
 
     def zero_grad(self, set_to_none: bool = False):
         """Reset gradients from model. If reuse_grad_buffers=True, this will raise an error."""
-        if self.use_gradient_averaging and self.grad_averager.reuse_grad_buffers:
+        if self.use_gradient_averaging and self.reuse_grad_buffers:
             raise ValueError(
                 f"When running {self.__class__.__name__} with reuse_grad_buffers=True, user should never "
                 f"call zero_grad manually. Gradients will be refreshed internally"
8 changes: 4 additions & 4 deletions hivemind/optim/state_averager.py
@@ -102,7 +102,7 @@ def __init__(
         if reuse_tensors and delta_rule_averaging:
             raise ValueError("reuse_tensors and delta_rule_averaging are mutually exclusive")
 
-        param_groups, main_parameters, parameter_names = self._check_params(optimizer, params, parameter_names)
+        param_groups, main_parameters, parameter_names = self.check_params(optimizer, params, parameter_names)
 
         self.status_loglevel = status_loglevel
         self.offload_optimizer, self.custom_gradients = offload_optimizer, custom_gradients
@@ -131,10 +131,10 @@ def __init__(
         )
 
     @staticmethod
-    def _check_params(
+    def check_params(
Review comment (Member Author):

Suggested change:
-    def check_params(
+    def prepare_params(

         optimizer: Union[TorchOptimizer, OptimizerFactory],
-        param_groups: Optional[Union[Parameters, ParamGroups]],
-        parameter_names: Optional[Sequence[str]],
+        param_groups: Optional[Union[Parameters, ParamGroups]] = None,
+        parameter_names: Optional[Sequence[str]] = None,
     ) -> Tuple[ParamGroups, Sequence[torch.Tensor], Sequence[str]]:
         """Get and verify parameters, groups and names"""
         if param_groups is None: