From 1b177cdf441ad4c51bf3d47d68b20ddb9d54b5ff Mon Sep 17 00:00:00 2001
From: xrsrke
Date: Fri, 13 Oct 2023 13:55:10 +0700
Subject: [PATCH] WIP: refactor forward job; all tests pass except sending
 output to the next pipeline stage

---
 .../nn/pipeline_parallel2/_job/backward.py    | 13 ++--
 .../nn/pipeline_parallel2/_job/creator.py     | 41 ++++++-----
 .../nn/pipeline_parallel2/_job/forward.py     | 12 +++-
 .../nn/pipeline_parallel2/pipeline_engine.py  | 22 ++++---
 .../nn/pipeline_parallel2/sync/handshake.py   |  3 +-
 tests/nn/pipeline_parallel_2/conftest.py      |  8 +--
 .../pipeline_parallel_2/job/test_backward.py  | 61 ++++++-------------
 .../pipeline_parallel_2/job/test_creator.py   | 61 +++++++++++++++----
 .../pipeline_parallel_2/job/test_forward.py   | 47 ++++++++++++++
 .../test_pipeline_context.py                  | 34 +++++++++++
 10 files changed, 212 insertions(+), 90 deletions(-)

diff --git a/pipegoose/nn/pipeline_parallel2/_job/backward.py b/pipegoose/nn/pipeline_parallel2/_job/backward.py
index f4c8d4e..536b2d9 100644
--- a/pipegoose/nn/pipeline_parallel2/_job/backward.py
+++ b/pipegoose/nn/pipeline_parallel2/_job/backward.py
@@ -1,5 +1,6 @@
 import torch
 
+from pipegoose.distributed.parallel_context import ParallelContext
 from pipegoose.nn.pipeline_parallel2._job.callback import Callback
 from pipegoose.nn.pipeline_parallel2._job.job import Job
 from pipegoose.nn.pipeline_parallel2._package import Package
@@ -28,14 +29,16 @@ class SendBackwardPackageCallback(Callback):
 
     order = 5
 
+    def __init__(self, parallel_context: ParallelContext):
+        self.parallel_context = parallel_context
+
     def after_compute(self):
         from pipegoose.nn.pipeline_parallel2._comm import send_package
 
-        parallel_context = self.job.pipeline_context.parallel_context
-        if parallel_context.pipeline_parallel_size > 1:
+        if self.parallel_context.pipeline_parallel_size > 1:
             output = self.job.output
             assert isinstance(output, Package), f"output must be an instance of Package, got {type(output)}"
-            send_package(output, parallel_context)
+            send_package(output, self.parallel_context)
 
 
 class BackwardJob(Job):
@@ -58,8 +61,8 @@ def run_compute(self) -> torch.Tensor:
         # # and we do gradient accumulation, we don't need return grads or send to other stages
         # assert isinstance(input.grad, torch.Tensor)
 
-        rank = self.pipeline_context.parallel_context.get_global_rank()
-        print(f"executing backward job, rank={rank}, microbatch_idx={microbatch_idx}, partition_idx={partition_idx}")
+        # rank = self.pipeline_context.parallel_context.get_global_rank()
+        # print(f"executing backward job, rank={rank}, microbatch_idx={microbatch_idx}, partition_idx={partition_idx}")
         print(f"yay! gradients: {input.grad.shape}")
 
         if input.grad is None:
diff --git a/pipegoose/nn/pipeline_parallel2/_job/creator.py b/pipegoose/nn/pipeline_parallel2/_job/creator.py
index 3e72957..04df5aa 100644
--- a/pipegoose/nn/pipeline_parallel2/_job/creator.py
+++ b/pipegoose/nn/pipeline_parallel2/_job/creator.py
@@ -3,6 +3,7 @@
 
 import torch
 
+from pipegoose.distributed.parallel_context import ParallelContext
 from pipegoose.nn.pipeline_parallel2._comm import (
     get_pipeline_context,
     set_pipeline_context,
@@ -35,27 +36,31 @@ def create(self) -> Job:
 class ScheduleBackwardJobCallback(Callback):
     order = 3
 
+    def __init__(self, pipeline_context: PipelineContext):
+        self.pipeline_context = pipeline_context
+
     def after_compute(self):
         package = self.job.output
-        new_package = schedule_backward_job(package, self.job.pipeline_context)
+        new_package = schedule_backward_job(package, self.pipeline_context)
         self.job.output = new_package
 
 
 class _ForwardJobCreator(JobCreator):
     """Put a forward job into job queue for a worker to execute."""
 
-    CBS = [
-        CreateForwardOutputPackageCallback,
-        SaveInputActivationsCallback,
-        SaveActivationIfTrainingCallback,
-        ScheduleBackwardJobCallback,
-        SendForwardPackageCallback,
-        ConfirmCompleteATaskToProgressTracker,
-    ]
-
-    @classmethod
-    def create(cls, function: Callable, package: Package, pipeline_context: PipelineContext) -> ForwardJob:
-        job = ForwardJob(function, package, cbs=cls.CBS, pipeline_context=pipeline_context)
+    @classmethod
+    def create(
+        cls, function: Callable, package: Package, parallel_context: ParallelContext, pipeline_context: PipelineContext
+    ) -> ForwardJob:
+        callbacks = [
+            CreateForwardOutputPackageCallback(parallel_context, pipeline_context),
+            SaveInputActivationsCallback,
+            SaveActivationIfTrainingCallback,
+            ScheduleBackwardJobCallback(pipeline_context),
+            SendForwardPackageCallback(parallel_context),
+            ConfirmCompleteATaskToProgressTracker(parallel_context),
+        ]
+        job = ForwardJob(function, package, callbacks)
         return job
 
 
@@ -63,7 +68,9 @@ class _BackwardJobCreator(JobCreator):
     # CBS = [CreateBackwardOutputPackageCallback, SendBackwardPackageCallback]
 
     @classmethod
-    def create(cls, function: Callable, package: Package, pipeline_context: PipelineContext) -> BackwardJob:
+    def create(
+        cls, function: Callable, package: Package, parallel_context: ParallelContext, pipeline_context: PipelineContext
+    ) -> BackwardJob:
         from pipegoose.nn.pipeline_parallel2.queue import (
             InputActivations,
             SavedActivation,
@@ -81,11 +88,13 @@ def create(
             ), f"No saved input activations for \
                 microbatch_idx={microbatch_idx}, partition_idx={partition_idx}"
 
-        job = BackwardJob(function, package, pipeline_context=pipeline_context)
+        job = BackwardJob(function, package)
         return job
 
 
-def create_job(function: Callable, package: Package, pipeline_context: PipelineContext) -> Union[ForwardJob, BackwardJob]:
+def create_job(
+    function: Callable, package: Package, parallel_context: ParallelContext, pipeline_context: PipelineContext
+) -> Union[ForwardJob, BackwardJob]:
     """Create a job based on the package."""
     assert isinstance(package, Package), f"package must be an instance of Package, got {type(package)}"
     assert isinstance(
@@ -98,7 +107,7 @@ def create_job(
     }
 
     job_type = package.metadata.job_type
-    job = JOB_TYPE_TO_CREATOR[job_type].create(function, package, pipeline_context)
+    job = JOB_TYPE_TO_CREATOR[job_type].create(function, package, parallel_context, pipeline_context)
 
     return job
 
diff --git a/pipegoose/nn/pipeline_parallel2/_job/forward.py b/pipegoose/nn/pipeline_parallel2/_job/forward.py
index ee4f37b..bccbcca 100644
--- a/pipegoose/nn/pipeline_parallel2/_job/forward.py
+++ b/pipegoose/nn/pipeline_parallel2/_job/forward.py
@@ -8,6 +8,7 @@
 from pipegoose.nn.pipeline_parallel2._job.job import Job
 from pipegoose.nn.pipeline_parallel2._package import Package
 from pipegoose.nn.pipeline_parallel2.pipeline_context import PipelineContext
+from pipegoose.nn.pipeline_parallel2.sync.handshake import get_progress_tracker
 
 
 class ForwardJob(Job):
@@ -120,11 +121,16 @@ class ConfirmCompleteATaskToProgressTracker(Callback):
 
     order = 6
 
-    def after_compute(self):
-        from pipegoose.nn.pipeline_parallel2.sync.handshake import get_progress_tracker
+    def __init__(self, parallel_context: ParallelContext):
+        assert get_progress_tracker() is not None, "Progress tracker must be initialized before using this callback"
 
-        progress_tracker = get_progress_tracker()
+        world_size = parallel_context.get_world_size(ParallelMode.GLOBAL)
+        assert world_size > 1, "Progress tracker is only used in distributed training"
+
+    def after_compute(self):
         microbatch_idx = self.job.input.metadata.microbatch_idx
         partition_idx = self.job.input.metadata.partition_idx
         key = (microbatch_idx, partition_idx)
+
+        progress_tracker = get_progress_tracker()
         progress_tracker.confirm(key)
diff --git a/pipegoose/nn/pipeline_parallel2/pipeline_engine.py b/pipegoose/nn/pipeline_parallel2/pipeline_engine.py
index 38723d7..60d9da5 100644
--- a/pipegoose/nn/pipeline_parallel2/pipeline_engine.py
+++ b/pipegoose/nn/pipeline_parallel2/pipeline_engine.py
@@ -23,6 +23,9 @@
     ProgressTracker,
     set_progress_tracker,
 )
+from pipegoose.nn.pipeline_parallel2.sync.progress_tracker import (
+    get_progresses_from_pipeline_context,
+)
 
 
 @dataclass
@@ -102,11 +105,12 @@ def after_new_clock_cycle(self, progress, clock_idx):
 
                 # if self.parallel_context.is_first_rank(ParallelMode.PIPELINE):
                 if self.parallel_context.get_global_rank() == 0:
-                    schedules = self.pipeline_context.schedules
-                    progress = {
-                        i: {(item.microbatch_idx, item.partition_idx): False for item in sublist}
-                        for i, sublist in enumerate(schedules)
-                    }
+                    # schedules = self.pipeline_context.schedules
+                    # progress = {
+                    #     i: {(item.microbatch_idx, item.partition_idx): False for item in sublist}
+                    #     for i, sublist in enumerate(schedules)
+                    # }
+                    progress = get_progresses_from_pipeline_context(self.pipeline_context)
                     progress_tracker.initiate(progress)
                     print(progress)
 
@@ -157,9 +161,11 @@
             # outputs = [SavedActivation.get_saved_activations((microbatch_idx, partition_idx)) for microbatch_idx in range(n_microbatches)]
             # outputs = torch.cat(outputs, dim=0)
             return outputs
-        # else:
-        #     # NOTE: not terminate the worker, make it wait for processing further backward jobs
-        #     time.sleep(100)
+        else:
+            import time
+
+            # NOTE: don't terminate the worker; make it wait to process further backward jobs
+            time.sleep(100)
 
         # dist.barrier()
diff --git a/pipegoose/nn/pipeline_parallel2/sync/handshake.py b/pipegoose/nn/pipeline_parallel2/sync/handshake.py
index 9841b7c..13f12a7 100644
--- a/pipegoose/nn/pipeline_parallel2/sync/handshake.py
+++ b/pipegoose/nn/pipeline_parallel2/sync/handshake.py
@@ -98,6 +98,7 @@ def is_initiated(self) -> bool:
 
     def initiate(self, progress: Progress):
         """Initiate the progress tracker."""
+        set_progress_tracker(self)
         if self.parallel_context.get_local_rank(self.parallel_mode) == self.master_rank:
             INITIAL_CLOCK_IDX = 0
 
@@ -105,8 +106,6 @@ def initiate(self, progress: Progress):
             ProgressTracker.progress = progress
             ProgressTracker.clock_idx = INITIAL_CLOCK_IDX
 
-        set_progress_tracker(self)
-
     @staticmethod
     def _broadcast_tasks(progress, clock_idx, is_init=False):
         parallel_context = ProgressTracker.parallel_context
diff --git a/tests/nn/pipeline_parallel_2/conftest.py b/tests/nn/pipeline_parallel_2/conftest.py
index 859a835..3f12ec4 100644
--- a/tests/nn/pipeline_parallel_2/conftest.py
+++ b/tests/nn/pipeline_parallel_2/conftest.py
@@ -133,18 +133,18 @@ def backward_package(base_package):
 
 
 @pytest.fixture(scope="function")
-def backward_job(backward_package, pipeline_context):
+def backward_job(backward_package, parallel_context, pipeline_context):
     def function():
         def backward_function(*args, **kwargs):
             return torch.randn(1)
 
         return backward_function
 
-    job = create_job(function, backward_package, pipeline_context)
+    job = create_job(function, backward_package, parallel_context, pipeline_context)
     return job
 
 
 @pytest.fixture(scope="function")
-def forward_job(forward_package, pipeline_context):
+def forward_job(forward_package, parallel_context, pipeline_context):
     function = nn.Linear(*LINEAR_SHAPE)
-    return create_job(function, forward_package, pipeline_context)
+    return create_job(function, forward_package, parallel_context, pipeline_context)
diff --git a/tests/nn/pipeline_parallel_2/job/test_backward.py b/tests/nn/pipeline_parallel_2/job/test_backward.py
index 24f5ebd..b74a480 100644
--- a/tests/nn/pipeline_parallel_2/job/test_backward.py
+++ b/tests/nn/pipeline_parallel_2/job/test_backward.py
@@ -12,34 +12,16 @@
     CreateBackwardOutputPackageCallback,
 )
 from pipegoose.nn.pipeline_parallel2._job.creator import schedule_backward_job
+from pipegoose.nn.pipeline_parallel2._job.job_type import JobType
+from pipegoose.nn.pipeline_parallel2._package import Package
 from pipegoose.nn.pipeline_parallel2.pipeline_context import PipelineContext
-from pipegoose.nn.pipeline_parallel2.queue import (
-    JobQueue,
-    SavedActivation,
-    get_input_activations,
-)
+from pipegoose.nn.pipeline_parallel2.queue import JobQueue, SavedActivation
 from pipegoose.nn.pipeline_parallel2.scheduler import SchedulerType, get_scheduler
 from pipegoose.testing.utils import init_parallel_context, init_pipeline_context, spawn
 
-# NOTE: use for creating a forward job
-function = nn.Linear(2, 4)
-
-# BATCH_SIZE = 2
-# SEQ_LEN = 5
-# HIDDEN_SIZE = 10
-# linear = nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE)
-
-# input = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, requires_grad=True)
-# INPUT = deepcopy(input)
-# LINEAR = deepcopy(linear)
-# OUTPUT = LINEAR(INPUT)
-# INITIAL_GRADS = torch.ones_like(OUTPUT)
-# OUTPUT.sum().backward()
-
-
-@pytest.fixture
-def forward_job():
-    """A forward job that set with callbacks that use in training. like save input activation and output activations for backward job"""
+# @pytest.fixture
+# def forward_job():
+#     """A forward job configured with the callbacks used in training, e.g. saving input and output activations for the backward job."""
 
 
 @pytest.fixture
@@ -61,15 +43,6 @@ def run_create_a_backward_job_if_a_tensor_do_backprop(
     forward_package.metadata.microbatch_idx
     forward_package.metadata.partition_idx
 
-    # N_PARTITIONS = 3
-    # N_MICROBATCHES = 5
-    # parallel_context = init_parallel_context(
-    #     rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size
-    # )
-
-    # scheduler = get_scheduler(SchedulerType.GPIPE)(N_MICROBATCHES, N_PARTITIONS)
-    # pipeline_context = PipelineContext(scheduler, parallel_context)
-    # set_pipeline_context(pipeline_context)
     pipeline_context, parallel_context = init_pipeline_context(
         rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size
     )
@@ -186,22 +159,28 @@ def test_execute_scheduled_backward_job(request, package, pipeline_parallel_size
     )
 
 
-def test_execute_a_backward_job(backward_package, pipeline_context):
+def test_execute_a_backward_job(forward_job, backward_package, pipeline_context):
     def function(*args, **kwargs):
         pass
 
-    MICROBATCH_IDX = backward_package.metadata.microbatch_idx
-    PARTITION_IDX = backward_package.metadata.partition_idx
-    # # NOTE: the backward job should do the backward pass
-    # # with respect to the input activations
-    INPUT_ACTS = get_input_activations(MICROBATCH_IDX, PARTITION_IDX)
+    # MICROBATCH_IDX = backward_package.metadata.microbatch_idx
+    # PARTITION_IDX = backward_package.metadata.partition_idx
+    # # # NOTE: the backward job should do the backward pass
+    # # # with respect to the input activations
+    # INPUT_ACTS = get_input_activations(MICROBATCH_IDX, PARTITION_IDX)
+
+    # backward_job = BackwardJob(function, backward_package)
 
-    backward_job = BackwardJob(function, backward_package, cbs=[], pipeline_context=pipeline_context)
+    output = forward_job.compute()
+    INITIAL_GRADS = torch.ones_like(output.data)
+    grad_package = Package(INITIAL_GRADS, forward_job.input.metadata)
+    grad_package.metadata.job_type = JobType.BACKWARD
 
+    backward_job = BackwardJob(function, grad_package)
     grads = backward_job.compute()
 
     assert isinstance(grads, torch.Tensor)
-    assert torch.equal(grads, INPUT_ACTS.grad)
+    # assert torch.equal(grads, INPUT_ACTS.grad)
 
 
 def run_execute_a_backward_job_and_send_the_output(
diff --git a/tests/nn/pipeline_parallel_2/job/test_creator.py b/tests/nn/pipeline_parallel_2/job/test_creator.py
index 2727650..7da37a1 100644
--- a/tests/nn/pipeline_parallel_2/job/test_creator.py
+++ b/tests/nn/pipeline_parallel_2/job/test_creator.py
@@ -1,33 +1,72 @@
 import pytest
+import torch.distributed as dist
 from torch import nn
 
+from pipegoose.distributed.parallel_mode import ParallelMode
 from pipegoose.nn.pipeline_parallel2._job.backward import BackwardJob
 from pipegoose.nn.pipeline_parallel2._job.creator import create_job
 from pipegoose.nn.pipeline_parallel2._job.forward import ForwardJob
 from pipegoose.nn.pipeline_parallel2._job.job import JobStatus
+from pipegoose.nn.pipeline_parallel2.sync.handshake import ProgressTracker
+from pipegoose.nn.pipeline_parallel2.sync.progress_tracker import (
+    get_progresses_from_pipeline_context,
+)
+from pipegoose.testing.utils import init_pipeline_context, spawn
 
 # NOTE: use for creating a forward job
 function = nn.Linear(2, 4)
 
 
-@pytest.mark.parametrize("package", ["forward_package", "backward_package"])
-def test_the_job_status_after_executing_a_job(request, package, pipeline_context):
-    package = request.getfixturevalue(package)
-    job = create_job(function, package, pipeline_context)
+# @pytest.mark.parametrize("package", ["forward_package", "backward_package"])
+# def test_the_job_status_after_executing_a_job(request, package, parallel_context, pipeline_context):
+#     package = request.getfixturevalue(package)
+#     job = create_job(function, package, parallel_context, pipeline_context)
 
-    job.compute()
+#     job.compute()
 
-    assert job.status == JobStatus.EXECUTED
+#     assert job.status == JobStatus.EXECUTED
 
 
-@pytest.mark.parametrize("package, job_cls", [("forward_package", ForwardJob), ("backward_package", BackwardJob)])
-def test_create_a_job_from_package(request, package, forward_job, job_cls, pipeline_context):
-    package = request.getfixturevalue(package)
+def run_create_a_job_from_package(
+    rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size, package, job_cls
+):
+    MASTER_RANK = 0
+    pipeline_context, parallel_context = init_pipeline_context(
+        rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size
+    )
+    tracker = ProgressTracker(MASTER_RANK, parallel_context=parallel_context, parallel_mode=ParallelMode.GLOBAL)
+    progresses = get_progresses_from_pipeline_context(pipeline_context)
+    tracker.initiate(progresses)
+
+    dist.barrier()
 
-    forward_job.compute()
-    job = create_job(function, package, pipeline_context)
+    job = create_job(function, package, parallel_context, pipeline_context)
 
     assert isinstance(job, job_cls)
     assert isinstance(job.key, str)
     assert callable(job.function) is True
     assert job.status == JobStatus.PENDING
+
+    job.compute()
+
+    assert job.status == JobStatus.EXECUTED
+
+
+@pytest.mark.parametrize("package, job_cls", [("forward_package", ForwardJob), ("backward_package", BackwardJob)])
+def test_create_a_job_from_package(request, package, job_cls):
+    TENSOR_PARALLEL_SIZE = 1
+    PIPELINE_PARALLEL_SIZE = 2
+    DATA_PARALLEL_SIZE = 1
+    WORLD_SIZE = TENSOR_PARALLEL_SIZE * PIPELINE_PARALLEL_SIZE * DATA_PARALLEL_SIZE
+
+    package = request.getfixturevalue(package)
+
+    spawn(
+        run_create_a_job_from_package,
+        world_size=WORLD_SIZE,
+        tensor_parallel_size=TENSOR_PARALLEL_SIZE,
+        pipeline_parallel_size=PIPELINE_PARALLEL_SIZE,
+        data_parallel_size=DATA_PARALLEL_SIZE,
+        package=package,
+        job_cls=job_cls,
+    )
diff --git a/tests/nn/pipeline_parallel_2/job/test_forward.py b/tests/nn/pipeline_parallel_2/job/test_forward.py
index 9bef755..a821a9b 100644
--- a/tests/nn/pipeline_parallel_2/job/test_forward.py
+++ b/tests/nn/pipeline_parallel_2/job/test_forward.py
@@ -3,6 +3,7 @@
 from torch import nn
 
 from pipegoose.nn.pipeline_parallel2._job.forward import (
+    ConfirmCompleteATaskToProgressTracker,
     CreateForwardOutputPackageCallback,
     ForwardJob,
     SaveActivationIfTrainingCallback,
@@ -174,3 +175,49 @@ def test_forward_job_send_output_to_the_next_pipeline_stage(forward_package, pip
         data_parallel_size=DATA_PARALLEL_SIZE,
         forward_package=forward_package,
     )
+
+
+def run_confirm_a_forward_job_after_completing_it(
+    rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size, forward_package
+):
+    import torch.distributed as dist
+
+    from pipegoose.distributed.parallel_mode import ParallelMode
+    from pipegoose.nn.pipeline_parallel2.sync.handshake import ProgressTracker
+    from pipegoose.nn.pipeline_parallel2.sync.progress_tracker import (
+        get_progresses_from_pipeline_context,
+    )
+
+    MASTER_RANK = 0
+
+    pipeline_context, parallel_context = init_pipeline_context(
+        rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size
+    )
+
+    tracker = ProgressTracker(MASTER_RANK, parallel_context=parallel_context, parallel_mode=ParallelMode.GLOBAL)
+    progresses = get_progresses_from_pipeline_context(pipeline_context)
+    tracker.initiate(progresses)
+    dist.barrier()
+
+    callbacks = [ConfirmCompleteATaskToProgressTracker(parallel_context)]
+    forward_job = ForwardJob(function, forward_package, callbacks)
+    forward_job.compute()
+
+    assert tracker.is_all_confirmed(clock_idx=0) is True
+
+
+def test_confirm_a_forward_job_after_completing_it(forward_package):
+    TENSOR_PARALLEL_SIZE = 1
+    PIPELINE_PARALLEL_SIZE = 2
+    DATA_PARALLEL_SIZE = 1
+
+    WORLD_SIZE = TENSOR_PARALLEL_SIZE * PIPELINE_PARALLEL_SIZE * DATA_PARALLEL_SIZE
+
+    spawn(
+        run_confirm_a_forward_job_after_completing_it,
+        world_size=WORLD_SIZE,
+        tensor_parallel_size=TENSOR_PARALLEL_SIZE,
+        pipeline_parallel_size=PIPELINE_PARALLEL_SIZE,
+        data_parallel_size=DATA_PARALLEL_SIZE,
+        forward_package=forward_package,
+    )
diff --git a/tests/nn/pipeline_parallel_2/test_pipeline_context.py b/tests/nn/pipeline_parallel_2/test_pipeline_context.py
index 75d829c..6348dc5 100644
--- a/tests/nn/pipeline_parallel_2/test_pipeline_context.py
+++ b/tests/nn/pipeline_parallel_2/test_pipeline_context.py
@@ -108,3 +108,37 @@ def test_get_syncronous_schedule():
         pipeline_parallel_size=PIPELINE_PARALLEL_SIZE,
         data_parallel_size=DATA_PARALLEL_SIZE,
     )
+
+
+# def run_pipeline_context_init_progress_tracker(rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size):
+#     N_PARTITIONS = 4
+#     N_MICROBATCHES = 5
+
+#     parallel_context = init_parallel_context(
+#         rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size
+#     )
+#     scheduler = get_scheduler(SchedulerType.GPIPE)(N_MICROBATCHES, N_PARTITIONS)
+#     TOTAL_SCHEDULES = scheduler.total_clock_cycles
+
+#     def increase_clock_every_second(pipeline_context):
+#         for _ in range(TOTAL_SCHEDULES):
+#             pipeline_context.increase_a_clock_cycle()
+
+#     pipeline_context = PipelineContext(scheduler, parallel_context)
+
+#     assert 1 == 1
+
+
+# def test_pipeline_context_init_progress_tracker():
+#     TENSOR_PARALLEL_SIZE = 1
+#     PIPELINE_PARALLEL_SIZE = 2
+#     DATA_PARALLEL_SIZE = 1
+#     WORLD_SIZE = TENSOR_PARALLEL_SIZE * PIPELINE_PARALLEL_SIZE * DATA_PARALLEL_SIZE
+
+#     spawn(
+#         run_pipeline_context_init_progress_tracker,
+#         world_size=WORLD_SIZE,
+#         tensor_parallel_size=TENSOR_PARALLEL_SIZE,
+#         pipeline_parallel_size=PIPELINE_PARALLEL_SIZE,
+#         data_parallel_size=DATA_PARALLEL_SIZE,
+#     )
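
Usage note (editor's sketch, not part of the commit): after this refactor, job
creation takes both contexts explicitly, and callbacks that communicate are
instantiated with them instead of reaching into `job.pipeline_context`. A
minimal sketch of the new call shape, assuming the `init_pipeline_context`
test helper from `pipegoose.testing.utils` and a `forward_package` built as in
the test fixtures above:

    from torch import nn

    from pipegoose.nn.pipeline_parallel2._job.creator import create_job
    from pipegoose.testing.utils import init_pipeline_context

    def run(rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size, forward_package):
        # Build both contexts for this worker process.
        pipeline_context, parallel_context = init_pipeline_context(
            rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size
        )
        # package.metadata.job_type selects the creator: a forward package
        # yields a ForwardJob, a backward package a BackwardJob.
        job = create_job(nn.Linear(2, 4), forward_package, parallel_context, pipeline_context)
        job.compute()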