Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not backpropagate through layers without gradient sources #2423

Merged
merged 2 commits into from
Feb 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/lbann/layers/layer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,8 @@ class Layer
template <hydrogen::Device Device>
friend class kfac_block_gru;

friend class model;

public:
/** @name Lifecycle */
///@{
Expand Down
18 changes: 14 additions & 4 deletions include/lbann/models/model.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,11 @@ class model
*/
void setup_weights();

/** @brief Tests whether a layer must be computed during
* backpropagation
*/
bool is_layer_needed_for_backprop(const Layer* l) const;

///@}
/** @name Subgraph parallelism implementation */
///@{
Expand Down Expand Up @@ -561,6 +566,14 @@ class model
/** @brief Current callbacks to process. */
std::vector<std::shared_ptr<callback_base>> m_callbacks;

/** @brief A set of layers needed for backpropagation.
* @details This set is populated by model::forward_prop and controls
* which layers will be computed during backpropagation. If the
* `NO_BACKPROP_DISABLE` option is enabled, this set will not change the
* behavior of backpropagation.
*/
std::unordered_set<const Layer*> m_needed_for_backprop;

/** @brief Is the model setup
* @details Flag to indicate if the setup function has been called
*/
Expand Down Expand Up @@ -793,10 +806,7 @@ model::set_current_mini_batch_size(uint64_t mini_batch_size) noexcept
return;
}

inline bool model::is_amp_enabled() const noexcept
{
return m_amp_enabled;
}
// Accessor for the m_amp_enabled flag (the flag that model::enable_amp sets
// to true).
inline bool model::is_amp_enabled() const noexcept { return m_amp_enabled; }

inline EvalType model::get_amp_scale_factor() const noexcept
{
Expand Down
4 changes: 2 additions & 2 deletions src/callbacks/check_gradients.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,8 @@ void check_gradients::do_check_gradients(model& m) const
m.get_objective_function()->differentiate();
m.get_objective_function()->compute_weight_regularization();

// Compute analytical gradients through model
m.backward_prop(false, /*skip_callbacks=*/true);
// Compute all analytical gradients through model
m.backward_prop(/*compute_weight_grads_only=*/false, /*skip_callbacks=*/true);

// Choose finite difference step
// Note: Consider a central difference scheme:
Expand Down
123 changes: 87 additions & 36 deletions src/models/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,34 @@ void model::serialize_to_onnx(onnx::ModelProto& mp)
}
#endif // LBANN_HAS_ONNX

/** Decide whether backpropagation has to compute through a layer.
 *
 *  The answer is `true` when either
 *   - any parent of the layer is already recorded in
 *     `m_needed_for_backprop`, or
 *   - the layer is not frozen and at least one of its weights has a
 *     non-null optimizer.
 *
 *  The parent test takes precedence: a frozen layer with a recorded parent
 *  is still reported as needed.
 *
 *  @param l Layer to examine. It is dereferenced unconditionally.
 *  @return `true` if the layer is needed for backpropagation, `false`
 *          otherwise.
 */
bool model::is_layer_needed_for_backprop(const Layer* l) const
{
  // A recorded parent means this layer is needed as well.
  const auto parent_count = l->get_num_parents();
  for (int parent_idx = 0; parent_idx < parent_count; ++parent_idx) {
    const auto* parent = &l->get_parent_layer(parent_idx);
    if (m_needed_for_backprop.count(parent) > 0) {
      return true;
    }
  }

  // No recorded parent: only an unfrozen layer holding at least one weight
  // with an optimizer attached still qualifies.
  if (!l->is_frozen()) {
    const auto weight_count = l->num_weights();
    for (size_t weight_idx = 0; weight_idx < weight_count; ++weight_idx) {
      if (l->get_weights(weight_idx).get_optimizer() != nullptr) {
        return true;
      }
    }
  }

  return false;
}

