diff --git a/examples/hybrid_parallelism.py b/examples/hybrid_parallelism.py
index 17930ac..2a480ad 100644
--- a/examples/hybrid_parallelism.py
+++ b/examples/hybrid_parallelism.py
@@ -4,8 +4,7 @@
 from torch.utils.data.distributed import DistributedSampler
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-from pipegoose.distributed.parallel_context import ParallelContext
-from pipegoose.distributed.parallel_mode import ParallelMode
+from pipegoose.distributed import ParallelContext, ParallelMode
 from pipegoose.nn import DataParallel, TensorParallel
 
 if __name__ == "__main__":
@@ -22,7 +21,7 @@
     rank = parallel_context.get_global_rank()
 
     dataset = load_dataset("imdb", split="train[:100]")
-    dataset = dataset.map(lambda x: {"text": x["text"][:30]})
+    dataset = dataset.map(lambda x: {"text": x["text"][:30]})  # for demonstration purposes, you can remove this line
 
     dp_rank = parallel_context.get_local_rank(ParallelMode.DATA)
     sampler = DistributedSampler(dataset, num_replicas=DATA_PARALLEL_SIZE, rank=dp_rank, seed=69)
diff --git a/pipegoose/nn/parallel.py b/pipegoose/nn/parallel.py
index ce06fb1..9e952c1 100644
--- a/pipegoose/nn/parallel.py
+++ b/pipegoose/nn/parallel.py
@@ -12,8 +12,6 @@
 
 @dataclass
 class ParallelMetadata:
-    is_moved_to_device: bool = False
-
     device: int = None
     local_device: int = None
 
@@ -74,7 +72,6 @@ def is_specific_device(device):
 
     assert parallel_metadata is not None, "Module is not parallelized yet"
     assert device in SUPPORTED_DEVICES, f"Device must be one of {SUPPORTED_DEVICES}, got {device}"
-    assert parallel_metadata.is_moved_to_device is False, "Module is already moved to device"
     assert not is_specific_device(
         device
     ), f'Moving to a specific device {device} is not supported. pipegoose will handle device assignment automatically. Please use "cuda" instead'
@@ -91,8 +88,6 @@ def is_specific_device(device):
     for b in self.buffers():
         b.data = b.to(f"cuda:{local_device}")
 
-    parallel_metadata.is_moved_to_device = True
-
 
 def _to_cuda(self):
     self.to("cuda")
diff --git a/tests/convergence/test_hybrid_3d.py b/tests/convergence/test_hybrid_3d.py
index cf878fd..6d626da 100644
--- a/tests/convergence/test_hybrid_3d.py
+++ b/tests/convergence/test_hybrid_3d.py
@@ -2,7 +2,6 @@
 
 import torch
 import torch.distributed as dist
-import wandb
 from datasets import load_dataset
 from torch.optim import SGD
 from torch.utils.data import DataLoader
@@ -24,6 +23,8 @@ def get_model_params_size(model, fp_bytes=4):
 
 
 if __name__ == "__main__":
+    import wandb
+
     DATA_PARALLEL_SIZE = 2
     TENSOR_PARALLEL_SIZE = 2
     PIPELINE_PARALLEL_SIZE = 1
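
For context, a minimal sketch of how the consolidated import path and the deferred wandb import from this patch look in use. Note that with the is_moved_to_device guard removed, a second call to .to("cuda") is no longer rejected by an assert. The snippet only re-assembles lines that appear in the diff above, except for ParallelContext.from_torch(...), which is an assumption about how the example constructs its context and may differ from the actual setup code not shown here.

    # Sketch only: mirrors the updated imports in examples/hybrid_parallelism.py.
    from datasets import load_dataset
    from torch.utils.data.distributed import DistributedSampler

    # Single-module import introduced by this patch, replacing the two
    # deep imports from parallel_context / parallel_mode.
    from pipegoose.distributed import ParallelContext, ParallelMode

    if __name__ == "__main__":
        # Deferred import, as in tests/convergence/test_hybrid_3d.py: wandb is
        # only needed when the script actually runs, not at module import time.
        import wandb

        DATA_PARALLEL_SIZE = 2

        # ASSUMPTION: the example builds its context roughly like this; the
        # exact constructor and arguments are not shown in this diff.
        parallel_context = ParallelContext.from_torch(
            data_parallel_size=DATA_PARALLEL_SIZE,
            tensor_parallel_size=2,
            pipeline_parallel_size=1,
        )

        # Per-rank sharding exactly as in examples/hybrid_parallelism.py:
        dataset = load_dataset("imdb", split="train[:100]")
        dp_rank = parallel_context.get_local_rank(ParallelMode.DATA)
        sampler = DistributedSampler(dataset, num_replicas=DATA_PARALLEL_SIZE, rank=dp_rank, seed=69)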