Skip to content

Commit

Permalink
Updating layer norm impl
Browse files Browse the repository at this point in the history
  • Loading branch information
szaman19 committed Jun 25, 2024
1 parent 9e44581 commit 021d2a7
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions src/layers/regularizers/layer_norm.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
////////////////////////////////////////////////////////////////////////////////
// Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
// Copyright (c) 2014-2024, Lawrence Livermore National Security, LLC.
// Produced at the Lawrence Livermore National Laboratory.
// Written by the LBANN Research Team (B. Van Essen, et al.) listed in
// the CONTRIBUTORS file. <[email protected]>
Expand Down Expand Up @@ -28,8 +28,8 @@
#include "layer_norm_kernels.cuh"
#include "lbann/comm_impl.hpp"
#include "lbann/layers/regularizers/layer_norm.hpp"
#include "lbann/optimizers/optimizer.hpp"
#include "lbann/layers/regularizers/layer_norm_impl.hpp"
#include "lbann/optimizers/optimizer.hpp"
#include "lbann/utils/gpu/helpers.hpp"

#ifdef LBANN_HAS_DISTCONV
Expand Down Expand Up @@ -556,12 +556,12 @@ void layer_norm_distconv_adapter<TensorDataType, Layout, Device>::bp_compute()
template <typename TensorDataType, data_layout Layout, El::Device Device>
void layer_norm_layer<TensorDataType, Layout, Device>::fp_compute()
{
#ifdef LBANN_HAS_DISTCONV
#ifdef LBANN_HAS_DISTCONV
if (this->distconv_enabled()) {
this->get_distconv_adapter().fp_compute();
return;
}
#endif // LBANN_HAS_DISTCONV
#endif // LBANN_HAS_DISTCONV

int weight_idx = 0;
const TensorDataType* scale_weights = nullptr;
Expand All @@ -575,6 +575,7 @@ void layer_norm_layer<TensorDataType, Layout, Device>::fp_compute()
El::Int norm_size, global_norm_size, num_norm, norm_stride;
this->get_normdims(norm_size, global_norm_size, num_norm, norm_stride);

<<<<<<< HEAD
<<<<<<< HEAD
#ifdef LBANN_HAS_DISTCONV
if (this->distconv_enabled()) {
Expand All @@ -585,6 +586,8 @@ void layer_norm_layer<TensorDataType, Layout, Device>::fp_compute()
=======
>>>>>>> f02146109 (Updated implementation with updating statistics tensors)
=======
>>>>>>> ecac28c9f (Updating layer norm impl)
fp_impl(*this->get_comm(),
this->m_epsilon,
norm_size,
Expand All @@ -601,13 +604,13 @@ void layer_norm_layer<TensorDataType, Layout, Device>::fp_compute()
template <typename TensorDataType, data_layout Layout, El::Device Device>
void layer_norm_layer<TensorDataType, Layout, Device>::bp_compute()
{
#ifdef LBANN_HAS_DISTCONV
#ifdef LBANN_HAS_DISTCONV
if (this->distconv_enabled()) {
this->get_distconv_adapter().bp_compute();
return;
}
#endif // LBANN_HAS_DISTCONV
#endif // LBANN_HAS_DISTCONV
// Obtain optional buffers
const TensorDataType* scale_weights = nullptr;
TensorDataType* scale_grad = nullptr;
Expand Down

0 comments on commit 021d2a7

Please sign in to comment.