diff --git a/pipegoose/nn/pipeline_parallel2/_comm.py b/pipegoose/nn/pipeline_parallel2/_comm.py
index 5a70854..47eb6f9 100644
--- a/pipegoose/nn/pipeline_parallel2/_comm.py
+++ b/pipegoose/nn/pipeline_parallel2/_comm.py
@@ -5,7 +5,8 @@
 from pipegoose.distributed.parallel_context import ParallelContext
 from pipegoose.nn.pipeline_parallel2._package import Package
-from pipegoose.nn.pipeline_parallel2.pipeline_context import PipelineContext
+
+# from pipegoose.nn.pipeline_parallel2.pipeline_context import PipelineContext

 RECV_QUEUE = Queue()
 # RECV_QUEUE = dict()

@@ -15,12 +16,12 @@
 _PIPELINE_CONTEXT = None


-def set_pipeline_context(pipeline_context: PipelineContext):
+def set_pipeline_context(pipeline_context):
     global _PIPELINE_CONTEXT
     _PIPELINE_CONTEXT = pipeline_context


-def get_pipeline_context() -> PipelineContext:
+def get_pipeline_context():
     return _PIPELINE_CONTEXT


diff --git a/pipegoose/nn/pipeline_parallel2/_job/backward.py b/pipegoose/nn/pipeline_parallel2/_job/backward.py
index 536b2d9..2252fcc 100644
--- a/pipegoose/nn/pipeline_parallel2/_job/backward.py
+++ b/pipegoose/nn/pipeline_parallel2/_job/backward.py
@@ -49,12 +49,17 @@ def run_compute(self) -> torch.Tensor:
         partition_idx = self.input.metadata.partition_idx

         prev_grad = self.input.data
-        output = get_output_activations(microbatch_idx, partition_idx)
         input = get_input_activations(microbatch_idx, partition_idx)
+        output = get_output_activations(microbatch_idx, partition_idx, is_pipeline=True)

         torch.autograd.backward(output, grad_tensors=prev_grad)

         if input.requires_grad is False:
+            raise PipelineGradientFlowError(
+                "Please set .requires_grad = True to input activations. Gradients can't flow back to the input of the pipeline stage"
+            )
+
+        if input.grad is None:
             raise PipelineGradientFlowError("Gradients can't flow back to the input of the pipeline stage")

         # # TODO: remove this, since the grads is stored in module's weights
@@ -65,7 +70,4 @@ def run_compute(self) -> torch.Tensor:
         # print(f"executing backward job, rank={rank}, microbatch_idx={microbatch_idx}, partition_idx={partition_idx}")
         print(f"yay! gradients: {input.grad.shape}")

-        if input.grad is None:
-            raise PipelineGradientFlowError("Gradients can't flow back to the input of the pipeline stage")
-
         return input.grad
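Note on the backward.py hunks above: the saved input activations are now fetched before the output activations, and the single gradient-flow check is split in two, both evaluated after `torch.autograd.backward` runs: first that the saved input is allowed to receive gradients at all (`requires_grad`), then that a gradient actually arrived (`input.grad is not None`). A minimal sketch of those two failure modes in plain PyTorch, with toy tensors standing in for the saved activations and the gradient received from the next stage:

# Minimal sketch, independent of pipegoose: toy tensors stand in for the saved
# input/output activations and for the gradient sent back by the next stage.
import torch

input = torch.randn(4, 4, requires_grad=True)  # saved input activations
output = (input * 2).sum()                     # saved output activations
prev_grad = torch.tensor(1.0)                  # gradient received from the next stage

torch.autograd.backward(output, grad_tensors=prev_grad)

assert input.requires_grad       # otherwise autograd never routes gradients to it
assert input.grad is not None    # otherwise the graph between output and input was cut
print(input.grad.shape)          # torch.Size([4, 4])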
diff --git a/pipegoose/nn/pipeline_parallel2/_job/creator.py b/pipegoose/nn/pipeline_parallel2/_job/creator.py
index c299164..88f7229 100644
--- a/pipegoose/nn/pipeline_parallel2/_job/creator.py
+++ b/pipegoose/nn/pipeline_parallel2/_job/creator.py
@@ -4,10 +4,7 @@
 import torch

 from pipegoose.distributed.parallel_context import ParallelContext
-from pipegoose.nn.pipeline_parallel2._comm import (
-    get_pipeline_context,
-    set_pipeline_context,
-)
+from pipegoose.nn.pipeline_parallel2._comm import get_pipeline_context
 from pipegoose.nn.pipeline_parallel2._job.backward import BackwardJob
 from pipegoose.nn.pipeline_parallel2._job.callback import Callback
 from pipegoose.nn.pipeline_parallel2._job.forward import (
@@ -138,54 +135,6 @@ def backward_function(self):


 def schedule_backward_job(package: Package, pipeline_context: PipelineContext) -> Package:
-    # assert isinstance(package, Package), f"package must be an instance of Package, got {type(package)}"
-
-    # class _ScheduleBackwardJob(torch.autograd.Function):
-    #     @staticmethod
-    #     def forward(ctx, metadata: Metadata, pipeline_context: PipelineContext, input: torch.Tensor):
-    #         # NOTE: can't assign metadata attribute to ctx
-    #         # "AttributeError: attribute 'metadata' of 'torch._C._FunctionBase'
-    #         # objects is not writable"
-    #         rank = pipeline_context.parallel_context.get_global_rank()
-    #         print(f"scheduled a backward job, rank={rank}, microbatch_idx={metadata.microbatch_idx}")
-    #         ctx.package_meta = metadata
-    #         ctx.pipeline_context = pipeline_context
-    #         ctx.input = input
-    #         return input
-
-    #     @staticmethod
-    #     def backward(ctx: Any, grad_input: torch.Tensor):
-    #         metadata = ctx.package_meta
-    #         # pipeline_context = ctx.pipeline_context
-    #         # parallel_context = pipeline_context.parallel_context
-
-    #         # rank = parallel_context.get_global_rank()
-    #         # microbatch_idx = metadata.microbatch_idx
-
-    #         # dst_worker_name = parallel_context.get_worker_name(metadata.dst)
-    #         # print(grad_input)
-    #         # print(f"creating a backward job, rank={rank}, microbatch_idx={microbatch_idx}, dst_worker_name={dst_worker_name}")
-
-    #         _create_backward_job_and_put_to_pending_queue(grad_input, metadata)
-    #         # TODO: because forward job and backward job are in the same node
-    #         # rpc isn't necessary
-    #         # rpc.rpc_sync(
-    #         #     # NOTE: the backward job create in the same node
-    #         #     # as the forward job
-    #         #     to=dst_worker_name,
-    #         #     func=_create_backward_job_and_put_to_pending_queue,
-    #         #     args=(grad_input, metadata),
-    #         # )
-
-    #         return (None, None, None)
-
-    set_pipeline_context(pipeline_context)
-
-    # metadata = package.metadata
-    # new_data = _ScheduleBackwardJob.apply(metadata, pipeline_context, package.data)
-    # package.data = new_data
-    # return package
-
     class Function(torch.autograd.Function):
         @staticmethod
         def forward(ctx, metadata: Metadata, input):
@@ -196,12 +145,6 @@ def forward(ctx, metadata: Metadata, input):
         def backward(ctx, grad_input):
             metadata = ctx.package_meta
             _create_backward_job_and_put_to_pending_queue(grad_input, metadata)
-
-            # from pipegoose.nn.pipeline_parallel2.queue import SavedActivation
-            # output = SavedActivation.get_saved_activations((0, 0))
-            # detached_output = output.detach().requires_grad_()
-            # torch.autograd.backward(detached_output, grad_input)
-            # return detached_output.grad
             return (None, grad_input)

     data = package.data
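What the cleanup above keeps is the inner `Function`: a pass-through in `forward` that stamps the package metadata onto `ctx`, and a `backward` that enqueues a backward job with the incoming gradient instead of computing anything inline. Below is a standalone sketch of that pattern with toy stand-ins (a plain `Queue` and a dict replace pipegoose's pending-job queue and `Metadata`); it is an illustration, not the library's API:

from queue import Queue

import torch

PENDING_BACKWARD_JOBS = Queue()  # toy stand-in for the real pending-job queue


class ScheduleBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, metadata, input):
        ctx.package_meta = metadata  # ride the metadata along on the autograd graph
        return input

    @staticmethod
    def backward(ctx, grad_input):
        # Hand the incoming gradient plus its metadata to whoever runs backward jobs later.
        PENDING_BACKWARD_JOBS.put((ctx.package_meta, grad_input))
        # One gradient per forward input: None for metadata, pass-through for the tensor.
        return (None, grad_input)


x = torch.ones(2, requires_grad=True)
y = ScheduleBackward.apply({"microbatch_idx": 0, "partition_idx": 0}, x)
y.sum().backward()
print(PENDING_BACKWARD_JOBS.qsize())  # 1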
diff --git a/pipegoose/nn/pipeline_parallel2/pipeline_context.py b/pipegoose/nn/pipeline_parallel2/pipeline_context.py
index a9c2584..09ca29e 100644
--- a/pipegoose/nn/pipeline_parallel2/pipeline_context.py
+++ b/pipegoose/nn/pipeline_parallel2/pipeline_context.py
@@ -3,6 +3,7 @@

 from pipegoose.distributed.parallel_context import ParallelContext
 from pipegoose.distributed.parallel_mode import ParallelMode
+from pipegoose.nn.pipeline_parallel2._comm import set_pipeline_context
 from pipegoose.nn.pipeline_parallel2._utils import get_partition_idx, is_last_stage
 from pipegoose.nn.pipeline_parallel2.scheduler import BaseScheduler

@@ -19,6 +20,8 @@ def __init__(self, scheduler: BaseScheduler, parallel_context: ParallelContext):
         # NOTE: block CPU thread until the next clock cycle
         self._wait_new_clock_cycle = threading.Condition()

+        set_pipeline_context(self)
+
     @property
     def partition_idx(self) -> int:
         parallel_context = self.parallel_context
diff --git a/tests/nn/pipeline_parallel_2/conftest.py b/tests/nn/pipeline_parallel_2/conftest.py
index 3f12ec4..8853ed7 100644
--- a/tests/nn/pipeline_parallel_2/conftest.py
+++ b/tests/nn/pipeline_parallel_2/conftest.py
@@ -144,7 +144,14 @@ def backward_function(*args, **kwargs):
     return job


+@pytest.fixture
+def forward_function():
+    return nn.Linear(*LINEAR_SHAPE)
+
+
 @pytest.fixture(scope="function")
-def forward_job(forward_package, parallel_context, pipeline_context):
-    function = nn.Linear(*LINEAR_SHAPE)
-    return create_job(function, forward_package, parallel_context, pipeline_context)
+def forward_job(forward_package, forward_function):
+    from pipegoose.nn.pipeline_parallel2._job.forward import ForwardJob
+
+    # return create_job(function, forward_package, parallel_context, pipeline_context)
+    return ForwardJob(forward_function, forward_package)
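With the fixtures above, `forward_job` no longer depends on a live `parallel_context`/`pipeline_context`; it builds a `ForwardJob` straight from a package and a small `nn.Linear`. A minimal pytest sketch of that fixture composition (the `LINEAR_SHAPE` value below is assumed for illustration; the real constant is defined elsewhere in conftest.py):

import pytest
import torch
from torch import nn

LINEAR_SHAPE = (4, 4)  # assumed shape, for illustration only


@pytest.fixture
def forward_function():
    return nn.Linear(*LINEAR_SHAPE)


def test_forward_function_runs(forward_function):
    # The fixture hands each test a fresh nn.Linear to use as the partition function.
    batch = torch.randn(1, LINEAR_SHAPE[0])
    assert forward_function(batch).shape == (1, LINEAR_SHAPE[1])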
diff --git a/tests/nn/pipeline_parallel_2/test_pipeline_context.py b/tests/nn/pipeline_parallel_2/test_pipeline_context.py
index 6348dc5..f18ec2d 100644
--- a/tests/nn/pipeline_parallel_2/test_pipeline_context.py
+++ b/tests/nn/pipeline_parallel_2/test_pipeline_context.py
@@ -2,6 +2,7 @@

 import pytest

+from pipegoose.nn.pipeline_parallel2._comm import get_pipeline_context
 from pipegoose.nn.pipeline_parallel2.pipeline_context import PipelineContext
 from pipegoose.nn.pipeline_parallel2.scheduler import SchedulerType, get_scheduler
 from pipegoose.nn.pipeline_parallel2.task import Task
@@ -37,6 +38,7 @@ def run_pipeline_context(rank, world_size, port, tensor_parallel_size, pipeline_
     assert isinstance(pipeline_context.schedules, list)
     assert isinstance(pipeline_context.get_schedule_from_partition(clock_idx=3, partition_idx=2), list)
     assert isinstance(pipeline_context.get_schedule_from_microbatch(clock_idx=3, microbatch_idx=0), list)
+    assert get_pipeline_context() == pipeline_context

     next_schedules = pipeline_context.get_next_schedule_from_microbatch(microbatch_idx=0)
     assert isinstance(next_schedules, list)
@@ -108,37 +110,3 @@ def test_get_syncronous_schedule():
         pipeline_parallel_size=PIPELINE_PARALLEL_SIZE,
         data_parallel_size=DATA_PARALLEL_SIZE,
     )
-
-
-# def run_pipeline_context_init_progress_tracker(rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size):
-#     N_PARTITIONS = 4
-#     N_MICROBATCHES = 5
-
-#     parallel_context = init_parallel_context(
-#         rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size
-#     )
-#     scheduler = get_scheduler(SchedulerType.GPIPE)(N_MICROBATCHES, N_PARTITIONS)
-#     TOTAL_SCHEDULES = scheduler.total_clock_cycles
-
-#     def increase_clock_every_second(pipeline_context):
-#         for _ in range(TOTAL_SCHEDULES):
-#             pipeline_context.increase_a_clock_cycle()
-
-#     pipeline_context = PipelineContext(scheduler, parallel_context)
-
-#     assert 1 == 1
-
-
-# def test_pipeline_context_init_progress_tracker():
-#     TENSOR_PARALLEL_SIZE = 1
-#     PIPELINE_PARALLEL_SIZE = 2
-#     DATA_PARALLEL_SIZE = 1
-#     WORLD_SIZE = TENSOR_PARALLEL_SIZE * PIPELINE_PARALLEL_SIZE * DATA_PARALLEL_SIZE
-
-#     spawn(
-#         run_pipeline_context_init_progress_tracker,
-#         world_size=WORLD_SIZE,
-#         tensor_parallel_size=TENSOR_PARALLEL_SIZE,
-#         pipeline_parallel_size=PIPELINE_PARALLEL_SIZE,
-#         data_parallel_size=DATA_PARALLEL_SIZE,
-#     )
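Taken together, these diffs converge on a small module-level registry: `_comm.py` stops importing `PipelineContext` (dropping the type hints with it, which would otherwise be a circular import now that `pipeline_context.py` imports from `_comm`), and `PipelineContext` registers itself at construction time, which is what the new `assert get_pipeline_context() == pipeline_context` in the test exercises. A condensed sketch, with names taken from the diff and the bodies trimmed to the relevant lines:

# Condensed from the diffs above; constructor arguments here are toy values.

# _comm.py -- untyped module-level slot, no import of PipelineContext
_PIPELINE_CONTEXT = None


def set_pipeline_context(pipeline_context):
    global _PIPELINE_CONTEXT
    _PIPELINE_CONTEXT = pipeline_context


def get_pipeline_context():
    return _PIPELINE_CONTEXT


# pipeline_context.py -- __init__ trimmed to the registration step
class PipelineContext:
    def __init__(self, scheduler, parallel_context):
        self.scheduler = scheduler
        self.parallel_context = parallel_context
        set_pipeline_context(self)  # afterwards get_pipeline_context() returns this instance


ctx = PipelineContext(scheduler=None, parallel_context=None)  # toy arguments
assert get_pipeline_context() is ctx  # mirrors the new assertion in the test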