Updated implementation with updating statistics tensors
szaman19 committed Dec 4, 2023
1 parent 7e16367 commit f021461
Showing 6 changed files with 367 additions and 202 deletions.
19 changes: 7 additions & 12 deletions include/lbann/layers/regularizers/distconv/distconv_layer_norm.hpp
@@ -27,7 +27,7 @@
#ifndef LBANN_LAYERS_REGULARIZERS_DISTCONV_LAYER_NORM
#define LBANN_LAYERS_REGULARIZERS_DISTCONV_LAYER_NORM

-#if LBANN_HAS_DISTCONV
+#ifdef LBANN_HAS_DISTCONV

namespace distconv {
template <typename Backend, typename DataType>
@@ -39,21 +39,17 @@ class LayerNormalization
using DCTensor = tensor::Tensor<DataType, LocaleMPI, Allocator>;

public:
-LayerNormalization(Backend& backend,
-                   Datatype epsilon,
-                   size_t max_mini_batch_size)
-  : m_backend(backend),
-    m_epsilon(epsilon),
-    m_max_mini_batch_size(max_mini_batch_size)
+LayerNormalization(Backend& backend, DataType epsilon)
+  : m_backend(backend), m_epsilon(epsilon)
{}

template <typename Allocator>
void calculate_forward_stats(const DCTensor<Allocator>& input,
-  DC<Allocator>& statistics);
+  DCTensor<Allocator>& statistics);

template <typename Allocator>
void apply_normalization(const DCTensor<Allocator>& input,
-  const DCTensor<Allocator>& statistics,
+  DCTensor<Allocator>& statistics,
DCTensor<Allocator>& output);

template <typename Allocator>
@@ -74,10 +70,9 @@

private:
DataType m_epsilon;
-size_t m_max_mini_batch_size;

}; // class definition LayerNorm
} // namespace distconv

-#endif // LBANN_HAS_DISTONV
-#endif // LBANN_LAYERSE_REGULARIZERS_DISTCONV_LAYER_NORM
+#endif // LBANN_HAS_DISTCONV
+#endif // LBANN_LAYERS_REGULARIZERS_DISTCONV_LAYER_NORM
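
For context, the reshaped interface above drops the max_mini_batch_size constructor argument, leaving epsilon as the operator's only state, and passes the statistics tensor to apply_normalization as a mutable reference, presumably so the operator can update it in place (per the commit title). A plausible reason for the two-phase calculate_forward_stats / apply_normalization split is that a distributed backend can reduce partial statistics across ranks between the two calls. Below is a minimal, self-contained sketch of the same shape of API, with std::vector standing in for the distconv tensor types; every name in it is illustrative, not LBANN's, and the sketch keeps the statistics const in phase 2 because it never rewrites them.

#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

class LayerNormSketch {
public:
  explicit LayerNormSketch(float epsilon) : m_epsilon(epsilon) {}

  // Phase 1: write one (mean, mean-of-squares) pair per sample into
  // `statistics`; a distributed backend could allreduce this buffer
  // before phase 2.
  void calculate_forward_stats(const std::vector<float>& input,
                               std::size_t sample_size,
                               std::vector<float>& statistics) const
  {
    const std::size_t num_samples = input.size() / sample_size;
    statistics.assign(2 * num_samples, 0.f);
    for (std::size_t s = 0; s < num_samples; ++s) {
      float sum = 0.f, sqsum = 0.f;
      for (std::size_t i = 0; i < sample_size; ++i) {
        const float x = input[s * sample_size + i];
        sum += x;
        sqsum += x * x;
      }
      statistics[2 * s] = sum / static_cast<float>(sample_size);
      statistics[2 * s + 1] = sqsum / static_cast<float>(sample_size);
    }
  }

  // Phase 2: normalize each sample with the statistics from phase 1.
  void apply_normalization(const std::vector<float>& input,
                           std::size_t sample_size,
                           const std::vector<float>& statistics,
                           std::vector<float>& output) const
  {
    const std::size_t num_samples = input.size() / sample_size;
    output.resize(input.size());
    for (std::size_t s = 0; s < num_samples; ++s) {
      const float mean = statistics[2 * s];
      const float var = statistics[2 * s + 1] - mean * mean;
      const float inv_stdev = 1.f / std::sqrt(var + m_epsilon);
      for (std::size_t i = 0; i < sample_size; ++i) {
        output[s * sample_size + i] =
          (input[s * sample_size + i] - mean) * inv_stdev;
      }
    }
  }

private:
  float m_epsilon; // the only state left after this commit's simplification
};

int main()
{
  LayerNormSketch norm(1e-5f); // epsilon-only construction, as in the diff
  const std::vector<float> input = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
  std::vector<float> stats, output;
  norm.calculate_forward_stats(input, /*sample_size=*/3, stats);
  norm.apply_normalization(input, 3, stats, output);
  for (float v : output) {
    std::cout << v << ' '; // two zero-mean, unit-variance samples
  }
  std::cout << '\n';
  return 0;
}

Compiled with any C++11 compiler, the demo prints two normalized three-element samples.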
23 changes: 9 additions & 14 deletions include/lbann/layers/regularizers/layer_norm.hpp
@@ -36,19 +36,20 @@
#include <memory>

#ifdef LBANN_HAS_DISTCONV
#include "lbann/utils/distconv.hpp"
#include "lbann/layers/data_type_distconv_adapter.hpp"
#include "lbann/layers/regularizers/distconv/distconv_layer_norm.hpp"
#include "lbann/utils/distconv.hpp"
#endif // LBANN_HAS_DISTCONV

namespace lbann {

#ifdef LBANN_HAS_DISTCONV
namespace dc {
-using Shape = ::distconv::tensor::Shape;
-using Backend= ::distconv::BackendDNNLib;
+using Shape = ::distconv::tensor::Shape;
+using Backend = ::distconv::BackendDNNLib;
template <typename TensorDataType>
-using LayerNormalization = ::distconv::LayerNormalization<Backend, TensorDataType>;
+using LayerNormalization =
+  ::distconv::LayerNormalization<Backend, TensorDataType>;
} // namespace dc

template <typename TensorDataType, data_layout Layout, El::Device Device>
@@ -67,12 +68,10 @@ class layer_norm_distconv_adapter

void setup_distributions(tensor_overlap_constraints& constraints) override;
void setup_layer(size_t workspace_capacity) override;
-void setup_fp_tensors() override;
-void setup_bp_tensors() override;

void fp_compute();
void bp_compute();

TensorDevType m_statistics;
TensorDevType m_statistics_grad;
std::unique_ptr<dc::LayerNormalization<TensorDataType>> m_layer_norm_operator;
@@ -419,13 +418,9 @@ ::get_distconv_adapter(){
// LayerNorm DistConv Adapter implementation
// =============================================================

#endif // LBANN_HAS_DISTCONV

LBANN_DEFINE_LAYER_BUILDER(layer_norm);

// =========================================================
// Explicit template instantiation
// =========================================================

#ifndef LBANN_LAYER_NORM_LAYER_INSTANTIATE
#define PROTO_DEVICE(T, Device) \
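For reference, the per-sample statistics behind the adapter's m_statistics tensor are the standard layer-norm moments. For a sample with entries $x_1, \dots, x_n$, the forward pass computes

$$\mu = \frac{1}{n} \sum_{i=1}^{n} x_i, \qquad m_2 = \frac{1}{n} \sum_{i=1}^{n} x_i^2, \qquad \sigma^2 = m_2 - \mu^2,$$

and normalizes each entry as $y_i = (x_i - \mu) / \sqrt{\sigma^2 + \epsilon}$. This is the textbook formulation; whether the statistics tensor stores $(\mu, \sigma^2)$ or $(\mu, m_2)$ is not visible in these hunks.
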
88 changes: 41 additions & 47 deletions include/lbann/layers/regularizers/layer_norm_impl.hpp
@@ -33,8 +33,7 @@
#include "lbann/layers/data_type_distconv_adapter.hpp"
#endif

-namespace lbann{
-
+namespace lbann {

// =========================================================
// Implementation
@@ -135,7 +134,6 @@ void layer_norm_layer<TensorDataType, Layout, Device>::setup_data(
m_statistics_gradient.reset(AbsDistMatrixType::Instantiate(dist));
}

-
#ifdef LBANN_HAS_DISTCONV

// =============================================================
@@ -174,57 +172,53 @@ layer_norm_layer<TensorDataType, Layout, Device>::get_distconv_adapter()
layer_norm_distconv_adapter<TensorDataType, Layout, Device>&>(
static_cast<const layer_norm_layer<TensorDataType, Layout, Device>&>(*this)
.get_distconv_adapter());
}

// =============================================================
// LayerNorm DistConv Adapter implementation
// =============================================================

-template <typename TensorDataType, data_layout Layout, El::Device Device>
-void layer_norm_distconv_adapter<TensorDataType, Layout, Device>::
-  setup_distributions(tensor_overlap_constraints & constraints)
-{
-  data_type_distconv_adapter<TensorDataType>::setup_distributions(
-    constraints);
-  // no overlap needed
-  for (auto& d : this->m_prev_activations_dists) {
-    d.clear_overlap();
-    constraints.mark_updated(d);
-    constraints.mark_invariant(d);
-  }
-  for (auto& d : this->m_activations_dists) {
-    d.clear_overlap();
-    constraints.mark_updated(d);
-    constraints.mark_invariant(d);
-  }
-  for (auto& d : this->m_prev_error_signals_dists) {
-    d.clear_overlap();
-    constraints.mark_updated(d);
-    constraints.mark_invariant(d);
-  }
-  for (auto& d : this->m_error_signals_dists) {
-    d.clear_overlap();
-    constraints.mark_updated(d);
-    constraints.mark_invariant(d);
-  }
-}
-
-template <typename TensorDataType, data_layout Layout, El::Device Device>
-void layer_norm_distconv_adapter<TensorDataType, Layout, Device>::setup_layer(
-  size_t workspace_capacity)
-{
-  data_type_distconv_adapter<TensorDataType>::setup_layer(workspace_capacity);
-  auto& layer = dynamic_cast<
-    channelwise_fully_connected_layer<TensorDataType, Layout, Device>&>(
-    this->layer());
-  const auto max_mini_batch_size =
-    layer.get_model()->m_max_mini_batch_size_distconv;
-
-  m_layer_norm_operator =
-    make_unique<dc::LayerNormalization<TensorDataType>>(dc::get_backend(),
-                                                        layer.m_epsilon,
-                                                        max_mini_batch_size);
-}
+template <typename TensorDataType, data_layout Layout, El::Device Device>
+void layer_norm_distconv_adapter<TensorDataType, Layout, Device>::
+  setup_distributions(tensor_overlap_constraints& constraints)
+{
+  data_type_distconv_adapter<TensorDataType>::setup_distributions(constraints);
+  // no overlap needed
+  for (auto& d : this->m_prev_activations_dists) {
+    d.clear_overlap();
+    constraints.mark_updated(d);
+    constraints.mark_invariant(d);
+  }
+  for (auto& d : this->m_activations_dists) {
+    d.clear_overlap();
+    constraints.mark_updated(d);
+    constraints.mark_invariant(d);
+  }
+  for (auto& d : this->m_prev_error_signals_dists) {
+    d.clear_overlap();
+    constraints.mark_updated(d);
+    constraints.mark_invariant(d);
+  }
+  for (auto& d : this->m_error_signals_dists) {
+    d.clear_overlap();
+    constraints.mark_updated(d);
+    constraints.mark_invariant(d);
+  }
+}
+
+template <typename TensorDataType, data_layout Layout, El::Device Device>
+void layer_norm_distconv_adapter<TensorDataType, Layout, Device>::setup_layer(
+  size_t workspace_capacity)
+{
+  data_type_distconv_adapter<TensorDataType>::setup_layer(workspace_capacity);
+  auto& layer =
+    dynamic_cast<layer_norm_layer<TensorDataType, Layout, Device>&>(
+      this->layer());
+
+  m_layer_norm_operator =
+    make_unique<dc::LayerNormalization<TensorDataType>>(dc::get_backend(),
+                                                        layer.m_epsilon);
+}

-#endif LBANN_HAS_DISTCONV
+#endif // LBANN_HAS_DISTCONV
} // namespace lbann
#endif // LBANN_LAYER_REGULARIZER_LAYER_NORM_IMPL_HPP_INCLUDED
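
For context on the m_statistics_grad member these hunks reference: in the standard layer-norm backward pass (no affine parameters), writing $g_i = \partial L / \partial y_i$ for the incoming gradient, the input gradient is

$$\frac{\partial L}{\partial x_i} = \frac{1}{\sqrt{\sigma^2 + \epsilon}} \left( g_i - \frac{1}{n} \sum_{j} g_j - \frac{y_i}{n} \sum_{j} g_j \, y_j \right),$$

so the two per-sample reductions $\sum_j g_j$ and $\sum_j g_j y_j$ are exactly the quantities a statistics-gradient tensor would cache and, in a distributed run, allreduce. This is the standard derivation, not text taken from the commit.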