add PipelineParallel
xrsrke committed Oct 18, 2023
1 parent 4593863 commit f923e00
Showing 6 changed files with 134 additions and 109 deletions.
5 changes: 1 addition & 4 deletions pipegoose/nn/pipeline_parallel2/pipeline_engine.py
@@ -42,7 +42,6 @@ def __init__(
scheduler: BaseScheduler,
worker_manager: BaseWorkerManager,
parallel_context: ParallelContext,
partition_func,
):
assert isinstance(module, nn.Module), f"module must be an instance of nn.Module, got {type(module)}"
assert isinstance(
@@ -53,9 +52,7 @@ def __init__
self.scheduler = scheduler
self.worker_manager = worker_manager
self.parallel_context = parallel_context

self.pipeline_context = PipelineContext(scheduler, parallel_context)
self.partition_func = partition_func

def run(self, inputs: torch.Tensor) -> torch.Tensor:
self.worker_manager.spawn()
@@ -114,7 +111,7 @@ def after_new_clock_cycle(self, progress, clock_idx):
else:
package = RECV_QUEUE.get()

job = create_job(self.partition_func, package, self.parallel_context, self.pipeline_context)
job = create_job(self.module, package, self.parallel_context, self.pipeline_context)
JobQueue.PENDING_JOBS.put(job)

dist.barrier()
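
After this change, PipelineEngine no longer takes a separate partition_func; jobs are created from the engine's own module. A minimal construction sketch under the new signature (hedged: parallel_context, partition, num_microbatches, and inputs are assumed to exist and are not defined here):

from pipegoose.nn.pipeline_parallel2._worker import WorkerManager
from pipegoose.nn.pipeline_parallel2.pipeline_engine import PipelineEngine
from pipegoose.nn.pipeline_parallel2.scheduler import GPipeScheduler

# assumed to exist: parallel_context, partition (this stage's nn.Module),
# num_microbatches, inputs
scheduler = GPipeScheduler(num_microbatches, parallel_context.pipeline_parallel_size)
worker_manager = WorkerManager()
engine = PipelineEngine(
    module=partition,              # the local partition itself, no partition_func
    scheduler=scheduler,
    worker_manager=worker_manager,
    parallel_context=parallel_context,
)
outputs = engine.run(inputs)       # create_job now reads self.module internally
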
38 changes: 14 additions & 24 deletions pipegoose/nn/pipeline_parallel2/pipeline_parallel.py
@@ -1,53 +1,43 @@
from typing import List

import torch
from torch import nn

from pipegoose.constants import PIPELINE_MAX_WORKERS, PIPELINE_MIN_WORKERS
from pipegoose.distributed.parallel_context import ParallelContext
from pipegoose.nn.pipeline_parallel2._utils import get_partition_idx
from pipegoose.nn.pipeline_parallel2._worker import WorkerManager
from pipegoose.nn.pipeline_parallel2.partitioner import PartitionPolicy
from pipegoose.nn.pipeline_parallel2.pipeline_engine import PipelineEngine
from pipegoose.nn.pipeline_parallel2.scheduler import SchedulerType, get_scheduler
from pipegoose.nn.pipeline_parallel2.scheduler import GPipeScheduler


class PipelineParallel:
"""Automatically parallelize a module using pipeline parallelism."""

def __init__(
self,
module: nn.Module,
modules: List[nn.Module],
num_microbatches: int,
scheduler_type: SchedulerType,
partition_policy: PartitionPolicy,
parallel_context: ParallelContext,
):
self.module = module
self.modules = modules
self.num_microbatches = num_microbatches
self.scheduler_type = scheduler_type
self.partition_policy = partition_policy
self.parallel_context = parallel_context

@torch.no_grad()
def parallelize(self) -> nn.Module:
module = self.module
partition_idx = get_partition_idx(self.parallel_context)
module = self.modules[partition_idx]

n_partitions = self.parallel_context.pipeline_parallel_size
scheduler = GPipeScheduler(self.num_microbatches, n_partitions)
worker_manager = WorkerManager()

# TODO: lazy init
scheduler = get_scheduler(
scheduler_type=self.scheduler_type,
num_microbatches=self.num_microbatches,
parallel_context=self.parallel_context,
)
worker_manager = WorkerManager(
min_workers=PIPELINE_MIN_WORKERS,
max_workers=PIPELINE_MAX_WORKERS,
parallel_context=self.parallel_context,
)
pipeline_engine = PipelineEngine(
module=module,
scheduler=scheduler,
worker_manager=worker_manager,
parallel_context=self.parallel_context,
)

pipeline_engine.parallelize(module)

return pipeline_engine
self.modules.forward = pipeline_engine.run
return self.modules
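
For context, a usage sketch of the reworked PipelineParallel API, following the updated test further down. The ParallelContext setup is assumed (in the tests it comes from pipegoose.testing.utils.init_parallel_context); the layer sizes mirror the test values:

import torch
from torch import nn

from pipegoose.nn.pipeline_parallel2.pipeline_parallel import PipelineParallel

# assumed to exist: parallel_context (one process per pipeline stage)
HIDDEN_DIM = 5
stages = nn.ModuleList(
    [nn.Sequential(nn.Linear(HIDDEN_DIM, HIDDEN_DIM), nn.ReLU()) for _ in range(parallel_context.pipeline_parallel_size)]
)
model = PipelineParallel(stages, num_microbatches=6, parallel_context=parallel_context).parallelize()

# forward now dispatches to PipelineEngine.run; the last stage returns the outputs
outputs = model(torch.randn(32, 10, HIDDEN_DIM))
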
64 changes: 31 additions & 33 deletions tests/nn/pipeline_parallel_2/job/test_backward.py
@@ -1,4 +1,3 @@
import time
from copy import deepcopy

import pytest
@@ -11,16 +10,15 @@
save_grad_loss,
)
from pipegoose.nn.pipeline_parallel2._job.callback import Callback
from pipegoose.nn.pipeline_parallel2._job.creator import schedule_backward_job
from pipegoose.nn.pipeline_parallel2._job.forward import (
CreateForwardOutputPackageCallback,
ForwardJob,
SaveActivationIfTrainingCallback,
SaveInputActivationsCallback,
)

# from pipegoose.nn.pipeline_parallel2._job.forward import (
# CreateForwardOutputPackageCallback,
# ForwardJob,
# # SaveActivationIfTrainingCallback,
# # SaveInputActivationsCallback,
# )
from pipegoose.nn.pipeline_parallel2._job.job_type import JobType
from pipegoose.nn.pipeline_parallel2._package import Package
from pipegoose.nn.pipeline_parallel2.queue import JobQueue
from pipegoose.testing.utils import init_pipeline_context, spawn


@@ -78,36 +76,36 @@ def backward_package_in_the_second_last_pipeline_stage(backward_package):
return backward_package


def test_create_a_backward_job_if_a_tensor_do_backprop(forward_package, forward_function, parallel_context, pipeline_context):
callbacks = [
CreateForwardOutputPackageCallback(parallel_context, pipeline_context),
SaveInputActivationsCallback,
SaveActivationIfTrainingCallback,
]
forward_job = ForwardJob(forward_function, forward_package, callbacks)
# def test_create_a_backward_job_if_a_tensor_do_backprop(forward_package, forward_function, parallel_context, pipeline_context):
# callbacks = [
# CreateForwardOutputPackageCallback(parallel_context, pipeline_context),
# SaveInputActivationsCallback,
# SaveActivationIfTrainingCallback,
# ]
# forward_job = ForwardJob(forward_function, forward_package, callbacks)

# NOTE: we enqueue the backward job in the destination rank
output = forward_job.compute()
DATA = output.data.clone()
METADATA = deepcopy(output.metadata)
# # NOTE: we enqueue the backward job in the destination rank
# output = forward_job.compute()
# DATA = output.data.clone()
# METADATA = deepcopy(output.metadata)

output = schedule_backward_job(output, pipeline_context)
# NOTE: make sure we don't change the package
assert torch.equal(output.data, DATA)
assert output.metadata == METADATA
# output = schedule_backward_job(output, pipeline_context)
# # NOTE: make sure we don't change the package
# assert torch.equal(output.data, DATA)
# assert output.metadata == METADATA

output.data.sum().backward(retain_graph=True)
# output.data.sum().backward(retain_graph=True)

# NOTE: since we don't launch any job selector workers in the background,
# after triggering the creation of a backward job,
# we expect the destination worker's job queue to have one job
time.sleep(0.1)
assert JobQueue.PENDING_JOBS.qsize() == 1
# # NOTE: since we don't launch any job selector workers in the background,
# # after triggering the creation of a backward job,
# # we expect the destination worker's job queue to have one job
# time.sleep(0.1)
# assert JobQueue.PENDING_JOBS.qsize() == 1

backward_job = JobQueue.PENDING_JOBS.get()
assert isinstance(backward_job, BackwardJob)
# backward_job = JobQueue.PENDING_JOBS.get()
# assert isinstance(backward_job, BackwardJob)

backward_job.compute()
# backward_job.compute()


def test_the_gradient_output_of_a_backward_job(backward_package):
36 changes: 17 additions & 19 deletions tests/nn/pipeline_parallel_2/job/test_forward.py
@@ -2,18 +2,16 @@
import torch
from torch import nn

from pipegoose.nn.pipeline_parallel2._job.forward import (
from pipegoose.nn.pipeline_parallel2._job.forward import ( # SaveActivationIfTrainingCallback,; SaveInputActivationsCallback,
ConfirmCompleteATaskToProgressTracker,
CreateForwardOutputPackageCallback,
ForwardJob,
SaveActivationIfTrainingCallback,
SaveInputActivationsCallback,
SendForwardPackageCallback,
)
from pipegoose.nn.pipeline_parallel2._job.job_type import JobType
from pipegoose.nn.pipeline_parallel2._package import Package
from pipegoose.nn.pipeline_parallel2._utils import sleep
from pipegoose.nn.pipeline_parallel2.queue import SavedActivation, get_input_activations
from pipegoose.nn.pipeline_parallel2.queue import get_input_activations
from pipegoose.testing.utils import init_pipeline_context, spawn

# NOTE: used for creating a forward job
@@ -138,25 +136,25 @@ def test_forward_job_save_input_activations_for_backward_pass(forward_package, p
assert saved_activations.requires_grad is True


def test_forward_job_save_output_activations_for_backward_pass(forward_package, parallel_context, pipeline_context):
MICROBATCH_IDX = forward_package.metadata.microbatch_idx
PARTITION_IDX = forward_package.metadata.partition_idx
CALLBACKS = [CreateForwardOutputPackageCallback(parallel_context, pipeline_context), SaveActivationIfTrainingCallback()]
# def test_forward_job_save_output_activations_for_backward_pass(forward_package, parallel_context, pipeline_context):
# MICROBATCH_IDX = forward_package.metadata.microbatch_idx
# PARTITION_IDX = forward_package.metadata.partition_idx
# CALLBACKS = [CreateForwardOutputPackageCallback(parallel_context, pipeline_context), SaveActivationIfTrainingCallback()]

key = SavedActivation.get_key(MICROBATCH_IDX, PARTITION_IDX)
forward_job = ForwardJob(function, forward_package, CALLBACKS)
# key = SavedActivation.get_key(MICROBATCH_IDX, PARTITION_IDX)
# forward_job = ForwardJob(function, forward_package, CALLBACKS)

output = forward_job.compute()
saved_activations = SavedActivation.get_saved_activations(key)
# output = forward_job.compute()
# saved_activations = SavedActivation.get_saved_activations(key)

assert isinstance(saved_activations, torch.Tensor)
assert torch.equal(saved_activations, output.data)
assert saved_activations.requires_grad is True
# assert isinstance(saved_activations, torch.Tensor)
# assert torch.equal(saved_activations, output.data)
# assert saved_activations.requires_grad is True

with pytest.raises(KeyError):
# NOTE: we expect the saved activations to be removed
# after retrieving them
SavedActivation.get_saved_activations(key)
# with pytest.raises(KeyError):
# # NOTE: we expect the saved activations to be removed
# # after retrieving them
# SavedActivation.get_saved_activations(key)


def run_forward_job_send_output_to_the_next_pipeline_stage(
22 changes: 11 additions & 11 deletions tests/nn/pipeline_parallel_2/test_pipeline_engine.py
@@ -22,8 +22,8 @@ def run_pipeline_engine(
n_microbatches,
model,
inputs,
outputs,
grads,
ref_outputs,
ref_grads,
):
forward_timeline = []
backward_timeline = []
@@ -51,13 +51,13 @@ def forward(self, input):
scheduler = GPipeScheduler(n_microbatches, pipeline_parallel_size)
worker_manager = WorkerManager()
partition_idx = get_partition_idx(parallel_context)
partition_func = Function(partition_idx)

partition = Function(partition_idx)
pipeline_engine = PipelineEngine(
module=model,
module=partition,
scheduler=scheduler,
worker_manager=worker_manager,
parallel_context=parallel_context,
partition_func=partition_func,
)
[(microbatch_idx, partition_idx) for microbatch_idx in range(n_microbatches)]
EXPECTED_FORWARD_TIMELINE = [(microbatch_idx, partition_idx) for microbatch_idx in range(n_microbatches)]
@@ -67,19 +67,19 @@ def forward(self, input):
assert forward_timeline == EXPECTED_FORWARD_TIMELINE

if is_last_stage(parallel_context):
assert torch.allclose(torch.cat(p_outputs, dim=0), outputs)
assert torch.allclose(torch.cat(p_outputs, dim=0), ref_outputs)
for output in p_outputs:
output.sum().backward(retain_graph=True)
else:
# NOTE: earlier stages should not return the final output
# assert p_outputs is None
p_outputs.sum().backward()

for param in partition_func.parameters():
for param in partition.parameters():
assert param.grad is not None

for p, ground_grad in zip(partition_func.parameters(), grads[partition_idx]):
assert torch.allclose(p.grad, ground_grad)
for p, ref_grad in zip(partition.parameters(), ref_grads[partition_idx]):
assert torch.allclose(p.grad, ref_grad)


@pytest.mark.parametrize(
@@ -115,6 +115,6 @@ def test_pipeline_engine(tensor_parallel_size, pipeline_parallel_size, data_para
n_microbatches=N_MICROBATCHES,
model=ORIG_MODEL,
inputs=inputs.detach(),
outputs=outputs.detach(),
grads=grads,
ref_outputs=outputs.detach(),
ref_grads=grads,
)
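
The reference values this test checks against come from the unparallelized model; their setup sits outside the shown hunks. An illustrative sketch of how ref_outputs and ref_grads could be produced (the helper name and return layout are ours, not the commit's; it assumes the original model is an nn.ModuleList with one sub-module per pipeline stage, as in the test below):

from functools import reduce

import torch
from torch import nn

def reference_outputs_and_grads(orig_model: nn.ModuleList, inputs: torch.Tensor):
    # run the stages sequentially, exactly as the pipeline would end to end
    outputs = reduce(lambda x, layer: layer(x), orig_model, inputs)
    outputs.sum().backward()
    # one gradient list per partition, so each rank compares only its own stage
    grads = [[p.grad for p in stage.parameters()] for stage in orig_model]
    return outputs.detach(), grads
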
78 changes: 60 additions & 18 deletions tests/nn/pipeline_parallel_2/test_pipeline_parallel.py
@@ -1,31 +1,73 @@
from copy import deepcopy
from functools import reduce

import pytest
import torch
from torch import nn

from pipegoose.nn.pipeline_parallel2.partitioner import PartitionPolicy
from pipegoose.nn.pipeline_parallel2._utils import is_last_stage
from pipegoose.nn.pipeline_parallel2.pipeline_parallel import PipelineParallel
from pipegoose.nn.pipeline_parallel2.scheduler import SchedulerType
from pipegoose.testing.utils import init_parallel_context, spawn


def run_pipeline_engine(
rank,
world_size,
port,
tensor_parallel_size,
pipeline_parallel_size,
data_parallel_size,
n_microbatches,
model,
inputs,
ref_outputs,
):

class FakeParallelContext:
pass
parallel_context = init_parallel_context(
rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size
)
model = PipelineParallel(model, num_microbatches=n_microbatches, parallel_context=parallel_context).parallelize()

outputs = model(inputs)

@pytest.mark.skip
def test_pipeline_parallel(model):
parallel_context = FakeParallelContext()
if is_last_stage(parallel_context):
assert torch.allclose(torch.cat(outputs, dim=0), ref_outputs)
for output in outputs:
output.sum().backward(retain_graph=True)
else:
outputs.sum().backward()

NUM_MICROBATCHES = 5

input = torch.randn(NUM_MICROBATCHES, 4)
@pytest.mark.parametrize(
"tensor_parallel_size, pipeline_parallel_size, data_parallel_size",
[
(1, 4, 1),
# TODO: not works with 3d parallelism yet
# (2, 4, 2)
],
)
def test_pipeline_engine(tensor_parallel_size, pipeline_parallel_size, data_parallel_size):
BATCH_SIZE = 32
N_MICROBATCHES = 6
SEQ_LEN = 10
HIDDEN_DIM = 5
WORLD_SIZE = tensor_parallel_size * pipeline_parallel_size * data_parallel_size

parallelized_model = PipelineParallel(
module=model,
num_microbatches=NUM_MICROBATCHES,
scheduler_type=SchedulerType.GPIPE,
partition_policy=PartitionPolicy.UNIFORM,
parallel_context=parallel_context,
).parallelize()
inputs = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_DIM, requires_grad=False)
model = nn.ModuleList([nn.Sequential(nn.Linear(HIDDEN_DIM, HIDDEN_DIM), nn.ReLU()) for _ in range(pipeline_parallel_size)])
ORIG_MODEL = deepcopy(model)
outputs = reduce(lambda inputs, layer: layer(inputs), model, inputs)

output = parallelized_model(input)
outputs.sum().backward()

assert isinstance(output, torch.Tensor)
spawn(
run_pipeline_engine,
world_size=WORLD_SIZE,
tensor_parallel_size=tensor_parallel_size,
pipeline_parallel_size=pipeline_parallel_size,
data_parallel_size=data_parallel_size,
n_microbatches=N_MICROBATCHES,
model=ORIG_MODEL,
inputs=inputs.detach(),
ref_outputs=outputs.detach(),
)
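
Illustrative only: the last-stage assertion above encodes the contract that concatenating the per-microbatch outputs reproduces the single-pass reference. How pipegoose splits the batch internally is not shown in this commit; torch.chunk and a stand-in model are used here purely to sketch the idea:

import torch
from torch import nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(5, 5), nn.ReLU())    # stand-in for the full stack
batch = torch.randn(32, 10, 5)

ref = model(batch)                                    # whole batch in one pass
outs = [model(mb) for mb in torch.chunk(batch, chunks=6, dim=0)]

assert torch.allclose(torch.cat(outs, dim=0), ref)    # mirrors the last-stage check
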
