diff --git a/applications/nlp/transformer/datasets/pretokenize/varlen/pretokenize-validation.py b/applications/nlp/transformer/datasets/pretokenize/varlen/pretokenize-validation.py
new file mode 100644
index 00000000000..dc3dad824a0
--- /dev/null
+++ b/applications/nlp/transformer/datasets/pretokenize/varlen/pretokenize-validation.py
@@ -0,0 +1,34 @@
+from tqdm import trange
+from multiprocessing import Pool
+import numpy as np
+
+
+class Processor:
+
+ def __init__(self, total_threads: int):
+ self.threads = total_threads
+
+ def __call__(self, tid: int):
+ import thepile as dataset
+ num_samples = dataset.num_val_samples()
+        filename = '/p/vast1/data/datasets/the-pile-huggingface/pretokenized-varlen/val.bin'
+        len_filename = '/p/vast1/data/datasets/the-pile-huggingface/pretokenized-varlen/val-seqlen.bin'
+
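+        # Tokens from all samples are appended to a single flat uint16 file;
+        # the per-sample sequence lengths go to a parallel uint32 file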
+ with open(filename, 'ab') as fp:
+ with open(len_filename, 'ab') as slfp:
+ for i in trange(num_samples):
+ text = dataset.dataset_val[i]['text']
+ tokenized = dataset.tokenize(text)
+ sample = np.array(tokenized, dtype=np.uint16)
+ sample_len = np.array([len(sample)], dtype=np.uint32)
+ sample.tofile(fp)
+ sample_len.tofile(slfp)
+
+ print('Done')
+
+
+if __name__ == '__main__':
+ threads = 1
+ with Pool(threads) as pool:
+ pool.map(Processor(threads), range(threads))
diff --git a/applications/nlp/transformer/datasets/pretokenize/varlen/pretokenize.py b/applications/nlp/transformer/datasets/pretokenize/varlen/pretokenize.py
new file mode 100644
index 00000000000..90a811a8abb
--- /dev/null
+++ b/applications/nlp/transformer/datasets/pretokenize/varlen/pretokenize.py
@@ -0,0 +1,78 @@
+from tqdm import trange
+from multiprocessing import Pool
+import numpy as np
+import os
+import argparse
+from pathlib import Path
+
+
+class Processor:
+
+ def __init__(self, total_threads: int):
+ self.threads = total_threads
+
+ def __call__(self, tid: int):
+ import thepile as dataset
+ num_samples = dataset.num_train_samples()
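+        # Fixed seed: every worker derives the same permutation, so each
+        # thread's slice of indices is stable across runs and restarts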
+ np.random.seed(20231023)
+ indices = np.random.permutation(num_samples)
+ local_samples = num_samples // self.threads
+ offset = tid * local_samples
+ # Add remainder
+ if tid == self.threads - 1:
+ local_samples += num_samples % self.threads
+ section = indices[offset:offset + local_samples]
+ filename = f'/p/vast1/data/datasets/the-pile-huggingface/pretokenized-varlen/train-pretokenized-{tid:02d}-of-{self.threads}.bin'
+ len_filename = f'/p/vast1/data/datasets/the-pile-huggingface/pretokenized-varlen/train-seqlen-{tid:02d}-of-{self.threads}.bin'
+
+ # Create file
+ if not os.path.isfile(filename):
+ Path(filename).touch()
+ if not os.path.isfile(len_filename):
+ Path(len_filename).touch()
+
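+        # Resume support: each length entry is a 4-byte uint32, so the size
+        # of the length file tells how many sequences were already written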
+ sz = os.path.getsize(len_filename)
+ assert sz % 4 == 0
+ sequences_processed = sz // 4
+ print(tid, ': Size in bytes:', sz, '. Sequences processed:',
+ sequences_processed)
+
+ with open(filename, 'ab') as fp:
+ with open(len_filename, 'ab') as slfp:
+ for i in trange(sequences_processed,
+ section.shape[0],
+ desc=f'Thread {tid}'):
+ text = dataset.dataset_train[int(section[i])]['text']
+ sample = dataset.tokenize(text)
+ sample = np.array(sample, dtype=np.uint16)
+ sample.tofile(fp)
+ sample_len = np.array([len(sample)], dtype=np.uint32)
+ sample_len.tofile(slfp)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument('-j',
+ action='store',
+ default=0,
+ type=int,
+ help='Threads (default 0 = number of cores)')
+ parser.add_argument('-t',
+ action='store',
+ default=0,
+ type=int,
+ help='Total Chunks (default 0 = number of threads)')
+ parser.add_argument('-o',
+ action='store',
+ default=0,
+ type=int,
+ help='Chunk offset (default 0)')
+ args = parser.parse_args()
+
+ threads = args.j or os.cpu_count()
+ total_chunks = args.t or threads
+ offset = args.o
+ assert offset + threads <= total_chunks
+ with Pool(threads) as pool:
+ pool.map(Processor(total_chunks), range(offset, offset + threads))
diff --git a/applications/nlp/transformer/datasets/thepile.py b/applications/nlp/transformer/datasets/thepile.py
index ee82f01a910..8678e061afd 100644
--- a/applications/nlp/transformer/datasets/thepile.py
+++ b/applications/nlp/transformer/datasets/thepile.py
@@ -91,7 +91,7 @@ def get_train_sample(index):
def get_val_sample(index):
"""Token indices for a data sample from the validation set."""
- text = dataset_train[index]['text']
+ text = dataset_val[index]['text']
tokenized = tokenize(text)
# Trim long sequences, left-pad short sequences
@@ -120,3 +120,12 @@ def sample_dims():
def vocab_size():
return tokenizer.get_vocab_size()
+
+
+if __name__ == '__main__':
+ print('Training samples:', num_train_samples())
+ print('Validation samples:', num_val_samples())
+ print('Training sample 101:')
+ print(tokenizer.decode(get_train_sample(101)))
+ print('Validation sample 233:')
+ print(tokenizer.decode(get_val_sample(233)))
diff --git a/applications/nlp/transformer/datasets/thepile_pretokenized.py b/applications/nlp/transformer/datasets/thepile_pretokenized.py
index 89c490e2db0..65b7c7e80ad 100644
--- a/applications/nlp/transformer/datasets/thepile_pretokenized.py
+++ b/applications/nlp/transformer/datasets/thepile_pretokenized.py
@@ -1,5 +1,5 @@
"""
-The Pile dataset, stored as pre-tokenized binary files for optimized processing.
+The Pile dataset, stored as pre-tokenized, pre-packed binary files for optimized processing.
"""
import os
import os.path
@@ -10,7 +10,9 @@
# Options
# ----------------------------------------------
-sequence_length = int(os.getenv('THE_PILE_SEQUENCE_LENGTH', default='512'))
+# Sequence length is hardcoded to 512 in the pre-packed binary dataset.
+# To use other sequence lengths, see ``thepile_pretokenized_varlen.py``
+sequence_length = 512
# ----------------------------------------------
# Setup
diff --git a/applications/nlp/transformer/datasets/thepile_pretokenized_varlen.py b/applications/nlp/transformer/datasets/thepile_pretokenized_varlen.py
new file mode 100644
index 00000000000..71e26e0c117
--- /dev/null
+++ b/applications/nlp/transformer/datasets/thepile_pretokenized_varlen.py
@@ -0,0 +1,105 @@
+"""
+The Pile dataset, stored as pre-tokenized, variable-length binary files for optimized processing.
+"""
+import os
+import os.path
+
+import numpy as np
+# ----------------------------------------------
+# Options
+# ----------------------------------------------
+
+sequence_length = int(os.getenv('THE_PILE_SEQUENCE_LENGTH', default='512'))
+
+# ----------------------------------------------
+# Setup
+# ----------------------------------------------
+
+# Load the datasets
+data_dir = os.getenv('THE_PILE_DATA_DIR',
+ '/p/vast1/data/datasets/the-pile-pretokenized')
+dataset_train = np.memmap(os.path.join(data_dir, 'train.bin'),
+ dtype=np.uint16,
+ mode='r')
+sample_lengths_train = np.fromfile(os.path.join(data_dir, 'train-seqlen.bin'),
+ dtype=np.uint32).astype(np.uint64)
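+# Offsets are the exclusive prefix sum of the sequence lengths, i.e., the
+# index of each sample's first token in the flat token stream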
+sample_offsets_train = np.zeros_like(sample_lengths_train)
+sample_offsets_train[1:] = np.cumsum(sample_lengths_train)[:-1]
+dataset_val = np.memmap(os.path.join(data_dir, 'val.bin'),
+ dtype=np.uint16,
+ mode='r')
+sample_lengths_val = np.fromfile(os.path.join(data_dir, 'val-seqlen.bin'),
+ dtype=np.uint32).astype(np.uint64)
+sample_offsets_val = np.zeros_like(sample_lengths_val)
+sample_offsets_val[1:] = np.cumsum(sample_lengths_val)[:-1]
+
+# Uses the definition from the GPT-NeoX-20B tokenizer
+pad_index = 1 # '<|padding|>'
+_vocab_size = 50277
+
+# ----------------------------------------------
+# Sample access functions
+# ----------------------------------------------
+
+
+def trim_and_pad(sample, random: bool):
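+    """
+    Crops or left-pads ``sample`` to exactly ``sequence_length`` tokens. Long
+    sequences are cropped at a random offset when ``random`` is True and
+    truncated to their leading tokens otherwise.
+    """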
+ # Trim long sequences
+ if len(sample) > sequence_length:
+ if random:
+ pos = np.random.rand()
+ offset = (len(sample) - sequence_length + 1) * pos
+ offset = int(np.floor(offset))
+ sample = sample[offset:offset + sequence_length]
+ else:
+ sample = sample[0:sequence_length]
+
+ # Left-pad short sequences
+ if len(sample) < sequence_length:
+ sample_pad = np.full(sequence_length, pad_index, dtype=np.int32)
+ if len(sample) > 0:
+ sample_pad[-len(sample):] = sample
+ return sample_pad
+
+ return sample
+
+
+def get_train_sample(index: int):
+ sample = np.copy(
+ dataset_train[sample_offsets_train[index]:sample_offsets_train[index] +
+ sample_lengths_train[index]]).astype(np.int32)
+ return trim_and_pad(sample, True)
+
+
+def get_val_sample(index):
+ sample = np.copy(
+ dataset_val[sample_offsets_val[index]:sample_offsets_val[index] +
+ sample_lengths_val[index]]).astype(np.int32)
+ return trim_and_pad(sample, False)
+
+
+def num_train_samples():
+ return sample_lengths_train.shape[0]
+
+
+def num_val_samples():
+ return sample_lengths_val.shape[0]
+
+
+def sample_dims():
+ return (sequence_length, )
+
+
+def vocab_size():
+ return _vocab_size
+
+
+if __name__ == '__main__':
+ print('Training samples:', num_train_samples())
+ print('Validation samples:', num_val_samples())
+ from tokenizers import Tokenizer
+ tokenizer = Tokenizer.from_file(
+ os.path.join(data_dir, '20B_tokenizer.json'))
+ print('Training sample 101:')
+ print(tokenizer.decode(get_train_sample(101)))
+ print('Validation sample 233:')
+ print(tokenizer.decode(get_val_sample(233)))
diff --git a/applications/nlp/transformer/modeling.py b/applications/nlp/transformer/modeling.py
index fc6ddf48bc8..013a8e180b9 100644
--- a/applications/nlp/transformer/modeling.py
+++ b/applications/nlp/transformer/modeling.py
@@ -87,6 +87,7 @@ def create_encoder_decoder_transformer(dataset, args: argparse.Namespace):
transformer, args)
parallelism.apply_ffn_model_parallelism(transformer, args)
parallelism.apply_fsdp_mlp(transformer, [embedding_weights], args)
+ parallelism.apply_layer_parallelism(transformer, args)
# Run through transformer
result = transformer(encoder_input, decoder_input, sequence_length - 1)
@@ -124,6 +125,7 @@ def create_encoder_decoder_transformer(dataset, args: argparse.Namespace):
)
parallelism.apply_fsdp_allweights(result, args)
+ parallelism.apply_layer_parallelism_postamble(result, args)
return result
@@ -186,6 +188,7 @@ def create_causal_lm_decoder_transformer(dataset, embed_dim: int,
transformer, args)
parallelism.apply_ffn_model_parallelism(transformer, args)
parallelism.apply_fsdp_mlp(transformer, [embedding_weights], args)
+ parallelism.apply_layer_parallelism(transformer, args)
# Run through transformer with the same sequence
result = transformer(decoder_input, decoder_input, sequence_length)
@@ -227,6 +230,7 @@ def create_causal_lm_decoder_transformer(dataset, embed_dim: int,
)
parallelism.apply_fsdp_allweights(result, args)
+ parallelism.apply_layer_parallelism_postamble(result, args)
return result
diff --git a/applications/nlp/transformer/parallelism.py b/applications/nlp/transformer/parallelism.py
index 7c56faf51eb..9c0c84ea50d 100644
--- a/applications/nlp/transformer/parallelism.py
+++ b/applications/nlp/transformer/parallelism.py
@@ -4,9 +4,12 @@
strategies found in this file.
"""
import argparse
+import collections
import itertools
import lbann
import lbann.models.subgraph.transformer
+import math
from typing import Any, Dict, Optional, List, Tuple, Union
#############################################################################
@@ -195,6 +198,108 @@ def apply_subgraph_parallelism(
return sgmodule, extra_model_kwargs
+#############################################################################
+# Layer parallelism
+
+lp_grids = None
+
+
+def apply_layer_parallelism(module: lbann.models.Transformer,
+ args: argparse.Namespace):
+ """
+ Applies a model-parallel strategy on sequences of contiguous transformer
+ blocks, sometimes referred to as pipeline parallelism or layer parallelism.
+
+ :param module: Transformer module to modify.
+ :param args: Command-line arguments.
+ """
+ if not args.layer_parallel:
+ return
+
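+    # Number of pipeline stages; by default, one stage per rank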
+ lp_count = args.lp_count
+ if args.lp_count == 0:
+ lp_count = args.nodes * args.procs_per_node
+
+ blocks = len(module.encoder) + len(module.decoder)
+
+ # Assign blocks to increasing grid tags
+ blocks_per_grid_tag = math.ceil(blocks / lp_count)
+ cur_grid_tag = 0
+
+ # Go over all blocks, applying grid tags in increasing order
+ for i, block in enumerate(itertools.chain(module.encoder, module.decoder)):
+ cur_grid_tag = max(cur_grid_tag, (i // blocks_per_grid_tag) + 1)
+ block.extra_layer_args['grid_tag'] = cur_grid_tag
+
+ global lp_grids
+ lp_grids = cur_grid_tag
+
+
+def _get_grid_tag(tag: Union[int, Dict[str, int]]):
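+    # Grid tags may arrive either as plain integers or as dicts with a
+    # 'value' entry; normalize to an integer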
+ if isinstance(tag, dict):
+ return tag.get('value', 0)
+ return tag
+
+
+def apply_layer_parallelism_postamble(model: lbann.Model,
+ args: argparse.Namespace):
+ """
+ Applies post-model creation optimizations of the layer-parallel strategy
+ (see ``apply_layer_parallelism``).
+
+ :param model: LBANN Model to modify.
+ :param args: Command-line arguments.
+ """
+ if not args.layer_parallel:
+ return
+
+    # Loop over all layers with multiple children and handle any outgoing
+    # cross-grid edges
+ layers_to_insert = []
+ for i, layer in enumerate(model.layers):
+ if len(layer.children) == 1:
+ continue
+ tag = _get_grid_tag(layer.grid_tag)
+ unique_grids = collections.defaultdict(list)
+ new_children = []
+ for child in layer.children:
+ ctag = _get_grid_tag(child.grid_tag)
+ if ctag != tag:
+ unique_grids[ctag].append(child)
+ new_children.append(None)
+ else:
+ new_children.append(child)
+
+ # Inject interim layers for each grid and reconnect
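+        # A single Identity per destination grid lets the tensor cross the
+        # grid boundary once and then fan out locally to its children there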
+ for dst_grid, children in unique_grids.items():
+ interim = lbann.Identity(layer, grid_tag=dst_grid)
+ layers_to_insert.append((i+1, interim))
+
+ # Reconnect parents
+ for child in children:
+ pind = child.parents.index(layer)
+ child.parents[pind] = interim
+ cind = layer.children.index(child)
+ new_children[cind] = interim
+
+ # Reconnect and condense children
+ if unique_grids:
+ layer.children = list(set(new_children))
+
+ # Add identity layers to the traversed graph right after the source layer
+ # was computed
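+    # Insert in reverse order so that earlier insertion indices remain valid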
+ for i, l in reversed(layers_to_insert):
+ model.layers.insert(i, l)
+
+
+def get_layer_parallel_args() -> List[str]:
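+    """
+    Returns the extra LBANN command-line arguments (the number of subgrids)
+    required when the layer-parallel strategy was applied.
+    """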
+ if lp_grids is not None:
+ return ['--num-subgrids', str(lp_grids)]
+ return []
+
+
+#############################################################################
+
+
def add_transformer_parallelism_arguments(parser: argparse.Namespace,
subgraph: bool = True):
@@ -277,3 +382,16 @@ def add_transformer_parallelism_arguments(parser: argparse.Namespace,
action='store_true',
help='Apply Fully-Sharded Data-Parallelism (FSDP) and shard MLP weights'
)
+
+ #######################################
+ # Layer parallelism
+ parser.add_argument(
+ '--layer-parallel',
+ action='store_true',
+ help='Apply layer parallelism (also referred to as pipelining)')
+ parser.add_argument(
+ '--lp-count',
+ default=0,
+ type=int,
+        help='In layer parallelism, the number of stages to divide the network'
+        ' into (default: divide evenly among all ranks)')
diff --git a/applications/nlp/transformer/trainer.py b/applications/nlp/transformer/trainer.py
index d34a894ecd1..395512eee91 100644
--- a/applications/nlp/transformer/trainer.py
+++ b/applications/nlp/transformer/trainer.py
@@ -12,6 +12,7 @@
from lbann.launcher.batch_script import BatchScript
import utils.paths
+import parallelism
def construct_training_task(model: lbann.Model,
@@ -238,7 +239,8 @@ def make_batch_script(model: lbann.Model,
script.add_parallel_command([
lbann.lbann_exe(),
f'--prototext={protobuf_file}',
- ] + lbann.contrib.args.get_profile_args(args))
+ ] + (lbann.contrib.args.get_profile_args(args) +
+ parallelism.get_layer_parallel_args()))
script.add_command('status=$?')
script.add_command('echo "Finished training at $(date)"')
script.add_command('exit ${status}')
diff --git a/python/lbann/contrib/args.py b/python/lbann/contrib/args.py
index d63acb0e8c6..a43de8a6f62 100644
--- a/python/lbann/contrib/args.py
+++ b/python/lbann/contrib/args.py
@@ -1,6 +1,6 @@
"""Helper functions to add common command-line arguments."""
-from typing import Any
+from typing import Any, List
import argparse
import shlex
@@ -250,10 +250,10 @@ def add_profiling_arguments(parser: argparse.ArgumentParser) -> None:
action='store_true',
default=False,
help='enable itemized memory usage analysis')
- parser.add_argument('--profile-init',
+ parser.add_argument('--profile-noinit',
action='store_true',
default=False,
- help='enable profiling initialization')
+ help='disable profiling initialization')
parser.add_argument('--caliper',
action='store_true',
default=False,
@@ -285,7 +285,7 @@ def create_profile_callbacks(args: argparse.Namespace) -> Any:
"""
try:
profile = args.profile
- profile_init = not args.profile_init
+ profile_noinit = args.profile_noinit
memprofile = args.memory_profile
memprof_verbose = args.memory_profile_verbose
except AttributeError:
@@ -294,7 +294,7 @@ def create_profile_callbacks(args: argparse.Namespace) -> Any:
result = []
if profile:
- result.append(lbann.CallbackProfiler(skip_init=profile_init))
+ result.append(lbann.CallbackProfiler(skip_init=profile_noinit))
if memprofile:
result.append(lbann.CallbackMemoryProfiler(
detailed_first_step=memprof_verbose))
@@ -302,7 +302,7 @@ def create_profile_callbacks(args: argparse.Namespace) -> Any:
return result
-def get_profile_args(args: argparse.Namespace) -> list[str]:
+def get_profile_args(args: argparse.Namespace) -> List[str]:
"""Get LBANN command-line arguments for profiling.
The parsed arguments must be generated by an
diff --git a/python/lbann/models/subgraph/transformer.py b/python/lbann/models/subgraph/transformer.py
index b0674f8ee87..e7291dc474b 100644
--- a/python/lbann/models/subgraph/transformer.py
+++ b/python/lbann/models/subgraph/transformer.py
@@ -1285,7 +1285,7 @@ def _subsequent_mask(self, size):
vals = np.triu(np.full((size, size), -1e9), k=1)
weights = lbann.Weights(
initializer=lbann.ValueInitializer(values=vals.flat),
- optimizer=None,
+ optimizer=lbann.NoOptimizer(),
name=f"{self.name}_mask{size}_weights",
)
self._subsequent_mask_cache[size] = lbann.WeightsLayer(
diff --git a/python/lbann/models/transformer.py b/python/lbann/models/transformer.py
index 95a10ffde35..63774869970 100644
--- a/python/lbann/models/transformer.py
+++ b/python/lbann/models/transformer.py
@@ -40,30 +40,33 @@ def __init__(self, normalized_shape, name=None, builtin=True):
name=f'{self.name}_bias',
)
- def forward(self, x):
+ def forward(self, x, **extra_kwargs):
if self.builtin:
return lbann.LayerNorm(x,
scale=True,
bias=True,
start_dim=-1,
name=self.name,
- weights=[self.weight, self.bias])
+ weights=[self.weight, self.bias],
+ **extra_kwargs)
# Normalization
- x = lbann.InstanceNorm(x)
+ x = lbann.InstanceNorm(x, **extra_kwargs)
# Affine transform
s = lbann.WeightsLayer(
weights=self.weight,
dims=[1] + list(make_iterable(self.normalized_shape)),
+ **extra_kwargs,
)
- s = lbann.Tessellate(s, hint_layer=x)
+ s = lbann.Tessellate(s, hint_layer=x, **extra_kwargs)
b = lbann.WeightsLayer(
weights=self.bias,
dims=[1] + list(make_iterable(self.normalized_shape)),
+ **extra_kwargs,
)
- b = lbann.Tessellate(b, hint_layer=x)
- x = lbann.Add(lbann.Multiply(s, x), b)
+ b = lbann.Tessellate(b, hint_layer=x, **extra_kwargs)
+ x = lbann.Add(lbann.Multiply(s, x, **extra_kwargs), b, **extra_kwargs)
return x
@@ -124,6 +127,7 @@ def __init__(
self.pre_layernorm = pre_layernorm
self.activation = activation
self.extra_ffn_args = {}
+ self.extra_layer_args = {}
# Module name
self.name = name
@@ -172,26 +176,27 @@ def forward(self, x, mask=None):
name = f'{self.name}_instance{self.instance}'
if self.pre_layernorm:
- y = self.norm1(x)
+ y = self.norm1(x, **self.extra_layer_args)
else:
y = x
# Self-attention with residual connection
- y = self.attention(y, y, y, mask=mask)
+ y = self.attention(y, y, y, mask=mask, **self.extra_layer_args)
if self.dropout_prob > 0:
y = lbann.Dropout(
y,
keep_prob=1 - self.dropout_prob,
name=f'{name}_drop1',
+ **self.extra_layer_args,
)
- z = lbann.Sum(x, y, name=f'{name}_sum1')
+ z = lbann.Sum(x, y, name=f'{name}_sum1', **self.extra_layer_args)
if not self.pre_layernorm:
- z = self.norm1(z)
+ z = self.norm1(z, **self.extra_layer_args)
x = z
# Feedforward network with residual connection
if self.pre_layernorm:
- y = self.norm2(z)
+ y = self.norm2(z, **self.extra_layer_args)
else:
y = x
@@ -200,14 +205,19 @@ def forward(self, x, mask=None):
weights=self.fc1_weights,
output_channel_dims=[self.feedforward_dim],
name=f'{name}_fc1',
+ **self.extra_layer_args,
**self.extra_ffn_args,
)
- y = self.activation(y, name=f'{name}_ffn_act', **self.extra_ffn_args)
+ y = self.activation(y,
+ name=f'{name}_ffn_act',
+ **self.extra_layer_args,
+ **self.extra_ffn_args)
if self.dropout_prob > 0:
y = lbann.Dropout(
y,
keep_prob=1 - self.dropout_prob,
name=f'{name}_drop2',
+ **self.extra_layer_args,
**self.extra_ffn_args,
)
y = lbann.ChannelwiseFullyConnected(
@@ -215,6 +225,7 @@ def forward(self, x, mask=None):
weights=self.fc2_weights,
output_channel_dims=[self.embed_dim],
name=f'{name}_fc2',
+ **self.extra_layer_args,
**self.extra_ffn_args,
)
if self.dropout_prob > 0:
@@ -222,11 +233,12 @@ def forward(self, x, mask=None):
y,
keep_prob=1 - self.dropout_prob,
name=f'{name}_drop3',
+ **self.extra_layer_args,
**self.extra_ffn_args,
)
- z = lbann.Sum(x, y, name=f'{name}_sum2')
+ z = lbann.Sum(x, y, name=f'{name}_sum2', **self.extra_layer_args)
if not self.pre_layernorm:
- z = self.norm2(z)
+ z = self.norm2(z, **self.extra_layer_args)
return z
@@ -288,6 +300,7 @@ def __init__(
self.pre_layernorm = pre_layernorm
self.activation = activation
self.extra_ffn_args = {}
+ self.extra_layer_args = {}
# Module name
self.name = name
@@ -350,22 +363,23 @@ def forward(self, x, memory, src_mask=None, tgt_mask=None):
name = f'{self.name}_instance{self.instance}'
if self.pre_layernorm:
- y = self.norm1(x)
+ y = self.norm1(x, **self.extra_layer_args)
else:
y = x
# Self-attention with residual connection
- y = self.attention1(y, y, y, mask=tgt_mask)
+ y = self.attention1(y, y, y, mask=tgt_mask, **self.extra_layer_args)
if self.dropout_prob > 0:
y = lbann.Dropout(
y,
keep_prob=1 - self.dropout_prob,
name=f'{name}_drop1',
+ **self.extra_layer_args,
)
- z = lbann.Sum(x, y, name=f'{name}_sum1')
+ z = lbann.Sum(x, y, name=f'{name}_sum1', **self.extra_layer_args)
if not self.pre_layernorm:
- z = self.norm1(z)
+ z = self.norm1(z, **self.extra_layer_args)
x = z
@@ -373,27 +387,30 @@ def forward(self, x, memory, src_mask=None, tgt_mask=None):
if memory is not None:
# Attention on encoder output with residual connection
if self.pre_layernorm:
- y = self.norm2(x)
+ y = self.norm2(x, **self.extra_layer_args)
else:
y = x
- y = self.attention2(y, memory, memory, mask=src_mask)
+ y = self.attention2(y,
+ memory,
+ memory,
+ mask=src_mask,
+ **self.extra_layer_args)
if self.dropout_prob > 0:
- y = lbann.Dropout(
- y,
- keep_prob=1 - self.dropout_prob,
- name=f'{name}_drop2',
- )
- z = lbann.Sum(x, y, name=f'{name}_sum2')
+ y = lbann.Dropout(y,
+ keep_prob=1 - self.dropout_prob,
+ name=f'{name}_drop2',
+ **self.extra_layer_args)
+ z = lbann.Sum(x, y, name=f'{name}_sum2', **self.extra_layer_args)
if not self.pre_layernorm:
- z = self.norm2(z)
+ z = self.norm2(z, **self.extra_layer_args)
x = z
# Feedforward network with residual connection
if self.pre_layernorm:
- y = self.norm3(x)
+ y = self.norm3(x, **self.extra_layer_args)
else:
y = x
@@ -402,14 +419,19 @@ def forward(self, x, memory, src_mask=None, tgt_mask=None):
weights=self.fc1_weights,
output_channel_dims=[self.feedforward_dim],
name=f'{name}_fc1',
+ **self.extra_layer_args,
**self.extra_ffn_args,
)
- y = self.activation(y, name=f'{name}_ffn_act', **self.extra_ffn_args)
+ y = self.activation(y,
+ name=f'{name}_ffn_act',
+ **self.extra_layer_args,
+ **self.extra_ffn_args)
if self.dropout_prob > 0:
y = lbann.Dropout(
y,
keep_prob=1 - self.dropout_prob,
name=f'{name}_drop3',
+ **self.extra_layer_args,
**self.extra_ffn_args,
)
y = lbann.ChannelwiseFullyConnected(
@@ -417,6 +439,7 @@ def forward(self, x, memory, src_mask=None, tgt_mask=None):
weights=self.fc2_weights,
output_channel_dims=[self.embed_dim],
name=f'{name}_fc2',
+ **self.extra_layer_args,
**self.extra_ffn_args,
)
if self.dropout_prob > 0:
@@ -424,12 +447,13 @@ def forward(self, x, memory, src_mask=None, tgt_mask=None):
y,
keep_prob=1 - self.dropout_prob,
name=f'{name}_drop4',
+ **self.extra_layer_args,
**self.extra_ffn_args,
)
- z = lbann.Sum(x, y, name=f'{name}_sum3')
+ z = lbann.Sum(x, y, name=f'{name}_sum3', **self.extra_layer_args)
if not self.pre_layernorm:
- z = self.norm3(z)
+ z = self.norm3(z, **self.extra_layer_args)
return z
@@ -570,7 +594,7 @@ def _subsequent_mask(self, size):
weights = lbann.Weights(
initializer=lbann.ValueInitializer(values=vals.flat),
- optimizer=None,
+ optimizer=lbann.NoOptimizer(),
name=f'{self.name}_mask{size}_weights',
)
self._subsequent_mask_cache[size] = lbann.WeightsLayer(
diff --git a/python/lbann/modules/transformer/attention.py b/python/lbann/modules/transformer/attention.py
index cac70833325..b721d9a2e80 100644
--- a/python/lbann/modules/transformer/attention.py
+++ b/python/lbann/modules/transformer/attention.py
@@ -113,7 +113,13 @@ def __init__(self,
name=f'{self.name}_output_bias'),
]
- def forward(self, queries, keys, values, mask=None, seqlen=None):
+ def forward(self,
+ queries,
+ keys,
+ values,
+ mask=None,
+ seqlen=None,
+ **extra_kwargs):
"""Apply multi-head attention.
The input and output tensors are interpreted as sequences of
@@ -147,7 +153,8 @@ def forward(self, queries, keys, values, mask=None, seqlen=None):
output_channel_dims=[self.embed_dim * 3],
name=f'{name}_qkv_fc',
bias=True,
- transpose=False)
+ transpose=False,
+ **extra_kwargs)
# Unstack
qkv_slice = lbann.Slice(qkv_fc,
@@ -155,10 +162,11 @@ def forward(self, queries, keys, values, mask=None, seqlen=None):
slice_points=[
0, self.embed_dim, 2 * self.embed_dim,
3 * self.embed_dim
- ])
- queries_fc = lbann.Identity(qkv_slice)
- keys_fc = lbann.Identity(qkv_slice)
- values_fc = lbann.Identity(qkv_slice)
+ ],
+ **extra_kwargs)
+ queries_fc = lbann.Identity(qkv_slice, **extra_kwargs)
+ keys_fc = lbann.Identity(qkv_slice, **extra_kwargs)
+ values_fc = lbann.Identity(qkv_slice, **extra_kwargs)
else:
# Otherwise, apply fully-connected layers to input sequences separately
queries_fc = lbann.ChannelwiseFullyConnected(
@@ -166,47 +174,52 @@ def forward(self, queries, keys, values, mask=None, seqlen=None):
weights=self.query_weights,
output_channel_dims=[self.embed_dim],
name=f'{name}_queries_fc',
+ **extra_kwargs,
)
keys_fc = lbann.ChannelwiseFullyConnected(
keys,
weights=self.key_weights,
output_channel_dims=[self.embed_dim],
name=f'{name}_keys_fc',
+ **extra_kwargs,
)
values_fc = lbann.ChannelwiseFullyConnected(
values,
weights=self.value_weights,
output_channel_dims=[self.embed_dim],
name=f'{name}_values_fc',
+ **extra_kwargs,
)
if self.positional_encoding is not None:
queries_fc, keys_fc, values_fc = self.positional_encoding.apply_layer(
- queries_fc, keys_fc, values_fc, seqlen)
+ queries_fc, keys_fc, values_fc, seqlen, **extra_kwargs)
if self.separate_heads:
attentions = self.dot_product_attn_separate_heads(
- name, queries_fc, keys_fc, values_fc, mask)
+ name, queries_fc, keys_fc, values_fc, mask, **extra_kwargs)
else:
attentions = self.dot_product_attn_batched(name, queries_fc,
keys_fc, values_fc,
- mask)
+ mask, **extra_kwargs)
outputs_fc = lbann.ChannelwiseFullyConnected(
attentions,
weights=self.output_weights,
output_channel_dims=[self.embed_dim],
name=f'{name}',
+ **extra_kwargs,
)
return outputs_fc
def dot_product_attn_batched(self, name, queries_fc, keys_fc, values_fc,
- mask):
+ mask, **extra_kwargs):
head_name = f'{name}_all_heads'
queries_fc = lbann.Scale(
queries_fc,
constant=1 / math.sqrt(self.head_dim),
name=f'{head_name}_scale',
+ **extra_kwargs,
)
# Dimension key:
@@ -216,15 +229,24 @@ def dot_product_attn_batched(self, name, queries_fc, keys_fc, values_fc,
# * P = Head size
# SxE -> HxPxS
- q_headsfirst = lbann.TensorPermute(queries_fc, axes=(1, 0))
+ q_headsfirst = lbann.TensorPermute(queries_fc,
+ axes=(1, 0),
+ **extra_kwargs)
q_headsfirst = lbann.Reshape(q_headsfirst,
- dims=(self.num_heads, self.head_dim, -1))
- k_headsfirst = lbann.TensorPermute(keys_fc, axes=(1, 0))
+ dims=(self.num_heads, self.head_dim, -1),
+ **extra_kwargs)
+ k_headsfirst = lbann.TensorPermute(keys_fc,
+ axes=(1, 0),
+ **extra_kwargs)
k_headsfirst = lbann.Reshape(k_headsfirst,
- dims=(self.num_heads, self.head_dim, -1))
- v_headsfirst = lbann.TensorPermute(values_fc, axes=(1, 0))
+ dims=(self.num_heads, self.head_dim, -1),
+ **extra_kwargs)
+ v_headsfirst = lbann.TensorPermute(values_fc,
+ axes=(1, 0),
+ **extra_kwargs)
v_headsfirst = lbann.Reshape(v_headsfirst,
- dims=(self.num_heads, self.head_dim, -1))
+ dims=(self.num_heads, self.head_dim, -1),
+ **extra_kwargs)
# HxPxS -> HxSxS
y = lbann.MatMul(
@@ -233,24 +255,30 @@ def dot_product_attn_batched(self, name, queries_fc, keys_fc, values_fc,
transpose_a=True,
transpose_b=False,
name=f'{head_name}_matmul',
+ **extra_kwargs,
)
if mask:
- y = lbann.Add(y, mask, name=f'{head_name}_mask')
+ y = lbann.Add(y, mask, name=f'{head_name}_mask', **extra_kwargs)
if self.bias:
- y = lbann.Add(y, self.bias, name=f'{head_name}_attnbias')
+ y = lbann.Add(y,
+ self.bias,
+ name=f'{head_name}_attnbias',
+ **extra_kwargs)
y = lbann.ChannelwiseSoftmax(y,
dim=-1,
single_dim_mode=True,
- name=f'{head_name}_softmax')
+ name=f'{head_name}_softmax',
+ **extra_kwargs)
if self.dropout > 0:
y = lbann.Dropout(
y,
keep_prob=1 - self.dropout,
name=f'{head_name}_drop',
+ **extra_kwargs,
)
# Attention output as batched matrix multiplication
@@ -258,11 +286,16 @@ def dot_product_attn_batched(self, name, queries_fc, keys_fc, values_fc,
attentions = lbann.MatMul(y,
v_headsfirst,
transpose_b=True,
- name=head_name)
+ name=head_name,
+ **extra_kwargs)
# HxSxP -> SxE
- attentions = lbann.TensorPermute(attentions, axes=(1, 0, 2))
- attentions = lbann.Reshape(attentions, dims=(-1, self.embed_dim))
+ attentions = lbann.TensorPermute(attentions,
+ axes=(1, 0, 2),
+ **extra_kwargs)
+ attentions = lbann.Reshape(attentions,
+ dims=(-1, self.embed_dim),
+ **extra_kwargs)
return attentions
def _get_subgraph(self, tag_id: int = 0) -> Dict[str, int]:
@@ -279,7 +312,7 @@ def _get_subgraph(self, tag_id: int = 0) -> Dict[str, int]:
return dict(grid_tag=tag_id)
def dot_product_attn_separate_heads(self, name, queries_fc, keys_fc,
- values_fc, mask):
+ values_fc, mask, **extra_kwargs):
# Slice embedding vectors for each head
slice_points = [self.head_dim * i for i in range(self.num_heads + 1)]
queries_slice = lbann.Slice(
@@ -288,6 +321,7 @@ def dot_product_attn_separate_heads(self, name, queries_fc, keys_fc,
slice_points=slice_points,
name=f'{name}_queries_slice',
parallel_strategy=self._get_subgraph(),
+ **extra_kwargs,
)
keys_slice = lbann.Slice(
keys_fc,
@@ -295,6 +329,7 @@ def dot_product_attn_separate_heads(self, name, queries_fc, keys_fc,
slice_points=slice_points,
name=f'{name}_keys_slice',
parallel_strategy=self._get_subgraph(),
+ **extra_kwargs,
)
values_slice = lbann.Slice(
values_fc,
@@ -302,6 +337,7 @@ def dot_product_attn_separate_heads(self, name, queries_fc, keys_fc,
slice_points=slice_points,
name=f'{name}_values_slice',
parallel_strategy=self._get_subgraph(),
+ **extra_kwargs,
)
if self.subgraph_branches > 0 and mask is not None:
@@ -321,11 +357,14 @@ def dot_product_attn_separate_heads(self, name, queries_fc, keys_fc,
# Attention inputs
q = lbann.Identity(queries_slice,
- parallel_strategy=self._get_subgraph(tag))
+ parallel_strategy=self._get_subgraph(tag),
+ **extra_kwargs)
k = lbann.Identity(keys_slice,
- parallel_strategy=self._get_subgraph(tag))
+ parallel_strategy=self._get_subgraph(tag),
+ **extra_kwargs)
v = lbann.Identity(values_slice,
- parallel_strategy=self._get_subgraph(tag))
+ parallel_strategy=self._get_subgraph(tag),
+ **extra_kwargs)
# Multiply queries and keys
# Note: num_queries x num_keys
@@ -334,36 +373,52 @@ def dot_product_attn_separate_heads(self, name, queries_fc, keys_fc,
k,
transpose_b=True,
name=f'{head_name}_matmul',
+ **extra_kwargs,
)
y = lbann.Scale(y,
constant=1 / math.sqrt(self.head_dim),
- name=f'{head_name}_scale')
+ name=f'{head_name}_scale',
+ **extra_kwargs)
if mask:
if self.subgraph_branches > 0:
- y = lbann.Add(y, mask[tag - 1], name=f'{head_name}_mask')
+ y = lbann.Add(y,
+ mask[tag - 1],
+ name=f'{head_name}_mask',
+ **extra_kwargs)
else:
- y = lbann.Add(y, mask, name=f'{head_name}_mask')
+ y = lbann.Add(y,
+ mask,
+ name=f'{head_name}_mask',
+ **extra_kwargs)
if self.bias:
- y = lbann.Add(y, self.bias, name=f'{head_name}_attnbias')
+ y = lbann.Add(y,
+ self.bias,
+ name=f'{head_name}_attnbias',
+ **extra_kwargs)
- y = lbann.ChannelwiseSoftmax(y, name=f'{head_name}_softmax')
+ y = lbann.ChannelwiseSoftmax(y,
+ name=f'{head_name}_softmax',
+ **extra_kwargs)
if self.dropout > 0:
y = lbann.Dropout(
y,
keep_prob=1 - self.dropout,
name=f'{head_name}_drop',
+ **extra_kwargs,
)
# Attention output
# Note: num_queries x head_dim
- attentions.append(lbann.MatMul(y, v, name=head_name))
+ attentions.append(
+ lbann.MatMul(y, v, name=head_name, **extra_kwargs))
# Concatenate heads and apply fully-connected layer
attentions = lbann.Concatenation(
attentions,
axis=1,
name=f'{name}_heads_concat',
- parallel_strategy=self._get_subgraph())
+ parallel_strategy=self._get_subgraph(),
+ **extra_kwargs)
return attentions
diff --git a/python/lbann/modules/transformer/encoding.py b/python/lbann/modules/transformer/encoding.py
index cd56915002f..653ce6d6f9d 100644
--- a/python/lbann/modules/transformer/encoding.py
+++ b/python/lbann/modules/transformer/encoding.py
@@ -15,20 +15,22 @@ class SequenceEncoding:
the layer type and index.
"""
- def apply_input(self, x: lbann.Layer, length: int) -> lbann.Layer:
+ def apply_input(self, x: lbann.Layer, length: int,
+ **extra_kwargs) -> lbann.Layer:
"""
Applies sequence encoding on the input of a transformer, immediately
after token embedding.
:param x: The output of the embedded sequence minibatch.
:param length: Sequence length.
+ :param extra_kwargs: Additional arguments to pass to each internal Layer.
:return: Encoded input.
"""
return x # Do nothing
def apply_layer(
- self, q: lbann.Layer, k: lbann.Layer, v: lbann.Layer,
- length: int) -> Tuple[lbann.Layer, lbann.Layer, lbann.Layer]:
+ self, q: lbann.Layer, k: lbann.Layer, v: lbann.Layer, length: int,
+ **extra_kwargs) -> Tuple[lbann.Layer, lbann.Layer, lbann.Layer]:
"""
Applies sequence encoding within a transformer encoder/decoder layer.
Encoding is performed on each transformer layer's multi-head attention
@@ -38,6 +40,7 @@ def apply_layer(
:param k: The input keys of the transformer layer.
:param v: The input values of the transformer layer.
:param length: Sequence length.
+ :param extra_kwargs: Additional arguments to pass to each internal Layer.
:return: Encoded tuple of (q, k, v).
"""
return q, k, v # Do nothing
@@ -121,13 +124,14 @@ def _positional_encoding(self, sequence_length):
# Return cached positional encoding
return self._positional_encoding_cache[sequence_length]
- def apply_input(self, inputs, input_length):
+ def apply_input(self, inputs, input_length, **extra_kwargs):
self.instance += 1
result = lbann.Add(
inputs,
self._positional_encoding(input_length),
name=f'{self.name}_instance{self.instance}_peadd',
+ **extra_kwargs,
)
# Input dropout
@@ -136,6 +140,7 @@ def apply_input(self, inputs, input_length):
result,
keep_prob=1 - self.dropout_prob,
name=f'{self.name}_pedrop',
+ **extra_kwargs,
)
return result
@@ -182,7 +187,11 @@ def compute_embeddings(self):
embedding_dim=self.embed_dim,
)
- def apply_input(self, inputs, input_length, learned_encoding=None):
+ def apply_input(self,
+ inputs,
+ input_length,
+ learned_encoding=None,
+ **extra_kwargs):
self.instance += 1
if learned_encoding is None:
@@ -193,12 +202,14 @@ def apply_input(self, inputs, input_length, learned_encoding=None):
learned_encoding = lbann.Identity(
lbann.Slice(learned_encoding,
axis=0,
- slice_points=[0, input_length]))
+ slice_points=[0, input_length],
+ **extra_kwargs), **extra_kwargs)
result = lbann.Add(
inputs,
learned_encoding,
name=f'{self.name}_instance{self.instance}_peadd',
+ **extra_kwargs,
)
# Input dropout
@@ -207,6 +218,7 @@ def apply_input(self, inputs, input_length, learned_encoding=None):
result,
keep_prob=1 - self.dropout_prob,
name=f'{self.name}_pedrop',
+ **extra_kwargs,
)
return result
@@ -278,35 +290,43 @@ def _precompute_frequencies(self, sequence_length: int):
_make_constant_from_array(sin, f'rope_sin_{sequence_length}'),
)
- def _rotate_half(self, x: lbann.Layer, length: int):
+ def _rotate_half(self, x: lbann.Layer, length: int, **extra_kwargs):
"""
Helper method that rotates half of a tensor x.
"""
# SxE -> SxHxP
- r = lbann.Reshape(x, dims=(length, self.num_heads, self.dim))
- s = lbann.Slice(r, slice_points=[0, self.dim // 2, self.dim], axis=2)
- x1 = lbann.Identity(s)
- x2 = lbann.Identity(s)
- nx2 = lbann.Scale(x2, constant=-1)
- cat = lbann.Concatenation([nx2, x1], axis=2)
+ r = lbann.Reshape(x,
+ dims=(length, self.num_heads, self.dim),
+ **extra_kwargs)
+ s = lbann.Slice(r,
+ slice_points=[0, self.dim // 2, self.dim],
+ axis=2,
+ **extra_kwargs)
+ x1 = lbann.Identity(s, **extra_kwargs)
+ x2 = lbann.Identity(s, **extra_kwargs)
+ nx2 = lbann.Scale(x2, constant=-1, **extra_kwargs)
+ cat = lbann.Concatenation([nx2, x1], axis=2, **extra_kwargs)
# Reshape back to SxE
- return lbann.Reshape(cat, dims=(length, self.num_heads * self.dim))
+ return lbann.Reshape(cat,
+ dims=(length, self.num_heads * self.dim),
+ **extra_kwargs)
def _embed(self, x: lbann.Layer, length: int, sliced_cos: lbann.Layer,
- sliced_sin: lbann.Layer):
+ sliced_sin: lbann.Layer, **extra_kwargs):
"""
Helper method that applies rotary embeddings on a tensor x.
"""
- rot = self._rotate_half(x, length)
+ rot = self._rotate_half(x, length, **extra_kwargs)
return lbann.Add(
- lbann.Multiply(x, sliced_cos),
- lbann.Multiply(rot, sliced_sin),
+ lbann.Multiply(x, sliced_cos, **extra_kwargs),
+ lbann.Multiply(rot, sliced_sin, **extra_kwargs),
+ **extra_kwargs,
)
def apply_layer(
- self, q: lbann.Layer, k: lbann.Layer, v: lbann.Layer,
- length: int) -> Tuple[lbann.Layer, lbann.Layer, lbann.Layer]:
+ self, q: lbann.Layer, k: lbann.Layer, v: lbann.Layer, length: int,
+ **extra_kwargs) -> Tuple[lbann.Layer, lbann.Layer, lbann.Layer]:
# If length is not given, maximum sequence length is assumed
if length is None:
length = self.max_sequence_length
@@ -316,15 +336,21 @@ def apply_layer(
sliced_sin = self.sin
else:
sliced_cos = lbann.Identity(
- lbann.Slice(self.cos, slice_points=[0, length], axis=0))
+ lbann.Slice(self.cos,
+ slice_points=[0, length],
+ axis=0,
+ **extra_kwargs), **extra_kwargs)
sliced_sin = lbann.Identity(
- lbann.Slice(self.sin, slice_points=[0, length], axis=0))
+ lbann.Slice(self.sin,
+ slice_points=[0, length],
+ axis=0,
+ **extra_kwargs), **extra_kwargs)
- eq = self._embed(q, length, sliced_cos, sliced_sin)
- ek = self._embed(k, length, sliced_cos, sliced_sin)
+ eq = self._embed(q, length, sliced_cos, sliced_sin, **extra_kwargs)
+ ek = self._embed(k, length, sliced_cos, sliced_sin, **extra_kwargs)
if self.embed_values:
- ev = self._embed(v, length, sliced_cos, sliced_sin)
+ ev = self._embed(v, length, sliced_cos, sliced_sin, **extra_kwargs)
else:
ev = v
diff --git a/python/lbann/proto/serialize.py b/python/lbann/proto/serialize.py
index 63b3f7f60f5..37d9eca0c7f 100644
--- a/python/lbann/proto/serialize.py
+++ b/python/lbann/proto/serialize.py
@@ -80,3 +80,21 @@ def bin2text(infile: str, outfile: str):
f.write(
google.protobuf.text_format.MessageToString(
message, use_index_order=True).encode())
+
+
+def generic_load(filename: str):
+ """
+ Loads a .protobin or .prototext file.
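+    Binary protobuf parsing is attempted first, falling back to text format.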
+ """
+ try: # Try binary first
+ message = lbann_pb2.LbannPB()
+
+ # Read file
+ with open(filename, 'rb') as f:
+ message.ParseFromString(f.read())
+    except Exception:  # Fall back to text format
+ with open(filename, 'rb') as f:
+ message = google.protobuf.text_format.Parse(
+ f.read(), lbann_pb2.LbannPB())
+
+ return message
diff --git a/scripts/viz.py b/scripts/viz.py
index 09d7d86dd33..09217d0c405 100755
--- a/scripts/viz.py
+++ b/scripts/viz.py
@@ -2,10 +2,17 @@
"""Visualize an LBANN model's layer graph and save to file."""
import argparse
import re
import graphviz
-import google.protobuf.text_format
from lbann import lbann_pb2, layers_pb2
+from lbann.proto import serialize
+
+# Pastel rainbow (slightly shuffled) from colorkit.co
+palette = [
+ '#ffffff', '#a0c4ff', '#ffadad', '#fdffb6', '#caffbf', '#9bf6ff',
+ '#bdb2ff', '#ffc6ff', '#ffd6a5'
+]
# Parse command-line arguments
parser = argparse.ArgumentParser(
@@ -17,14 +24,14 @@
parser.add_argument('output',
action='store',
nargs='?',
- default='graph.pdf',
+ default='graph.dot',
type=str,
- help='output file (default: graph.pdf)')
+ help='output file (default: graph.dot)')
parser.add_argument('--file-format',
action='store',
- default='pdf',
+ default='dot',
type=str,
- help='output file format (default: pdf)',
+ help='output file format (default: dot)',
metavar='FORMAT')
parser.add_argument('--label-format',
action='store',
@@ -39,6 +46,10 @@
type=str,
help='Graphviz visualization scheme (default: dot)',
metavar='ENGINE')
+parser.add_argument('--color-cross-grid',
+ action='store_true',
+ default=False,
+ help='Highlight cross-grid edges')
args = parser.parse_args()
# Strip extension from filename
@@ -51,9 +62,7 @@
label_format = re.sub(r' |-|_', '', args.label_format.lower())
# Read prototext file
-proto = lbann_pb2.LbannPB()
-with open(args.input, 'r') as f:
- google.protobuf.text_format.Merge(f.read(), proto)
+proto = serialize.generic_load(args.input)
model = proto.model
# Construct graphviz graph
@@ -62,29 +71,36 @@
engine=args.graphviz_engine)
graph.attr('node', shape='rect')
+layer_to_grid_tag = {}
+
# Construct nodes in layer graph
layer_types = (set(layers_pb2.Layer.DESCRIPTOR.fields_by_name.keys()) - set([
'name', 'parents', 'children', 'datatype', 'data_layout',
'device_allocation', 'weights', 'freeze', 'hint_layer', 'top', 'bottom',
- 'type', 'motif_layer'
+ 'type', 'motif_layer', 'parallel_strategy', 'grid_tag'
]))
for l in model.layer:
# Determine layer type
- type = ''
+ ltype = ''
for _type in layer_types:
if l.HasField(_type):
- type = getattr(l, _type).DESCRIPTOR.name
+ ltype = getattr(l, _type).DESCRIPTOR.name
break
+ # If operator layer, use operator type
+ if ltype == 'OperatorLayer':
+ url = l.operator_layer.ops[0].parameters.type_url
+ ltype = url[url.rfind('.') + 1:]
+
# Construct node label
label = ''
if label_format == 'nameonly':
label = l.name
elif label_format == 'typeonly':
- label = type
+ label = ltype
elif label_format == 'typeandname':
-        label = '<{0}<br/>{1}>'.format(type, l.name)
+        label = '<{0}<br/>{1}>'.format(ltype, l.name)
elif label_format == 'full':
label = '<'
for (index, line) in enumerate(str(l).strip().split('\n')):
@@ -94,14 +110,36 @@
label += '>'
# Add layer as layer graph node
- graph.node(l.name, label=label)
+ tag = l.grid_tag.value
+ layer_to_grid_tag[l.name] = tag
+ attrs = {}
+ if tag != 0:
+ attrs = dict(style='filled', fillcolor=palette[tag % len(palette)])
+ graph.node(l.name, label=label, **attrs)
# Add parent/child relationships as layer graph edges
edges = set()
+cross_grid_edges = set()
for l in model.layer:
- edges.update([(p, l.name) for p in l.parents.split()])
- edges.update([(l.name, c) for c in l.children.split()])
+ tag = layer_to_grid_tag[l.name]
+ for p in l.parents:
+ if tag != layer_to_grid_tag[p]:
+ cross_grid_edges.add((p, l.name))
+ else:
+ edges.add((p, l.name))
+
+ for c in l.children:
+ if tag != layer_to_grid_tag[c]:
+ cross_grid_edges.add((l.name, c))
+ else:
+ edges.add((l.name, c))
+
graph.edges(edges)
+if args.color_cross_grid:
+ for u, v in cross_grid_edges:
+ graph.edge(u, v, color='red')
+else:
+ graph.edges(cross_grid_edges)
# Save to file
graph.render(filename=filename, cleanup=True, format=file_format)
diff --git a/src/layers/transform/weights.cpp b/src/layers/transform/weights.cpp
index bac97570a71..e9d40e46504 100644
--- a/src/layers/transform/weights.cpp
+++ b/src/layers/transform/weights.cpp
@@ -157,15 +157,29 @@ void weights_layer::fp_compute()
// Duplicate weights across columns of output matrix
const auto& local_weights = this->weights_values(0).LockedMatrix();
- MatType ones;
- El::Ones(ones, local_output.Width(), 1);
- El::Gemm(El::NORMAL,
- El::TRANSPOSE,
-           El::TypeTraits<TensorDataType>::One(),
-           local_weights,
-           ones,
-           El::TypeTraits<TensorDataType>::Zero(),
- local_output);
+ if (local_output.Width() <= 32) { // The number 32 is a heuristic
+ // Use copies for broadcast
+ for (int i = 0; i < local_output.Width(); ++i) {
+ MatType v;
+ El::View(v,
+ local_output,
+ El::IR(0, local_weights.Height()),
+ El::IR(i, i + 1));
+ El::Copy(local_weights, v);
+ }
+ }
+ else {
+ // Use GEMM with ones for broadcast
+ MatType ones;
+ El::Ones(ones, local_output.Width(), 1);
+ El::Gemm(El::NORMAL,
+ El::TRANSPOSE,
+             El::TypeTraits<TensorDataType>::One(),
+             local_weights,
+             ones,
+             El::TypeTraits<TensorDataType>::Zero(),
+ local_output);
+ }
}
template
diff --git a/src/models/model.cpp b/src/models/model.cpp
index 6ed719817dd..719471e6f0f 100644
--- a/src/models/model.cpp
+++ b/src/models/model.cpp
@@ -1382,10 +1382,13 @@ void model::add_split_layers(std::unordered_set<std::string>& layer_names)
split->set_name(name);
layer_names.insert(name);
- // Copy parallel strategy from parent.
+ // Copy parallel strategy and grid tag from parent.
ParallelStrategy& ps = split->get_parallel_strategy();
ParallelStrategy& orig_ps = l.get_parallel_strategy();
ps = orig_ps;
+ if (l.grid_tag() >= 0) {
+ split->grid_tag(l.grid_tag());
+ }
// Setup relationships between split layer and child layers
for (int j = 0; j < l.get_num_children(); ++j) {
@@ -1674,8 +1677,9 @@ void model::backward_prop(bool compute_weight_grads_only, bool skip_callbacks)
// Based on gradient/optimizer requirements
if (compute_weight_grads_only && m_needed_for_backprop.size() > 0 &&
- m_needed_for_backprop.find(&l) == m_needed_for_backprop.end())
+ m_needed_for_backprop.find(&l) == m_needed_for_backprop.end()) {
enable_layer = false;
+ }
}
// Check if all children skip gradient backpropagation
diff --git a/src/utils/options.cpp b/src/utils/options.cpp
index 4466d1aca20..b83509f22af 100644
--- a/src/utils/options.cpp
+++ b/src/utils/options.cpp
@@ -53,6 +53,7 @@ void construct_std_options()
arg_parser.add_flag(
LBANN_OPTION_DISABLE_SIGNAL_HANDLER,
{"--disable_signal_handler"},
+ utils::ENV("LBANN_DISABLE_SIGNAL_HANDLER"),
"[STD] Disables signal handling (signal handling on by default)");
arg_parser.add_flag(LBANN_OPTION_EXIT_AFTER_SETUP,
{"--exit_after_setup"},