Skip to content

Commit

Permalink
[Bug] fixed import path in hybrid parallelism test
Browse files Browse the repository at this point in the history
  • Loading branch information
xrsrke committed Oct 24, 2023
1 parent 2619390 commit 2d9729d
Showing 1 changed file with 4 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import torch
import torch.distributed as dist
import wandb
from datasets import load_dataset
from torch.optim import SGD
from torch.utils.data import DataLoader
Expand All @@ -11,8 +10,7 @@

from pipegoose.distributed.parallel_context import ParallelContext
from pipegoose.distributed.parallel_mode import ParallelMode
from pipegoose.nn.data_parallel.data_parallel import DataParallel
from pipegoose.nn.tensor_parallel.tensor_parallel import TensorParallel
from pipegoose.nn import DataParallel, TensorParallel


def get_model_params_size(model, fp_bytes=4):
Expand All @@ -24,6 +22,8 @@ def get_model_params_size(model, fp_bytes=4):


if __name__ == "__main__":
import wandb

DATA_PARALLEL_SIZE = 2
TENSOR_PARALLEL_SIZE = 2
PIPELINE_PARALLEL_SIZE = 1
Expand Down Expand Up @@ -133,5 +133,4 @@ def get_time_name():
step += 1

wandb.finish()

model.cpu()
model.cpu()

0 comments on commit 2d9729d

Please sign in to comment.