WIP refactor forward job, passed all tests except send output to the next pipeline stage

xrsrke committed Oct 13, 2023
1 parent 3d8b167 commit 1b177cd
Showing 10 changed files with 212 additions and 90 deletions.
13 changes: 8 additions & 5 deletions pipegoose/nn/pipeline_parallel2/_job/backward.py
@@ -1,5 +1,6 @@
import torch

from pipegoose.distributed.parallel_context import ParallelContext
from pipegoose.nn.pipeline_parallel2._job.callback import Callback
from pipegoose.nn.pipeline_parallel2._job.job import Job
from pipegoose.nn.pipeline_parallel2._package import Package
@@ -28,14 +29,16 @@ class SendBackwardPackageCallback(Callback):

order = 5

def __init__(self, parallel_context: ParallelContext):
self.parallel_context = parallel_context

def after_compute(self):
from pipegoose.nn.pipeline_parallel2._comm import send_package

parallel_context = self.job.pipeline_context.parallel_context
if parallel_context.pipeline_parallel_size > 1:
if self.parallel_context.pipeline_parallel_size > 1:
output = self.job.output
assert isinstance(output, Package), f"output must be an instance of Package, got {type(output)}"
send_package(output, parallel_context)
send_package(output, self.parallel_context)


class BackwardJob(Job):
@@ -58,8 +61,8 @@ def run_compute(self) -> torch.Tensor:
# # and we do gradient accumulation, we don't need return grads or send to other stages
# assert isinstance(input.grad, torch.Tensor)

rank = self.pipeline_context.parallel_context.get_global_rank()
print(f"executing backward job, rank={rank}, microbatch_idx={microbatch_idx}, partition_idx={partition_idx}")
# rank = self.pipeline_context.parallel_context.get_global_rank()
# print(f"executing backward job, rank={rank}, microbatch_idx={microbatch_idx}, partition_idx={partition_idx}")
print(f"yay! gradients: {input.grad.shape}")

if input.grad is None:
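
A minimal sketch of the new wiring, not part of the commit: SendBackwardPackageCallback now receives a ParallelContext at construction instead of reaching through self.job.pipeline_context. The helper name below and the pre-existing parallel_context object are assumptions; only the class name, constructor argument, and after_compute() behavior come from this diff.

from pipegoose.nn.pipeline_parallel2._job.backward import SendBackwardPackageCallback


def make_send_backward_callback(parallel_context):
    # Hypothetical helper: the context is injected once, at construction time.
    callback = SendBackwardPackageCallback(parallel_context)
    # after_compute() now reads self.parallel_context directly: when
    # pipeline_parallel_size > 1 it asserts the job output is a Package and
    # calls send_package(output, self.parallel_context).
    return callback
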
41 changes: 25 additions & 16 deletions pipegoose/nn/pipeline_parallel2/_job/creator.py
@@ -3,6 +3,7 @@

import torch

from pipegoose.distributed.parallel_context import ParallelContext
from pipegoose.nn.pipeline_parallel2._comm import (
get_pipeline_context,
set_pipeline_context,
@@ -35,35 +36,41 @@ def create(self) -> Job:
class ScheduleBackwardJobCallback(Callback):
order = 3

def __init__(self, pipeline_context: PipelineContext):
self.pipeline_context = pipeline_context

def after_compute(self):
package = self.job.output
new_package = schedule_backward_job(package, self.job.pipeline_context)
new_package = schedule_backward_job(package, self.pipeline_context)
self.job.output = new_package


class _ForwardJobCreator(JobCreator):
"""Put a forward job into job queue for a worker to execute."""

CBS = [
CreateForwardOutputPackageCallback,
SaveInputActivationsCallback,
SaveActivationIfTrainingCallback,
ScheduleBackwardJobCallback,
SendForwardPackageCallback,
ConfirmCompleteATaskToProgressTracker,
]

@classmethod
def create(cls, function: Callable, package: Package, pipeline_context: PipelineContext) -> ForwardJob:
job = ForwardJob(function, package, cbs=cls.CBS, pipeline_context=pipeline_context)
def create(
cls, function: Callable, package: Package, parallel_context: ParallelContext, pipeline_context: PipelineContext
) -> ForwardJob:
callbacks = [
CreateForwardOutputPackageCallback(parallel_context, pipeline_context),
SaveInputActivationsCallback,
SaveActivationIfTrainingCallback,
ScheduleBackwardJobCallback(pipeline_context),
SendForwardPackageCallback(parallel_context),
ConfirmCompleteATaskToProgressTracker(parallel_context),
]
job = ForwardJob(function, package, callbacks)
return job


class _BackwardJobCreator(JobCreator):
# CBS = [CreateBackwardOutputPackageCallback, SendBackwardPackageCallback]

@classmethod
def create(cls, function: Callable, package: Package, pipeline_context: PipelineContext) -> BackwardJob:
def create(
cls, function: Callable, package: Package, parallel_context: ParallelContext, pipeline_context: PipelineContext
) -> BackwardJob:
from pipegoose.nn.pipeline_parallel2.queue import (
InputActivations,
SavedActivation,
@@ -81,11 +88,13 @@ def create(cls, function: Callable, package: Package, pipeline_context: Pipeline
), f"No saved input activations for \
microbatch_idx={microbatch_idx}, partition_idx={partition_idx}"

job = BackwardJob(function, package, pipeline_context=pipeline_context)
job = BackwardJob(function, package)
return job


def create_job(function: Callable, package: Package, pipeline_context: PipelineContext) -> Union[ForwardJob, BackwardJob]:
def create_job(
function: Callable, package: Package, parallel_context: ParallelContext, pipeline_context: PipelineContext
) -> Union[ForwardJob, BackwardJob]:
"""Create a job based on the package."""
assert isinstance(package, Package), f"package must be an instance of Package, got {type(package)}"
assert isinstance(
@@ -98,7 +107,7 @@ def create_job(function: Callable, package: Package, pipeline_context: PipelineC
}

job_type = package.metadata.job_type
job = JOB_TYPE_TO_CREATOR[job_type].create(function, package, pipeline_context)
job = JOB_TYPE_TO_CREATOR[job_type].create(function, package, parallel_context, pipeline_context)

return job

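
With this refactor, create_job and both job creators take a ParallelContext alongside the PipelineContext, and the forward callbacks are instantiated per call instead of being listed as classes. Below is a hedged sketch of creating a forward job under the new signature, modelled on the test helpers changed in this commit; the worker-function name, the import path of create_job, and the toy nn.Linear partition are assumptions rather than code from the diff.

from torch import nn

from pipegoose.nn.pipeline_parallel2._job.creator import create_job
from pipegoose.testing.utils import init_pipeline_context


def run_create_forward_job(
    rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size, forward_package
):
    # init_pipeline_context returns (pipeline_context, parallel_context),
    # matching the order used in the tests below.
    pipeline_context, parallel_context = init_pipeline_context(
        rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size
    )
    function = nn.Linear(2, 4)  # toy partition, mirroring the test module
    # Both contexts are now required; the job type is read from forward_package.metadata.job_type.
    return create_job(function, forward_package, parallel_context, pipeline_context)
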
12 changes: 9 additions & 3 deletions pipegoose/nn/pipeline_parallel2/_job/forward.py
@@ -8,6 +8,7 @@
from pipegoose.nn.pipeline_parallel2._job.job import Job
from pipegoose.nn.pipeline_parallel2._package import Package
from pipegoose.nn.pipeline_parallel2.pipeline_context import PipelineContext
from pipegoose.nn.pipeline_parallel2.sync.handshake import get_progress_tracker


class ForwardJob(Job):
@@ -120,11 +121,16 @@ class ConfirmCompleteATaskToProgressTracker(Callback):

order = 6

def after_compute(self):
from pipegoose.nn.pipeline_parallel2.sync.handshake import get_progress_tracker
def __init__(self, parallel_context: ParallelContext):
assert get_progress_tracker() is not None, "Progress tracker must be initialized before using this callback"

progress_tracker = get_progress_tracker()
world_size = parallel_context.get_world_size(ParallelMode.GLOBAL)
assert world_size > 1, "Progress tracker is only used in distributed training"

def after_compute(self):
microbatch_idx = self.job.input.metadata.microbatch_idx
partition_idx = self.job.input.metadata.partition_idx
key = (microbatch_idx, partition_idx)

progress_tracker = get_progress_tracker()
progress_tracker.confirm(key)
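
ConfirmCompleteATaskToProgressTracker now fails fast at construction time: it requires a globally registered progress tracker and a global world size greater than one, rather than looking the tracker up inside after_compute(). A hedged sketch of the intended usage follows; the helper name is hypothetical and parallel_context is assumed to be initialized elsewhere (e.g. by the pipeline engine in this commit).

from pipegoose.nn.pipeline_parallel2._job.forward import ConfirmCompleteATaskToProgressTracker


def make_confirm_callback(parallel_context):
    # Hypothetical helper: construction asserts get_progress_tracker() is not None
    # and that parallel_context.get_world_size(ParallelMode.GLOBAL) > 1.
    callback = ConfirmCompleteATaskToProgressTracker(parallel_context)
    # after_compute() then confirms key = (microbatch_idx, partition_idx)
    # taken from the job's input package metadata.
    return callback
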
22 changes: 14 additions & 8 deletions pipegoose/nn/pipeline_parallel2/pipeline_engine.py
@@ -23,6 +23,9 @@
ProgressTracker,
set_progress_tracker,
)
from pipegoose.nn.pipeline_parallel2.sync.progress_tracker import (
get_progresses_from_pipeline_context,
)


@dataclass
@@ -102,11 +105,12 @@ def after_new_clock_cycle(self, progress, clock_idx):

# if self.parallel_context.is_first_rank(ParallelMode.PIPELINE):
if self.parallel_context.get_global_rank() == 0:
schedules = self.pipeline_context.schedules
progress = {
i: {(item.microbatch_idx, item.partition_idx): False for item in sublist}
for i, sublist in enumerate(schedules)
}
# schedules = self.pipeline_context.schedules
# progress = {
# i: {(item.microbatch_idx, item.partition_idx): False for item in sublist}
# for i, sublist in enumerate(schedules)
# }
progress = get_progresses_from_pipeline_context(self.pipeline_context)
progress_tracker.initiate(progress)
print(progress)

@@ -157,9 +161,11 @@ def after_new_clock_cycle(self, progress, clock_idx):
# outputs = [SavedActivation.get_saved_activations((microbatch_idx, partition_idx)) for microbatch_idx in range(n_microbatches)]
# outputs = torch.cat(outputs, dim=0)
return outputs
# else:
# # NOTE: not terminate the worker, make it wait for processing further backward jobs
# time.sleep(100)
else:
import time

# NOTE: not terminate the worker, make it wait for processing further backward jobs
time.sleep(100)

# dist.barrier()

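
The inline construction of the initial progress dictionary is replaced by get_progresses_from_pipeline_context. Its body is not shown in this commit; judging from the commented-out block above, it presumably builds the same mapping from clock index to per-task completion flags. A hedged reconstruction of that equivalent logic:

def get_progresses_from_pipeline_context(pipeline_context):
    # Presumed equivalent of the commented-out code: one entry per clock cycle,
    # each task keyed by (microbatch_idx, partition_idx) and initialized to False.
    schedules = pipeline_context.schedules
    return {
        i: {(item.microbatch_idx, item.partition_idx): False for item in sublist}
        for i, sublist in enumerate(schedules)
    }
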
3 changes: 1 addition & 2 deletions pipegoose/nn/pipeline_parallel2/sync/handshake.py
@@ -98,15 +98,14 @@ def is_initiated(self) -> bool:

def initiate(self, progress: Progress):
"""Initiate the progress tracker."""
set_progress_tracker(self)
if self.parallel_context.get_local_rank(self.parallel_mode) == self.master_rank:
INITIAL_CLOCK_IDX = 0

ProgressTracker._broadcast_tasks(progress, clock_idx=INITIAL_CLOCK_IDX, is_init=True)
ProgressTracker.progress = progress
ProgressTracker.clock_idx = INITIAL_CLOCK_IDX

set_progress_tracker(self)

@staticmethod
def _broadcast_tasks(progress, clock_idx, is_init=False):
parallel_context = ProgressTracker.parallel_context
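
Moving set_progress_tracker(self) out of the master-rank branch means every rank registers the tracker as soon as initiate() is called, which is what the new ConfirmCompleteATaskToProgressTracker constructor relies on. A hedged sketch of the resulting flow; the helper name is hypothetical and the tracker and progress objects are assumed to be constructed elsewhere.

from pipegoose.nn.pipeline_parallel2.sync.handshake import get_progress_tracker


def initiate_on_all_ranks(tracker, progress):
    # Hypothetical helper: initiate() now registers the tracker globally on every
    # rank before the master-rank-only broadcast of the initial tasks.
    tracker.initiate(progress)
    assert get_progress_tracker() is tracker  # the check the new forward callback depends on
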
8 changes: 4 additions & 4 deletions tests/nn/pipeline_parallel_2/conftest.py
@@ -133,18 +133,18 @@ def backward_package(base_package):


@pytest.fixture(scope="function")
def backward_job(backward_package, pipeline_context):
def backward_job(backward_package, parallel_context, pipeline_context):
def function():
def backward_function(*args, **kwargs):
return torch.randn(1)

return backward_function

job = create_job(function, backward_package, pipeline_context)
job = create_job(function, backward_package, parallel_context, pipeline_context)
return job


@pytest.fixture(scope="function")
def forward_job(forward_package, pipeline_context):
def forward_job(forward_package, parallel_context, pipeline_context):
function = nn.Linear(*LINEAR_SHAPE)
return create_job(function, forward_package, pipeline_context)
return create_job(function, forward_package, parallel_context, pipeline_context)
61 changes: 20 additions & 41 deletions tests/nn/pipeline_parallel_2/job/test_backward.py
@@ -12,34 +12,16 @@
CreateBackwardOutputPackageCallback,
)
from pipegoose.nn.pipeline_parallel2._job.creator import schedule_backward_job
from pipegoose.nn.pipeline_parallel2._job.job_type import JobType
from pipegoose.nn.pipeline_parallel2._package import Package
from pipegoose.nn.pipeline_parallel2.pipeline_context import PipelineContext
from pipegoose.nn.pipeline_parallel2.queue import (
JobQueue,
SavedActivation,
get_input_activations,
)
from pipegoose.nn.pipeline_parallel2.queue import JobQueue, SavedActivation
from pipegoose.nn.pipeline_parallel2.scheduler import SchedulerType, get_scheduler
from pipegoose.testing.utils import init_parallel_context, init_pipeline_context, spawn

# NOTE: use for creating a forward job
function = nn.Linear(2, 4)

# BATCH_SIZE = 2
# SEQ_LEN = 5
# HIDDEN_SIZE = 10
# linear = nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE)

# input = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, requires_grad=True)
# INPUT = deepcopy(input)
# LINEAR = deepcopy(linear)
# OUTPUT = LINEAR(INPUT)
# INITIAL_GRADS = torch.ones_like(OUTPUT)
# OUTPUT.sum().backward()


@pytest.fixture
def forward_job():
"""A forward job that set with callbacks that use in training. like save input activation and output activations for backward job"""
# @pytest.fixture
# def forward_job():
# """A forward job that set with callbacks that use in training. like save input activation and output activations for backward job"""


@pytest.fixture
@@ -61,15 +43,6 @@ def run_create_a_backward_job_if_a_tensor_do_backprop(
forward_package.metadata.microbatch_idx
forward_package.metadata.partition_idx

# N_PARTITIONS = 3
# N_MICROBATCHES = 5
# parallel_context = init_parallel_context(
# rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size
# )

# scheduler = get_scheduler(SchedulerType.GPIPE)(N_MICROBATCHES, N_PARTITIONS)
# pipeline_context = PipelineContext(scheduler, parallel_context)
# set_pipeline_context(pipeline_context)
pipeline_context, parallel_context = init_pipeline_context(
rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size
)
@@ -186,22 +159,28 @@ def test_execute_scheduled_backward_job(request, package, pipeline_parallel_size
)


def test_execute_a_backward_job(backward_package, pipeline_context):
def test_execute_a_backward_job(forward_job, backward_package, pipeline_context):
def function(*args, **kwargs):
pass

MICROBATCH_IDX = backward_package.metadata.microbatch_idx
PARTITION_IDX = backward_package.metadata.partition_idx
# # NOTE: the backward job should do the backward pass
# # with respect to the input activations
INPUT_ACTS = get_input_activations(MICROBATCH_IDX, PARTITION_IDX)
# MICROBATCH_IDX = backward_package.metadata.microbatch_idx
# PARTITION_IDX = backward_package.metadata.partition_idx
# # # NOTE: the backward job should do the backward pass
# # # with respect to the input activations
# INPUT_ACTS = get_input_activations(MICROBATCH_IDX, PARTITION_IDX)

# backward_job = BackwardJob(function, backward_package)

backward_job = BackwardJob(function, backward_package, cbs=[], pipeline_context=pipeline_context)
output = forward_job.compute()
INITIAL_GRADS = torch.ones_like(output.data)
grad_package = Package(INITIAL_GRADS, forward_job.input.metadata)
grad_package.metadata.job_type = JobType.BACKWARD

backward_job = BackwardJob(function, grad_package)
grads = backward_job.compute()

assert isinstance(grads, torch.Tensor)
assert torch.equal(grads, INPUT_ACTS.grad)
# assert torch.equal(grads, INPUT_ACTS.grad)


def run_execute_a_backward_job_and_send_the_output(