diff --git a/applications/nlp/transformer/parallelism.py b/applications/nlp/transformer/parallelism.py index 580e5d4caa3..0693935f617 100644 --- a/applications/nlp/transformer/parallelism.py +++ b/applications/nlp/transformer/parallelism.py @@ -263,7 +263,7 @@ def apply_layer_parallelism_postamble(model: lbann.Model, # Inject interim layers for each grid and reconnect for dst_grid, children in unique_grids.items(): interim = lbann.Identity(layer, grid_tag=dst_grid) - layers_to_insert.append((i, interim)) + layers_to_insert.append((i+1, interim)) # Reconnect parents for child in children: @@ -272,9 +272,9 @@ def apply_layer_parallelism_postamble(model: lbann.Model, cind = layer.children.index(child) new_children[cind] = interim - # Reconnect children + # Reconnect and condense children if unique_grids: - layer.children = new_children + layer.children = list(set(new_children)) # Add identity layers to the traversed graph right after the source layer # was computed