Skip to content

Commit

Permalink
Initial layer parallelism (#2342)
Browse files Browse the repository at this point in the history
* Quick implementation of basic layer parallelism

* Layer-parallel lenet! (probably remove before merge)

* Address comments from review -- BROKEN

* This just crashes, which is maybe better than a hang?

* Revert lenet changes; add layer-parallel lenet driver

* Fixes so layer-parallel lenet actually runs

* Apply suggestions from code review

Co-authored-by: Tal Ben-Nun <[email protected]>

* Address review concerns

* Add comment to lenet_lp.py

* Remove some questionable code

* Remove obsolete method and fix incorrect callbacks

* Fix misuse of rank API for matrix participation and avoid calling unnecessary collectives

* clang-format

* Fix a logic error in subgrid setup

* Improve definition of is_participating

* Remove incorrect early returns

* Fix the grid tag logic

* Fix is_participating for cross_grid_sum_slice

---------

Co-authored-by: Tal Ben-Nun <[email protected]>
Co-authored-by: Tal Ben-Nun <[email protected]>
  • Loading branch information
3 people authored Jan 24, 2024
1 parent 2756920 commit cebacc2
Show file tree
Hide file tree
Showing 25 changed files with 516 additions and 261 deletions.
100 changes: 100 additions & 0 deletions applications/vision/lenet_lp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
"""
Trains Lenet using 2 "grids" in a layer-parallel configuration.
"""

import argparse
import lbann
import data.mnist
import lbann.contrib.args
import lbann.contrib.launcher

# ----------------------------------
# Command-line arguments
# ----------------------------------

desc = ('Train LeNet on MNIST data using LBANN.')
parser = argparse.ArgumentParser(description=desc)
lbann.contrib.args.add_scheduler_arguments(parser, 'lbann_lenet')
args = parser.parse_args()

# ----------------------------------
# Construct layer graph
# ----------------------------------

# Input data
images = lbann.Input(data_field='samples', grid_tag=0)
labels = lbann.Input(data_field='labels', grid_tag=0)

# LeNet
x = lbann.Convolution(images,
num_dims = 2,
out_channels = 6,
groups = 1,
kernel_size = 5,
stride = 1,
dilation = 1,
has_bias = True, grid_tag=1)
x = lbann.Relu(x, grid_tag=1)
x = lbann.Pooling(x,
num_dims = 2,
pool_dims_i = 2,
pool_strides_i = 2,
pool_mode = "max", grid_tag=1)
x = lbann.Convolution(x,
num_dims = 2,
out_channels = 16,
groups = 1,
kernel_size = 5,
stride = 1,
dilation = 1,
has_bias = True, grid_tag=1)
x = lbann.Relu(x, grid_tag=1)
x = lbann.Pooling(x,
num_dims = 2,
pool_dims_i = 2,
pool_strides_i = 2,
pool_mode = "max", grid_tag=1)

x = lbann.FullyConnected(x, num_neurons = 120, has_bias = True, grid_tag=2)
x = lbann.Relu(x, grid_tag=2)
x = lbann.FullyConnected(x, num_neurons = 84, has_bias = True, grid_tag=2)
x = lbann.Relu(x, grid_tag=2)
x = lbann.FullyConnected(x, num_neurons = 10, has_bias = True, grid_tag=2)
probs = lbann.Softmax(x, grid_tag=2)

# Loss function and accuracy
loss = lbann.CrossEntropy(probs, labels, grid_tag=2)
acc = lbann.CategoricalAccuracy(probs, labels, grid_tag=2)

# ----------------------------------
# Setup experiment
# ----------------------------------

# Setup model
mini_batch_size = 64
num_epochs = 20
model = lbann.Model(num_epochs,
layers=lbann.traverse_layer_graph([images, labels]),
objective_function=loss,
metrics=[lbann.Metric(acc, name='accuracy', unit='%')],
callbacks=[lbann.CallbackPrintModelDescription(),
lbann.CallbackPrint(),
lbann.CallbackTimer()])

# Setup optimizer
opt = lbann.SGD(learn_rate=0.01, momentum=0.9)

# Setup data reader
data_reader = data.mnist.make_data_reader()

# Setup trainer
trainer = lbann.Trainer(mini_batch_size=mini_batch_size)

# ----------------------------------
# Run experiment
# ----------------------------------
kwargs = lbann.contrib.args.get_scheduler_kwargs(args)
lbann.contrib.launcher.run(trainer, model, data_reader, opt,
job_name=args.job_name,
lbann_args=['--num-subgrids', '2'],
**kwargs)
9 changes: 9 additions & 0 deletions include/lbann/layers/data_type_layer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ class data_type_layer : public Layer
return TypeName<OutputTensorDataType>();
};

/** @brief Determine if we're participating in the compute on this process */
bool is_participating() const override;

/** Forward propagation step.
* Apply a mathematical operation to input tensors to obtain output
* tensors.
Expand Down Expand Up @@ -248,6 +251,12 @@ class data_type_layer : public Layer
*/
void setup_matrices(const std::vector<El::Grid*>& grids) override;

