[Refactor] Refactor move device to GPU

xrsrke committed Oct 24, 2023
1 parent ada59a8 commit e14c69e

Showing 3 changed files with 4 additions and 9 deletions.
5 changes: 2 additions & 3 deletions examples/hybrid_parallelism.py
@@ -4,8 +4,7 @@
 from torch.utils.data.distributed import DistributedSampler
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-from pipegoose.distributed.parallel_context import ParallelContext
-from pipegoose.distributed.parallel_mode import ParallelMode
+from pipegoose.distributed import ParallelContext, ParallelMode
 from pipegoose.nn import DataParallel, TensorParallel
 
 if __name__ == "__main__":
@@ -22,7 +21,7 @@
     rank = parallel_context.get_global_rank()
 
     dataset = load_dataset("imdb", split="train[:100]")
-    dataset = dataset.map(lambda x: {"text": x["text"][:30]})
+    dataset = dataset.map(lambda x: {"text": x["text"][:30]})  # for demonstration purposes, you can remove this line
 
     dp_rank = parallel_context.get_local_rank(ParallelMode.DATA)
     sampler = DistributedSampler(dataset, num_replicas=DATA_PARALLEL_SIZE, rank=dp_rank, seed=69)
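
The shortened import only works if pipegoose/distributed/__init__.py re-exports both names. A minimal sketch of that presumed re-export, inferred from the two module paths in the deleted lines (the __init__.py file itself is not part of this diff):

# pipegoose/distributed/__init__.py — presumed contents, shown for context only
from pipegoose.distributed.parallel_context import ParallelContext
from pipegoose.distributed.parallel_mode import ParallelMode

__all__ = ["ParallelContext", "ParallelMode"]
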
5 changes: 0 additions & 5 deletions pipegoose/nn/parallel.py
@@ -12,8 +12,6 @@
 
 @dataclass
 class ParallelMetadata:
-    is_moved_to_device: bool = False
-
     device: int = None
     local_device: int = None
 
@@ -74,7 +72,6 @@ def is_specific_device(device):
 
     assert parallel_metadata is not None, "Module is not parallelized yet"
     assert device in SUPPORTED_DEVICES, f"Device must be one of {SUPPORTED_DEVICES}, got {device}"
-    assert parallel_metadata.is_moved_to_device is False, "Module is already moved to device"
     assert not is_specific_device(
         device
     ), f'Moving to a specific device {device} is not supported. pipegoose will handle device assignment automatically. Please use "cuda" instead'
@@ -91,8 +88,6 @@ def is_specific_device(device):
 
     for b in self.buffers():
         b.data = b.to(f"cuda:{local_device}")
-
-    parallel_metadata.is_moved_to_device = True
 
 
 def _to_cuda(self):
     self.to("cuda")
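
Net effect of this file's deletions: ParallelMetadata no longer tracks whether a module was already moved, so a repeated .to("cuda") re-runs the move instead of failing an assertion. A rough sketch of the resulting move logic, reconstructed from the context lines above — the function name _to_device, the SUPPORTED_DEVICES value, the metadata lookup, and the parameter loop are assumptions; only the asserts and the buffer loop appear in this diff:

# Hypothetical reconstruction; the real body in pipegoose/nn/parallel.py may differ.
SUPPORTED_DEVICES = ["cpu", "cuda"]  # assumed definition

def _to_device(self, device: str):
    def is_specific_device(device):
        # "cuda:0"-style strings pin an explicit device index
        return ":" in device  # assumed check

    parallel_metadata = getattr(self, "parallel_metadata", None)  # assumed lookup

    assert parallel_metadata is not None, "Module is not parallelized yet"
    assert device in SUPPORTED_DEVICES, f"Device must be one of {SUPPORTED_DEVICES}, got {device}"
    # The is_moved_to_device guard was deleted in this commit, so calling
    # .to("cuda") a second time simply repeats the moves below.
    assert not is_specific_device(
        device
    ), f'Moving to a specific device {device} is not supported. pipegoose will handle device assignment automatically. Please use "cuda" instead'

    local_device = parallel_metadata.local_device
    for p in self.parameters():  # parameter loop assumed, mirroring the buffer loop
        p.data = p.to(f"cuda:{local_device}")
    for b in self.buffers():
        b.data = b.to(f"cuda:{local_device}")
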
3 changes: 2 additions & 1 deletion tests/convergence/test_hybrid_3d.py
@@ -2,7 +2,6 @@
 
 import torch
 import torch.distributed as dist
-import wandb
 from datasets import load_dataset
 from torch.optim import SGD
 from torch.utils.data import DataLoader
@@ -24,6 +23,8 @@ def get_model_params_size(model, fp_bytes=4):
 
 
 if __name__ == "__main__":
+    import wandb
+
     DATA_PARALLEL_SIZE = 2
     TENSOR_PARALLEL_SIZE = 2
     PIPELINE_PARALLEL_SIZE = 1
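
Moving import wandb from module scope to under the __main__ guard defers the dependency: the test module can be imported (e.g., during test collection) without wandb installed, and the import cost is paid only when the script actually runs. The pattern in miniature (the wandb.init call is illustrative, not shown in this diff):

if __name__ == "__main__":
    import wandb  # deferred: importing this module no longer requires wandb

    wandb.init(project="pipegoose")  # illustrative arguments, not from the diff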