// =============================================
// Model specification
// =============================================
Expand Down Expand Up @@ -929,24 +957,20 @@ void model::setup_subcommunicators(const std::vector<El::Grid*>& fngrids)
for (El::Int node = 0; node < num_layers; ++node) {
Layer* const layer = layers[node];
std::string const& layer_type = layer->get_type();
if ((layer_type == "slice" ||
layer_type == "split" ||
layer_type == "concatenate" ||
layer_type == "sum") &&
if ((layer_type == "slice" || layer_type == "split" ||
layer_type == "concatenate" || layer_type == "sum") &&
layer->subgraph_parallelism_execution()) {
if (subCommunicatorsSubgrids.find(one_index) !=
subCommunicatorsSubgrids.end()) {
layer->reset_inter_subgrid_vc_comm(
subCommunicatorsSubgrids[one_index]);
layer->reset_inter_subgrid_vc_comm(subCommunicatorsSubgrids[one_index]);
}
else {
subCommunicatorsSubgrids[one_index] = std::make_shared<El::mpi::Comm>();
const auto& childs = layer->get_child_layers();

int indexSubgrid = -1;
for (int child = 0; child < layer->get_num_children(); ++child) {
if (fngrids.at(childs[child]->get_grid_tag())->InGrid())
{
if (fngrids.at(childs[child]->get_grid_tag())->InGrid()) {
indexSubgrid = child;
}
}
Expand All @@ -955,9 +979,17 @@ void model::setup_subcommunicators(const std::vector<El::Grid*>& fngrids)
const int layer_tag = layer->get_grid_tag();

if (child_tag < 0)
LBANN_ERROR("child_tag=", child_tag, " (child=", childs[indexSubgrid]->get_name(), ")");
LBANN_ERROR("child_tag=",
child_tag,
" (child=",
childs[indexSubgrid]->get_name(),
")");
if (layer_tag < 0)
LBANN_ERROR("layer_tag=", layer_tag, " (layer=", layer->get_name(), ")");
LBANN_ERROR("layer_tag=",
layer_tag,
" (layer=",
layer->get_name(),
")");

const int posInSubGrid = fngrids[child_tag]->VCRank();
const int posInGrid = fngrids[layer_tag]->ViewingRank();
Expand All @@ -966,15 +998,13 @@ void model::setup_subcommunicators(const std::vector<El::Grid*>& fngrids)
posInGrid,
*subCommunicatorsSubgrids[one_index]);

layer->reset_inter_subgrid_vc_comm(
subCommunicatorsSubgrids[one_index]);
layer->reset_inter_subgrid_vc_comm(subCommunicatorsSubgrids[one_index]);
}
}

if (layer_type == "cross_grid_sum" ||
layer_type == "cross_grid_sum_slice") {
layer->reset_inter_subgrid_vc_comm(
subCommunicatorsSubgrids[one_index]);
layer->reset_inter_subgrid_vc_comm(subCommunicatorsSubgrids[one_index]);
}
}
}
Expand Down Expand Up @@ -1321,14 +1351,15 @@ void model::add_split_layers(std::unordered_set<std::string>& layer_names)
l.get_data_layout(),
l.get_device_allocation());

#define PROTO_DEVICE_LAYOUT(T_datatype, T_layout, T_device) \
if (args == args_tuple(std::type_index(typeid(T_datatype)), T_layout, T_device)) { \
split.reset(new split_layer<T_datatype, T_layout, T_device>(m_comm)); \
}
#define PROTO_DEVICE_LAYOUT(T_datatype, T_layout, T_device) \
if (args == \
args_tuple(std::type_index(typeid(T_datatype)), T_layout, T_device)) { \
split.reset(new split_layer<T_datatype, T_layout, T_device>(m_comm)); \
}

#define PROTO_DEVICE(T_datatype, T_device) \
PROTO_DEVICE_LAYOUT(T_datatype, data_layout::DATA_PARALLEL, T_device); \
PROTO_DEVICE_LAYOUT(T_datatype, data_layout::MODEL_PARALLEL, T_device);
#define PROTO_DEVICE(T_datatype, T_device) \
PROTO_DEVICE_LAYOUT(T_datatype, data_layout::DATA_PARALLEL, T_device); \
PROTO_DEVICE_LAYOUT(T_datatype, data_layout::MODEL_PARALLEL, T_device);

#include "lbann/macros/instantiate_device.hpp"
#undef PROTO_DEVICE_LAYOUT
Expand Down Expand Up @@ -1449,7 +1480,7 @@ void model::remove_layer(std::string const& removable_layer_name)
auto& parent =
const_cast<Layer&>(l.get_parent_layer(0)); // assuming only one parent
auto& child =
const_cast<Layer&>(l.get_child_layer(0)); // assuming only one child
const_cast<Layer&>(l.get_child_layer(0)); // assuming only one child

