[BUG] fix LinearParallel's tests
xrsrke committed Oct 22, 2023
1 parent 35eb547 commit 9bc42a4
Showing 1 changed file with 66 additions and 42 deletions.
tests/nn/tensor_parallel/test_linear.py (108 changes: 66 additions & 42 deletions)
@@ -1,27 +1,33 @@
 from copy import deepcopy
 
 import pytest
 import torch
 from torch import nn
 
 from pipegoose.distributed.parallel_mode import ParallelMode
 from pipegoose.nn.tensor_parallel.linear import ColumnParallelLinear, RowParallelLinear
-from pipegoose.testing.utils import init_parallel_context, spawn
+from pipegoose.testing.utils import get_partition, init_parallel_context, spawn
 
 
-def run_parallel_column_linear(
+def run_column_parallel_linear(
     rank,
     world_size,
     port,
     tensor_parallel_size,
     pipeline_parallel_size,
     data_parallel_size,
     batch_size,
     in_features,
     out_features,
     inputs,
-    outputs,
-    params,
-    grads,
+    ref_outputs,
+    orig_params,
+    ref_params,
+    ref_grads,
 ):
+    ORIG_PARAMS = deepcopy(orig_params)
+    REF_PARAMS = deepcopy(ref_params)
+    REF_GRADS = deepcopy(ref_grads)
+
     parallel_context = init_parallel_context(
         rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size
     )
@@ -39,39 +45,50 @@ def run_parallel_column_linear(
         parallel_context=parallel_context,
     )
 
-    partition_size = params["weight"].shape[0] // local_world_size
+    partition_size = ORIG_PARAMS["weight"].shape[0] // local_world_size
     partition_start, partition_end = local_rank * partition_size, (local_rank + 1) * partition_size
 
-    model.weight.data = params["weight"][partition_start:partition_end, :]
-    model.bias.data = params["bias"][partition_start:partition_end]
+    model.weight.data = ORIG_PARAMS["weight"][partition_start:partition_end, :]
+    model.bias.data = ORIG_PARAMS["bias"][partition_start:partition_end]
 
-    parallel_outputs = model(inputs)
+    outputs = model(inputs)
 
-    assert parallel_outputs.shape == (batch_size, out_features)
-    # NOTE: sometimes it's not equal due to small relative differences (rtol)
-    assert torch.allclose(parallel_outputs, outputs, rtol=1e-3)
+    assert outputs.shape == ref_outputs.shape
+    assert torch.allclose(outputs, ref_outputs)
 
-    parallel_outputs.sum().backward()
+    outputs.sum().backward()
 
-    assert torch.allclose(model.weight.grad, grads["weight"][local_rank])
-    assert torch.allclose(model.bias.grad, grads["bias"][local_rank])
+    split_dim = 0
+    REF_WEIGHT_GRADS = get_partition(REF_GRADS["weight"], dim=split_dim, parallel_context=parallel_context)
+    REF_BIAS_GRADS = get_partition(REF_GRADS["bias"], dim=split_dim, parallel_context=parallel_context)
+    assert torch.allclose(model.weight.grad, REF_WEIGHT_GRADS)
+    assert torch.allclose(model.bias.grad, REF_BIAS_GRADS)
+
+    REF_WEIGHT = get_partition(REF_PARAMS["weight"], dim=split_dim, parallel_context=parallel_context)
+    REF_BIAS = get_partition(REF_PARAMS["bias"], dim=split_dim, parallel_context=parallel_context)
+    assert torch.allclose(model.weight, REF_WEIGHT)
+    assert torch.allclose(model.bias, REF_BIAS)
 
-def run_parallel_row_linear(
+
+def run_row_parallel_linear(
     rank,
     world_size,
     port,
     tensor_parallel_size,
     pipeline_parallel_size,
     data_parallel_size,
     batch_size,
     in_features,
     out_features,
     inputs,
-    outputs,
-    params,
-    grads,
+    ref_outputs,
+    orig_params,
+    ref_params,
+    ref_grads,
 ):
+    ORIG_PARAMS = deepcopy(orig_params)
+    REF_PARAMS = deepcopy(ref_params)
+    REF_GRADS = deepcopy(ref_grads)
+
     parallel_context = init_parallel_context(
         rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size
     )
@@ -88,26 +105,30 @@ def run_parallel_row_linear(
         parallel_context=parallel_context,
     )
 
-    partition_size = params["weight"].shape[1] // local_world_size
+    partition_size = ORIG_PARAMS["weight"].shape[1] // local_world_size
     partition_start, partition_end = local_rank * partition_size, (local_rank + 1) * partition_size
 
-    model.weight.data = params["weight"][:, partition_start:partition_end]
-    model.bias.data = params["bias"]
+    model.weight.data = ORIG_PARAMS["weight"][:, partition_start:partition_end]
+    model.bias.data = ORIG_PARAMS["bias"]
 
-    parallel_outputs = model(inputs)
+    outputs = model(inputs)
 
-    assert parallel_outputs.shape == outputs.shape
-    assert torch.allclose(parallel_outputs, outputs)
+    assert outputs.shape == ref_outputs.shape
+    assert torch.allclose(outputs, ref_outputs)
 
-    parallel_outputs.sum().backward()
+    outputs.sum().backward()
 
-    weight_grad_chunks = torch.split(grads["weight"], partition_size, dim=1)
+    split_dim = 1
+    REF_WEIGHT_GRADS = get_partition(REF_GRADS["weight"], dim=split_dim, parallel_context=parallel_context)
+    assert torch.allclose(model.weight.grad, REF_WEIGHT_GRADS)
+    assert torch.allclose(model.bias.grad, REF_GRADS["bias"])
 
-    assert torch.allclose(model.weight.grad, weight_grad_chunks[local_rank])
-    assert torch.allclose(model.bias.grad, grads["bias"])
+    REF_WEIGHT = get_partition(REF_PARAMS["weight"], dim=split_dim, parallel_context=parallel_context)
+    assert torch.allclose(model.weight, REF_WEIGHT)
+    assert torch.allclose(model.bias, REF_PARAMS["bias"])
 
 
-@pytest.mark.parametrize("run_linear", [run_parallel_column_linear, run_parallel_row_linear])
+@pytest.mark.parametrize("run_linear", [run_column_parallel_linear, run_row_parallel_linear])
 def test_parallel_linear(run_linear):
     TENSOR_PARALLEL_SIZE = 2
     PIPELINE_PARALLEL_SIZE = 1
@@ -119,31 +140,34 @@ def test_parallel_linear(run_linear):

     inputs = torch.randn(batch_size, in_features)
     model = nn.Linear(in_features, out_features)
+    ORIG_PARAMS = {
+        "weight": deepcopy(model.weight.detach().requires_grad_(False)),
+        "bias": deepcopy(model.bias.detach().requires_grad_(False)),
+    }
 
     outputs = model(inputs)
     outputs.sum().backward()
 
-    params = {
-        "weight": model.weight.detach().requires_grad_(False),
-        "bias": model.bias.detach().requires_grad_(False),
-    }
-
-    grads = {
+    REF_GRADS = {
         "weight": model.weight.grad.detach().requires_grad_(False),
         "bias": model.bias.grad.detach().requires_grad_(False),
     }
+    REF_PARAMS = {
+        "weight": deepcopy(model.weight.detach().requires_grad_(False)),
+        "bias": deepcopy(model.bias.detach().requires_grad_(False)),
+    }
 
     spawn(
         run_linear,
         world_size=TENSOR_PARALLEL_SIZE,
         tensor_parallel_size=TENSOR_PARALLEL_SIZE,
         pipeline_parallel_size=PIPELINE_PARALLEL_SIZE,
         data_parallel_size=DATA_PARALLEL_SIZE,
         batch_size=batch_size,
         in_features=in_features,
         out_features=out_features,
         inputs=inputs.detach(),
-        outputs=outputs.detach(),
-        params=params,
-        grads=grads,
+        ref_outputs=outputs.detach(),
+        orig_params=ORIG_PARAMS,
+        ref_params=REF_PARAMS,
+        ref_grads=REF_GRADS,
     )
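
For reference, the updated assertions rely on a get_partition helper imported from pipegoose.testing.utils, whose implementation is not shown in this diff. Below is a minimal sketch of what such a helper could look like, assuming it returns the current tensor-parallel rank's chunk of a reference tensor along dim; the ParallelContext method names used here (get_world_size, get_local_rank) are assumptions for illustration, not taken from this commit.

import torch

from pipegoose.distributed.parallel_mode import ParallelMode


def get_partition(data: torch.Tensor, dim: int, parallel_context) -> torch.Tensor:
    # Hypothetical sketch, not the actual pipegoose implementation.
    # Size of the tensor-parallel group and this process's rank within it
    # (method names on ParallelContext are assumed).
    world_size = parallel_context.get_world_size(ParallelMode.TENSOR)
    rank = parallel_context.get_local_rank(ParallelMode.TENSOR)
    # Split the full reference tensor into equal chunks along `dim` and return
    # the chunk owned by this rank, mirroring how ColumnParallelLinear (dim=0)
    # and RowParallelLinear (dim=1) shard their weights in these tests.
    chunks = torch.chunk(data, chunks=world_size, dim=dim)
    return chunks[rank]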
