diff --git a/pipegoose/nn/pipeline_parallel2/pipeline_engine.py b/pipegoose/nn/pipeline_parallel2/pipeline_engine.py index 6022949..042edef 100644 --- a/pipegoose/nn/pipeline_parallel2/pipeline_engine.py +++ b/pipegoose/nn/pipeline_parallel2/pipeline_engine.py @@ -42,7 +42,6 @@ def __init__( scheduler: BaseScheduler, worker_manager: BaseWorkerManager, parallel_context: ParallelContext, - partition_func, ): assert isinstance(module, nn.Module), f"module must be an instance of nn.Module, got {type(module)}" assert isinstance( @@ -53,9 +52,7 @@ def __init__( self.scheduler = scheduler self.worker_manager = worker_manager self.parallel_context = parallel_context - self.pipeline_context = PipelineContext(scheduler, parallel_context) - self.partition_func = partition_func def run(self, inputs: torch.Tensor) -> torch.Tensor: self.worker_manager.spawn() @@ -114,7 +111,7 @@ def after_new_clock_cycle(self, progress, clock_idx): else: package = RECV_QUEUE.get() - job = create_job(self.partition_func, package, self.parallel_context, self.pipeline_context) + job = create_job(self.module, package, self.parallel_context, self.pipeline_context) JobQueue.PENDING_JOBS.put(job) dist.barrier() diff --git a/pipegoose/nn/pipeline_parallel2/pipeline_parallel.py b/pipegoose/nn/pipeline_parallel2/pipeline_parallel.py index 58fbe81..9914ec7 100644 --- a/pipegoose/nn/pipeline_parallel2/pipeline_parallel.py +++ b/pipegoose/nn/pipeline_parallel2/pipeline_parallel.py @@ -1,12 +1,13 @@ +from typing import List + import torch from torch import nn -from pipegoose.constants import PIPELINE_MAX_WORKERS, PIPELINE_MIN_WORKERS from pipegoose.distributed.parallel_context import ParallelContext +from pipegoose.nn.pipeline_parallel2._utils import get_partition_idx from pipegoose.nn.pipeline_parallel2._worker import WorkerManager -from pipegoose.nn.pipeline_parallel2.partitioner import PartitionPolicy from pipegoose.nn.pipeline_parallel2.pipeline_engine import PipelineEngine -from pipegoose.nn.pipeline_parallel2.scheduler import SchedulerType, get_scheduler +from pipegoose.nn.pipeline_parallel2.scheduler import GPipeScheduler class PipelineParallel: @@ -14,33 +15,23 @@ class PipelineParallel: def __init__( self, - module: nn.Module, + modules: List[nn.Module], num_microbatches: int, - scheduler_type: SchedulerType, - partition_policy: PartitionPolicy, parallel_context: ParallelContext, ): - self.module = module + self.modules = modules self.num_microbatches = num_microbatches - self.scheduler_type = scheduler_type - self.partition_policy = partition_policy self.parallel_context = parallel_context @torch.no_grad() def parallelize(self) -> nn.Module: - module = self.module + partition_idx = get_partition_idx(self.parallel_context) + module = self.modules[partition_idx] + + n_partitions = self.parallel_context.pipeline_parallel_size + scheduler = GPipeScheduler(self.num_microbatches, n_partitions) + worker_manager = WorkerManager() - # TODO: lazy init - scheduler = get_scheduler( - scheduler_type=self.scheduler_type, - num_microbatches=self.num_microbatches, - parallel_context=self.parallel_context, - ) - worker_manager = WorkerManager( - min_workers=PIPELINE_MIN_WORKERS, - max_workers=PIPELINE_MAX_WORKERS, - parallel_context=self.parallel_context, - ) pipeline_engine = PipelineEngine( module=module, scheduler=scheduler, @@ -48,6 +39,5 @@ def parallelize(self) -> nn.Module: parallel_context=self.parallel_context, ) - pipeline_engine.parallelize(module) - - return pipeline_engine + self.modules.forward = pipeline_engine.run + return self.modules diff --git a/tests/nn/pipeline_parallel_2/job/test_backward.py b/tests/nn/pipeline_parallel_2/job/test_backward.py index 9f6974e..1677c05 100644 --- a/tests/nn/pipeline_parallel_2/job/test_backward.py +++ b/tests/nn/pipeline_parallel_2/job/test_backward.py @@ -1,4 +1,3 @@ -import time from copy import deepcopy import pytest @@ -11,16 +10,15 @@ save_grad_loss, ) from pipegoose.nn.pipeline_parallel2._job.callback import Callback -from pipegoose.nn.pipeline_parallel2._job.creator import schedule_backward_job -from pipegoose.nn.pipeline_parallel2._job.forward import ( - CreateForwardOutputPackageCallback, - ForwardJob, - SaveActivationIfTrainingCallback, - SaveInputActivationsCallback, -) + +# from pipegoose.nn.pipeline_parallel2._job.forward import ( +# CreateForwardOutputPackageCallback, +# ForwardJob, +# # SaveActivationIfTrainingCallback, +# # SaveInputActivationsCallback, +# ) from pipegoose.nn.pipeline_parallel2._job.job_type import JobType from pipegoose.nn.pipeline_parallel2._package import Package -from pipegoose.nn.pipeline_parallel2.queue import JobQueue from pipegoose.testing.utils import init_pipeline_context, spawn @@ -78,36 +76,36 @@ def backward_package_in_the_second_last_pipeline_stage(backward_package): return backward_package -def test_create_a_backward_job_if_a_tensor_do_backprop(forward_package, forward_function, parallel_context, pipeline_context): - callbacks = [ - CreateForwardOutputPackageCallback(parallel_context, pipeline_context), - SaveInputActivationsCallback, - SaveActivationIfTrainingCallback, - ] - forward_job = ForwardJob(forward_function, forward_package, callbacks) +# def test_create_a_backward_job_if_a_tensor_do_backprop(forward_package, forward_function, parallel_context, pipeline_context): +# callbacks = [ +# CreateForwardOutputPackageCallback(parallel_context, pipeline_context), +# SaveInputActivationsCallback, +# SaveActivationIfTrainingCallback, +# ] +# forward_job = ForwardJob(forward_function, forward_package, callbacks) - # NOTE: we enqueue the backward job in the destination rank - output = forward_job.compute() - DATA = output.data.clone() - METADATA = deepcopy(output.metadata) +# # NOTE: we enqueue the backward job in the destination rank +# output = forward_job.compute() +# DATA = output.data.clone() +# METADATA = deepcopy(output.metadata) - output = schedule_backward_job(output, pipeline_context) - # NOTE: make sure we aren't change the package - assert torch.equal(output.data, DATA) - assert output.metadata == METADATA +# output = schedule_backward_job(output, pipeline_context) +# # NOTE: make sure we aren't change the package +# assert torch.equal(output.data, DATA) +# assert output.metadata == METADATA - output.data.sum().backward(retain_graph=True) +# output.data.sum().backward(retain_graph=True) - # NOTE: since we don't launch any job selector workers in the background, - # after triggering the creation of a backward job, - # we expect the destination worker's job queue to have one job - time.sleep(0.1) - assert JobQueue.PENDING_JOBS.qsize() == 1 +# # NOTE: since we don't launch any job selector workers in the background, +# # after triggering the creation of a backward job, +# # we expect the destination worker's job queue to have one job +# time.sleep(0.1) +# assert JobQueue.PENDING_JOBS.qsize() == 1 - backward_job = JobQueue.PENDING_JOBS.get() - assert isinstance(backward_job, BackwardJob) +# backward_job = JobQueue.PENDING_JOBS.get() +# assert isinstance(backward_job, BackwardJob) - backward_job.compute() +# backward_job.compute() def test_the_gradient_output_of_a_backward_job(backward_package): diff --git a/tests/nn/pipeline_parallel_2/job/test_forward.py b/tests/nn/pipeline_parallel_2/job/test_forward.py index 1feb2af..d23e91a 100644 --- a/tests/nn/pipeline_parallel_2/job/test_forward.py +++ b/tests/nn/pipeline_parallel_2/job/test_forward.py @@ -2,18 +2,16 @@ import torch from torch import nn -from pipegoose.nn.pipeline_parallel2._job.forward import ( +from pipegoose.nn.pipeline_parallel2._job.forward import ( # SaveActivationIfTrainingCallback,; SaveInputActivationsCallback, ConfirmCompleteATaskToProgressTracker, CreateForwardOutputPackageCallback, ForwardJob, - SaveActivationIfTrainingCallback, - SaveInputActivationsCallback, SendForwardPackageCallback, ) from pipegoose.nn.pipeline_parallel2._job.job_type import JobType from pipegoose.nn.pipeline_parallel2._package import Package from pipegoose.nn.pipeline_parallel2._utils import sleep -from pipegoose.nn.pipeline_parallel2.queue import SavedActivation, get_input_activations +from pipegoose.nn.pipeline_parallel2.queue import get_input_activations from pipegoose.testing.utils import init_pipeline_context, spawn # NOTE: use for creating a forward job @@ -138,25 +136,25 @@ def test_forward_job_save_input_activations_for_backward_pass(forward_package, p assert saved_activations.requires_grad is True -def test_forward_job_save_output_activations_for_backward_pass(forward_package, parallel_context, pipeline_context): - MICROBATCH_IDX = forward_package.metadata.microbatch_idx - PARTITION_IDX = forward_package.metadata.partition_idx - CALLBACKS = [CreateForwardOutputPackageCallback(parallel_context, pipeline_context), SaveActivationIfTrainingCallback()] +# def test_forward_job_save_output_activations_for_backward_pass(forward_package, parallel_context, pipeline_context): +# MICROBATCH_IDX = forward_package.metadata.microbatch_idx +# PARTITION_IDX = forward_package.metadata.partition_idx +# CALLBACKS = [CreateForwardOutputPackageCallback(parallel_context, pipeline_context), SaveActivationIfTrainingCallback()] - key = SavedActivation.get_key(MICROBATCH_IDX, PARTITION_IDX) - forward_job = ForwardJob(function, forward_package, CALLBACKS) +# key = SavedActivation.get_key(MICROBATCH_IDX, PARTITION_IDX) +# forward_job = ForwardJob(function, forward_package, CALLBACKS) - output = forward_job.compute() - saved_activations = SavedActivation.get_saved_activations(key) +# output = forward_job.compute() +# saved_activations = SavedActivation.get_saved_activations(key) - assert isinstance(saved_activations, torch.Tensor) - assert torch.equal(saved_activations, output.data) - assert saved_activations.requires_grad is True +# assert isinstance(saved_activations, torch.Tensor) +# assert torch.equal(saved_activations, output.data) +# assert saved_activations.requires_grad is True - with pytest.raises(KeyError): - # NOTE: we expect the saved activations to be removed - # after retrieving them - SavedActivation.get_saved_activations(key) +# with pytest.raises(KeyError): +# # NOTE: we expect the saved activations to be removed +# # after retrieving them +# SavedActivation.get_saved_activations(key) def run_forward_job_send_output_to_the_next_pipeline_stage( diff --git a/tests/nn/pipeline_parallel_2/test_pipeline_engine.py b/tests/nn/pipeline_parallel_2/test_pipeline_engine.py index a548f92..de61d77 100644 --- a/tests/nn/pipeline_parallel_2/test_pipeline_engine.py +++ b/tests/nn/pipeline_parallel_2/test_pipeline_engine.py @@ -22,8 +22,8 @@ def run_pipeline_engine( n_microbatches, model, inputs, - outputs, - grads, + ref_outputs, + ref_grads, ): forward_timeline = [] backward_timeline = [] @@ -51,13 +51,13 @@ def forward(self, input): scheduler = GPipeScheduler(n_microbatches, pipeline_parallel_size) worker_manager = WorkerManager() partition_idx = get_partition_idx(parallel_context) - partition_func = Function(partition_idx) + + partition = Function(partition_idx) pipeline_engine = PipelineEngine( - module=model, + module=partition, scheduler=scheduler, worker_manager=worker_manager, parallel_context=parallel_context, - partition_func=partition_func, ) [(microbatch_idx, partition_idx) for microbatch_idx in range(n_microbatches)] EXPECTED_FORWARD_TIMELINE = [(microbatch_idx, partition_idx) for microbatch_idx in range(n_microbatches)] @@ -67,7 +67,7 @@ def forward(self, input): assert forward_timeline == EXPECTED_FORWARD_TIMELINE if is_last_stage(parallel_context): - assert torch.allclose(torch.cat(p_outputs, dim=0), outputs) + assert torch.allclose(torch.cat(p_outputs, dim=0), ref_outputs) for output in p_outputs: output.sum().backward(retain_graph=True) else: @@ -75,11 +75,11 @@ def forward(self, input): # assert p_outputs is None p_outputs.sum().backward() - for param in partition_func.parameters(): + for param in partition.parameters(): assert param.grad is not None - for p, ground_grad in zip(partition_func.parameters(), grads[partition_idx]): - assert torch.allclose(p.grad, ground_grad) + for p, ref_grad in zip(partition.parameters(), ref_grads[partition_idx]): + assert torch.allclose(p.grad, ref_grad) @pytest.mark.parametrize( @@ -115,6 +115,6 @@ def test_pipeline_engine(tensor_parallel_size, pipeline_parallel_size, data_para n_microbatches=N_MICROBATCHES, model=ORIG_MODEL, inputs=inputs.detach(), - outputs=outputs.detach(), - grads=grads, + ref_outputs=outputs.detach(), + ref_grads=grads, ) diff --git a/tests/nn/pipeline_parallel_2/test_pipeline_parallel.py b/tests/nn/pipeline_parallel_2/test_pipeline_parallel.py index 6c128a8..889efac 100644 --- a/tests/nn/pipeline_parallel_2/test_pipeline_parallel.py +++ b/tests/nn/pipeline_parallel_2/test_pipeline_parallel.py @@ -1,31 +1,73 @@ +from copy import deepcopy +from functools import reduce + import pytest import torch +from torch import nn -from pipegoose.nn.pipeline_parallel2.partitioner import PartitionPolicy +from pipegoose.nn.pipeline_parallel2._utils import is_last_stage from pipegoose.nn.pipeline_parallel2.pipeline_parallel import PipelineParallel -from pipegoose.nn.pipeline_parallel2.scheduler import SchedulerType +from pipegoose.testing.utils import init_parallel_context, spawn + +def run_pipeline_engine( + rank, + world_size, + port, + tensor_parallel_size, + pipeline_parallel_size, + data_parallel_size, + n_microbatches, + model, + inputs, + ref_outputs, +): -class FakeParallelContext: - pass + parallel_context = init_parallel_context( + rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size + ) + model = PipelineParallel(model, num_microbatches=n_microbatches, parallel_context=parallel_context).parallelize() + outputs = model(inputs) -@pytest.mark.skip -def test_pipeline_parallel(model): - parallel_context = FakeParallelContext() + if is_last_stage(parallel_context): + assert torch.allclose(torch.cat(outputs, dim=0), ref_outputs) + for output in outputs: + output.sum().backward(retain_graph=True) + else: + outputs.sum().backward() - NUM_MICROBATCHES = 5 - input = torch.randn(NUM_MICROBATCHES, 4) +@pytest.mark.parametrize( + "tensor_parallel_size, pipeline_parallel_size, data_parallel_size", + [ + (1, 4, 1), + # TODO: not works with 3d parallelism yet + # (2, 4, 2) + ], +) +def test_pipeline_engine(tensor_parallel_size, pipeline_parallel_size, data_parallel_size): + BATCH_SIZE = 32 + N_MICROBATCHES = 6 + SEQ_LEN = 10 + HIDDEN_DIM = 5 + WORLD_SIZE = tensor_parallel_size * pipeline_parallel_size * data_parallel_size - parallelized_model = PipelineParallel( - module=model, - num_microbatches=NUM_MICROBATCHES, - scheduler_type=SchedulerType.GPIPE, - partition_policy=PartitionPolicy.UNIFORM, - parallel_context=parallel_context, - ).parallelize() + inputs = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_DIM, requires_grad=False) + model = nn.ModuleList([nn.Sequential(nn.Linear(HIDDEN_DIM, HIDDEN_DIM), nn.ReLU()) for _ in range(pipeline_parallel_size)]) + ORIG_MODEL = deepcopy(model) + outputs = reduce(lambda inputs, layer: layer(inputs), model, inputs) - output = parallelized_model(input) + outputs.sum().backward() - assert isinstance(output, torch.Tensor) + spawn( + run_pipeline_engine, + world_size=WORLD_SIZE, + tensor_parallel_size=tensor_parallel_size, + pipeline_parallel_size=pipeline_parallel_size, + data_parallel_size=data_parallel_size, + n_microbatches=N_MICROBATCHES, + model=ORIG_MODEL, + inputs=inputs.detach(), + ref_outputs=outputs.detach(), + )