Skip to content

Commit

Permalink
[FIX] DataParallel's tests
Browse files Browse the repository at this point in the history
  • Loading branch information
xrsrke committed Oct 22, 2023
1 parent 38b215d commit f33d5e4
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 58 deletions.
12 changes: 4 additions & 8 deletions pipegoose/nn/data_parallel/data_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch.distributed as dist
from torch import nn

from pipegoose.distributed.functional import all_reduce
from pipegoose.distributed.parallel_context import ParallelContext
from pipegoose.distributed.parallel_mode import ParallelMode

Expand All @@ -25,12 +26,7 @@ def _register_grad_avg_hook(self, module: nn.Module):
if p.requires_grad is True:
p.register_hook(self._average_grad)

def _average_grad(self, grad: torch.Tensor) -> torch.Tensor:
data_parallel_size = self.parallel_context.data_parallel_size
process_group = self.parallel_context.get_group(ParallelMode.DATA)

def _average_grad(self, grad: torch.Tensor):
# NOTE: (grad1 + grad2 + ... + gradn) / n = grad1/n + grad2/n + ... + gradn/n
new_grad = grad / data_parallel_size
dist.all_reduce(new_grad, op=dist.ReduceOp.SUM, group=process_group)

return new_grad
grad.div_(self.parallel_context.data_parallel_size)
all_reduce(grad, op=dist.ReduceOp.SUM, parallel_context=self.parallel_context, parallel_mode=ParallelMode.DATA)
30 changes: 28 additions & 2 deletions pipegoose/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@
from typing import Callable

import pytest
import torch
import torch.multiprocessing as mp
from torch import nn

from pipegoose.distributed.parallel_context import ParallelContext
from pipegoose.distributed.parallel_mode import ParallelMode

# NOTE: skipped because these tests run too slowly in GitHub Actions
skip_in_github_actions = pytest.mark.skipif(os.getenv("GITHUB_ACTIONS") == "true", reason="Test skipped in GitHub Actions")
Expand Down Expand Up @@ -35,8 +40,6 @@ def spawn(func: Callable, world_size: int = 1, **kwargs):


def init_parallel_context(rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size):
from pipegoose.distributed.parallel_context import ParallelContext

HOST = "localhost"
SEED = 69
BACKEND = "gloo"
Expand Down Expand Up @@ -86,3 +89,26 @@ def init_pipeline_context(
)

return pipeline_context, parallel_context


def get_partition(data: torch.Tensor, dim: int, parallel_context: ParallelContext) -> torch.Tensor:
    """Return the slice of ``data`` that belongs to the current tensor-parallel rank.

    ``data`` is split along ``dim`` into as many chunks as the world size
    reported by ``parallel_context`` for ``ParallelMode.TENSOR``; the chunk at
    the index of this process's local rank in that same mode is returned.

    Args:
        data: Tensor to partition.
        dim: Dimension along which ``data`` is chunked.
        parallel_context: Context queried for the tensor-parallel world size
            and local rank.

    Returns:
        The chunk of ``data`` selected by the local tensor-parallel rank.
    """
    num_partitions = parallel_context.get_world_size(ParallelMode.TENSOR)
    partition_idx = parallel_context.get_local_rank(ParallelMode.TENSOR)
    return data.chunk(num_partitions, dim=dim)[partition_idx]


def calculate_parameter_similarity(module1: nn.Module, module2: nn.Module, rtol: float = 1e-3) -> float:
    """Return the fraction of parameter elements that are close in both modules.

    NOTE: In some test cases, the parameters of a model after ``.step()`` are
    very close to the parameters of the original model. This function is used
    to check whether the updated parameters have deviated enough from the
    original ones.

    Args:
        module1: First module to compare.
        module2: Second module; its parameters are paired with those of
            ``module1`` in ``.parameters()`` order.
        rtol: Relative tolerance forwarded to ``torch.isclose``.

    Returns:
        ``equal_elements / total_elements`` as a float in ``[0, 1]``. Two
        modules without any parameters are reported as fully similar (``1.0``).

    Raises:
        AssertionError: If the modules have a different number of parameters,
            or a pair of parameters has mismatching shapes.
    """
    params1 = list(module1.parameters())
    params2 = list(module2.parameters())
    # zip() would silently drop the trailing parameters of the longer module
    # and report a misleading similarity, so check the counts explicitly.
    assert len(params1) == len(params2), "Modules have different numbers of parameters"

    total_parameters, equal_parameters = 0, 0
    for param1, param2 in zip(params1, params2):
        assert param1.shape == param2.shape, "Parameters have different shapes"
        total_parameters += param1.numel()
        # torch.isclose is element-wise for any shape, so no flattening is
        # needed; this also avoids .view(-1) failing on non-contiguous tensors.
        equal_parameters += int(torch.isclose(param1, param2, rtol=rtol).sum().item())

    if total_parameters == 0:
        # Nothing to compare: avoid dividing by zero and treat as identical.
        return 1.0

    return equal_parameters / total_parameters
98 changes: 50 additions & 48 deletions tests/nn/data_parallel/test_data_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,18 @@

