
Commit

Fixed an issue where ZeRO couldn't partition optimizer states in hybrid parallelism. Added tests for hybrid parallelism.
xrsrke committed Oct 11, 2023
1 parent eed0c9e commit 4400931
Showing 3 changed files with 74 additions and 12 deletions.
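For context on the commit message: "partition optimizer states" is ZeRO-1-style sharding, where each data-parallel rank keeps optimizer state for only a slice of the parameters. Below is a minimal hypothetical sketch in plain PyTorch (the round-robin split and the shard_params helper are made up for illustration; pipegoose's OptimizerStateSharding does the real partitioning):

import torch

def shard_params(params, dp_size):
    """Hypothetical round-robin split of parameters into dp_size shards."""
    shards = [[] for _ in range(dp_size)]
    for i, p in enumerate(params):
        shards[i % dp_size].append(p)
    return shards

params = [torch.nn.Parameter(torch.randn(4, 4)) for _ in range(6)]
shards = shard_params(params, dp_size=2)

# Each data-parallel rank constructs its optimizer over its own shard only, so the
# Adam buffers (exp_avg, exp_avg_sq) are allocated for 1/dp_size of the parameters
# per rank instead of being replicated on every rank.
optim_for_rank_0 = torch.optim.Adam(shards[0])
optim_for_rank_1 = torch.optim.Adam(shards[1])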
11 changes: 3 additions & 8 deletions pipegoose/optim/zero/optim.py
@@ -15,24 +15,19 @@ def __init__(self, optim: Optimizer, parallel_context: ParallelContext):
    self.optim = optim
    self.parallel_context = parallel_context

    self._master_params = {}
    self._setup_local_optim()

def _setup_local_optim(self):
    """Setup local optimizer."""

    # NOTE: shard and assign the corresponding local parameters to the local optimizer
    for i, param_groups in enumerate(self.optim.param_groups):
        self._master_params[i] = param_groups["params"]

    sharded_param_groups = OptimizerStateSharding(
        self.optim.param_groups, self.parallel_context, ParallelMode.DATA
    ).shard()
    ranks_in_group = self.parallel_context.get_ranks_in_group(ParallelMode.DATA)
    self._rank_to_param_groups = {rank: params for rank, params in zip(ranks_in_group, sharded_param_groups)}

    local_rank = self.parallel_context.get_local_rank(ParallelMode.DATA)
    self.optim.param_groups = self._rank_to_param_groups[local_rank]
    dp_local_rank = self.parallel_context.get_local_rank(ParallelMode.DATA)
    dp_global_rank = self.parallel_context.get_global_rank_from_local_rank(dp_local_rank, ParallelMode.DATA)
    self.optim.param_groups = self._rank_to_param_groups[dp_global_rank]

@property
def defaults(self):
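The substance of the fix above: _rank_to_param_groups is keyed by the ranks returned from get_ranks_in_group(ParallelMode.DATA), which in hybrid parallelism are global ranks, so the old lookup by local data-parallel rank picked the wrong key (or none at all) whenever tensor/pipeline parallelism was enabled. Here is a minimal standalone sketch of that failure mode; it is not pipegoose code, and the rank-layout formula in global_ranks_in_dp_group is a hypothetical stand-in for whatever layout ParallelContext actually uses.

# Sketch: why the shard lookup must use the global rank, not the local DP rank.
def global_ranks_in_dp_group(tp_rank, pp_rank, tp_size, dp_size):
    """Hypothetical rank layout: global = (pp_rank * dp_size + dp_rank) * tp_size + tp_rank."""
    return [(pp_rank * dp_size + dp_rank) * tp_size + tp_rank for dp_rank in range(dp_size)]

tp_size, dp_size = 2, 2

# The data-parallel group of the process at tp_rank=1, pp_rank=1; these are global ranks.
ranks_in_group = global_ranks_in_dp_group(tp_rank=1, pp_rank=1, tp_size=tp_size, dp_size=dp_size)
sharded_param_groups = [f"shard_{i}" for i in range(dp_size)]
rank_to_param_groups = dict(zip(ranks_in_group, sharded_param_groups))  # {5: 'shard_0', 7: 'shard_1'}

dp_local_rank = 1                               # this process's rank inside its DP group
dp_global_rank = ranks_in_group[dp_local_rank]  # 7 -- what the new lookup uses

print(rank_to_param_groups[dp_global_rank])     # 'shard_1' -- the correct shard
# rank_to_param_groups[dp_local_rank] raises KeyError(1): the pre-fix behaviour whenever
# local and global ranks diverge, i.e. whenever tensor/pipeline parallel sizes > 1.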
8 changes: 4 additions & 4 deletions tests/nn/optim/zero/test_optim.py
@@ -49,8 +49,9 @@ def run_dist_optim(

    # NOTE: make sure the optimizer keep the gradients after .step()
    # it's up to the user to call .zero_grad() or not
    p_grads = [p.grad for p in model.parameters()]
    for p1, p2 in zip(p_grads, grads):
    # NOTE: dist_grads just means the gradients of the model parameters
    dist_grads = [p.grad for p in model.parameters()]
    for p1, p2 in zip(dist_grads, grads):
        assert p1 is not None
        assert torch.allclose(p1, p2), f"p1: {p1}, p2: {p2}"

@@ -60,13 +61,12 @@ def run_dist_optim(
        assert p.grad is None


@pytest.mark.parametrize("data_parallel_size", [2, 4, 5])
@pytest.mark.parametrize("data_parallel_size", [2, 4])
def test_dist_optim(data_parallel_size):
    TENSOR_PARALLEL_SIZE = 1
    PIPELINE_PARALLEL_SIZE = 1
    WORLD_SIZE = TENSOR_PARALLEL_SIZE * PIPELINE_PARALLEL_SIZE * data_parallel_size

    # LR = 1e-3
    BATCH_SIZE = 500
    HIDDEN_SIZE = 1000
    OUTPUT_SIZE = 100
67 changes: 67 additions & 0 deletions tests/test_hybrid.py
@@ -0,0 +1,67 @@
import pytest
from torch.optim import Adam
from transformers import AutoModelForCausalLM, AutoTokenizer

from pipegoose.nn import DataParallel
from pipegoose.optim.zero.optim import DistributedOptimizer
from pipegoose.testing.utils import init_parallel_context, spawn

MODEL_NAME = "prajjwal1/bert-tiny"


@pytest.fixture(scope="module")
def model():
    return AutoModelForCausalLM.from_pretrained(MODEL_NAME)


@pytest.fixture(scope="module")
def tokenizer():
    return AutoTokenizer.from_pretrained(MODEL_NAME)


def run_hybrid_parallelism(rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size, kwargs):
    parallel_context = init_parallel_context(
        rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size
    )

    parallelized_model = DataParallel(kwargs["model"], parallel_context).parallelize()
    optim = Adam(parallelized_model.parameters())
    dist_optim = DistributedOptimizer(optim, parallel_context)

    output = parallelized_model(**kwargs["input"], labels=kwargs["labels"])
    loss = output.loss

    dist_optim.zero_grad()
    loss.backward()
    dist_optim.step()


@pytest.mark.parametrize("tensor_parallel_size", [2])
@pytest.mark.parametrize("pipeline_parallel_size", [2])
@pytest.mark.parametrize("data_parallel_size", [2])
def test_hybrid_parallelism(model, tokenizer, tensor_parallel_size, pipeline_parallel_size, data_parallel_size):
    WORLD_SIZE = tensor_parallel_size * pipeline_parallel_size * data_parallel_size
    GENERATION_CONFIGS = {"max_new_tokens": 1}

    text = "Persistence is all you need."
    input = tokenizer(text, return_tensors="pt")
    labels = input["input_ids"]

    kwargs = {
        "model": model,
        "generation_configs": GENERATION_CONFIGS,
        "input": input,
        "labels": labels,
        # "generated_tokens": generated_tokens.detach(),
        # "logits": logits.detach(),
        # "loss": loss.detach(),
    }

    spawn(
        run_hybrid_parallelism,
        world_size=WORLD_SIZE,
        tensor_parallel_size=tensor_parallel_size,
        pipeline_parallel_size=pipeline_parallel_size,
        data_parallel_size=data_parallel_size,
        kwargs=kwargs,
    )
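For orientation, here is a single-process equivalent of the body each spawned worker runs in the new test, with the pipegoose pieces (init_parallel_context, DataParallel, DistributedOptimizer) stripped out so plain Adam stands in for the sharded optimizer; only torch and transformers are assumed. Passing labels=input_ids is what makes the model return the loss that gets backpropagated.

from torch.optim import Adam
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("prajjwal1/bert-tiny")
tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-tiny")
optim = Adam(model.parameters())

inputs = tokenizer("Persistence is all you need.", return_tensors="pt")

# labels=input_ids asks the model to compute its own language-modeling loss.
output = model(**inputs, labels=inputs["input_ids"])

optim.zero_grad()
output.loss.backward()
optim.step()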