/** @brief Setup distributed matrices with "subgraph parallelism" logic */
virtual void do_setup_matrices_subgraph(std::vector<El::Grid*> const& grids);

/** @brief Setup all distributed matrices with given grid */
virtual void do_setup_matrices_simple(El::Grid const& grid);

/** Setup layer data.
* Called by the 'setup' function. Memory is allocated for
* distributed matrices.
Expand Down
46 changes: 44 additions & 2 deletions include/lbann/layers/layer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,13 @@ class Layer
*/
virtual El::Device get_device_allocation() const = 0;

/** @brief Get whether this layer participates on this process.
*
* @note This is technically possible to implement here, but easier
* in data_type_layer.
*/
virtual bool is_participating() const = 0;

/** @brief Get expected number of parent layers.
* A negative value indicates no limit.
*/
Expand Down Expand Up @@ -666,9 +673,18 @@ class Layer
/** Get reference to LBANN communicator. */
lbann_comm* get_comm() const;

/** @brief Identifying tag for process grid */
/** @name Layer parallelism interface */
///@{
/** @brief Get the "layer parallelism" grid tag. */
int grid_tag() const noexcept;
/** @brief Set the "layer parallelism" grid tag. */
void grid_tag(int tag);
///@}

/// @todo Unify Layer-Parallel and Subgraph-Parallel implementations
/** @brief Identifying tag for process grid (subgraph parallelism) */
int get_grid_tag() const noexcept;
/** @brief Set process grid */
/** @brief Set process grid (subgraph parallelism) */
void set_grid_tag(int tag);

/** @name Hint layer access functions */
Expand Down Expand Up @@ -885,6 +901,32 @@ class Layer
*/
bool m_runs_inplace = false;

/** @name Layer parallelism */
///@{

/** @brief The tag used to choose the grid.
*
* During model setup, this will be checked. If it has not been set
* (i.e., it is "-1"), then it will be chosen to match its parents
* (which must all be on the same grid -- "transitional" layers
* must be explicitly marked).
*
* Temporary: While the legacy "subgraph parallelism"
* infrastructure coexists, setup will also check that this
* and the subgraph-related "m_grid_tag" are not both set. If using
* "subgraph", every layer will leave this as "-1" and the grid
* setup will proceed according to the legacy subgraph setup.
*
* After setup, this is guaranteed to be >= 0, except when the
* legacy "subgraph" codepath is being used. The actual @c Grid
* object can be retrieved through the activation matrices. These
* are guaranteed to be assigned the "layer parallelism" grid, if
* any.
*/
int m_lp_grid_tag = -1;

///@}

// -------------------------------------------------------
// Objects for sub-grid parallelism
// -------------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion include/lbann/layers/transform/cross_grid_sum_slice.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ class cross_grid_sum_slice_layer : public data_type_layer<TensorDataType>
this->setup_reference_counter(output);
}
}

