From 2d9729d0fb3ff386970615d39fcd32f6d9b68843 Mon Sep 17 00:00:00 2001 From: xrsrke Date: Wed, 25 Oct 2023 06:26:53 +0700 Subject: [PATCH] [Bug] fixed import path in hybrid parallelism test --- .../{test_hybrid_parallel.py => run_hybrid_parallel.py} | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) rename tests/convergence/{test_hybrid_parallel.py => run_hybrid_parallel.py} (96%) diff --git a/tests/convergence/test_hybrid_parallel.py b/tests/convergence/run_hybrid_parallel.py similarity index 96% rename from tests/convergence/test_hybrid_parallel.py rename to tests/convergence/run_hybrid_parallel.py index befd20a..bfcb215 100644 --- a/tests/convergence/test_hybrid_parallel.py +++ b/tests/convergence/run_hybrid_parallel.py @@ -2,7 +2,6 @@ import torch import torch.distributed as dist -import wandb from datasets import load_dataset from torch.optim import SGD from torch.utils.data import DataLoader @@ -11,8 +10,7 @@ from pipegoose.distributed.parallel_context import ParallelContext from pipegoose.distributed.parallel_mode import ParallelMode -from pipegoose.nn.data_parallel.data_parallel import DataParallel -from pipegoose.nn.tensor_parallel.tensor_parallel import TensorParallel +from pipegoose.nn import DataParallel, TensorParallel def get_model_params_size(model, fp_bytes=4): @@ -24,6 +22,8 @@ def get_model_params_size(model, fp_bytes=4): if __name__ == "__main__": + import wandb + DATA_PARALLEL_SIZE = 2 TENSOR_PARALLEL_SIZE = 2 PIPELINE_PARALLEL_SIZE = 1 @@ -133,5 +133,4 @@ def get_time_name(): step += 1 wandb.finish() - -model.cpu() + model.cpu()