from pipegoose.distributed.parallel_mode import ParallelMode
from pipegoose.nn.data_parallel.data_parallel import DataParallel
from pipegoose.testing.utils import init_parallel_context, spawn
from pipegoose.testing.utils import (
calculate_parameter_similarity,
init_parallel_context,
spawn,
)

MODEL_NAME = "prajjwal1/bert-tiny"


@pytest.fixture(scope="module")
def model():
return AutoModelForCausalLM.from_pretrained(MODEL_NAME)
# @pytest.fixture(scope="module")
# def model():
# return AutoModelForCausalLM.from_pretrained(MODEL_NAME)


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -82,78 +86,76 @@ def test_parallelize_a_transformer_and_inference(model, tokenizer, data_parallel
def run_backward_a_parallelized_transformers(
rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size, kwargs
):
def get_microbatch(kwargs):
input_ids = kwargs["inputs"]["input_ids"][local_rank].unsqueeze(0)
attention_mask = kwargs["inputs"]["attention_mask"][local_rank].unsqueeze(0)
labels = kwargs["labels"][local_rank].unsqueeze(0)
return input_ids, attention_mask, labels

model = kwargs["model"]
loss = kwargs["loss"]
lr = kwargs["lr"]
def get_microbatch(inputs, labels):
local_rank = parallel_context.get_local_rank(ParallelMode.DATA)
input_chunks = torch.chunk(inputs["input_ids"], chunks=world_size, dim=0)
attention_chunks = torch.chunk(inputs["attention_mask"], chunks=world_size, dim=0)
label_chunks = torch.chunk(labels, chunks=world_size, dim=0)
return input_chunks[local_rank], attention_chunks[local_rank], label_chunks[local_rank]

model = deepcopy(kwargs["model"])
UPDATED_MODEL = deepcopy(kwargs["updated_model"])
LR = kwargs["lr"]
inputs = kwargs["inputs"]
labels = kwargs["labels"]

parallel_context = init_parallel_context(
rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size
)

input_ids, attention_mask, labels = get_microbatch(inputs, labels)
parallelized_model = DataParallel(model, parallel_context).parallelize()
local_rank = parallel_context.get_local_rank(ParallelMode.DATA)

input_ids, attention_mask, labels = get_microbatch(kwargs)
optim = SGD(parallelized_model.parameters(), lr=LR)

p_output = parallelized_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
p_loss = p_output.loss

# NOTE: since each replica only computes on a subset of data,
# the replica's loss and the loss of the original model that trains
# on the whole set of data should not be equal.
assert not torch.allclose(p_loss, loss)

optim = SGD(parallelized_model.parameters(), lr=lr)
optim.zero_grad()
p_loss.backward()
outputs = parallelized_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

loss = outputs.loss
loss.backward()
optim.step()

# NOTE: after averaging the gradient, we expect the gradient of a replica
# that trains on a subset of data to be equal to the gradient of
# the original model that trains on the whole set of data
for parallel_p, p in zip(parallelized_model.parameters(), model.parameters()):
assert torch.allclose(parallel_p.grad, p.grad, rtol=1e-3)
for p, ref_p in zip(parallelized_model.parameters(), UPDATED_MODEL.parameters()):
assert torch.allclose(p, ref_p)


@pytest.mark.parametrize("data_parallel_size", [1, 2])
def test_backward_pass_a_parallelized_transformers(model, tokenizer, data_parallel_size):
def test_backward_pass_a_parallelized_transformers(tokenizer, data_parallel_size):
TENSOR_PARALLEL_SIZE = 1
PIPELINE_PARALLEL_SIZE = 1

LR = 1e-3
# NOTE: if we use a small learning rate,
# the updated model's and the original model's weights can be identical in some cases,
# which could lead to a misleading test result
LR = 1e-1

text = ["Persistence is all you need.", "3D parallelism is all you need."]
inputs = tokenizer(text, return_tensors="pt", padding="longest")
labels = inputs["input_ids"]
optim = SGD(model.parameters(), lr=LR)

GRADS = []

def save_orig_grad(grad):
GRADS.append(grad.clone().detach())

for p in model.parameters():
if p.requires_grad is True:
p.register_hook(save_orig_grad)
labels = torch.randint_like(inputs["input_ids"], low=100, high=200)

model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
ORIG_MODEL = deepcopy(model)
optim = SGD(model.parameters(), lr=LR)
optim.zero_grad()
outputs = model(**inputs, labels=labels)

# NOTE: we make a copy of the model before updating its weights
# so the output of the model is not affected by the updated weights
orig_model = deepcopy(model)
loss = outputs.loss

optim.zero_grad()
loss.backward()
optim.step()

kwargs = {"model": orig_model, "lr": LR, "inputs": inputs, "labels": labels, "grads": GRADS, "loss": loss.detach()}
# NOTE: in some cases, the updated model's and the original model's weights can be identical,
# so we need to make sure that their weights are actually different
similarity = calculate_parameter_similarity(ORIG_MODEL, model)
assert similarity < 0.95, f"Two models should be different before training. Similarity: {similarity}"

kwargs = {
"model": ORIG_MODEL,
"updated_model": model,
"lr": LR,
"inputs": inputs,
"labels": labels,
}

spawn(
run_backward_a_parallelized_transformers,
Expand Down

0 comments on commit f33d5e4

Please sign in to comment.