Skip to content

Commit

Permalink
Updating layer norm impl
Browse files Browse the repository at this point in the history
  • Loading branch information
szaman19 committed Jun 25, 2024
1 parent 01a361a commit ecac28c
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions src/layers/regularizers/layer_norm.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
////////////////////////////////////////////////////////////////////////////////
// Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC.
// Copyright (c) 2014-2024, Lawrence Livermore National Security, LLC.
// Produced at the Lawrence Livermore National Laboratory.
// Written by the LBANN Research Team (B. Van Essen, et al.) listed in
// the CONTRIBUTORS file. <lbann-dev@llnl.gov>
Expand Down Expand Up @@ -28,8 +28,8 @@
#include "layer_norm_kernels.cuh"
#include "lbann/comm_impl.hpp"
#include "lbann/layers/regularizers/layer_norm.hpp"
#include "lbann/optimizers/optimizer.hpp"
#include "lbann/layers/regularizers/layer_norm_impl.hpp"
#include "lbann/optimizers/optimizer.hpp"
#include "lbann/utils/gpu/helpers.hpp"

#ifdef LBANN_HAS_DISTCONV
Expand Down Expand Up @@ -654,12 +654,12 @@ void layer_norm_distconv_adapter<TensorDataType, Layout, Device>::bp_compute()
template <typename TensorDataType, data_layout Layout, El::Device Device>
void layer_norm_layer<TensorDataType, Layout, Device>::fp_compute()
{
#ifdef LBANN_HAS_DISTCONV
#ifdef LBANN_HAS_DISTCONV
if (this->distconv_enabled()) {
this->get_distconv_adapter().fp_compute();
return;
}
#endif // LBANN_HAS_DISTCONV
#endif // LBANN_HAS_DISTCONV

int weight_idx = 0;
const TensorDataType* scale_weights = nullptr;
Expand All @@ -671,7 +671,6 @@ void layer_norm_layer<TensorDataType, Layout, Device>::fp_compute()
bias_weights =
this->weights_values(weight_idx).LockedMatrix().LockedBuffer();


fp_impl(*this->get_comm(),
this->m_epsilon,
this->get_prev_activations(),
Expand All @@ -684,13 +683,13 @@ void layer_norm_layer<TensorDataType, Layout, Device>::fp_compute()
template <typename TensorDataType, data_layout Layout, El::Device Device>
void layer_norm_layer<TensorDataType, Layout, Device>::bp_compute()
{
#ifdef LBANN_HAS_DISTCONV
#ifdef LBANN_HAS_DISTCONV
if (this->distconv_enabled()) {
this->get_distconv_adapter().bp_compute();
return;
}
#endif // LBANN_HAS_DISTCONV
#endif // LBANN_HAS_DISTCONV

// Obtain optional buffers
const TensorDataType* scale_weights = nullptr;
TensorDataType* scale_grad = nullptr;
Expand Down

0 comments on commit ecac28c

Please sign in to comment.