From 32a761bd11552478a1cdf6db3372d0aea4f0703b Mon Sep 17 00:00:00 2001
From: Tal Ben-Nun
Date: Mon, 12 Feb 2024 10:07:24 -0800
Subject: [PATCH] Do not backpropagate through layers without gradient sources
 (#2423)

---
 include/lbann/layers/layer.hpp    |   2 +
 include/lbann/models/model.hpp    |  18 ++++-
 src/callbacks/check_gradients.cpp |   4 +-
 src/models/model.cpp              | 123 +++++++++++++++++++++---------
 4 files changed, 105 insertions(+), 42 deletions(-)

diff --git a/include/lbann/layers/layer.hpp b/include/lbann/layers/layer.hpp
index 04ff80416f1..8a80aeaae57 100644
--- a/include/lbann/layers/layer.hpp
+++ b/include/lbann/layers/layer.hpp
@@ -297,6 +297,8 @@ class Layer
   template <El::Device Device>
   friend class kfac_block_gru;
 
+  friend class model;
+
 public:
   /** @name Lifecycle */
   ///@{
diff --git a/include/lbann/models/model.hpp b/include/lbann/models/model.hpp
index 5fc7bdd2c80..91335abe35b 100644
--- a/include/lbann/models/model.hpp
+++ b/include/lbann/models/model.hpp
@@ -343,6 +343,11 @@ class model
    */
  void setup_weights();
 
+  /** @brief Tests whether a layer needs to be computed during
+   *  backpropagation.
+   */
+  bool is_layer_needed_for_backprop(const Layer* l) const;
+
   ///@}
   /** @name Subgraph parallelism implementation */
   ///@{
@@ -561,6 +566,14 @@ class model
   /** @brief Current callbacks to process. */
   std::vector<std::shared_ptr<callback_base>> m_callbacks;
 
+  /** @brief A set of layers needed for backpropagation.
+   *  @details This set is populated by model::forward_prop and controls
+   *  which layers will be computed during backpropagation. If the
+   *  `NO_BACKPROP_DISABLE` option is enabled, this set will not change the
+   *  behavior of backpropagation.
+   */
+  std::unordered_set<const Layer*> m_needed_for_backprop;
+
   /** @brief Is the model setup
    *  @details Flag to indicate if the setup function has been called
    */
@@ -793,10 +806,7 @@ model::set_current_mini_batch_size(uint64_t mini_batch_size) noexcept
   return;
 }
 
-inline bool model::is_amp_enabled() const noexcept
-{
-  return m_amp_enabled;
-}
+inline bool model::is_amp_enabled() const noexcept { return m_amp_enabled; }
 
 inline EvalType model::get_amp_scale_factor() const noexcept
 {
diff --git a/src/callbacks/check_gradients.cpp b/src/callbacks/check_gradients.cpp
index 42f1627534d..1e9e0f2cb84 100644
--- a/src/callbacks/check_gradients.cpp
+++ b/src/callbacks/check_gradients.cpp
@@ -273,8 +273,8 @@ void check_gradients::do_check_gradients(model& m) const
   m.get_objective_function()->differentiate();
   m.get_objective_function()->compute_weight_regularization();
 
-  // Compute analytical gradients through model
-  m.backward_prop(false, /*skip_callbacks=*/true);
+  // Compute all analytical gradients through model
+  m.backward_prop(/*compute_weight_grads_only=*/false, /*skip_callbacks=*/true);
 
   // Choose finite difference step
   // Note: Consider a central difference scheme:
diff --git a/src/models/model.cpp b/src/models/model.cpp
index e4c38f8b257..6ed719817dd 100644
--- a/src/models/model.cpp
+++ b/src/models/model.cpp
@@ -467,6 +467,34 @@ void model::serialize_to_onnx(onnx::ModelProto& mp)
 }
 #endif // LBANN_HAS_ONNX
 
+bool model::is_layer_needed_for_backprop(const Layer* l) const
+{
+  // First, check parents. If one of the parent layers requires gradients,
+  // this layer does too.
+  for (int i = 0; i < l->get_num_parents(); ++i) {
+    if (m_needed_for_backprop.find(&l->get_parent_layer(i)) !=
+        m_needed_for_backprop.end()) {
+      return true;
+    }
+  }
+
+  // Second, check the layer itself. If frozen, backprop is not necessary.
+  if (l->is_frozen()) {
+    return false;
+  }
+
+  // Otherwise, check weight optimizers. If one of the associated optimizers
+  // is not nullptr, then backprop will be necessary.
+  for (size_t i = 0; i < l->num_weights(); ++i) {
+    if (l->get_weights(i).get_optimizer() != nullptr) {
+      return true;
+    }
+  }
+
+  // Not needed for backprop
+  return false;
+}
+
 // =============================================
 // Model specification
 // =============================================
 
@@ -929,15 +957,12 @@ void model::setup_subcommunicators(const std::vector& fngrids)
   for (El::Int node = 0; node < num_layers; ++node) {
     Layer* const layer = layers[node];
     std::string const& layer_type = layer->get_type();
-    if ((layer_type == "slice" ||
-         layer_type == "split" ||
-         layer_type == "concatenate" ||
-         layer_type == "sum") &&
+    if ((layer_type == "slice" || layer_type == "split" ||
+         layer_type == "concatenate" || layer_type == "sum") &&
         layer->subgraph_parallelism_execution()) {
       if (subCommunicatorsSubgrids.find(one_index) !=
           subCommunicatorsSubgrids.end()) {
-        layer->reset_inter_subgrid_vc_comm(
-          subCommunicatorsSubgrids[one_index]);
+        layer->reset_inter_subgrid_vc_comm(subCommunicatorsSubgrids[one_index]);
       }
       else {
        subCommunicatorsSubgrids[one_index] = std::make_shared<El::mpi::Comm>();
 
        int indexSubgrid = -1;
        for (int child = 0; child < layer->get_num_children(); ++child) {
-         if (fngrids.at(childs[child]->get_grid_tag())->InGrid())
-         {
+         if (fngrids.at(childs[child]->get_grid_tag())->InGrid()) {
            indexSubgrid = child;
          }
        }
 
        const int child_tag = childs[indexSubgrid]->get_grid_tag();
        const int layer_tag = layer->get_grid_tag();
 
        if (child_tag < 0)
-         LBANN_ERROR("child_tag=", child_tag, " (child=", childs[indexSubgrid]->get_name(), ")");
+         LBANN_ERROR("child_tag=",
+                     child_tag,
+                     " (child=",
+                     childs[indexSubgrid]->get_name(),
+                     ")");
        if (layer_tag < 0)
-         LBANN_ERROR("layer_tag=", layer_tag, " (layer=", layer->get_name(), ")");
+         LBANN_ERROR("layer_tag=",
+                     layer_tag,
+                     " (layer=",
+                     layer->get_name(),
+                     ")");
 
        const int posInSubGrid = fngrids[child_tag]->VCRank();
        const int posInGrid = fngrids[layer_tag]->ViewingRank();
                       posInGrid,
                       *subCommunicatorsSubgrids[one_index]);
 
-       layer->reset_inter_subgrid_vc_comm(
-         subCommunicatorsSubgrids[one_index]);
+       layer->reset_inter_subgrid_vc_comm(subCommunicatorsSubgrids[one_index]);
      }
    }
 
    if (layer_type == "cross_grid_sum" ||
        layer_type == "cross_grid_sum_slice") {
-     layer->reset_inter_subgrid_vc_comm(
-       subCommunicatorsSubgrids[one_index]);
+     layer->reset_inter_subgrid_vc_comm(subCommunicatorsSubgrids[one_index]);
    }
  }
 }
@@ -1321,14 +1351,15 @@ void model::add_split_layers(std::unordered_set<std::string>& layer_names)
                                    l.get_data_layout(),
                                    l.get_device_allocation());
 
-#define PROTO_DEVICE_LAYOUT(T_datatype, T_layout, T_device)                   \
-  if (args == args_tuple(std::type_index(typeid(T_datatype)), T_layout, T_device)) { \
-    split.reset(new split_layer<T_datatype, T_layout, T_device>(m_comm));     \
-  }
+#define PROTO_DEVICE_LAYOUT(T_datatype, T_layout, T_device)                   \
+  if (args ==                                                                 \
+      args_tuple(std::type_index(typeid(T_datatype)), T_layout, T_device)) {  \
+    split.reset(new split_layer<T_datatype, T_layout, T_device>(m_comm));     \
+  }
 
-#define PROTO_DEVICE(T_datatype, T_device)                                    \
-  PROTO_DEVICE_LAYOUT(T_datatype, data_layout::DATA_PARALLEL, T_device);      \
-  PROTO_DEVICE_LAYOUT(T_datatype, data_layout::MODEL_PARALLEL, T_device);
+#define PROTO_DEVICE(T_datatype, T_device)                                    \
+  PROTO_DEVICE_LAYOUT(T_datatype, data_layout::DATA_PARALLEL, T_device);      \
+  PROTO_DEVICE_LAYOUT(T_datatype, data_layout::MODEL_PARALLEL, T_device);
 
 #include "lbann/macros/instantiate_device.hpp"
 #undef PROTO_DEVICE_LAYOUT
@@ -1449,7 +1480,7 @@ void model::remove_layer(std::string const& removable_layer_name)
   auto& parent =
     const_cast<Layer&>(l.get_parent_layer(0)); // assuming only one parent
   auto& child =
-    const_cast<Layer&>(l.get_child_layer(0)); // assuming only one child
+    const_cast<Layer&>(l.get_child_layer(0));  // assuming only one child
 
   // Setup relationship between parent layer and child layer
   child.replace_parent_layer(l.get_parent_layer_pointer(0),
@@ -1501,7 +1532,7 @@ void model::replace_layer(OwningLayerPtr&& new_layer,
   auto& parent =
     const_cast<Layer&>(l.get_parent_layer(0)); // assuming only one parent
   auto& child =
-    const_cast<Layer&>(l.get_child_layer(0)); // assuming only one child
+    const_cast<Layer&>(l.get_child_layer(0));  // assuming only one child
 
   // Setup relationship between the new layer and child of old layer (which
   // becomes child of new layer)
@@ -1582,6 +1613,9 @@ void model::forward_prop(execution_mode mode, bool skip_callbacks)
   // Clear activations in reference counter
   m_activation_refcnt.clear();
 
+  // Clear the set of layers required for backpropagation
+  m_needed_for_backprop.clear();
+
   for (El::Int i = 0; i < get_num_layers(); ++i) {
     auto& l = get_layer(i);
 
@@ -1605,6 +1639,9 @@ void model::forward_prop(execution_mode mode, bool skip_callbacks)
       if (!skip_callbacks)
         do_layer_forward_prop_end_cbs(mode, &l);
     }
+
+    if (is_layer_needed_for_backprop(&l))
+      m_needed_for_backprop.insert(&l);
   }
 
   if (!skip_callbacks)
     do_model_forward_prop_end_cbs(mode);
@@ -1627,8 +1664,19 @@ void model::backward_prop(bool compute_weight_grads_only, bool skip_callbacks)
 
     // Perform backward prop step on current layer
     auto& l = get_layer(i);
-    bool enable_layer = (!envvar_disable_layers ||
-                         disabled_layers.find(&l) == disabled_layers.end());
+
+    // Check if the layer should be skipped
+    bool enable_layer = true;
+    if (envvar_disable_layers) {
+      // Based on layers explicitly disabled via environment variable
+      if (disabled_layers.find(&l) != disabled_layers.end())
+        enable_layer = false;
+
+      // Based on gradient/optimizer requirements
+      if (compute_weight_grads_only && m_needed_for_backprop.size() > 0 &&
+          m_needed_for_backprop.find(&l) == m_needed_for_backprop.end())
+        enable_layer = false;
+    }
 
     // Check if all children skip gradient backpropagation
     if (enable_layer && envvar_disable_layers) {
@@ -1746,25 +1794,27 @@ void model::update_weights()
       ++m_amp_cur_skipped_steps;
       // Keep scale factor to the smallest positive normalized value for
       // floats. Even when EvalType is double, we may cast to float.
-      m_amp_scale_factor = std::max(
-        static_cast<EvalType>(std::numeric_limits<float>::min()),
-        m_amp_scale_factor * m_amp_backoff_factor);
+      m_amp_scale_factor =
+        std::max(static_cast<EvalType>(std::numeric_limits<float>::min()),
+                 m_amp_scale_factor * m_amp_backoff_factor);
       // Warn if we've been skipping too many steps.
       // Check exact number to avoid printing repeatedly.
       if (m_amp_cur_skipped_steps == 10) {
-        LBANN_WARNING(
-          "AMP skipped ten steps in a row, your model may have issues with AMP");
+        LBANN_WARNING("AMP skipped ten steps in a row, your model may have "
+                      "issues with AMP");
       }
-    } else {
+    }
+    else {
      if (m_amp_cur_steps + 1 == m_amp_growth_interval) {
        m_amp_cur_steps = 0;
        m_amp_cur_skipped_steps = 0;
        // Prevent scale factor from overflowing to inf when cast to
        // float.
-       m_amp_scale_factor = std::min(
-         static_cast<EvalType>(std::numeric_limits<float>::max()),
-         m_amp_scale_factor * m_amp_growth_factor);
-     } else {
+       m_amp_scale_factor =
+         std::min(static_cast<EvalType>(std::numeric_limits<float>::max()),
+                  m_amp_scale_factor * m_amp_growth_factor);
+     }
+     else {
        ++m_amp_cur_steps;
      }
    }
@@ -1809,7 +1859,8 @@ void model::reconcile_weight_values()
 void model::enable_amp(EvalType init_scale_factor,
                        EvalType growth_factor,
                        EvalType backoff_factor,
-                       size_t growth_interval) {
+                       size_t growth_interval)
+{
   m_amp_enabled = true;
   m_amp_scale_factor = init_scale_factor;
   m_amp_growth_factor = growth_factor;
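
For illustration, a minimal standalone C++14 sketch of the rule this patch implements follows. MockLayer, needed_for_backprop(), and the three-layer chain are hypothetical stand-ins rather than LBANN types; the real logic lives in model::is_layer_needed_for_backprop and model::backward_prop. A layer is marked as needed if any of its parents is already marked, or if it is unfrozen and owns weights with an attached optimizer; unmarked layers can then be skipped when only weight gradients are required.

// Standalone illustration (hypothetical types, not LBANN code): a layer is
// "needed for backprop" if any parent already is, or if it is unfrozen and
// owns weights with an optimizer attached.
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

struct MockLayer {
  std::string name;
  std::vector<const MockLayer*> parents;
  bool frozen = false;
  bool has_optimizer = false; // stands in for "some weight has a non-null optimizer"
};

bool needed_for_backprop(const MockLayer& l,
                         const std::unordered_set<const MockLayer*>& needed)
{
  // Parents first: if gradients must reach a parent, they must flow through l.
  for (const MockLayer* p : l.parents) {
    if (needed.count(p) != 0) {
      return true;
    }
  }
  // Frozen layers never start a gradient path of their own.
  if (l.frozen) {
    return false;
  }
  // Unfrozen layers with trainable weights start one.
  return l.has_optimizer;
}

int main()
{
  // input -> frozen_encoder -> trainable_head, visited in forward order.
  MockLayer input{"input", {}, false, false};
  MockLayer encoder{"frozen_encoder", {&input}, true, true};
  MockLayer head{"trainable_head", {&encoder}, false, true};

  std::unordered_set<const MockLayer*> needed;
  for (const MockLayer* l : {&input, &encoder, &head}) {
    if (needed_for_backprop(*l, needed)) {
      needed.insert(l);
    }
  }

  // During the backward pass, layers outside "needed" can be skipped when
  // only weight gradients are requested.
  for (const MockLayer* l : {&input, &encoder, &head}) {
    std::cout << l->name << ": " << (needed.count(l) ? "backprop" : "skipped")
              << "\n";
  }
  return 0;
}

Compiled with a C++14 compiler, the sketch reports that only trainable_head runs backprop, mirroring how m_needed_for_backprop lets a frozen or optimizer-free prefix of the layer graph be skipped during backpropagation.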