Skip to content

Commit

Permalink
refactor pipeline engine
Browse files Browse the repository at this point in the history
  • Loading branch information
xrsrke committed Oct 18, 2023
1 parent 3402dbb commit 4593863
Show file tree
Hide file tree
Showing 11 changed files with 70 additions and 145 deletions.
5 changes: 0 additions & 5 deletions pipegoose/nn/pipeline_parallel2/_comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,7 @@
from pipegoose.distributed.parallel_context import ParallelContext
from pipegoose.nn.pipeline_parallel2._package import Package

# from pipegoose.nn.pipeline_parallel2.pipeline_context import PipelineContext

RECV_QUEUE = Queue()
# RECV_QUEUE = dict()

# TODO: refactor to a singleton class
# NOTE: save parallel context for backward job
Expand Down Expand Up @@ -54,6 +51,4 @@ def _recv_package(package: Package, src: int, dst: int):
package.metadata.microbatch_idx
package.metadata.partition_idx

# TODO: refactor; callers should not need to know how to construct the key
# RECV_QUEUE[(microbatch_idx, partition_idx)] = package
RECV_QUEUE.put(package)
46 changes: 17 additions & 29 deletions pipegoose/nn/pipeline_parallel2/_job/backward.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,27 +19,26 @@
class _SaveGradLossFunction(torch.autograd.Function):
@staticmethod
def forward(ctx: Any, key, metadata, tensor: torch.Tensor):
print("forward of saving grad loss", key)
ctx.key = key
ctx.package_metadata = metadata
new_tensor = tensor.detach().clone()
return new_tensor

@staticmethod
def backward(ctx: Any, grad_output: torch.Tensor):
with torch.autograd.set_grad_enabled(False):
from pipegoose.nn.pipeline_parallel2.queue import (
_SAVED_GRAD_LOSS,
_SAVED_METADATA_of_GRAD_LOSS,
)
from pipegoose.nn.pipeline_parallel2.queue import (
_SAVED_GRAD_LOSS,
_SAVED_METADATA_of_GRAD_LOSS,
)

with torch.autograd.set_grad_enabled(False):
key = ctx.key
print("backward of saving grad loss", ctx.key)

_SAVED_GRAD_LOSS[key] = grad_output
_SAVED_METADATA_of_GRAD_LOSS[key] = ctx.package_metadata
# NOTE: prevent the grad from flowing back to the input of the pipeline stage
# since we only do one backward pass at a time

# NOTE: prevent the grad from flowing back to the input of the pipeline stage
# because we only do one backward pass at a time
# and it's orchestrated by the pipeline engine
return (None, None, None)


Expand Down Expand Up @@ -119,48 +118,37 @@ def __init__(self, *args, is_scheduled: bool = False, **kwargs):
self.is_scheduled: bool = is_scheduled

def run_compute(self) -> torch.Tensor:
from pipegoose.nn.pipeline_parallel2._comm import get_pipeline_context

microbatch_idx = self.input.metadata.microbatch_idx
partition_idx = self.input.metadata.partition_idx
prev_grad = self.input.data

from pipegoose.nn.pipeline_parallel2._comm import get_pipeline_context

pipeline_context = get_pipeline_context()
rank = pipeline_context.parallel_context.get_global_rank()

input = get_input_activations(microbatch_idx, partition_idx)
output = get_output_activations(microbatch_idx, partition_idx, self.is_scheduled)

if input.requires_grad is False and partition_idx != 0:
if pipeline_context.is_first_stage is False and input.requires_grad is False:
# NOTE: the input of the first pipeline stage is the input of the model
# which we don't need to compute gradients for
raise PipelineGradientFlowError(
f"Please set .requires_grad = True to input activations. Gradients can't flow back to the input of the pipeline stage, rank={rank}, microbatch_idx={microbatch_idx}, partition_idx={partition_idx}"
)

def is_last_microbatch(microbatch_idx):
return microbatch_idx == pipeline_context.num_microbatches - 1

# if not is_last_microbatch(microbatch_idx):
# output = output.detach().requires_grad_(True)

