diff --git a/include/lbann/layers/regularizers/distconv/distconv_layer_norm.hpp b/include/lbann/layers/regularizers/distconv/distconv_layer_norm.hpp
index e2f7fda48b3..3206d7b87ac 100644
--- a/include/lbann/layers/regularizers/distconv/distconv_layer_norm.hpp
+++ b/include/lbann/layers/regularizers/distconv/distconv_layer_norm.hpp
@@ -27,7 +27,7 @@
-#ifndef LBANN_LAYERSE_REGULARIZERS_DISTCONV_LAYER_NORM
-#define LBANN_LAYERSE_REGULARIZERS_DISTCONV_LAYER_NORM
+#ifndef LBANN_LAYERS_REGULARIZERS_DISTCONV_LAYER_NORM
+#define LBANN_LAYERS_REGULARIZERS_DISTCONV_LAYER_NORM
 
-#if LBANN_HAS_DISTCONV
+#ifdef LBANN_HAS_DISTCONV
 namespace distconv {
 template <typename Backend, typename DataType>
@@ -39,21 +39,17 @@ class LayerNormalization
   using DCTensor = tensor::Tensor<DataType, tensor::LocaleMPI, Allocator>;
 
 public:
-  LayerNormalization(Backend& backend,
-                     Datatype epsilon,
-                     size_t max_mini_batch_size)
-    : m_backend(backend),
-      m_epsilon(epsilon),
-      m_max_mini_batch_size(max_mini_batch_size)
+  LayerNormalization(Backend& backend, DataType epsilon)
+    : m_backend(backend), m_epsilon(epsilon)
   {}
 
   template <typename Allocator>
   void calculate_forward_stats(const DCTensor<Allocator>& input,
-                               DC& statistics);
+                               DCTensor<Allocator>& statistics);
 
   template <typename Allocator>
   void apply_normalization(const DCTensor<Allocator>& input,
-                           const DCTensor<Allocator>& statistics,
+                           DCTensor<Allocator>& statistics,
                            DCTensor<Allocator>& output);
 
   template <typename Allocator>
@@ -74,10 +70,9 @@ class LayerNormalization
 
 private:
   DataType m_epsilon;
-  size_t m_max_mini_batch_size;
 }; // class definition LayerNorm
 
 } // namespace distconv
-#endif // LBANN_HAS_DISTONV
-#endif // LBANN_LAYERSE_REGULARIZERS_DISTCONV_LAYER_NORM
\ No newline at end of file
+#endif // LBANN_HAS_DISTCONV
+#endif // LBANN_LAYERS_REGULARIZERS_DISTCONV_LAYER_NORM
\ No newline at end of file
diff --git a/include/lbann/layers/regularizers/layer_norm.hpp b/include/lbann/layers/regularizers/layer_norm.hpp
index 7b4beac1f24..ce31e1d1e4f 100644
--- a/include/lbann/layers/regularizers/layer_norm.hpp
+++ b/include/lbann/layers/regularizers/layer_norm.hpp
@@ -36,19 +36,20 @@
 #include <memory>
 
 #ifdef LBANN_HAS_DISTCONV
-#include "lbann/utils/distconv.hpp"
 #include "lbann/layers/data_type_distconv_adapter.hpp"
 #include "lbann/layers/regularizers/distconv/distconv_layer_norm.hpp"
