From 9bc42a44a4d4f6b636f084969f3e99f0ab8f8875 Mon Sep 17 00:00:00 2001
From: xrsrke
Date: Sun, 22 Oct 2023 09:40:03 +0700
Subject: [PATCH] [BUG] fix LinearParallel's tests

---
 tests/nn/tensor_parallel/test_linear.py | 108 +++++++++++++++---------
 1 file changed, 66 insertions(+), 42 deletions(-)

diff --git a/tests/nn/tensor_parallel/test_linear.py b/tests/nn/tensor_parallel/test_linear.py
index 7cec801..88c75e3 100644
--- a/tests/nn/tensor_parallel/test_linear.py
+++ b/tests/nn/tensor_parallel/test_linear.py
@@ -1,27 +1,33 @@
+from copy import deepcopy
+
 import pytest
 import torch
 from torch import nn
 
 from pipegoose.distributed.parallel_mode import ParallelMode
 from pipegoose.nn.tensor_parallel.linear import ColumnParallelLinear, RowParallelLinear
-from pipegoose.testing.utils import init_parallel_context, spawn
+from pipegoose.testing.utils import get_partition, init_parallel_context, spawn
 
 
-def run_parallel_column_linear(
+def run_column_parallel_linear(
     rank,
     world_size,
     port,
     tensor_parallel_size,
     pipeline_parallel_size,
     data_parallel_size,
-    batch_size,
     in_features,
     out_features,
     inputs,
-    outputs,
-    params,
-    grads,
+    ref_outputs,
+    orig_params,
+    ref_params,
+    ref_grads,
 ):
+    ORIG_PARAMS = deepcopy(orig_params)
+    REF_PARAMS = deepcopy(ref_params)
+    REF_GRADS = deepcopy(ref_grads)
+
     parallel_context = init_parallel_context(
         rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size
     )
@@ -39,39 +45,50 @@ def run_parallel_column_linear(
         parallel_context=parallel_context,
     )
 
-    partition_size = params["weight"].shape[0] // local_world_size
+    partition_size = ORIG_PARAMS["weight"].shape[0] // local_world_size
     partition_start, partition_end = local_rank * partition_size, (local_rank + 1) * partition_size
 
-    model.weight.data = params["weight"][partition_start:partition_end, :]
-    model.bias.data = params["bias"][partition_start:partition_end]
+    model.weight.data = ORIG_PARAMS["weight"][partition_start:partition_end, :]
+    model.bias.data = ORIG_PARAMS["bias"][partition_start:partition_end]
 
-    parallel_outputs = model(inputs)
+    outputs = model(inputs)
 
-    assert parallel_outputs.shape == (batch_size, out_features)
-    # NOTE: sometimes it's not equal due to small relative differences (rtol)
-    assert torch.allclose(parallel_outputs, outputs, rtol=1e-3)
+    assert outputs.shape == ref_outputs.shape
+    assert torch.allclose(outputs, ref_outputs)
 
-    parallel_outputs.sum().backward()
+    outputs.sum().backward()
 
-    assert torch.allclose(model.weight.grad, grads["weight"][local_rank])
-    assert torch.allclose(model.bias.grad, grads["bias"][local_rank])
+    split_dim = 0
+    REF_WEIGHT_GRADS = get_partition(REF_GRADS["weight"], dim=split_dim, parallel_context=parallel_context)
+    REF_BIAS_GRADS = get_partition(REF_GRADS["bias"], dim=split_dim, parallel_context=parallel_context)
+    assert torch.allclose(model.weight.grad, REF_WEIGHT_GRADS)
+    assert torch.allclose(model.bias.grad, REF_BIAS_GRADS)
 
+    REF_WEIGHT = get_partition(REF_PARAMS["weight"], dim=split_dim, parallel_context=parallel_context)
+    REF_BIAS = get_partition(REF_PARAMS["bias"], dim=split_dim, parallel_context=parallel_context)
+    assert torch.allclose(model.weight, REF_WEIGHT)
+    assert torch.allclose(model.bias, REF_BIAS)
 
 
-def run_parallel_row_linear(
+
+def run_row_parallel_linear(
     rank,
     world_size,
     port,
     tensor_parallel_size,
     pipeline_parallel_size,
     data_parallel_size,
-    batch_size,
     in_features,
     out_features,
     inputs,
-    outputs,
-    params,
-    grads,
+    ref_outputs,
+    orig_params,
+    ref_params,
+    ref_grads,
 ):
+    ORIG_PARAMS = deepcopy(orig_params)
+    REF_PARAMS = deepcopy(ref_params)
+    REF_GRADS = deepcopy(ref_grads)
+
     parallel_context = init_parallel_context(
         rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size
     )
@@ -88,26 +105,30 @@ def run_parallel_row_linear(
         parallel_context=parallel_context,
     )
 
-    partition_size = params["weight"].shape[1] // local_world_size
+    partition_size = ORIG_PARAMS["weight"].shape[1] // local_world_size
     partition_start, partition_end = local_rank * partition_size, (local_rank + 1) * partition_size
 
-    model.weight.data = params["weight"][:, partition_start:partition_end]
-    model.bias.data = params["bias"]
+    model.weight.data = ORIG_PARAMS["weight"][:, partition_start:partition_end]
+    model.bias.data = ORIG_PARAMS["bias"]
 
-    parallel_outputs = model(inputs)
+    outputs = model(inputs)
 
-    assert parallel_outputs.shape == outputs.shape
-    assert torch.allclose(parallel_outputs, outputs)
+    assert outputs.shape == ref_outputs.shape
+    assert torch.allclose(outputs, ref_outputs)
 
-    parallel_outputs.sum().backward()
+    outputs.sum().backward()
 
-    weight_grad_chunks = torch.split(grads["weight"], partition_size, dim=1)
+    split_dim = 1
+    REF_WEIGHT_GRADS = get_partition(REF_GRADS["weight"], dim=split_dim, parallel_context=parallel_context)
+    assert torch.allclose(model.weight.grad, REF_WEIGHT_GRADS)
+    assert torch.allclose(model.bias.grad, REF_GRADS["bias"])
 
-    assert torch.allclose(model.weight.grad, weight_grad_chunks[local_rank])
-    assert torch.allclose(model.bias.grad, grads["bias"])
+    REF_WEIGHT = get_partition(REF_PARAMS["weight"], dim=split_dim, parallel_context=parallel_context)
+    assert torch.allclose(model.weight, REF_WEIGHT)
+    assert torch.allclose(model.bias, REF_PARAMS["bias"])
 
 
-@pytest.mark.parametrize("run_linear", [run_parallel_column_linear, run_parallel_row_linear])
+@pytest.mark.parametrize("run_linear", [run_column_parallel_linear, run_row_parallel_linear])
 def test_parallel_linear(run_linear):
     TENSOR_PARALLEL_SIZE = 2
     PIPELINE_PARALLEL_SIZE = 1
@@ -119,19 +140,22 @@ def test_parallel_linear(run_linear):
 
     inputs = torch.randn(batch_size, in_features)
     model = nn.Linear(in_features, out_features)
+    ORIG_PARAMS = {
+        "weight": deepcopy(model.weight.detach().requires_grad_(False)),
+        "bias": deepcopy(model.bias.detach().requires_grad_(False)),
+    }
 
     outputs = model(inputs)
     outputs.sum().backward()
 
-    params = {
-        "weight": model.weight.detach().requires_grad_(False),
-        "bias": model.bias.detach().requires_grad_(False),
-    }
-
-    grads = {
+    REF_GRADS = {
         "weight": model.weight.grad.detach().requires_grad_(False),
         "bias": model.bias.grad.detach().requires_grad_(False),
     }
+    REF_PARAMS = {
+        "weight": deepcopy(model.weight.detach().requires_grad_(False)),
+        "bias": deepcopy(model.bias.detach().requires_grad_(False)),
+    }
 
     spawn(
         run_linear,
@@ -139,11 +163,11 @@ def test_parallel_linear(run_linear):
         tensor_parallel_size=TENSOR_PARALLEL_SIZE,
         pipeline_parallel_size=PIPELINE_PARALLEL_SIZE,
         data_parallel_size=DATA_PARALLEL_SIZE,
-        batch_size=batch_size,
         in_features=in_features,
         out_features=out_features,
         inputs=inputs.detach(),
-        outputs=outputs.detach(),
-        params=params,
-        grads=grads,
+        ref_outputs=outputs.detach(),
+        orig_params=ORIG_PARAMS,
+        ref_params=REF_PARAMS,
+        ref_grads=REF_GRADS,
     )
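
Note: the rewritten assertions build their per-rank expectations with get_partition
from pipegoose.testing.utils, which this patch imports but does not define. As a
rough sketch only — assuming the parallel context exposes the tensor-parallel rank
and world size through lookups named get_local_rank/get_local_world_size (assumed
names, not confirmed by this patch) — such a helper could look like:

    import torch
    from pipegoose.distributed.parallel_mode import ParallelMode

    def get_partition(tensor, dim, parallel_context):
        # Hypothetical sketch, not pipegoose's actual implementation:
        # take this rank's equal-sized chunk of the full reference tensor.
        rank = parallel_context.get_local_rank(ParallelMode.TENSOR)
        world_size = parallel_context.get_local_world_size(ParallelMode.TENSOR)
        return torch.chunk(tensor, world_size, dim=dim)[rank]

This mirrors the manual slicing the tests do when seeding model.weight.data and
model.bias.data, so the asserted shards line up with how the layer was partitioned.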
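Note on the two split_dim values: nn.Linear stores its weight as
[out_features, in_features], so a column-parallel layer shards dim 0 (the output
features, along with the bias), while a row-parallel layer shards dim 1 (the input
features) and keeps a full, replicated bias — which is why the row test compares
model.bias against the unsplit REF_PARAMS["bias"]. A small self-contained
illustration of the expected shard shapes:

    import torch

    out_features, in_features, world_size = 40, 20, 2
    weight = torch.randn(out_features, in_features)  # nn.Linear layout: [out, in]
    bias = torch.randn(out_features)

    # Column parallel: shard the output dim; each rank also gets a bias shard.
    col_weight = torch.chunk(weight, world_size, dim=0)[0]
    col_bias = torch.chunk(bias, world_size, dim=0)[0]
    assert col_weight.shape == (20, 20) and col_bias.shape == (20,)

    # Row parallel: shard the input dim; the bias stays whole on every rank.
    row_weight = torch.chunk(weight, world_size, dim=1)[0]
    assert row_weight.shape == (40, 10)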