if rank == 3 and microbatch_idx == 4:
assert 1 == 1

# new_output = output.detach().requires_grad_(True)
torch.autograd.backward(output, grad_tensors=prev_grad)
torch.autograd.backward(output, grad_tensors=prev_grad, retain_graph=True)

if partition_idx == 0:
# NOTE: the first pipeline stage is the end of the backward pass
# no need to send the gradients to any other pipeline stage
return

if input.grad is None:
    # The message must be an f-string: without the prefix the {rank},
    # {microbatch_idx} and {partition_idx} placeholders are emitted
    # literally instead of the values for the failing job.
    raise PipelineGradientFlowError(
        f"Gradients can't flow back to the input of the pipeline stage, rank={rank}, microbatch_idx={microbatch_idx}, partition_idx={partition_idx}"
    )

# # TODO: remove this, since the grads is stored in module's weights
# # and we do gradient accumulation, we don't need return grads or send to other stages
# assert isinstance(input.grad, torch.Tensor)

# print(f"executing backward job, rank={rank}, microbatch_idx={microbatch_idx}, partition_idx={partition_idx}")
print(
f"rank={rank}, microbatch_idx={microbatch_idx}, partition_idx={partition_idx}, yay! gradients: {input.grad.shape}"
)
Expand Down
31 changes: 8 additions & 23 deletions pipegoose/nn/pipeline_parallel2/_job/creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
ConfirmCompleteATaskToProgressTracker,
CreateForwardOutputPackageCallback,
ForwardJob,
SaveBufferForBackwardCallback,
SendForwardPackageCallback,
)
from pipegoose.nn.pipeline_parallel2._job.job import Job
Expand Down Expand Up @@ -73,6 +74,7 @@ def create(
) -> ForwardJob:
callbacks = [
CreateForwardOutputPackageCallback(parallel_context, pipeline_context),
SaveBufferForBackwardCallback(),
ScheduleBackwardJobCallback(pipeline_context),
# SaveActivationIfTrainingCallback(),
# SaveInputActivationsCallback(),
Expand Down Expand Up @@ -205,9 +207,6 @@ def backward(ctx, grad_input: torch.Tensor) -> (None, torch.Tensor):
def _run_backward_execution(grad_input, metadata):
import torch.distributed as dist

def is_last_microbatch(microbatch_idx):
return microbatch_idx == pipeline_context.num_microbatches - 1

def backward_function(self):
pass

Expand All @@ -228,10 +227,8 @@ def backward_function(self):
if parallel_context.get_global_rank() == 0:
progress = get_progresses_from_pipeline_context(pipeline_context)
progress_tracker.initiate(progress)
print("new backward progress: ", progress)

rank = parallel_context.get_global_rank()
print(f"rank={rank}, running main thread backward pass")

dist.barrier()

Expand All @@ -240,9 +237,6 @@ def backward_function(self):

print(f"rank={rank}, entered clock_idx: {pipeline_context.clock_idx}")

if pipeline_context.clock_idx == 9 and pipeline_context.partition_idx == 0:
assert 1 == 1

if len(tasks) > 0:
for task in tasks:
microbatch_idx = task.microbatch_idx
Expand All @@ -253,10 +247,7 @@ def backward_function(self):
)

