refactor scheduling backward job
xrsrke committed Oct 13, 2023
1 parent 1ae7cd1 commit 73f8169
Showing 6 changed files with 26 additions and 102 deletions.
7 changes: 4 additions & 3 deletions pipegoose/nn/pipeline_parallel2/_comm.py
@@ -5,7 +5,8 @@

from pipegoose.distributed.parallel_context import ParallelContext
from pipegoose.nn.pipeline_parallel2._package import Package
from pipegoose.nn.pipeline_parallel2.pipeline_context import PipelineContext

# from pipegoose.nn.pipeline_parallel2.pipeline_context import PipelineContext

RECV_QUEUE = Queue()
# RECV_QUEUE = dict()
@@ -15,12 +16,12 @@
_PIPELINE_CONTEXT = None


def set_pipeline_context(pipeline_context: PipelineContext):
def set_pipeline_context(pipeline_context):
global _PIPELINE_CONTEXT
_PIPELINE_CONTEXT = pipeline_context


def get_pipeline_context() -> PipelineContext:
def get_pipeline_context():
return _PIPELINE_CONTEXT


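Note: the direct import of `PipelineContext` is commented out and the type hints on `set_pipeline_context`/`get_pipeline_context` are dropped, which reads like a fix for a circular import now that `pipeline_context.py` imports `set_pipeline_context` from `_comm` (see the change further down). A minimal sketch of an equivalent pattern that keeps the annotations under that assumption, using a `TYPE_CHECKING` guard and string annotations; this is an alternative, not what the commit does:

```python
# _comm.py (sketch) -- assumes the only blocker is a runtime circular import
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # seen only by type checkers, never executed at runtime
    from pipegoose.nn.pipeline_parallel2.pipeline_context import PipelineContext

_PIPELINE_CONTEXT: Optional["PipelineContext"] = None


def set_pipeline_context(pipeline_context: "PipelineContext") -> None:
    global _PIPELINE_CONTEXT
    _PIPELINE_CONTEXT = pipeline_context


def get_pipeline_context() -> Optional["PipelineContext"]:
    return _PIPELINE_CONTEXT
```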
10 changes: 6 additions & 4 deletions pipegoose/nn/pipeline_parallel2/_job/backward.py
@@ -49,12 +49,17 @@ def run_compute(self) -> torch.Tensor:
partition_idx = self.input.metadata.partition_idx
prev_grad = self.input.data

output = get_output_activations(microbatch_idx, partition_idx)
input = get_input_activations(microbatch_idx, partition_idx)
output = get_output_activations(microbatch_idx, partition_idx, is_pipeline=True)

torch.autograd.backward(output, grad_tensors=prev_grad)

if input.requires_grad is False:
raise PipelineGradientFlowError(
"Please set .requires_grad = True to input activations. Gradients can't flow back to the input of the pipeline stage"
)

if input.grad is None:
raise PipelineGradientFlowError("Gradients can't flow back to the input of the pipeline stage")

# # TODO: remove this, since the grads is stored in module's weights
@@ -65,7 +70,4 @@ def run_compute(self) -> torch.Tensor:
# print(f"executing backward job, rank={rank}, microbatch_idx={microbatch_idx}, partition_idx={partition_idx}")
print(f"yay! gradients: {input.grad.shape}")

if input.grad is None:
raise PipelineGradientFlowError("Gradients can't flow back to the input of the pipeline stage")

return input.grad
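Note: the backward job now checks `input.requires_grad` before running autograd and verifies that `input.grad` was actually populated afterwards. A self-contained sketch of that gradient flow on toy tensors (the tensor names are illustrative, not the pipegoose activation-store API):

```python
import torch

# stand-ins for the saved input/output activations of one pipeline stage
input = torch.randn(4, 8, requires_grad=True)   # must require grad, or .grad stays None
output = input @ torch.randn(8, 2)

# gradient arriving from the next pipeline stage
prev_grad = torch.ones_like(output)

# backprop through this stage only; populates input.grad
torch.autograd.backward(output, grad_tensors=prev_grad)

assert input.grad is not None, "gradients can't flow back to the stage input"
print(input.grad.shape)  # torch.Size([4, 8])
```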
59 changes: 1 addition & 58 deletions pipegoose/nn/pipeline_parallel2/_job/creator.py
@@ -4,10 +4,7 @@
import torch

from pipegoose.distributed.parallel_context import ParallelContext
from pipegoose.nn.pipeline_parallel2._comm import (
get_pipeline_context,
set_pipeline_context,
)
from pipegoose.nn.pipeline_parallel2._comm import get_pipeline_context
from pipegoose.nn.pipeline_parallel2._job.backward import BackwardJob
from pipegoose.nn.pipeline_parallel2._job.callback import Callback
from pipegoose.nn.pipeline_parallel2._job.forward import (
@@ -138,54 +135,6 @@ def backward_function(self):


def schedule_backward_job(package: Package, pipeline_context: PipelineContext) -> Package:
# assert isinstance(package, Package), f"package must be an instance of Package, got {type(package)}"

# class _ScheduleBackwardJob(torch.autograd.Function):
# @staticmethod
# def forward(ctx, metadata: Metadata, pipeline_context: PipelineContext, input: torch.Tensor):
# # NOTE: can't assign metadata attribute to ctx
# # "AttributeError: attribute 'metadata' of 'torch._C._FunctionBase'
# # objects is not writable"
# rank = pipeline_context.parallel_context.get_global_rank()
# print(f"scheduled a backward job, rank={rank}, microbatch_idx={metadata.microbatch_idx}")
# ctx.package_meta = metadata
# ctx.pipeline_context = pipeline_context
# ctx.input = input
# return input

# @staticmethod
# def backward(ctx: Any, grad_input: torch.Tensor):
# metadata = ctx.package_meta
# # pipeline_context = ctx.pipeline_context
# # parallel_context = pipeline_context.parallel_context

# # rank = parallel_context.get_global_rank()
# # microbatch_idx = metadata.microbatch_idx

# # dst_worker_name = parallel_context.get_worker_name(metadata.dst)
# # print(grad_input)
# # print(f"creating a backward job, rank={rank}, microbatch_idx={microbatch_idx}, dst_worker_name={dst_worker_name}")

# _create_backward_job_and_put_to_pending_queue(grad_input, metadata)
# # TODO: because forward job and backward job are in the same node
# # rpc isn't necessary
# # rpc.rpc_sync(
# # # NOTE: the backward job create in the same node
# # # as the forward job
# # to=dst_worker_name,
# # func=_create_backward_job_and_put_to_pending_queue,
# # args=(grad_input, metadata),
# # )

# return (None, None, None)

set_pipeline_context(pipeline_context)

# metadata = package.metadata
# new_data = _ScheduleBackwardJob.apply(metadata, pipeline_context, package.data)
# package.data = new_data
# return package

class Function(torch.autograd.Function):
@staticmethod
def forward(ctx, metadata: Metadata, input):
@@ -196,12 +145,6 @@ def forward(ctx, metadata: Metadata, input):
def backward(ctx, grad_input):
metadata = ctx.package_meta
_create_backward_job_and_put_to_pending_queue(grad_input, metadata)

# from pipegoose.nn.pipeline_parallel2.queue import SavedActivation
# output = SavedActivation.get_saved_activations((0, 0))
# detached_output = output.detach().requires_grad_()
# torch.autograd.backward(detached_output, grad_input)
# return detached_output.grad
return (None, grad_input)

data = package.data
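Note: `schedule_backward_job` wraps the package data in an identity-style `torch.autograd.Function` whose `backward` enqueues a backward job as a side effect before passing the gradient through. A minimal, self-contained sketch of that pattern, with a plain list standing in for the real pending-job queue (the queue and metadata here are illustrative, not pipegoose's):

```python
import torch

PENDING_BACKWARD_JOBS = []  # stand-in for the real pending-job queue


class ScheduleBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, metadata, input):
        # stash the package metadata on ctx; the forward pass is an identity
        ctx.package_meta = metadata
        return input

    @staticmethod
    def backward(ctx, grad_input):
        # side effect: record a backward job when gradients flow through this point
        PENDING_BACKWARD_JOBS.append((ctx.package_meta, grad_input))
        # no gradient for the metadata argument; forward the incoming gradient unchanged
        return (None, grad_input)


x = torch.randn(3, requires_grad=True)
y = ScheduleBackward.apply({"microbatch_idx": 0}, x)
y.sum().backward()
assert len(PENDING_BACKWARD_JOBS) == 1
```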
3 changes: 3 additions & 0 deletions pipegoose/nn/pipeline_parallel2/pipeline_context.py
@@ -3,6 +3,7 @@

from pipegoose.distributed.parallel_context import ParallelContext
from pipegoose.distributed.parallel_mode import ParallelMode
from pipegoose.nn.pipeline_parallel2._comm import set_pipeline_context
from pipegoose.nn.pipeline_parallel2._utils import get_partition_idx, is_last_stage
from pipegoose.nn.pipeline_parallel2.scheduler import BaseScheduler

@@ -19,6 +20,8 @@ def __init__(self, scheduler: BaseScheduler, parallel_context: ParallelContext):
# NOTE: block CPU thread until the next clock cycle
self._wait_new_clock_cycle = threading.Condition()

set_pipeline_context(self)

@property
def partition_idx(self) -> int:
parallel_context = self.parallel_context
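Note: with the constructor calling `set_pipeline_context(self)`, code elsewhere can presumably fetch the active context instead of having it threaded through every call, which is what the new assertion in the test below checks:

```python
from pipegoose.nn.pipeline_parallel2._comm import get_pipeline_context

# returns the most recently constructed PipelineContext
pipeline_context = get_pipeline_context()
```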
13 changes: 10 additions & 3 deletions tests/nn/pipeline_parallel_2/conftest.py
@@ -144,7 +144,14 @@ def backward_function(*args, **kwargs):
return job


@pytest.fixture
def forward_function():
return nn.Linear(*LINEAR_SHAPE)


@pytest.fixture(scope="function")
def forward_job(forward_package, parallel_context, pipeline_context):
function = nn.Linear(*LINEAR_SHAPE)
return create_job(function, forward_package, parallel_context, pipeline_context)
def forward_job(forward_package, forward_function):
from pipegoose.nn.pipeline_parallel2._job.forward import ForwardJob

# return create_job(function, forward_package, parallel_context, pipeline_context)
return ForwardJob(forward_function, forward_package)
36 changes: 2 additions & 34 deletions tests/nn/pipeline_parallel_2/test_pipeline_context.py
@@ -2,6 +2,7 @@

import pytest

from pipegoose.nn.pipeline_parallel2._comm import get_pipeline_context
from pipegoose.nn.pipeline_parallel2.pipeline_context import PipelineContext
from pipegoose.nn.pipeline_parallel2.scheduler import SchedulerType, get_scheduler
from pipegoose.nn.pipeline_parallel2.task import Task
@@ -37,6 +38,7 @@ def run_pipeline_context(rank, world_size, port, tensor_parallel_size, pipeline_
assert isinstance(pipeline_context.schedules, list)
assert isinstance(pipeline_context.get_schedule_from_partition(clock_idx=3, partition_idx=2), list)
assert isinstance(pipeline_context.get_schedule_from_microbatch(clock_idx=3, microbatch_idx=0), list)
assert get_pipeline_context() == pipeline_context

next_schedules = pipeline_context.get_next_schedule_from_microbatch(microbatch_idx=0)
assert isinstance(next_schedules, list)
@@ -108,37 +110,3 @@ def test_get_syncronous_schedule():
pipeline_parallel_size=PIPELINE_PARALLEL_SIZE,
data_parallel_size=DATA_PARALLEL_SIZE,
)


# def run_pipeline_context_init_progress_tracker(rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size):
# N_PARTITIONS = 4
# N_MICROBATCHES = 5

# parallel_context = init_parallel_context(
# rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size
# )
# scheduler = get_scheduler(SchedulerType.GPIPE)(N_MICROBATCHES, N_PARTITIONS)
# TOTAL_SCHEDULES = scheduler.total_clock_cycles

# def increase_clock_every_second(pipeline_context):
# for _ in range(TOTAL_SCHEDULES):
# pipeline_context.increase_a_clock_cycle()

# pipeline_context = PipelineContext(scheduler, parallel_context)

# assert 1 == 1


# def test_pipeline_context_init_progress_tracker():
# TENSOR_PARALLEL_SIZE = 1
# PIPELINE_PARALLEL_SIZE = 2
# DATA_PARALLEL_SIZE = 1
# WORLD_SIZE = TENSOR_PARALLEL_SIZE * PIPELINE_PARALLEL_SIZE * DATA_PARALLEL_SIZE

# spawn(
# run_pipeline_context_init_progress_tracker,
# world_size=WORLD_SIZE,
# tensor_parallel_size=TENSOR_PARALLEL_SIZE,
# pipeline_parallel_size=PIPELINE_PARALLEL_SIZE,
# data_parallel_size=DATA_PARALLEL_SIZE,
# )