+#include "lbann/utils/distconv.hpp"
 #endif // LBANN_HAS_DISTCONV
 
 namespace lbann {
 
 #ifdef LBANN_HAS_DISTCONV
 namespace dc {
-using Shape = ::distconv::tensor::Shape;
-using Backend= ::distconv::BackendDNNLib;
+using Shape = ::distconv::tensor::Shape;
+using Backend = ::distconv::BackendDNNLib;
 template <typename DataType>
-using LayerNormalization = ::distconv::LayerNormalization<Backend, DataType>;
+using LayerNormalization =
+  ::distconv::LayerNormalization<Backend, DataType>;
 } // namespace dc
 
 template <typename TensorDataType, data_layout Layout, El::Device Device>
@@ -67,12 +68,10 @@ class layer_norm_distconv_adapter
   void setup_distributions(tensor_overlap_constraints& constraints) override;
   void setup_layer(size_t workspace_capacity) override;
-  void setup_fp_tensors() override;
-  void setup_bp_tensors() override;
   void fp_compute();
   void bp_compute();
-
+
   TensorDevType m_statistics;
   TensorDevType m_statistics_grad;
   std::unique_ptr<dc::LayerNormalization<TensorDataType>> m_layer_norm_operator;
@@ -419,13 +418,9 @@ ::get_distconv_adapter(){
 // Scatter DistConv Adapter implementation
 // =============================================================
 
-#endif // LBANN_HAS_DISTCONV
-
-LBANN_DEFINE_LAYER_BUILDER(layer_norm);
-
-  // =========================================================
-  // Explicit template instantiation
-  // =========================================================
+// =========================================================
+// Explicit template instantiation
+// =========================================================
 
 #ifndef LBANN_LAYER_NORM_LAYER_INSTANTIATE
 #define PROTO_DEVICE(T, Device) \
diff --git a/include/lbann/layers/regularizers/layer_norm_impl.hpp b/include/lbann/layers/regularizers/layer_norm_impl.hpp
index 0cc91e17573..14d7076ccbb 100644
--- a/include/lbann/layers/regularizers/layer_norm_impl.hpp
+++ b/include/lbann/layers/regularizers/layer_norm_impl.hpp
@@ -33,8 +33,7 @@
 #include "lbann/layers/data_type_distconv_adapter.hpp"
 #endif
 
-namespace lbann{
-
+namespace lbann {
 
 // =========================================================
 // Implementation
@@ -135,7 +134,6 @@ void layer_norm_layer<TensorDataType, Layout, Device>::setup_data(
   m_statistics_gradient.reset(AbsDistMatrixType::Instantiate(dist));
 }
 
-
 #ifdef LBANN_HAS_DISTCONV
 
 // =============================================================
@@ -174,57 +172,53 @@ layer_norm_layer<TensorDataType, Layout, Device>::get_distconv_adapter()
     layer_norm_distconv_adapter<TensorDataType, Layout, Device>&>(
     static_cast<const layer_norm_layer<TensorDataType, Layout, Device>&>(*this)
       .get_distconv_adapter());
+}
 
 // =============================================================
 // LayerNorm DistConv Adapter implementation
 // =============================================================
 
-  template <typename TensorDataType, data_layout Layout, El::Device Device>
-  void layer_norm_distconv_adapter<TensorDataType, Layout, Device>::
-    setup_distributions(tensor_overlap_constraints & constraints)
-  {
-    data_type_distconv_adapter<TensorDataType>::setup_distributions(
-      constraints);
-    // no overlap needed
-    for (auto& d : this->m_prev_activations_dists) {
-      d.clear_overlap();
-      constraints.mark_updated(d);
-      constraints.mark_invariant(d);
-    }
-    for (auto& d : this->m_activations_dists) {
-      d.clear_overlap();
-      constraints.mark_updated(d);
-      constraints.mark_invariant(d);
-    }
-    for (auto& d : this->m_prev_error_signals_dists) {
-      d.clear_overlap();
-      constraints.mark_updated(d);
-      constraints.mark_invariant(d);
-    }
-    for (auto& d : this->m_error_signals_dists) {
-      d.clear_overlap();
-      constraints.mark_updated(d);
-      constraints.mark_invariant(d);
-    }
-  }
-
-  template <typename TensorDataType, data_layout Layout, El::Device Device>
-  void layer_norm_distconv_adapter<TensorDataType, Layout, Device>::setup_layer(
-    size_t workspace_capacity)
-  {
-    data_type_distconv_adapter<TensorDataType>::setup_layer(workspace_capacity);
-    auto& layer = dynamic_cast<
-      channelwise_fully_connected_layer<TensorDataType, Layout, Device>&>(
-      this->layer());
-    const auto max_mini_batch_size =
-      layer.get_model()->m_max_mini_batch_size_distconv;
-
-    m_layer_norm_operator =
-      make_unique<dc::LayerNormalization<TensorDataType>>(dc::get_backend(),
-                                                          layer.m_epsilon,
-                                                          max_mini_batch_size);
-  }
+template <typename TensorDataType, data_layout Layout, El::Device Device>
+void layer_norm_distconv_adapter<TensorDataType, Layout, Device>::
+  setup_distributions(tensor_overlap_constraints& constraints)
+{
+  data_type_distconv_adapter<TensorDataType>::setup_distributions(constraints);
+  // no overlap needed
+  for (auto& d : this->m_prev_activations_dists) {
+    d.clear_overlap();
+    constraints.mark_updated(d);
+    constraints.mark_invariant(d);
+  }
+  for (auto& d : this->m_activations_dists) {
+    d.clear_overlap();
+    constraints.mark_updated(d);
+    constraints.mark_invariant(d);
+  }
+  for (auto& d : this->m_prev_error_signals_dists) {
+    d.clear_overlap();
+    constraints.mark_updated(d);
+    constraints.mark_invariant(d);
+  }
+  for (auto& d : this->m_error_signals_dists) {
+    d.clear_overlap();
+    constraints.mark_updated(d);
+    constraints.mark_invariant(d);
+  }
+}
+
+template <typename TensorDataType, data_layout Layout, El::Device Device>
+void layer_norm_distconv_adapter<TensorDataType, Layout, Device>::setup_layer(
+  size_t workspace_capacity)
+{
+  data_type_distconv_adapter<TensorDataType>::setup_layer(workspace_capacity);
+  auto& layer = dynamic_cast<layer_norm_layer<TensorDataType, Layout, Device>&>(
+    this->layer());
+
+  m_layer_norm_operator =
+    make_unique<dc::LayerNormalization<TensorDataType>>(dc::get_backend(),
+                                                        layer.m_epsilon);
+}
 
-#endif LBANN_HAS_DISTCONV
+#endif // LBANN_HAS_DISTCONV
 
 } // namespace lbann
 #endif // LBANN_LAYER_REGULARIZER_LAYER_NORM_IMPL_HPP_INCLUDED
\ No newline at end of file
diff --git a/src/layers/regularizers/distconv/distconv_layer_norm.cu b/src/layers/regularizers/distconv/distconv_layer_norm.cu
index 9e2cf0ee6cc..ba102f2a897 100644
--- a/src/layers/regularizers/distconv/distconv_layer_norm.cu
+++ b/src/layers/regularizers/distconv/distconv_layer_norm.cu
@@ -26,27 +26,27 @@
 
 #define LBANN_LAYERS_REGULARIZERS_DISTCONV_LAYER_NORM_INSTANTIATE
 
-#include "../layer_norm_kernel.cuh"
-#include "lbann/layers/regularizers/distconv/distonv_layer_norm.hpp"
+#include "../layer_norm_kernels.cuh"
+#include "lbann/layers/regularizers/distconv/distconv_layer_norm.hpp"
 
 #ifdef LBANN_HAS_DISTCONV
 
+namespace distconv {
+
 template <typename Backend, typename DataType>
 template <typename Allocator>
-void LayerNormalization ::calculate_forward_stats(
+void LayerNormalization<Backend, DataType>::calculate_forward_stats(
   const DCTensor<Allocator>& input,
   DCTensor<Allocator>& statistics)
 {
-  if (input_0.get_local_size() == 0) {
+  if (input.get_local_size() == 0) {
     util::MPIRootPrintStreamInfo() << "WARNING: EMPTY INPUT FOUND \n";
     return; // no op for empty inputs
   }
 
   const auto& input_dims = input.get_local_shape();
   const auto& statistics_dims = statistics.get_local_shape();
-
-  const auto local_num_samples = input_0_dims[3];
-
+  const auto local_num_samples = input_dims[3];
   const auto global_num_samples = statistics_dims[3];
   const auto local_sample_size = std::accumulate(input_dims.begin(),
@@ -60,10 +60,7 @@ void LayerNormalization<Backend, DataType>::calculate_forward_stats(
     input.get_buffer(),
     local_sample_size);
 
-  LocalMat local_statistics(2,
-                            local_num_samples,
-                            statistics.get_local_shape(),
-                            2);
+  LocalMat local_statistics(2, global_num_samples, statistics.get_buffer(), 2);
   El::Zero(local_statistics);
 
   auto local_means = El::View(local_statistics, El::IR(0), El::ALL);
   auto local_vars = El::View(local_statistics, El::IR(1), El::ALL);
 
   {
     using namespace hydrogen;
-    auto multisync = El::MakeMultiSync(gpu::get_sync_info(local_statistics),
-                                       gpu::get_sync_info(local_input));
+    auto multisync = El::MakeMultiSync(El::SyncInfoFromMatrix(local_statistics),
+                                       El::SyncInfoFromMatrix(local_input));
     constexpr size_t block_size = 256;
     dim3 block_dims, grid_dims;
     block_dims.x = block_size;
     grid_dims.x = (local_sample_size + block_size - 1) / block_size;
     grid_dims.y = local_num_samples;
     hydrogen::gpu::LaunchKernel(
-      ::lbann::layer_norm_fp_sums_kernel,
+      ::lbann::layer_norm_fp_sums_kernel<block_size, DataType>,
       grid_dims,
       block_dims,
       0,
@@ -97,28 +94,252 @@ void LayerNormalization<Backend, DataType>::calculate_forward_stats(
 
 template <typename Backend, typename DataType>
 template <typename Allocator>
-void LayerNormalization::apply_normalization(
+void LayerNormalization<Backend, DataType>::apply_normalization(
   const DCTensor<Allocator>& input,
-  const DCTensor<Allocator>& statistics,
+  DCTensor<Allocator>& statistics,
   DCTensor<Allocator>& output)
-{}
+{
+  const auto& input_dims = input.get_local_shape();
+  const auto& statistics_dims = statistics.get_local_shape();
+  const auto local_num_samples = input_dims[3];
+  const auto global_num_samples = statistics_dims[3];
+  const auto local_sample_size = std::accumulate(input_dims.begin(),
+                                                 input_dims.end() - 1,
+                                                 1,
+                                                 std::multiplies<size_t>());
+
+  using LocalMat = El::Matrix<DataType, El::Device::GPU>;
+  const LocalMat local_input(local_sample_size,
+                             local_num_samples,
+                             input.get_buffer(),
+                             local_sample_size);
+
+  LocalMat local_statistics(2, global_num_samples, statistics.get_buffer(), 2);
+
+  LocalMat local_output(local_sample_size,
+                        local_num_samples,
+                        output.get_buffer(),
+                        local_sample_size);
+
+  auto local_means = El::View(local_statistics, El::IR(0), El::ALL);
+  auto local_vars = El::View(local_statistics, El::IR(1), El::ALL);
+
+  {
+    using namespace hydrogen;
+    auto sync_info = El::SyncInfoFromMatrix(local_statistics);
+    // Finalize the (already allreduced) per-sample sums into mean and
+    // variance, one thread per local sample.
+    constexpr size_t block_size = 256;
+    dim3 block_dims, grid_dims;
+    block_dims.x = block_size;
+    grid_dims.x = (local_num_samples + block_size - 1) / block_size;
+    hydrogen::gpu::LaunchKernel(
+      ::lbann::layer_norm_fp_statistics_kernel<DataType>,
+      grid_dims,
+      block_dims,
+      0,
+      sync_info,
+      local_sample_size,
+      local_num_samples,
+      local_means.Buffer(),
+      local_means.LDim(),
+      local_vars.Buffer(),
+      local_vars.LDim());
+
+    auto multisync = El::MakeMultiSync(El::SyncInfoFromMatrix(local_output),
+                                       El::SyncInfoFromMatrix(local_statistics),
+                                       El::SyncInfoFromMatrix(local_input));
+
+    // Normalize each local sample with its mean and variance.
+    constexpr size_t block_size_output_kernel = 256;
+    dim3 block_dims_output_kernel, grid_dims_output_kernel;
+    block_dims_output_kernel.x = block_size_output_kernel;
+    grid_dims_output_kernel.x =
+      (local_sample_size + block_size_output_kernel - 1) /
+      block_size_output_kernel;
+    grid_dims_output_kernel.y = local_num_samples;
+    hydrogen::gpu::LaunchKernel(
+      ::lbann::layer_norm_fp_output_kernel<block_size_output_kernel, DataType>,
+      grid_dims_output_kernel,
+      block_dims_output_kernel,
+      0,
+      multisync,
+      local_num_samples,
+      local_sample_size,
+      m_epsilon,
+      local_input.LockedBuffer(),
+      local_input.LDim(),
+      local_output.Buffer(),
+      local_output.LDim(),
+      local_means.Buffer(),
+      local_means.LDim(),
+      local_vars.Buffer(),
+      local_vars.LDim());
+  }
+}
 
 template <typename Backend, typename DataType>
 template <typename Allocator>
-void LayerNormalization::calculate_backward_stats(
+void LayerNormalization<Backend, DataType>::calculate_backward_stats(
   const DCTensor<Allocator>& input,
   const DCTensor<Allocator>& output_grad,
   const DCTensor<Allocator>& statistics,
   DCTensor<Allocator>& statistics_grad)
-{}
+{
+  const auto& input_dims = input.get_local_shape();
+  const auto& statistics_dims = statistics.get_local_shape();
+  const auto local_num_samples = input_dims[3];
+  const auto global_num_samples = statistics_dims[3];
+  const auto local_sample_size = std::accumulate(input_dims.begin(),
+                                                 input_dims.end() - 1,
+                                                 1,
+                                                 std::multiplies<size_t>());
+
+  using LocalMat = El::Matrix<DataType, El::Device::GPU>;
+  const LocalMat local_input(local_sample_size,
+                             local_num_samples,
+                             input.get_buffer(),
+                             local_sample_size);
+  const LocalMat local_output_grad(local_sample_size,
+                                   local_num_samples,
+                                   output_grad.get_buffer(),
+                                   local_sample_size);
+
+  const LocalMat local_statistics(2,
+                                  global_num_samples,
+                                  statistics.get_buffer(),
+                                  2);
+
+  LocalMat local_statistics_grad(2,
+                                 global_num_samples,
+                                 statistics_grad.get_buffer(),
+                                 2);
+
+  const auto local_means = El::LockedView(local_statistics, El::IR(0), El::ALL);
+  const auto local_vars = El::LockedView(local_statistics, El::IR(1), El::ALL);
+
+  auto local_means_grad = El::View(local_statistics_grad, El::IR(0), El::ALL);
+  auto local_vars_grad = El::View(local_statistics_grad, El::IR(1), El::ALL);
+
+  {
+    using namespace hydrogen;
+    auto multisync =
+      El::MakeMultiSync(El::SyncInfoFromMatrix(local_statistics_grad),
+                        El::SyncInfoFromMatrix(local_output_grad),
+                        El::SyncInfoFromMatrix(local_statistics),
+                        El::SyncInfoFromMatrix(local_input));
+    // Each block accumulates one sample's gradients w.r.t. mean and variance.
+    constexpr size_t block_size = 256;
+    dim3 block_dims, grid_dims;
+    block_dims.x = block_size;
+    grid_dims.x = (local_sample_size + block_size - 1) / block_size;
+    grid_dims.y = local_num_samples;
+    hydrogen::gpu::LaunchKernel(
+      ::lbann::layer_norm_bp_statistics_grad_kernel<block_size, DataType>,
+      grid_dims,
+      block_dims,
+      0,
+      multisync,
+      local_num_samples,
+      local_sample_size,
+      m_epsilon,
+      local_input.LockedBuffer(),
+      local_input.LDim(),
+      local_output_grad.LockedBuffer(),
+      local_output_grad.LDim(),
+      local_means.LockedBuffer(),
+      local_means.LDim(),
+      local_vars.LockedBuffer(),
+      local_vars.LDim(),
+      local_means_grad.Buffer(),
+      local_means_grad.LDim(),
+      local_vars_grad.Buffer(),
+      local_vars_grad.LDim());
+  }
+}
 
 template <typename Backend, typename DataType>
 template <typename Allocator>
-void LayerNormalization::apply_grad(const DCTensor<Allocator>& input,
-                                    const DCTensor<Allocator>& output_grad,
-                                    const DCTensor<Allocator>& statistics,
-                                    const DCTensor<Allocator>& statistics_grad,
-                                    DCTensor<Allocator>& input_grad)
-{}
+void LayerNormalization<Backend, DataType>::apply_grad(
+  const DCTensor<Allocator>& input,
+  const DCTensor<Allocator>& output_grad,
+  const DCTensor<Allocator>& statistics,
+  const DCTensor<Allocator>& statistics_grad,
+  DCTensor<Allocator>& input_grad)
+{
+  const auto& input_dims = input.get_local_shape();
+  const auto& statistics_dims = statistics.get_local_shape();
+  const auto local_num_samples = input_dims[3];
+  const auto global_num_samples = statistics_dims[3];
+  const auto local_sample_size = std::accumulate(input_dims.begin(),
+                                                 input_dims.end() - 1,
+                                                 1,
+                                                 std::multiplies<size_t>());
+
+  // Assumes the sample dimension is not split across ranks, so the local
+  // and global sample sizes coincide.
+  const auto global_sample_size = local_sample_size;
+
+  using LocalMat = El::Matrix<DataType, El::Device::GPU>;
+  const LocalMat local_input(local_sample_size,
+                             local_num_samples,
+                             input.get_buffer(),
+                             local_sample_size);
+  const LocalMat local_output_grad(local_sample_size,
+                                   local_num_samples,
+                                   output_grad.get_buffer(),
+                                   local_sample_size);
+
+  const LocalMat local_statistics(2,
+                                  global_num_samples,
+                                  statistics.get_buffer(),
+                                  2);
+
+  const LocalMat local_statistics_grad(2,
+                                       global_num_samples,
+                                       statistics_grad.get_buffer(),
+                                       2);
+
+  LocalMat local_input_grad(local_sample_size,
+                            local_num_samples,
+                            input_grad.get_buffer(),
+                            local_sample_size);
+
+  const auto local_means = El::LockedView(local_statistics, El::IR(0), El::ALL);
+  const auto local_vars = El::LockedView(local_statistics, El::IR(1), El::ALL);
+  const auto local_means_grad =
+    El::LockedView(local_statistics_grad, El::IR(0), El::ALL);
+  const auto local_vars_grad =
+    El::LockedView(local_statistics_grad, El::IR(1), El::ALL);
+
+  {
+    using namespace hydrogen;
+    // Sync on the written matrix (local_input_grad) as well as all inputs.
+    auto multisync =
+      El::MakeMultiSync(El::SyncInfoFromMatrix(local_input_grad),
+                        El::SyncInfoFromMatrix(local_statistics_grad),
+                        El::SyncInfoFromMatrix(local_output_grad),
+                        El::SyncInfoFromMatrix(local_statistics),
+                        El::SyncInfoFromMatrix(local_input));
+    constexpr size_t block_size = 256;
+    dim3 block_dims, grid_dims;
+    block_dims.x = block_size;
+    grid_dims.x = (local_sample_size + block_size - 1) / block_size;
+    grid_dims.y = local_num_samples;
+    hydrogen::gpu::LaunchKernel(
+      ::lbann::layer_norm_bp_input_grad_kernel<block_size, DataType>,
+      grid_dims,
+      block_dims,
+      0,
+      multisync,
+      global_sample_size,
+      local_num_samples,
+      local_sample_size,
+      m_epsilon,
+      local_input.LockedBuffer(),
+      local_input.LDim(),
+      local_output_grad.LockedBuffer(),
+      local_output_grad.LDim(),
+      local_input_grad.Buffer(),
+      local_input_grad.LDim(),
+      local_means.LockedBuffer(),
+      local_means.LDim(),
+      local_vars.LockedBuffer(),
+      local_vars.LDim(),
+      local_means_grad.LockedBuffer(),
+      local_means_grad.LDim(),
+      local_vars_grad.LockedBuffer(),
+      local_vars_grad.LDim());
+  }
+}
 
 #define ETI(T, Backend)                                                       \
   template class LayerNormalization<Backend, T>;                              \
@@ -129,8 +350,7 @@ void LayerNormalization<Backend, DataType>::apply_grad(
   template void                                                               \
   LayerNormalization<Backend, T>::apply_normalization<tensor::CUDAAllocator>( \
     const tensor::Tensor<T, tensor::LocaleMPI, tensor::CUDAAllocator>& input, \
-    const tensor::Tensor<T, tensor::LocaleMPI, tensor::CUDAAllocator>&        \
-      statistics,                                                             \
+    tensor::Tensor<T, tensor::LocaleMPI, tensor::CUDAAllocator>& statistics,  \
    tensor::Tensor<T, tensor::LocaleMPI, tensor::CUDAAllocator>& output);      \
   template void LayerNormalization<Backend, T>::calculate_backward_stats<     \
     tensor::CUDAAllocator>(                                                   \
@@ -154,5 +374,6 @@ void LayerNormalization<Backend, DataType>::apply_grad(
 
 ETI(float, BackendDNNLib)
 ETI(double, BackendDNNLib)
-#endef ETI
+#undef ETI
+
+} // namespace distconv
 #endif // LBANN_HAS_DISTCONV
\ No newline at end of file
diff --git a/src/layers/regularizers/layer_norm.cpp b/src/layers/regularizers/layer_norm.cpp
index 7a2226378cd..fe30602362e 100644
--- a/src/layers/regularizers/layer_norm.cpp
+++ b/src/layers/regularizers/layer_norm.cpp
@@ -25,9 +25,9 @@
 ////////////////////////////////////////////////////////////////////////////////
 
 #define LBANN_LAYER_NORM_LAYER_INSTANTIATE
-#include "lbann/layers/regularizers/layer_norm.hpp"
 #include "lbann/comm_impl.hpp"
 #include "lbann/optimizers/optimizer.hpp"
+#include "lbann/layers/regularizers/layer_norm_impl.hpp"
 
 #ifdef LBANN_HAS_DISTCONV
 #include "lbann/layers/data_type_distconv_adapter.hpp"
diff --git a/src/layers/regularizers/layer_norm.cu b/src/layers/regularizers/layer_norm.cu
index 3b59ab90402..e013efd5f5f 100644
--- a/src/layers/regularizers/layer_norm.cu
+++ b/src/layers/regularizers/layer_norm.cu
@@ -25,9 +25,11 @@
 ////////////////////////////////////////////////////////////////////////////////
 
 #define LBANN_LAYER_NORM_LAYER_INSTANTIATE
+#include "layer_norm_kernels.cuh"
 #include "lbann/comm_impl.hpp"
 #include "lbann/layers/regularizers/layer_norm.hpp"
 #include "lbann/optimizers/optimizer.hpp"
+#include "lbann/layers/regularizers/layer_norm_impl.hpp"
 #include "lbann/utils/gpu/helpers.hpp"
 
 #ifdef LBANN_HAS_DISTCONV
@@ -278,7 +280,6 @@ void fp_impl(lbann_comm& comm,
   block_dims.x = block_size;
   grid_dims.x = (local_sample_size + block_size - 1) / block_size;
   grid_dims.y = local_num_samples;
-<<<<<<< HEAD
   auto kernel =
     ((!local_scale && !local_bias)
        ? fp_output_kernel<block_size, TensorDataType, false, false>
@@ -447,27 +448,6 @@ bp_input_grad_kernel(unsigned long long sample_size,
       }
     }
   }
-=======
-  hydrogen::gpu::LaunchKernel(layer_norm_fp_output_kernel,
-                              grid_dims,
-                              block_dims,
-                              0,
-                              multisync,
-                              local_num_samples,
-                              local_sample_size,
-                              epsilon,
-                              local_input.LockedBuffer(),
-                              local_input.LDim(),
-                              local_output.Buffer(),
-                              local_output.LDim(),
-                              local_means.LockedBuffer(),
-                              local_means.LDim(),
-                              local_vars.LockedBuffer(),
-                              local_vars.LDim());
-  }
-}
-
->>>>>>> 9167f88b6 (Add implementation files)
 
 /** @brief Backprop */
 template <typename TensorDataType>
@@ -529,7 +509,6 @@ void bp_impl(lbann_comm& comm,
   block_dims.x = block_size;
   grid_dims.x = (local_sample_size + block_size - 1) / block_size;
   grid_dims.y = local_num_samples;
-<<<<<<< HEAD
   auto kernel =
     ((!scale_grad && !bias_grad)
       ? bp_statistics_grad_kernel<block_size, TensorDataType, false, false>
@@ -570,29 +549,6 @@ void bp_impl(lbann_comm& comm,
                                 local_scale,
                                 scale_grad,
                                 bias_grad);
-=======
-  hydrogen::gpu::LaunchKernel(
-    layer_norm_bp_statistics_grad_kernel,
-    grid_dims,
-    block_dims,
-    0,
-    multisync,
-    local_num_samples,
-    local_sample_size,
-    epsilon,
-    local_input.LockedBuffer(),
-    local_input.LDim(),
-    local_output_grad.LockedBuffer(),
-    local_output_grad.LDim(),
-    local_means.LockedBuffer(),
-    local_means.LDim(),
-    local_vars.LockedBuffer(),
-    local_vars.LDim(),
-    local_means_grad.Buffer(),
-    local_means_grad.LDim(),
-    local_vars_grad.Buffer(),
-    local_vars_grad.LDim());
->>>>>>> 9167f88b6 (Add implementation files)
   }
 
   comm.allreduce(statistics_grad,
                  statistics_grad.RedundantComm(),
@@ -610,7 +566,6 @@ void bp_impl(lbann_comm& comm,
   block_dims.x = block_size;
   grid_dims.x = (local_sample_size + block_size - 1) / block_size;
   grid_dims.y = local_num_samples;
-<<<<<<< HEAD
   auto kernel = (local_scale
                    ? bp_input_grad_kernel<block_size, TensorDataType, true>
                    : bp_input_grad_kernel<block_size, TensorDataType, false>);
   hydrogen::gpu::LaunchKernel(kernel,
@@ -637,31 +592,6 @@ void bp_impl(lbann_comm& comm,
                               local_vars_grad.LockedBuffer(),
                               local_vars_grad.LDim(),
                               local_scale);
-=======
-  hydrogen::gpu::LaunchKernel(layer_norm_bp_input_grad_kernel,
-                              grid_dims,
-                              block_dims,
-                              0,
-                              multisync,
-                              sample_size,
-                              local_num_samples,
-                              local_sample_size,
-                              epsilon,
-                              local_input.LockedBuffer(),
-                              local_input.LDim(),
-                              local_output_grad.LockedBuffer(),
-                              local_output_grad.LDim(),
-                              local_input_grad.Buffer(),
-                              local_input_grad.LDim(),
-                              local_means.LockedBuffer(),
-                              local_means.LDim(),
-                              local_vars.LockedBuffer(),
-                              local_vars.LDim(),
-                              local_means_grad.LockedBuffer(),
-                              local_means_grad.LDim(),
-                              local_vars_grad.LockedBuffer(),
-                              local_vars_grad.LDim());
->>>>>>> 9167f88b6 (Add implementation files)
   }
 }
 
@@ -673,19 +603,64 @@ void bp_impl(lbann_comm& comm,
 
 #ifdef LBANN_HAS_DISTCONV
 template <typename TensorDataType, data_layout Layout, El::Device Device>
-void layer_norm_distconv_adapter fp_compute()
-{}
+void layer_norm_distconv_adapter<TensorDataType, Layout, Device>::fp_compute()
+{
+  auto& l = dynamic_cast<layer_norm_layer<TensorDataType, Layout, Device>&>(
+    this->layer());
+  lbann_comm& comm = *(l.get_comm());
+
+  auto& statistics = *l.m_statistics;
+  assert0(dc::tensor::View(m_statistics, statistics.Buffer()));
+
+  // Accumulate local sums, then allreduce so every rank holds the global
+  // per-sample statistics before normalizing.
+  m_layer_norm_operator->calculate_forward_stats(this->get_prev_activations(),
+                                                 m_statistics);
+  comm.allreduce(statistics, statistics.RedundantComm(), El::mpi::SUM);
+  m_layer_norm_operator->apply_normalization(this->get_prev_activations(),
+                                             m_statistics,
+                                             this->get_activations());
+}
 
 template <typename TensorDataType, data_layout Layout, El::Device Device>
-void layer_norm_distconv_adapter bp_compute()
-{}
+void layer_norm_distconv_adapter<TensorDataType, Layout, Device>::bp_compute()
+{
+  auto& l = dynamic_cast<layer_norm_layer<TensorDataType, Layout, Device>&>(
+    this->layer());
+  lbann_comm& comm = *(l.get_comm());
+
+  auto& statistics = *l.m_statistics;
+  auto& statistics_grad = *l.m_statistics_gradient;
+  assert0(dc::tensor::View(m_statistics, statistics.Buffer()));
+  assert0(dc::tensor::View(m_statistics_grad, statistics_grad.Buffer()));
+
+  m_layer_norm_operator->calculate_backward_stats(
+    this->get_prev_activations(),
+    this->get_prev_error_signals(),
+    m_statistics,
+    m_statistics_grad);
+  comm.allreduce(statistics_grad,
+                 statistics_grad.RedundantComm(),
+                 El::mpi::SUM);
+  m_layer_norm_operator->apply_grad(this->get_prev_activations(),
+                                    this->get_prev_error_signals(),
+                                    m_statistics,
+                                    m_statistics_grad,
+                                    this->get_error_signals());
+}
 #endif // LBANN_HAS_DISTCONV
 
 // Template instantiation
 template <typename TensorDataType, data_layout Layout, El::Device Device>
-<<<<<<< HEAD
 void layer_norm_layer<TensorDataType, Layout, Device>::fp_compute()
 {
+#ifdef LBANN_HAS_DISTCONV
+  if (this->distconv_enabled()) {
+    this->get_distconv_adapter().fp_compute();
+    return;
+  }
+#endif // LBANN_HAS_DISTCONV
+
   int weight_idx = 0;
   const TensorDataType* scale_weights = nullptr;
   const TensorDataType* bias_weights = nullptr;
@@ -696,16 +671,7 @@ void layer_norm_layer<TensorDataType, Layout, Device>::fp_compute()
     bias_weights =
       this->weights_values(weight_idx).LockedMatrix().LockedBuffer();
-=======
-void layer_norm_layer::fp_compute()
-{
-#ifdef LBANN_HAS_DISTCONV
-  if (this->distconv_enabled()) {
-    this->get_distconv_adapter().fp_compute();
-    return;
-  }
-#endif // LBANN_HAS_DISTCONV
->>>>>>> 9167f88b6 (Add implementation files)
+
   fp_impl(*this->get_comm(),
           this->m_epsilon,
           this->get_prev_activations(),
@@ -716,9 +682,15 @@ void layer_norm_layer<TensorDataType, Layout, Device>::fp_compute()
 }
 
 template <typename TensorDataType, data_layout Layout, El::Device Device>
-<<<<<<< HEAD
 void layer_norm_layer<TensorDataType, Layout, Device>::bp_compute()
 {
+#ifdef LBANN_HAS_DISTCONV
+  if (this->distconv_enabled()) {
+    this->get_distconv_adapter().bp_compute();
+    return;
+  }
+#endif // LBANN_HAS_DISTCONV
+
   // Obtain optional buffers
   const TensorDataType* scale_weights = nullptr;
   TensorDataType* scale_grad = nullptr;
@@ -735,18 +707,6 @@ void layer_norm_layer<TensorDataType, Layout, Device>::bp_compute()
     bias_grad = this->m_bias_gradient->Buffer();
   }
 
-// Compute backpropagation
-=======
-void layer_norm_layer::bp_compute()
-{
-#ifdef LBANN_HAS_DISTCONV
-  if (this->distconv_enabled()) {
-    this->get_distconv_adapter().bp_compute();
-    return;
-  }
-#endif // LBANN_HAS_DISTCONV
-
->>>>>>> 9167f88b6 (Add implementation files)
   bp_impl(*this->get_comm(),
           this->m_epsilon,
          this->get_prev_activations(),