Showing 6 changed files with 134 additions and 109 deletions.
pipegoose/nn/pipeline_parallel2/pipeline_parallel.py:

@@ -1,53 +1,43 @@
+from typing import List
+
 import torch
 from torch import nn
 
-from pipegoose.constants import PIPELINE_MAX_WORKERS, PIPELINE_MIN_WORKERS
 from pipegoose.distributed.parallel_context import ParallelContext
+from pipegoose.nn.pipeline_parallel2._utils import get_partition_idx
 from pipegoose.nn.pipeline_parallel2._worker import WorkerManager
-from pipegoose.nn.pipeline_parallel2.partitioner import PartitionPolicy
 from pipegoose.nn.pipeline_parallel2.pipeline_engine import PipelineEngine
-from pipegoose.nn.pipeline_parallel2.scheduler import SchedulerType, get_scheduler
+from pipegoose.nn.pipeline_parallel2.scheduler import GPipeScheduler
 
 
 class PipelineParallel:
     """Automatically parallelize a module using pipeline parallelism."""
 
     def __init__(
         self,
-        module: nn.Module,
+        modules: List[nn.Module],
         num_microbatches: int,
-        scheduler_type: SchedulerType,
-        partition_policy: PartitionPolicy,
         parallel_context: ParallelContext,
     ):
-        self.module = module
+        self.modules = modules
         self.num_microbatches = num_microbatches
-        self.scheduler_type = scheduler_type
-        self.partition_policy = partition_policy
         self.parallel_context = parallel_context
 
     @torch.no_grad()
     def parallelize(self) -> nn.Module:
-        module = self.module
+        partition_idx = get_partition_idx(self.parallel_context)
+        module = self.modules[partition_idx]
 
-        # TODO: lazy init
-        scheduler = get_scheduler(
-            scheduler_type=self.scheduler_type,
-            num_microbatches=self.num_microbatches,
-            parallel_context=self.parallel_context,
-        )
-        worker_manager = WorkerManager(
-            min_workers=PIPELINE_MIN_WORKERS,
-            max_workers=PIPELINE_MAX_WORKERS,
-            parallel_context=self.parallel_context,
-        )
+        n_partitions = self.parallel_context.pipeline_parallel_size
+        scheduler = GPipeScheduler(self.num_microbatches, n_partitions)
+        worker_manager = WorkerManager()
         pipeline_engine = PipelineEngine(
             module=module,
             scheduler=scheduler,
             worker_manager=worker_manager,
             parallel_context=self.parallel_context,
         )
 
         pipeline_engine.parallelize(module)
-
-        return pipeline_engine
+        self.modules.forward = pipeline_engine.run
+        return self.modules
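The diff above rewrites PipelineParallel: instead of receiving a single module together with a SchedulerType and a PartitionPolicy and partitioning internally, it now takes a pre-partitioned List[nn.Module], builds the engine around this rank's stage (selected with get_partition_idx), hardcodes the GPipe schedule, and re-points the ModuleList's forward at PipelineEngine.run so that an ordinary model(inputs) call drives the pipeline. A minimal usage sketch under the new signature follows; the stage layout and shapes are illustrative (borrowed from the test below), and parallel_context is assumed to be an already-initialized ParallelContext, such as the one init_parallel_context builds in the test:

import torch
from torch import nn

from pipegoose.nn.pipeline_parallel2.pipeline_parallel import PipelineParallel


def build_and_run(parallel_context, pipeline_parallel_size=4):
    # The caller pre-partitions the model: one stage per pipeline rank,
    # rather than one module plus a PartitionPolicy as in the old API.
    stages = nn.ModuleList(
        [nn.Sequential(nn.Linear(5, 5), nn.ReLU()) for _ in range(pipeline_parallel_size)]
    )

    # parallelize() picks this rank's stage via get_partition_idx(), wires up
    # a GPipeScheduler and a WorkerManager, and re-points stages.forward at
    # PipelineEngine.run.
    model = PipelineParallel(
        stages,
        num_microbatches=6,
        parallel_context=parallel_context,
    ).parallelize()

    # An ordinary call now executes the pipelined forward pass for this rank.
    return model(torch.randn(32, 10, 5))

Because forward is swapped on the instance, the user-facing call site is unchanged from a plain module: the same model(inputs) works before and after parallelization. The next diff updates the corresponding test to exercise this path end to end.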
Pipeline engine test:

@@ -1,31 +1,73 @@
+from copy import deepcopy
+from functools import reduce
+
 import pytest
 import torch
+from torch import nn
 
-from pipegoose.nn.pipeline_parallel2.partitioner import PartitionPolicy
+from pipegoose.nn.pipeline_parallel2._utils import is_last_stage
 from pipegoose.nn.pipeline_parallel2.pipeline_parallel import PipelineParallel
-from pipegoose.nn.pipeline_parallel2.scheduler import SchedulerType
+from pipegoose.testing.utils import init_parallel_context, spawn
 
 
-class FakeParallelContext:
-    pass
+def run_pipeline_engine(
+    rank,
+    world_size,
+    port,
+    tensor_parallel_size,
+    pipeline_parallel_size,
+    data_parallel_size,
+    n_microbatches,
+    model,
+    inputs,
+    ref_outputs,
+):
+    parallel_context = init_parallel_context(
+        rank, world_size, port, tensor_parallel_size, pipeline_parallel_size, data_parallel_size
+    )
+    model = PipelineParallel(model, num_microbatches=n_microbatches, parallel_context=parallel_context).parallelize()
 
+    outputs = model(inputs)
 
-@pytest.mark.skip
-def test_pipeline_parallel(model):
-    parallel_context = FakeParallelContext()
+    if is_last_stage(parallel_context):
+        assert torch.allclose(torch.cat(outputs, dim=0), ref_outputs)
+        for output in outputs:
+            output.sum().backward(retain_graph=True)
+    else:
+        outputs.sum().backward()
 
-    NUM_MICROBATCHES = 5
 
-    input = torch.randn(NUM_MICROBATCHES, 4)
+@pytest.mark.parametrize(
+    "tensor_parallel_size, pipeline_parallel_size, data_parallel_size",
+    [
+        (1, 4, 1),
+        # TODO: not works with 3d parallelism yet
+        # (2, 4, 2)
+    ],
+)
+def test_pipeline_engine(tensor_parallel_size, pipeline_parallel_size, data_parallel_size):
+    BATCH_SIZE = 32
+    N_MICROBATCHES = 6
+    SEQ_LEN = 10
+    HIDDEN_DIM = 5
+    WORLD_SIZE = tensor_parallel_size * pipeline_parallel_size * data_parallel_size
 
-    parallelized_model = PipelineParallel(
-        module=model,
-        num_microbatches=NUM_MICROBATCHES,
-        scheduler_type=SchedulerType.GPIPE,
-        partition_policy=PartitionPolicy.UNIFORM,
-        parallel_context=parallel_context,
-    ).parallelize()
+    inputs = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_DIM, requires_grad=False)
+    model = nn.ModuleList([nn.Sequential(nn.Linear(HIDDEN_DIM, HIDDEN_DIM), nn.ReLU()) for _ in range(pipeline_parallel_size)])
+    ORIG_MODEL = deepcopy(model)
+    outputs = reduce(lambda inputs, layer: layer(inputs), model, inputs)
 
-    output = parallelized_model(input)
+    outputs.sum().backward()
 
-    assert isinstance(output, torch.Tensor)
+    spawn(
+        run_pipeline_engine,
+        world_size=WORLD_SIZE,
+        tensor_parallel_size=tensor_parallel_size,
+        pipeline_parallel_size=pipeline_parallel_size,
+        data_parallel_size=data_parallel_size,
+        n_microbatches=N_MICROBATCHES,
+        model=ORIG_MODEL,
+        inputs=inputs.detach(),
+        ref_outputs=outputs.detach(),
+    )
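The rewritten test replaces the skipped FakeParallelContext version with a real multi-process check: it computes reference outputs by folding the inputs through the stages sequentially (the reduce call), spawns WORLD_SIZE processes, and on the last pipeline stage asserts that concatenating the per-microbatch outputs reproduces the single-process batch. The invariant behind that assertion is the microbatch split/merge identity, sketched below without any distributed machinery (the stage here is a stand-in, not pipegoose code):

import torch
from torch import nn

# A stand-in for one pipeline stage; any deterministic module works.
stage = nn.Sequential(nn.Linear(5, 5), nn.ReLU())

batch = torch.randn(32, 10, 5)
ref = stage(batch)  # single-process reference, like the test's reduce(...)

# Split the batch into microbatches, run each independently, merge along dim 0.
microbatches = torch.chunk(batch, chunks=4, dim=0)
outputs = [stage(mb) for mb in microbatches]

# The property the last stage asserts with torch.allclose.
assert torch.allclose(torch.cat(outputs, dim=0), ref)

One caveat worth noting: BATCH_SIZE = 32 with N_MICROBATCHES = 6 implies unequal microbatch sizes if the engine splits with torch.chunk-style semantics; the sketch uses a divisible split for simplicity.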