if pipeline_context.is_last_stage:
if is_last_microbatch(microbatch_idx):
package = Package(grad_input, metadata)
package.metadata.job_type = JobType.BACKWARD
else:
if pipeline_context.is_last_microbatch(microbatch_idx) is False:
from pipegoose.nn.pipeline_parallel2.queue import (
_SAVED_GRAD_LOSS,
_SAVED_METADATA_of_GRAD_LOSS,
Expand All @@ -265,26 +256,20 @@ def backward_function(self):
grad_input = _SAVED_GRAD_LOSS[(microbatch_idx, partition_idx)]
metadata = _SAVED_METADATA_of_GRAD_LOSS[(microbatch_idx, partition_idx)]

package = Package(grad_input, metadata)
package.metadata.job_type = JobType.BACKWARD
package = Package(grad_input, metadata)
package.metadata.job_type = JobType.BACKWARD
else:
from pipegoose.nn.pipeline_parallel2._comm import RECV_QUEUE

package = RECV_QUEUE.get()

rank = parallel_context.get_global_rank()
microbatch_idx = metadata.microbatch_idx

backward_job = create_job(backward_function, package, parallel_context, pipeline_context)
# NOTE: this is a bug, not consistent with the test cases
# backward_job.is_scheduled = False
print(f"rank={rank}, created backward job: microbatch_idx={microbatch_idx}, partition_idx={partition_idx}")

# NOTE: put the backward job into the pending queue
JobQueue.PENDING_JOBS.put(backward_job)

microbatch_idx = metadata.microbatch_idx
print(f"rank={rank}, created backward job: microbatch_idx={microbatch_idx}, partition_idx={partition_idx}")

dist.barrier()

dist.barrier()

print("done")
80 changes: 22 additions & 58 deletions pipegoose/nn/pipeline_parallel2/_job/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,10 @@

class ForwardJob(Job):
def run_compute(self) -> torch.Tensor:
with torch.set_grad_enabled(True):
# pipeline_context = self.pipeline_context
# if pipeline_context.partition_idx > 0:
# assert 1 == 1
with torch.enable_grad():
output = self.function(self.input.data)

# from pipegoose.nn.pipeline_parallel2.queue import SavedActivation

# # TODO: refactor
# microbatch_idx = self.input.metadata.microbatch_idx
# partition_idx = self.input.metadata.partition_idx

# key = SavedActivation.get_key(microbatch_idx, partition_idx)
# # print("saving activation, data.shape=", self.job.output.data.shape)
# SavedActivation.save_activations(key, output)

return output
is_training = self.input.metadata.training.is_training
with torch.set_grad_enabled(is_training):
output = self.function(self.input.data)
return output


class CreateForwardOutputPackageCallback(Callback):
Expand All @@ -43,27 +29,28 @@ def __init__(self, parallel_context: ParallelContext, pipeline_context: Pipeline
self.pipeline_context = pipeline_context

def after_compute(self):
from pipegoose.nn.pipeline_parallel2.queue import save_input_activations

microbatch_idx = self.job.input.metadata.microbatch_idx
partition_idx = self.job.input.metadata.partition_idx
save_input_activations(self.job.input.data, microbatch_idx, partition_idx)

data = self.job.output
from pipegoose.nn.pipeline_parallel2.queue import save_output_activations

save_output_activations(data, microbatch_idx, partition_idx)
self._save_buffer_for_backward()

input_metadata = deepcopy(self.job.input.metadata)

package = Package(data, input_metadata)
package = Package(self.job.output, input_metadata)

if not self.pipeline_context.is_last_stage:
package = self._update_next_pipeline_stage(package)
package = self._update_src_and_dst_rank(package)

self.job.output = package

def _save_buffer_for_backward(self):
    """Hand this job's input data and output to the activation-saving helpers.

    Both values are passed together with the microbatch and partition
    indices read from the job's input metadata.
    """
    # Imported inside the method, as in the rest of this module.
    from pipegoose.nn.pipeline_parallel2.queue import (
        save_input_activations,
        save_output_activations,
    )

    job = self.job
    metadata = job.input.metadata
    indices = (metadata.microbatch_idx, metadata.partition_idx)

    save_input_activations(job.input.data, *indices)
    save_output_activations(job.output, *indices)

def _update_next_pipeline_stage(self, package: Package) -> Package:
microbatch_idx = package.metadata.microbatch_idx

Expand All @@ -86,39 +73,16 @@ def _update_src_and_dst_rank(self, package: Package) -> Package:
return package


class SaveInputActivationsCallback(Callback):
"""Save the input activations for backward pass."""

class SaveBufferForBackwardCallback(Callback):
order = 1

def after_compute(self):
# from pipegoose.nn.pipeline_parallel2.queue import save_input_activations
def ater_compute(self):
# from pipegoose.nn.pipeline_parallel2.queue import save_input_activations, save_output_activations

# data = self.job.input.data
# microbatch_idx = self.job.input.metadata.microbatch_idx
# partition_idx = self.job.input.metadata.partition_idx
# save_input_activations(data, microbatch_idx, partition_idx)
pass


class SaveActivationIfTrainingCallback(Callback):
"""Save the activation of a forward job for backward pass if training."""

order = 2

def after_compute(self):
# is_training = self.job.input.metadata.training.is_training

# if is_training is True:
# from pipegoose.nn.pipeline_parallel2.queue import SavedActivation

# # TODO: refactor
# microbatch_idx = self.job.input.metadata.microbatch_idx
# partition_idx = self.job.input.metadata.partition_idx

# key = SavedActivation.get_key(microbatch_idx, partition_idx)
# print("saving activation, data.shape=", self.job.output.data.shape)
# SavedActivation.save_activations(key, self.job.output.data)
# save_input_activations(self.job.input.data, microbatch_idx, partition_idx)
# save_output_activations(self.job.output.data, microbatch_idx, partition_idx)
pass


Expand Down
3 changes: 3 additions & 0 deletions pipegoose/nn/pipeline_parallel2/pipeline_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,6 @@ def is_first_stage(self) -> bool:
@property
def is_last_stage(self) -> bool:
return is_last_stage(self.parallel_context)

def is_last_microbatch(self, microbatch_idx: int) -> bool:
    """Return whether ``microbatch_idx`` equals ``self.num_microbatches - 1``."""
    final_idx = self.num_microbatches - 1
    return microbatch_idx == final_idx
14 changes: 1 addition & 13 deletions pipegoose/nn/pipeline_parallel2/pipeline_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from pipegoose.nn.pipeline_parallel2._package import Metadata, Package, TrainingMetadata
from pipegoose.nn.pipeline_parallel2._worker import BaseWorkerManager
from pipegoose.nn.pipeline_parallel2.pipeline_context import PipelineContext
from pipegoose.nn.pipeline_parallel2.queue import JobQueue, save_input_activations
from pipegoose.nn.pipeline_parallel2.queue import JobQueue
from pipegoose.nn.pipeline_parallel2.scheduler import BaseScheduler
from pipegoose.nn.pipeline_parallel2.sync.callback import Callback
from pipegoose.nn.pipeline_parallel2.sync.handshake import (
Expand All @@ -39,7 +39,6 @@ class PipelineEngine:
def __init__(
self,
module: nn.Module,
# partitioner: BasePartitioner,
scheduler: BaseScheduler,
worker_manager: BaseWorkerManager,
parallel_context: ParallelContext,
Expand All @@ -51,7 +50,6 @@ def __init__(
), f"parallel_context must be an instance of ParallelContext, got {type(parallel_context)}"

self.module = module
# self.partitioner = partitioner
self.scheduler = scheduler
self.worker_manager = worker_manager
self.parallel_context = parallel_context
Expand Down Expand Up @@ -116,8 +114,6 @@ def after_new_clock_cycle(self, progress, clock_idx):
else:
package = RECV_QUEUE.get()

save_input_activations(package.data, microbatch_idx=microbatch_idx, partition_idx=partition_idx)

job = create_job(self.partition_func, package, self.parallel_context, self.pipeline_context)
JobQueue.PENDING_JOBS.put(job)

Expand All @@ -132,17 +128,9 @@ def after_new_clock_cycle(self, progress, clock_idx):

for microbatch_idx in range(n_microbatches):
output = _SAVED_SCHEDULED_ACTIVATIONS[(microbatch_idx, self.pipeline_context.partition_idx)]
# outputs.append(get_output_activations(microbatch_idx, self.pipeline_context.partition_idx))
outputs.append(output)

# outputs = torch.cat(outputs, dim=0)
return outputs

# from pipegoose.nn.pipeline_parallel2.queue import _SAVED_ACTIVATIONS
# _SAVED_ACTIVATIONS[(0, 3)].sum().backward()

# import time
# time.sleep(10)
else:
output = _SAVED_SCHEDULED_ACTIVATIONS[(microbatch_idx, self.pipeline_context.partition_idx)]
return output
Expand Down
Loading

0 comments on commit 4593863

Please sign in to comment.