From 789442ad131445eb34aa8b608b9d94483643eaf7 Mon Sep 17 00:00:00 2001
From: xrsrke
Date: Mon, 23 Oct 2023 15:38:08 +0700
Subject: [PATCH] [BUG] Fix moving padding tensors to GPU and add hybrid parallelism tests

---
 pipegoose/nn/parallel.py                      |  28 +++-
 .../nn/pipeline_parallel/pipeline_parallel.py |  41 +++--
 pipegoose/nn/tensor_parallel/parallelizer.py  |   2 +-
 .../nn/tensor_parallel/tensor_parallel.py     |   3 +-
 tests/convergence/dataset.py                  | 148 ++++++++++++++++++
 tests/convergence/test_dp.py                  |   0
 tests/convergence/test_hybrid_3d.py           | 145 +++++++++++++++++
 tests/convergence/test_pp.py                  |   0
 tests/convergence/test_tp.py                  |   0
 tests/nn/data_parallel/test_data_parallel.py  |   3 +
 10 files changed, 344 insertions(+), 26 deletions(-)
 create mode 100644 tests/convergence/dataset.py
 create mode 100644 tests/convergence/test_dp.py
 create mode 100644 tests/convergence/test_hybrid_3d.py
 create mode 100644 tests/convergence/test_pp.py
 create mode 100644 tests/convergence/test_tp.py

diff --git a/pipegoose/nn/parallel.py b/pipegoose/nn/parallel.py
index 5b6aa7f..1d63125 100644
--- a/pipegoose/nn/parallel.py
+++ b/pipegoose/nn/parallel.py
@@ -1,6 +1,7 @@
 from abc import abstractclassmethod
 from dataclasses import dataclass
 from functools import partial
+from typing import cast
 
 from torch import nn
 
@@ -60,20 +61,33 @@ def _to_device(self, device: str):
     """Move a parallelized module to accelerators."""
     SUPPORTED_DEVICES = ["cuda"]
 
-    assert self.parallel_metadata is not None, "module is not parallelized yet"
-    assert device in SUPPORTED_DEVICES, f"device must be one of {SUPPORTED_DEVICES}, got {device}"
-    assert self.parallel_metadata.is_moved_to_device is False, "module is already moved to device"
+    def is_specific_device(device):
+        import re
 
-    local_device = self.parallel_metadata.local_device
+        pattern = r"^cuda:[0-9]+$"
+        if re.match(pattern, device):
+            return True
+        return False
+
+    parallel_metadata = cast(ParallelMetadata, getattr(self, "parallel_metadata", None))
+
+    assert parallel_metadata is not None, "Module is not parallelized yet"
+    assert device in SUPPORTED_DEVICES, f"Device must be one of {SUPPORTED_DEVICES}, got {device}"
+    assert parallel_metadata.is_moved_to_device is False, "Module is already moved to device"
+    assert not is_specific_device(
+        device
+    ), f'Moving to a specific device {device} is not supported. pipegoose will handle device assignment automatically. 
Please use "cuda" instead' + + local_device = parallel_metadata.local_device for p in self.parameters(): - p.data = p.to(f"cuda:{local_device}") + p = p.to(f"cuda:{local_device}") if p.grad is not None: p.grad = p.grad.to(f"cuda:{local_device}") for b in self.buffers(): - b.data = b.to(f"cuda:{local_device}") + b = b.to(f"cuda:{local_device}") - self.parallel_metadata.is_moved_to_device = True + parallel_metadata.is_moved_to_device = True def _to_cuda(self): diff --git a/pipegoose/nn/pipeline_parallel/pipeline_parallel.py b/pipegoose/nn/pipeline_parallel/pipeline_parallel.py index 7945a2d..4f6ef20 100644 --- a/pipegoose/nn/pipeline_parallel/pipeline_parallel.py +++ b/pipegoose/nn/pipeline_parallel/pipeline_parallel.py @@ -4,13 +4,14 @@ from torch import nn from pipegoose.distributed.parallel_context import ParallelContext +from pipegoose.nn.parallel import Parallel from pipegoose.nn.pipeline_parallel._utils import get_partition_idx from pipegoose.nn.pipeline_parallel._worker import WorkerManager from pipegoose.nn.pipeline_parallel.pipeline_engine import PipelineEngine from pipegoose.nn.pipeline_parallel.scheduler import GPipeScheduler -class PipelineParallel: +class PipelineParallel(Parallel): """Automatically parallelize a module using pipeline parallelism.""" def __init__( @@ -25,19 +26,25 @@ def __init__( @torch.no_grad() def parallelize(self) -> nn.Module: - partition_idx = get_partition_idx(self.parallel_context) - module = self.modules[partition_idx] - - n_partitions = self.parallel_context.pipeline_parallel_size - scheduler = GPipeScheduler(self.num_microbatches, n_partitions) - worker_manager = WorkerManager() - - pipeline_engine = PipelineEngine( - module=module, - scheduler=scheduler, - worker_manager=worker_manager, - parallel_context=self.parallel_context, - ) - - module.forward = pipeline_engine.run - return module + if self.parallel_context.pipeline_parallel_size > 1: + partition_idx = get_partition_idx(self.parallel_context) + module = self.modules[partition_idx] + + n_partitions = self.parallel_context.pipeline_parallel_size + scheduler = GPipeScheduler(self.num_microbatches, n_partitions) + worker_manager = WorkerManager() + + pipeline_engine = PipelineEngine( + module=module, + scheduler=scheduler, + worker_manager=worker_manager, + parallel_context=self.parallel_context, + ) + + module.forward = pipeline_engine.run + + self._save_metadata(module, self.parallel_context) + + return module + else: + return self.modules diff --git a/pipegoose/nn/tensor_parallel/parallelizer.py b/pipegoose/nn/tensor_parallel/parallelizer.py index ebc6f49..8d9ae5a 100644 --- a/pipegoose/nn/tensor_parallel/parallelizer.py +++ b/pipegoose/nn/tensor_parallel/parallelizer.py @@ -159,7 +159,7 @@ def _resize_vocab_size(self, module: nn.Module): padding_size += 1 if padding_size > 0: - padding = torch.zeros((padding_size, embedding_dim)) + padding = torch.zeros((padding_size, embedding_dim), device=module.weight.device) new_embeddings = torch.cat([module.weight, padding], dim=0) module.weight.data = new_embeddings diff --git a/pipegoose/nn/tensor_parallel/tensor_parallel.py b/pipegoose/nn/tensor_parallel/tensor_parallel.py index 1a13fdc..b0130bd 100644 --- a/pipegoose/nn/tensor_parallel/tensor_parallel.py +++ b/pipegoose/nn/tensor_parallel/tensor_parallel.py @@ -4,6 +4,7 @@ from torch import nn from pipegoose.distributed.parallel_context import ParallelContext +from pipegoose.nn.parallel import Parallel from pipegoose.nn.tensor_parallel.parallelizer import ( EmbeddingParallelizer, 
LayerNormParallelizer, @@ -13,7 +14,7 @@ ) -class TensorParallel: +class TensorParallel(Parallel): """Turn a 🤗 transformers model into a tensor parallel model.""" PARALLELIZERS = [EmbeddingParallelizer, LinearParallelizer, LayerNormParallelizer, LMHeadParallelizer] diff --git a/tests/convergence/dataset.py b/tests/convergence/dataset.py new file mode 100644 index 0000000..c7e488d --- /dev/null +++ b/tests/convergence/dataset.py @@ -0,0 +1,148 @@ +from copy import deepcopy + +import torch +import wandb +from datasets import load_dataset +from torch.optim import SGD +from torch.utils.data import DataLoader, Dataset +from torch.utils.data.distributed import DistributedSampler +from transformers import AutoModelForCausalLM, AutoTokenizer + +from pipegoose.distributed.parallel_context import ParallelContext +from pipegoose.distributed.parallel_mode import ParallelMode +from pipegoose.nn.data_parallel.data_parallel import DataParallel +from pipegoose.nn.tensor_parallel.tensor_parallel import TensorParallel + + +def get_model_params_size(model, fp_bytes=4): + params_size = 0 + for p in model.parameters(): + params_size += p.numel() + params_gb = params_size * fp_bytes / 2**30 + return params_gb + + +class SimpleDataset(Dataset): + def __init__(self, data): + self.data = data + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + return self.data[idx] + + +DATA_PARALLEL_SIZE = 2 +TENSOR_PARALLEL_SIZE = 1 +PIPELINE_PARALLEL_SIZE = 1 +MODEL = "bigscience/bloom-560m" +DATASET = "imdb" +NUM_EPOCHS = 1 +LR = 1e-3 +SEED = 69 +BATCH_SIZE = 4 +CONTEXT_LENGTH = 1024 + +print("started") + + +parallel_context = ParallelContext.from_torch( + seed=SEED, + backend="gloo", + data_parallel_size=DATA_PARALLEL_SIZE, + tensor_parallel_size=TENSOR_PARALLEL_SIZE, + pipeline_parallel_size=PIPELINE_PARALLEL_SIZE, +) +rank = parallel_context.get_global_rank() + + +print("inited parallel_context") + +# dataset = SimpleDataset(data=list(range(1, 9))) +dataset = load_dataset("imdb", split="train[:100]") +dataset = dataset.map(lambda x: {"text": x["text"][:30]}) + +dp_rank = parallel_context.get_local_rank(ParallelMode.DATA) +sampler = DistributedSampler(dataset, num_replicas=DATA_PARALLEL_SIZE, rank=dp_rank, seed=SEED) +dataloader = DataLoader(dataset, batch_size=BATCH_SIZE // DATA_PARALLEL_SIZE, shuffle=False, sampler=sampler) + + +model = AutoModelForCausalLM.from_pretrained(MODEL) +tokenizer = AutoTokenizer.from_pretrained(MODEL) +tokenizer.pad_token = tokenizer.eos_token + + +ref_model = deepcopy(model) +ref_model = torch.nn.parallel.DistributedDataParallel(ref_model) +ref_optim = SGD(ref_model.parameters(), lr=LR) + + +print(f"rank={rank}, model size before parallelizing: {round(get_model_params_size(model), 3)} GB") + + +model = DataParallel(model, parallel_context).parallelize() +model = TensorParallel(model, parallel_context).parallelize() +optim = SGD(model.parameters(), lr=LR) + +print(f"rank={rank}, model size before parallelizing: {round(get_model_params_size(model), 3)} GB") + +if rank == 0: + + def get_time_name(): + import datetime + + today = datetime.datetime.now() + return today.strftime("%d/%m/%Y_%H:%M:%S") + + wandb.init( + project="pipegoose", + name=f"{get_time_name()}.test_tp_dp_converegence", + config={ + "data_parallel_size": DATA_PARALLEL_SIZE, + "tensor_parallel_size": TENSOR_PARALLEL_SIZE, + "pipeline_parallel_size": PIPELINE_PARALLEL_SIZE, + "model": MODEL, + "dataset": DATASET, + "epochs": NUM_EPOCHS, + "learning_rate": LR, + "seed": SEED, + "batch_size": BATCH_SIZE, + 
}, + ) + +step = 0 + +for epoch in range(NUM_EPOCHS): + + sampler.set_epoch(epoch) + print(f"rank={rank}, epoch={epoch}") + + for batch in dataloader: + # print(f"dp_rank: {dp_rank}: {batch}") + + print(f"rank={rank}, step={step}") + print(batch["text"]) + + inputs = tokenizer(batch["text"], padding=True, truncation=True, max_length=CONTEXT_LENGTH, return_tensors="pt") + labels = inputs["input_ids"] + + outputs = model(**inputs, labels=labels) + ref_outputs = ref_model(**inputs, labels=labels) + + optim.zero_grad() + outputs.loss.backward() + optim.step() + + ref_optim.zero_grad() + ref_outputs.loss.backward() + ref_optim.step() + + print(f"rank={rank}, loss={outputs.loss}, ref_loss={ref_outputs.loss}") + # print(f"rank={rank}, ref_loss={ref_outputs.loss}") + # print(f"rank={rank}, loss={outputs.loss}") + + if rank == 0: + wandb.log({"loss": outputs.loss, "ref_loss": ref_outputs.loss, "step": step, "epoch": epoch}) + + step += 1 diff --git a/tests/convergence/test_dp.py b/tests/convergence/test_dp.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/convergence/test_hybrid_3d.py b/tests/convergence/test_hybrid_3d.py new file mode 100644 index 0000000..5c95da1 --- /dev/null +++ b/tests/convergence/test_hybrid_3d.py @@ -0,0 +1,145 @@ +from copy import deepcopy + +import torch +import torch.distributed as dist +import wandb +from datasets import load_dataset +from torch.optim import SGD +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from transformers import AutoModelForCausalLM, AutoTokenizer + +from pipegoose.distributed.parallel_context import ParallelContext +from pipegoose.distributed.parallel_mode import ParallelMode +from pipegoose.nn.data_parallel.data_parallel import DataParallel +from pipegoose.nn.tensor_parallel.tensor_parallel import TensorParallel + + +def get_model_params_size(model, fp_bytes=4): + params_size = 0 + for p in model.parameters(): + params_size += p.numel() + params_gb = params_size * fp_bytes / 2**30 + return params_gb + + +if __name__ == "__main__": + DATA_PARALLEL_SIZE = 2 + TENSOR_PARALLEL_SIZE = 2 + PIPELINE_PARALLEL_SIZE = 1 + MODEL = "bigscience/bloom-560m" + DATASET = "imdb" + NUM_EPOCHS = 100 + LR = 1e-3 + SEED = 69 + BATCH_SIZE = 4 + CONTEXT_LENGTH = 1024 + + print("started") + + parallel_context = ParallelContext.from_torch( + seed=SEED, + backend="gloo", + data_parallel_size=DATA_PARALLEL_SIZE, + tensor_parallel_size=TENSOR_PARALLEL_SIZE, + pipeline_parallel_size=PIPELINE_PARALLEL_SIZE, + ) + rank = parallel_context.get_global_rank() + + print("inited parallel_context") + + if rank == 0: + + def get_time_name(): + import datetime + + today = datetime.datetime.now() + return today.strftime("%d/%m/%Y_%H:%M:%S") + + wandb.init( + project="pipegoose", + name=f"{get_time_name()}.test_tp_dp_converegence", + config={ + "data_parallel_size": DATA_PARALLEL_SIZE, + "tensor_parallel_size": TENSOR_PARALLEL_SIZE, + "pipeline_parallel_size": PIPELINE_PARALLEL_SIZE, + "model": MODEL, + "dataset": DATASET, + "epochs": NUM_EPOCHS, + "learning_rate": LR, + "seed": SEED, + "batch_size": BATCH_SIZE, + "is_cuda": True, + }, + ) + + dist.barrier() + + print("logged wandb") + + dataset = load_dataset("imdb", split="train[:100]") + dataset = dataset.map(lambda x: {"text": x["text"][:30]}) + + dp_rank = parallel_context.get_local_rank(ParallelMode.DATA) + sampler = DistributedSampler(dataset, num_replicas=DATA_PARALLEL_SIZE, rank=dp_rank, seed=SEED) + dataloader = DataLoader(dataset, batch_size=BATCH_SIZE // 
DATA_PARALLEL_SIZE, shuffle=False, sampler=sampler) + + model = AutoModelForCausalLM.from_pretrained(MODEL) + ref_model = deepcopy(model) + tokenizer = AutoTokenizer.from_pretrained(MODEL) + tokenizer.pad_token = tokenizer.eos_token + + print(f"rank={rank}, model size before parallelizing: {round(get_model_params_size(model), 3)} GB") + + model = DataParallel(model, parallel_context).parallelize() + model.to("cuda") + model = TensorParallel(model, parallel_context).parallelize() + optim = SGD(model.parameters(), lr=LR) + device = next(model.parameters()).device + + print(f"rank={rank}, model size before parallelizing: {round(get_model_params_size(model), 3)} GB") + print(f"rank={rank}, moved to device: {device}") + + ref_model.to(device) + if DATA_PARALLEL_SIZE > 1: + ref_model = torch.nn.parallel.DistributedDataParallel(ref_model, device_ids=[device]) + + ref_optim = SGD(ref_model.parameters(), lr=LR) + + dist.barrier() + step = 0 + + for epoch in range(NUM_EPOCHS): + + sampler.set_epoch(epoch) + print(f"rank={rank}, epoch={epoch}") + + for batch in dataloader: + # print(f"dp_rank: {dp_rank}: {batch}") + + print(f"rank={rank}, step={step}") + print(batch["text"]) + + inputs = tokenizer(batch["text"], padding=True, truncation=True, max_length=CONTEXT_LENGTH, return_tensors="pt") + inputs = {name: tensor.to(device) for name, tensor in inputs.items()} + labels = inputs["input_ids"] + + outputs = model(**inputs, labels=labels) + ref_outputs = ref_model(**inputs, labels=labels) + + optim.zero_grad() + outputs.loss.backward() + optim.step() + + ref_optim.zero_grad() + ref_outputs.loss.backward() + ref_optim.step() + + print(f"rank={rank}, loss={outputs.loss}, ref_loss={ref_outputs.loss}") + + if rank == 0: + wandb.log({"loss": outputs.loss, "ref_loss": ref_outputs.loss, "step": step, "epoch": epoch}) + + step += 1 + + wandb.finish() diff --git a/tests/convergence/test_pp.py b/tests/convergence/test_pp.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/convergence/test_tp.py b/tests/convergence/test_tp.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/nn/data_parallel/test_data_parallel.py b/tests/nn/data_parallel/test_data_parallel.py index 823ed14..bf51478 100644 --- a/tests/nn/data_parallel/test_data_parallel.py +++ b/tests/nn/data_parallel/test_data_parallel.py @@ -181,6 +181,9 @@ def run_move_a_model_to_gpu(rank, world_size, port, tensor_parallel_size, pipeli for p in parallelized_model.parameters(): assert p.device.type == "cuda" + if p.grad is not None: + assert p.grad.device.type == "cuda" + for b in parallelized_model.buffers(): assert b.device.type == "cuda"
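
For reference, a minimal usage sketch of the device-placement contract patched into pipegoose/nn/parallel.py above. It is not part of the patch: it assumes a 2-GPU, torchrun-style launch and that the parallelized module's .to() is routed through the _to_device helper (as the data-parallel GPU test exercises); it only uses APIs that appear in this diff, and the model name and parallel sizes mirror tests/convergence/dataset.py.

    from transformers import AutoModelForCausalLM

    from pipegoose.distributed.parallel_context import ParallelContext
    from pipegoose.nn.data_parallel.data_parallel import DataParallel
    from pipegoose.nn.tensor_parallel.tensor_parallel import TensorParallel

    # Build the process-group topology (sizes mirror tests/convergence/dataset.py).
    parallel_context = ParallelContext.from_torch(
        seed=69,
        backend="gloo",
        data_parallel_size=2,
        tensor_parallel_size=1,
        pipeline_parallel_size=1,
    )

    model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
    model = DataParallel(model, parallel_context).parallelize()
    model = TensorParallel(model, parallel_context).parallelize()

    # After this patch, only the generic "cuda" device string is accepted;
    # pipegoose resolves the concrete local device from the parallel metadata.
    model.to("cuda")      # intended to move parameters, gradients and buffers to the local GPU
    # model.to("cuda:0")  # rejected by the new assert in _to_device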