// Setup relationship between parent layer and child layer
child.replace_parent_layer(l.get_parent_layer_pointer(0),
Expand Down Expand Up @@ -1501,7 +1532,7 @@ void model::replace_layer(OwningLayerPtr&& new_layer,
auto& parent =
const_cast<Layer&>(l.get_parent_layer(0)); // assuming only one parent
auto& child =
const_cast<Layer&>(l.get_child_layer(0)); // assuming only one child
const_cast<Layer&>(l.get_child_layer(0)); // assuming only one child

// Setup relationship between the new layer and child of old layer (which
// becomes child of new layer)
Expand Down Expand Up @@ -1582,6 +1613,9 @@ void model::forward_prop(execution_mode mode, bool skip_callbacks)
// Clear activations in reference counter
m_activation_refcnt.clear();

// Clear layers that will be required in backpropagation
m_needed_for_backprop.clear();

for (El::Int i = 0; i < get_num_layers(); ++i) {
auto& l = get_layer(i);

Expand All @@ -1605,6 +1639,9 @@ void model::forward_prop(execution_mode mode, bool skip_callbacks)
if (!skip_callbacks)
do_layer_forward_prop_end_cbs(mode, &l);
}

if (is_layer_needed_for_backprop(&l))
m_needed_for_backprop.insert(&l);
}
if (!skip_callbacks)
do_model_forward_prop_end_cbs(mode);
Expand All @@ -1627,8 +1664,19 @@ void model::backward_prop(bool compute_weight_grads_only, bool skip_callbacks)

// Perform backward prop step on current layer
auto& l = get_layer(i);
bool enable_layer = (!envvar_disable_layers ||
disabled_layers.find(&l) == disabled_layers.end());

// Check if layer should be skipped
bool enable_layer = true;
if (envvar_disable_layers) {
// Based on backpropagation requirements
if (disabled_layers.find(&l) != disabled_layers.end())
enable_layer = false;

// Based on gradient/optimizer requirements
if (compute_weight_grads_only && m_needed_for_backprop.size() > 0 &&
m_needed_for_backprop.find(&l) == m_needed_for_backprop.end())
enable_layer = false;
}

// Check if all children skip gradient backpropagation
if (enable_layer && envvar_disable_layers) {
Expand Down Expand Up @@ -1746,25 +1794,27 @@ void model::update_weights()
++m_amp_cur_skipped_steps;
// Keep scale factor to the smallest positive normalized value for
// floats. Even when EvalType is double, we may cast to float.
m_amp_scale_factor = std::max(
static_cast<EvalType>(std::numeric_limits<float>::min()),
m_amp_scale_factor * m_amp_backoff_factor);
m_amp_scale_factor =
std::max(static_cast<EvalType>(std::numeric_limits<float>::min()),
m_amp_scale_factor * m_amp_backoff_factor);
// Warn if we've been skipping too many steps.
// Check exact number to avoid printing repeatedly.
if (m_amp_cur_skipped_steps == 10) {
LBANN_WARNING(
"AMP skipped ten steps in a row, your model may have issues with AMP");
LBANN_WARNING("AMP skipped ten steps in a row, your model may have "
"issues with AMP");
}
} else {
}
else {
if (m_amp_cur_steps + 1 == m_amp_growth_interval) {
m_amp_cur_steps = 0;
m_amp_cur_skipped_steps = 0;
// Prevent scale factor from overflowing to inf when cast to
// float.
m_amp_scale_factor = std::min(
static_cast<EvalType>(std::numeric_limits<float>::max()),
m_amp_scale_factor * m_amp_growth_factor);
} else {
m_amp_scale_factor =
std::min(static_cast<EvalType>(std::numeric_limits<float>::max()),
m_amp_scale_factor * m_amp_growth_factor);
}
else {
++m_amp_cur_steps;
}
}
Expand Down Expand Up @@ -1809,7 +1859,8 @@ void model::reconcile_weight_values()
void model::enable_amp(EvalType init_scale_factor,
EvalType growth_factor,
EvalType backoff_factor,
size_t growth_interval) {
size_t growth_interval)
{
m_amp_enabled = true;
m_amp_scale_factor = init_scale_factor;
m_amp_growth_factor = growth_factor;
Expand Down
Loading