Skip to content

Commit

Permalink
change the return output of the middle pipeline stages to have the sa…
Browse files Browse the repository at this point in the history
…me shape as the final stage
  • Loading branch information
xrsrke committed Oct 18, 2023
1 parent f923e00 commit 142a144
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 22 deletions.
11 changes: 4 additions & 7 deletions pipegoose/nn/pipeline_parallel2/pipeline_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from pipegoose.nn.pipeline_parallel2._package import Metadata, Package, TrainingMetadata
from pipegoose.nn.pipeline_parallel2._worker import BaseWorkerManager
from pipegoose.nn.pipeline_parallel2.pipeline_context import PipelineContext
from pipegoose.nn.pipeline_parallel2.queue import JobQueue
from pipegoose.nn.pipeline_parallel2.queue import _SAVED_SCHEDULED_ACTIVATIONS, JobQueue
from pipegoose.nn.pipeline_parallel2.scheduler import BaseScheduler
from pipegoose.nn.pipeline_parallel2.sync.callback import Callback
from pipegoose.nn.pipeline_parallel2.sync.handshake import (
Expand Down Expand Up @@ -118,19 +118,16 @@ def after_new_clock_cycle(self, progress, clock_idx):

dist.barrier()

from pipegoose.nn.pipeline_parallel2.queue import _SAVED_SCHEDULED_ACTIVATIONS

if self.pipeline_context.is_last_stage:
outputs = []

for microbatch_idx in range(n_microbatches):
output = _SAVED_SCHEDULED_ACTIVATIONS[(microbatch_idx, self.pipeline_context.partition_idx)]
outputs.append(output)

return outputs
else:
output = _SAVED_SCHEDULED_ACTIVATIONS[(microbatch_idx, self.pipeline_context.partition_idx)]
return output
outputs = [torch.zeros(1, requires_grad=True).float() for _ in range(n_microbatches - 1)] + [output]

return outputs

def _construct_first_package(self, microbatch_idx: int, input: torch.Tensor) -> Package:
"""Construct the first forward package of a microbatch."""
Expand Down
13 changes: 5 additions & 8 deletions tests/nn/pipeline_parallel_2/test_pipeline_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,18 +62,15 @@ def forward(self, input):
[(microbatch_idx, partition_idx) for microbatch_idx in range(n_microbatches)]
EXPECTED_FORWARD_TIMELINE = [(microbatch_idx, partition_idx) for microbatch_idx in range(n_microbatches)]
# EXPECTED_BACKWARD_TIMELINE = [(microbatch_idx, partition_idx) for microbatch_idx in range(n_microbatches, -1, -1)]
p_outputs = pipeline_engine.run(inputs)
outputs = pipeline_engine.run(inputs)

assert forward_timeline == EXPECTED_FORWARD_TIMELINE

if is_last_stage(parallel_context):
assert torch.allclose(torch.cat(p_outputs, dim=0), ref_outputs)
for output in p_outputs:
output.sum().backward(retain_graph=True)
else:
# NOTE: earlier stages should not return the final output
# assert p_outputs is None
p_outputs.sum().backward()
assert torch.allclose(torch.cat(outputs, dim=0), ref_outputs)

for output in outputs:
output.sum().backward(retain_graph=True)

for param in partition.parameters():
assert param.grad is not None
Expand Down
13 changes: 6 additions & 7 deletions tests/nn/pipeline_parallel_2/test_pipeline_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pipegoose.testing.utils import init_parallel_context, spawn


def run_pipeline_engine(
def run_pipeline_parallel(
rank,
world_size,
port,
Expand All @@ -32,10 +32,9 @@ def run_pipeline_engine(

if is_last_stage(parallel_context):
assert torch.allclose(torch.cat(outputs, dim=0), ref_outputs)
for output in outputs:
output.sum().backward(retain_graph=True)
else:
outputs.sum().backward()

for output in outputs:
output.sum().backward(retain_graph=True)


@pytest.mark.parametrize(
Expand All @@ -46,7 +45,7 @@ def run_pipeline_engine(
# (2, 4, 2)
],
)
def test_pipeline_engine(tensor_parallel_size, pipeline_parallel_size, data_parallel_size):
def test_pipeline_parallel(tensor_parallel_size, pipeline_parallel_size, data_parallel_size):
BATCH_SIZE = 32
N_MICROBATCHES = 6
SEQ_LEN = 10
Expand All @@ -61,7 +60,7 @@ def test_pipeline_engine(tensor_parallel_size, pipeline_parallel_size, data_para
outputs.sum().backward()

spawn(
run_pipeline_engine,
run_pipeline_parallel,
world_size=WORLD_SIZE,
tensor_parallel_size=tensor_parallel_size,
pipeline_parallel_size=pipeline_parallel_size,
Expand Down

0 comments on commit 142a144

Please sign in to comment.