bool is_participating() const final { return true; }
void bp_setup_gradient_wrt_inputs() override
{
auto children = this->get_child_layers();
Expand Down
3 changes: 2 additions & 1 deletion include/lbann/layers/transform/split.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ class split_layer : public data_type_layer<TensorDataType>

void fp_setup_outputs() override
{

const auto& input = this->get_prev_activations();
auto mini_batch_size =
this->infer_mini_batch_size_from_parents_or_default_to_current();
Expand All @@ -129,6 +128,8 @@ class split_layer : public data_type_layer<TensorDataType>
El::VC,
El::ELEMENT,
Dev> const*>(&input);
LBANN_ASSERT(ptr_input);

int tag = 0;
auto childs = this->get_child_layers();
if (this->get_communication_flag() == COLL_OPT) {
Expand Down
8 changes: 0 additions & 8 deletions include/lbann/optimizers/data_type_optimizer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,6 @@ class data_type_optimizer
/** @name Gradient update management */
///@{

/** @brief Get the full objective function gradient w.r.t. the weights,
* synchronized across all ranks in the trainer.
*
* A collective operation (allreduce or allgather) may be launched and/or
* synchronized if needed.
*/
std::unique_ptr<AbsDistMatrixType> get_gradient();

/** @brief Get the raw objective function gradient w.r.t. the weights,
* synchronized across all ranks in the trainer. This may be a local sharded
* version and not contain the gradients of all weights, but it will
Expand Down
26 changes: 0 additions & 26 deletions include/lbann/optimizers/data_type_optimizer_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,32 +93,6 @@ auto data_type_optimizer<TensorDataType>::get_weights() const
return *m_weights;
}

template <typename TensorDataType>
auto data_type_optimizer<TensorDataType>::get_gradient()
-> std::unique_ptr<AbsDistMatrixType>
{
auto matrix_dist = std::get<2>(this->get_matrix_info());

// Create a new matrix with the correct value distribution (usually STAR_STAR)
// and copy the values from there.
std::unique_ptr<AbsDistMatrixType> result;
result.reset(AbsDistMatrixType::Instantiate(*matrix_dist.grid,
matrix_dist.root,
matrix_dist.colDist,
matrix_dist.rowDist,
El::ELEMENT,
matrix_dist.device));

// If the gradient is not sharded, return a view
if (!this->is_sharded()) {
El::LockedView(*result, *m_gradient);
}
else {
El::Copy(*m_gradient, *result);
}
return result;
}

template <typename TensorDataType>
auto data_type_optimizer<TensorDataType>::get_gradient_sharded()
-> AbsDistMatrixType&
Expand Down
12 changes: 12 additions & 0 deletions include/lbann/optimizers/optimizer_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ class GradientHelperImpl : public optimizer::GradientHelper

void start_sync(lbann_comm& comm) override
{
if (!global_gradient_->Participating()) {
return;
}

// Complete outstanding synchronization of the same data type
static GradientHelperImpl<TensorDataType>* lastsync = nullptr;
if (lastsync != nullptr) {
Expand Down Expand Up @@ -148,6 +152,10 @@ class GradientHelperImpl : public optimizer::GradientHelper

void complete_sync(lbann_comm& comm) override
{
if (!global_gradient_->Participating()) {
return;
}

switch (this->get_status()) {
case optimizer_gradient_status::sync_started:
comm.wait(sync_req_);
Expand Down Expand Up @@ -287,6 +295,10 @@ void optimizer::accumulate_all_gradient_contributions(
using AbsDistMatType = El::AbstractDistMatrix<TensorDataType>;
static const TensorDataType one = TensorDataType(1.f);

if (!gradient.Participating()) {
return;
}

// There are a few cases to note here:
// 1. One update of the same type.
// 2. One update of a different type.
Expand Down
10 changes: 5 additions & 5 deletions include/lbann/utils/summary_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
#ifndef LBANN_SUMMARY_IMPL_HPP_INCLUDED
#define LBANN_SUMMARY_IMPL_HPP_INCLUDED

#include "lbann/utils/summary.hpp"
#include "lbann/utils/profiling.hpp"
#include "lbann/utils/summary.hpp"

namespace lbann {

Expand All @@ -48,7 +48,7 @@ lbann_summary::reduce_mean(const std::string tag,
El::DistData mat_format(mat);
if (mat_format.colDist == El::STAR && mat_format.rowDist == El::STAR) {
// Compute local sum on master process if matrix is Star,Star
if (m_comm->am_trainer_master()) {
if (mat.RedundantRank() == 0) {
sum = local_sum(mat.LockedMatrix());
}
}
Expand Down Expand Up @@ -105,7 +105,7 @@ lbann_summary::reduce_stdev(const std::string tag,
El::DistData mat_format(mat);
if (mat_format.colDist == El::STAR && mat_format.rowDist == El::STAR) {
// Compute local sums on master process if matrix is Star,Star
if (m_comm->am_trainer_master()) {
if (mat.RedundantRank() == 0) {
local_sum_sqsum(mat.LockedMatrix(), sum, sqsum);
}
}
Expand All @@ -129,7 +129,7 @@ template <typename TensorDataType>
inline void
lbann_summary::reduce_scalar(const std::string tag, TensorDataType s, int step)
{
if (m_comm->am_trainer_master()) {
if (mat.RedundantRank() == 0) {
m_pending_scalars.emplace_back(tag, step, s);
}
}
Expand Down Expand Up @@ -166,7 +166,7 @@ inline void lbann_summary::reduce_histogram(
El::DistData mat_format(mat);
if (mat_format.colDist == El::STAR && mat_format.rowDist == El::STAR) {
// Compute local sums on master process if matrix is Star,Star
if (m_comm->am_trainer_master()) {
if (mat.RedundantRank() == 0) {
local_sum_sqsum(mat.LockedMatrix(), sum, sqsum);
}
}
Expand Down
17 changes: 8 additions & 9 deletions include/lbann/utils/tensor_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,7 @@ void do_tensor_copy(const BaseDistMat& src, El::AbstractDistMatrix<TDT>& tgt)
El::CopyAsync(src, tgt);
}
else {
if (src.DistData().grid == tgt.DistData().grid) {
El::Copy(src, tgt);
}
else {
utils::details::do_tensor_copy_between_grids(src, tgt);
}
El::Copy(src, tgt);
}
}

Expand Down Expand Up @@ -158,9 +153,13 @@ void view_or_copy_tensor(const BaseDistMat& src,

if (src.DistData() == tgt.DistData()) {
if (locked_view) {
El::LockedView(tgt, dynamic_cast<const El::AbstractDistMatrix<TDT>&>(src));
} else {
El::View(tgt, dynamic_cast<El::AbstractDistMatrix<TDT>&>(const_cast<BaseDistMat&>(src)));
El::LockedView(tgt,
dynamic_cast<const El::AbstractDistMatrix<TDT>&>(src));
}
else {
El::View(tgt,
dynamic_cast<El::AbstractDistMatrix<TDT>&>(
const_cast<BaseDistMat&>(src)));
}
}
else {
Expand Down
Loading

0 comments on commit cebacc2

Please sign in to comment.