diff --git a/pipegoose/nn/data_parallel/data_parallel.py b/pipegoose/nn/data_parallel/data_parallel.py
index e0cb888..0ebd555 100644
--- a/pipegoose/nn/data_parallel/data_parallel.py
+++ b/pipegoose/nn/data_parallel/data_parallel.py
@@ -5,9 +5,10 @@
 from pipegoose.distributed.functional import all_reduce
 from pipegoose.distributed.parallel_context import ParallelContext
 from pipegoose.distributed.parallel_mode import ParallelMode
+from pipegoose.nn.parallel import Parallel
 
 
-class DataParallel:
+class DataParallel(Parallel):
     def __init__(self, module: nn.Module, parallel_context: ParallelContext):
         self.module = module
         self.parallel_context = parallel_context
@@ -18,6 +19,7 @@ def parallelize(self) -> nn.Module:
 
         if self.parallel_context.data_parallel_size > 1:
             self._register_grad_avg_hook(module)
+            self._save_metadata(module, self.parallel_context)
 
         return module
 
diff --git a/pipegoose/nn/tensor_parallel/tensor_parallel.py b/pipegoose/nn/tensor_parallel/tensor_parallel.py
index 6e65718..1a13fdc 100644
--- a/pipegoose/nn/tensor_parallel/tensor_parallel.py
+++ b/pipegoose/nn/tensor_parallel/tensor_parallel.py
@@ -16,7 +16,7 @@
 class TensorParallel:
     """Turn a 🤗 transformers model into a tensor parallel model."""
 
-    PARALLELIZERS = [EmbeddingParallelizer, LayerNormParallelizer, LinearParallelizer, LMHeadParallelizer]
+    PARALLELIZERS = [EmbeddingParallelizer, LinearParallelizer, LayerNormParallelizer, LMHeadParallelizer]
 
     def __init__(self, module: nn.Module, parallel_context: ParallelContext):
         self.module = module
@@ -24,16 +24,21 @@ def __init__(self, module: nn.Module, parallel_context: ParallelContext):
 
     @torch.no_grad()
     def parallelize(self) -> nn.Module:
-        # NOTE: because module.named_modules returns a leaf more than once,
-        # this could potentially lead to the weight of a module being split
-        # multiple times. so we filter out and retain the non-repetitive modules (leaf modules)
-        leaf_modules = self._get_leaf_modules(self.module)
-        for module_name, leaf_module in leaf_modules:
-            parallelizer = self._find_parallelizer(module_name, leaf_module)
-            if parallelizer is not None:
-                parallelizer(module_name, leaf_module, self.module, self.parallel_context).parallelize()
+        module = self.module
 
-        return self.module
+        if self.parallel_context.tensor_parallel_size > 1:
+            # NOTE: because module.named_modules returns a leaf more than once,
+            # this could potentially lead to the weight of a module being split
+            # multiple times. so we filter out and retain the non-repetitive modules (leaf modules)
+            leaf_modules = self._get_leaf_modules(module)
+            for module_name, leaf_module in leaf_modules:
+                parallelizer = self._find_parallelizer(module_name, leaf_module)
+                if parallelizer is not None:
+                    parallelizer(module_name, leaf_module, module, self.parallel_context).parallelize()
+
+        self._save_metadata(module, self.parallel_context)
+
+        return module
 
     def _get_leaf_modules(self, model: nn.Module) -> List[Tuple[str, nn.Module]]:
         leaf_modules = []
@@ -53,7 +58,4 @@ def _find_parallelizer(self, module_name: str, module: nn.Module) -> Optional[Mo
     @torch.no_grad()
     def deparallelize(self) -> nn.Module:
         for module_name, module in self.module.named_modules():
-            self.parallelers[module].deparallelize(module_name, module, self.parallel_context)
-
-    def from_pretrained(self):
-        pass
+            self.PARALLELIZERS[module].deparallelize(module_name, module, self.parallel_context)
diff --git a/pipegoose/testing/utils.py b/pipegoose/testing/utils.py
index 5730c6f..57db4e5 100644
--- a/pipegoose/testing/utils.py
+++ b/pipegoose/testing/utils.py
@@ -14,6 +14,7 @@
 
 # NOTE: because these tests run too slow in GitHub Actions
 skip_in_github_actions = pytest.mark.skipif(os.getenv("GITHUB_ACTIONS") == "true", reason="Test skipped in GitHub Actions")
+skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
 
 
 def find_free_port(min_port: int = 2000, max_port: int = 65000) -> int:
diff --git a/tests/nn/data_parallel/test_data_parallel.py b/tests/nn/data_parallel/test_data_parallel.py
index 0f2de81..823ed14 100644
--- a/tests/nn/data_parallel/test_data_parallel.py
+++ b/tests/nn/data_parallel/test_data_parallel.py
@@ -6,10 +6,11 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from pipegoose.distributed.parallel_mode import ParallelMode
-from pipegoose.nn.data_parallel.data_parallel import DataParallel
+from pipegoose.nn import DataParallel
 from pipegoose.testing.utils import (
     calculate_parameter_similarity,
     init_parallel_context,
+    skip_if_no_cuda,
     spawn,
 )
 
@@ -166,3 +167,37 @@ def test_backward_pass_a_parallelized_transformers(model, tokenizer, data_parall
         data_parallel_size=data_parallel_size,
         kwargs=kwargs,
     )
+
+
+def run_move_a_model_to_gpu(rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size, model):
+    model = deepcopy(model)
+    parallel_context = init_parallel_context(
+        rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size
+    )
+    parallelized_model = DataParallel(model, parallel_context).parallelize()
+
+    parallelized_model.to("cuda")
+
+    for p in parallelized_model.parameters():
+        assert p.device.type == "cuda"
+
+    for b in parallelized_model.buffers():
+        assert b.device.type == "cuda"
+
+
+@skip_if_no_cuda
+def test_move_a_model_to_gpu(model):
+    DATA_PARALLEL_SIZE = 2
+    TENSOR_PARALLEL_SIZE = 1
+    PIPELINE_PARALLEL_SIZE = 1
+
+    WORLD_SIZE = DATA_PARALLEL_SIZE * TENSOR_PARALLEL_SIZE * PIPELINE_PARALLEL_SIZE
+
+    spawn(
+        run_move_a_model_to_gpu,
+        world_size=WORLD_SIZE,
+        tensor_parallel_size=TENSOR_PARALLEL_SIZE,
+        pipeline_parallel_size=PIPELINE_PARALLEL_SIZE,
+        data_parallel_size=DATA_PARALLEL_SIZE,
+        model=model,
+    )
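
Usage note (illustrative, not part of the diff): the skip_if_no_cuda marker added in pipegoose/testing/utils.py is a plain pytest.mark.skipif guard, so any test decorated with it is reported as skipped on machines where torch cannot see a CUDA device. A minimal, self-contained sketch of that pattern, assuming only pytest and torch; the test name and body below are hypothetical and not taken from the repository:

    import pytest
    import torch

    # Skip the decorated test entirely when no CUDA device is visible to torch.
    skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")

    @skip_if_no_cuda
    def test_tensor_lands_on_gpu():
        # Only runs on GPU machines; elsewhere pytest reports it as skipped.
        x = torch.ones(2, 2, device="cuda")
        assert x.device.type == "cuda"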