Support layer parallelism in transformer application

tbennun committed Jan 29, 2024
1 parent d7c5780 commit f959ed7
Showing 4 changed files with 84 additions and 7 deletions.
2 changes: 2 additions & 0 deletions applications/nlp/transformer/modeling.py
@@ -124,6 +124,7 @@ def create_encoder_decoder_transformer(dataset, args: argparse.Namespace):
    )

    parallelism.apply_fsdp_allweights(result, args)
    parallelism.apply_layer_parallelism(transformer, result, args)
    return result


@@ -227,6 +228,7 @@ def create_causal_lm_decoder_transformer(dataset, embed_dim: int,
    )

    parallelism.apply_fsdp_allweights(result, args)
    parallelism.apply_layer_parallelism(transformer, result, args)
    return result


73 changes: 73 additions & 0 deletions applications/nlp/transformer/parallelism.py
@@ -7,6 +7,8 @@
import itertools
import lbann
import lbann.models.subgraph.transformer
import math
import re
from typing import Any, Dict, Optional, List, Tuple, Union

#############################################################################
@@ -186,6 +188,64 @@ def apply_subgraph_parallelism(
    return sgmodule, extra_model_kwargs


#############################################################################
# Layer parallelism

# Number of layer-parallel grids assigned by apply_layer_parallelism, if any
lp_grids = None
def apply_layer_parallelism(module: lbann.models.Transformer,
                            model: lbann.Model, args: argparse.Namespace):
    """
    Applies a model-parallel strategy to sequences of contiguous transformer
    blocks, sometimes referred to as pipeline parallelism or layer parallelism.

    :param module: Transformer module to take as reference for block counts.
    :param model: The model to modify.
    :param args: Command-line arguments.
    """
    if not args.layer_parallel:
        return

    lp_count = args.lp_count
    if args.lp_count == 0:
        lp_count = args.nodes * args.procs_per_node

    blocks = len(module.encoder) + len(module.decoder)

    # Assign blocks to increasing grid tags
    blocks_per_grid_tag = math.ceil(blocks / lp_count)
    cur_grid_tag = 0

    # Go over all layers in traversal order, applying grid tags in increasing order
    last_block_id = -1
    block_id = -1
    total_block_id = 0
    for layer in model.layers:
        if layer.name.startswith('transformer_decoder'):
            block_id = int(
                re.search(r'transformer_decoder(\d+)_',
                          layer.name).group(1))
        elif layer.name.startswith('transformer_encoder'):
            block_id = int(
                re.search(r'transformer_encoder(\d+)_',
                          layer.name).group(1))
        if last_block_id != block_id:
            if total_block_id % blocks_per_grid_tag == 0:
                cur_grid_tag += 1
            last_block_id = block_id
            total_block_id += 1

        # Apply layer parallelism
        layer.grid_tag = {'value': cur_grid_tag}

    global lp_grids
    lp_grids = cur_grid_tag

def get_layer_parallel_args() -> List[str]:
    """Returns LBANN command-line arguments for layer parallelism, if enabled."""
    if lp_grids is not None:
        return ['--num-subgrids', str(lp_grids)]
    return []

def add_transformer_parallelism_arguments(parser: argparse.Namespace,
                                          subgraph: bool = True):

@@ -259,3 +319,16 @@ def add_transformer_parallelism_arguments(parser: argparse.Namespace,
        action='store_true',
        help='Apply Fully-Sharded Data-Parallelism (FSDP) and shard MLP weights'
    )

    #######################################
    # Layer parallelism
    parser.add_argument(
        '--layer-parallel',
        action='store_true',
        help='Apply layer parallelism (also referred to as pipelining)')
    parser.add_argument(
        '--lp-count',
        default=0,
        type=int,
        help='In layer parallelism, the number of portions to divide the network'
        ' into (Default: divide evenly among all ranks)')
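
As a rough illustration (not part of the commit) of how apply_layer_parallelism distributes blocks over grid tags, the sketch below reproduces the assignment arithmetic for a hypothetical model with 12 transformer blocks split across 4 grids:

import math

# Hypothetical numbers: 12 transformer blocks, lp_count = 4
# (e.g., --lp-count 4, or nodes * procs_per_node when --lp-count is 0).
blocks = 12
lp_count = 4
blocks_per_grid_tag = math.ceil(blocks / lp_count)  # 3 blocks per grid tag

cur_grid_tag = 0
for total_block_id in range(blocks):
    # A new grid tag starts every `blocks_per_grid_tag` blocks, mirroring the
    # increment inside the layer-traversal loop above.
    if total_block_id % blocks_per_grid_tag == 0:
        cur_grid_tag += 1
    print(f'block {total_block_id} -> grid tag {cur_grid_tag}')

# Blocks 0-2 land on grid tag 1, 3-5 on 2, 6-8 on 3, and 9-11 on 4; with four
# tags assigned, get_layer_parallel_args() would forward '--num-subgrids 4'.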
4 changes: 3 additions & 1 deletion applications/nlp/transformer/trainer.py
@@ -12,6 +12,7 @@
from lbann.launcher.batch_script import BatchScript

import utils.paths
import parallelism


def construct_training_task(model: lbann.Model,
@@ -234,7 +235,8 @@ def make_batch_script(model: lbann.Model,
    script.add_parallel_command([
        lbann.lbann_exe(),
        f'--prototext={protobuf_file}',
    ] + lbann.contrib.args.get_profile_args(args))
    ] + (lbann.contrib.args.get_profile_args(args) +
         parallelism.get_layer_parallel_args()))
    script.add_command('status=$?')
    script.add_command('echo "Finished training at $(date)"')
    script.add_command('exit ${status}')
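For reference (not part of the diff), the argument list assembled above expands roughly as follows when layer parallelism has assigned four grids; the executable name and prototext path are placeholders:

# Hypothetical expansion of the list concatenation in make_batch_script:
cmd = (['lbann', '--prototext=experiment.prototext']  # lbann_exe() and prototext
       + []                                           # get_profile_args(args), empty here
       + ['--num-subgrids', '4'])                     # get_layer_parallel_args()
# i.e. the launched command is: lbann --prototext=experiment.prototext --num-subgrids 4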
12 changes: 6 additions & 6 deletions python/lbann/contrib/args.py
@@ -1,6 +1,6 @@
"""Helper functions to add common command-line arguments."""

from typing import Any
from typing import Any, List

import argparse
import shlex
@@ -250,10 +250,10 @@ def add_profiling_arguments(parser: argparse.ArgumentParser) -> None:
                        action='store_true',
                        default=False,
                        help='enable itemized memory usage analysis')
    parser.add_argument('--profile-init',
    parser.add_argument('--profile-noinit',
                        action='store_true',
                        default=False,
                        help='enable profiling initialization')
                        help='disable profiling initialization')
    parser.add_argument('--caliper',
                        action='store_true',
                        default=False,
@@ -285,7 +285,7 @@ def create_profile_callbacks(args: argparse.Namespace) -> Any:
"""
try:
profile = args.profile
profile_init = not args.profile_init
profile_noinit = args.profile_noinit
memprofile = args.memory_profile
memprof_verbose = args.memory_profile_verbose
except AttributeError:
@@ -294,15 +294,15 @@ def create_profile_callbacks(args: argparse.Namespace) -> Any:

    result = []
    if profile:
        result.append(lbann.CallbackProfiler(skip_init=profile_init))
        result.append(lbann.CallbackProfiler(skip_init=profile_noinit))
    if memprofile:
        result.append(lbann.CallbackMemoryProfiler(
            detailed_first_step=memprof_verbose))

    return result


def get_profile_args(args: argparse.Namespace) -> list[str]:
def get_profile_args(args: argparse.Namespace) -> List[str]:
"""Get LBANN command-line arguments for profiling.
The parsed arguments must be generated by an
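
A minimal usage sketch of the renamed profiling flag (not part of the commit); it assumes add_profiling_arguments also defines --profile and the memory-profile flags that create_profile_callbacks reads:

import argparse

import lbann.contrib.args as contrib_args

parser = argparse.ArgumentParser()
contrib_args.add_profiling_arguments(parser)

# Profile the run, but skip the initialization phase.
args = parser.parse_args(['--profile', '--profile-noinit'])

callbacks = contrib_args.create_profile_callbacks(args)
# With this change, skip_init is taken directly from args.profile_noinit,
# so the resulting CallbackProfiler is constructed with skip_init=True.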
