From d59e2dbe09a7b078e1316dad810ad9683072eae4 Mon Sep 17 00:00:00 2001 From: Shehtab Zaman Date: Fri, 17 Mar 2023 01:56:24 -0400 Subject: [PATCH] Basic implementation completed --- .../distconv/distconv_layer_norm.cu | 216 +++++++++++++++++- src/layers/regularizers/layer_norm.cu | 190 ++++++++++----- 2 files changed, 342 insertions(+), 64 deletions(-) diff --git a/src/layers/regularizers/distconv/distconv_layer_norm.cu b/src/layers/regularizers/distconv/distconv_layer_norm.cu index 9e2cf0ee6cc..6efdab532b1 100644 --- a/src/layers/regularizers/distconv/distconv_layer_norm.cu +++ b/src/layers/regularizers/distconv/distconv_layer_norm.cu @@ -41,14 +41,10 @@ void LayerNormalization ::calculate_forward_stats( util::MPIRootPrintStreamInfo() << "WARNING: EMPTY INPUT FOUND \n"; return; // no op for empty inputs } - const auto& input_dims = input.get_local_shape(); const auto& statistics_dims = statistics.get_local_shape(); - const auto local_num_samples = input_0_dims[3]; - const auto global_num_samples = statistics_dims[3]; - const auto local_sample_size = std::accumulate(input_dims.begin(), input_dims.end() - 1, 1, @@ -61,7 +57,7 @@ void LayerNormalization ::calculate_forward_stats( local_sample_size); LocalMat local_statistics(2, - local_num_samples, + global_num_samples, statistics.get_local_shape(), 2); @@ -101,7 +97,79 @@ void LayerNormalization::apply_normalization( const DCTensor& input, const DCTensor& statistics, DCTensor& output) -{} +{ + const auto& input_dims = input.get_local_shape(); + const auto& statistics_dims = statistics.get_local_shape(); + const auto local_num_samples = input_0_dims[3]; + const auto global_num_samples = statistics_dims[3]; + const auto local_sample_size = std::accumulate(input_dims.begin(), + input_dims.end() - 1, + 1, + std::multiplies()); + + using LocalMat = El::Matrix; + const LocalMat local_input(local_sample_size, + local_num_samples, + input.get_buffer(), + local_sample_size); + + const LocalMat local_statistics(2, + global_num_samples, + statistics.get_local_shape(), + 2); + + LocalMat local_output(local_sample_size, + local_num_samples, + output.get_buffer(), + local_sample_size); + + const auto local_means = El::View(local_statistics, El::IR(0), El::ALL); + const auto local_vars = El::View(local_statistics, El::IR(1), El::ALL); + { + using namespace hydrogen; + auto sync_info = gpu::get_sync_info(local_statistics); + constexpr size_t block_size = 256; + dim3 block_dims, grid_dims; + block_dims.x = block_size; + grid_dims.x = (local_num_samples + block_size - 1) / block_size; + hydrogen::gpu::LaunchKernel(layer_norm_fp_statistics_kernel, + grid_dims, + block_dims, + 0, + sync_info, + sample_size, + local_num_samples, + local_means.Buffer(), + local_means.LDim(), + local_vars.Buffer(), + local_vars.LDim()); + + auto multisync = El::MakeMultiSync(gpu::get_sync_info(local_output), + gpu::get_sync_info(local_statistics), + gpu::get_sync_info(local_input)); + constexpr size_t block_size = 256; + dim3 block_dims, grid_dims; + block_dims.x = block_size; + grid_dims.x = (local_sample_size + block_size - 1) / block_size; + grid_dims.y = local_num_samples; + hydrogen::gpu::LaunchKernel(layer_norm_fp_output_kernel, + grid_dims, + block_dims, + 0, + multisync, + local_num_samples, + local_sample_size, + epsilon, + local_input.LockedBuffer(), + local_input.LDim(), + local_output.Buffer(), + local_output.LDim(), + local_means.LockedBuffer(), + local_means.LDim(), + local_vars.LockedBuffer(), + local_vars.LDim()); + } +} template template @@ -110,7 +178,70 @@ void LayerNormalization::calculate_backward_stats( const DCTensor& output_grad, const DCTensor& statistics, DCTensor& statistics_grad) -{} +{ + const auto& input_dims = input.get_local_shape(); + const auto& statistics_dims = statistics.get_local_shape(); + const auto local_num_samples = input_0_dims[3]; + const auto global_num_samples = statistics_dims[3]; + const auto local_sample_size = std::accumulate(input_dims.begin(), + input_dims.end() - 1, + 1, + std::multiplies()); + using LocalMat = El::Matrix; + const LocalMat local_input(local_sample_size, + local_num_samples, + input.get_buffer(), + local_sample_size); + const LocalMat local_output_grad(local_sample_size, + local_num_samples, + output_grad.get_buffer(), + local_sample_size); + + const LocalMat local_statistics(2, + global_num_samples, + statistics.get_local_shape(), + 2); + + LocalMat local_statistics_grad(2, + global_num_samples, + statistics_grad.get_buffer(), + 2); + { + using namespace hydrogen; + auto multisync = + El::MakeMultiSync(gpu::get_sync_info(local_statistics_grad), + gpu::get_sync_info(local_output_grad), + gpu::get_sync_info(local_statistics), + gpu::get_sync_info(local_input)); + constexpr size_t block_size = 256; + dim3 block_dims, grid_dims; + block_dims.x = block_size; + grid_dims.x = (local_sample_size + block_size - 1) / block_size; + grid_dims.y = local_num_samples; + hydrogen::gpu::LaunchKernel( + layer_norm_bp_statistics_grad_kernel, + grid_dims, + block_dims, + 0, + multisync, + local_num_samples, + local_sample_size, + m_epsilon, + local_input.LockedBuffer(), + local_input.LDim(), + local_output_grad.LockedBuffer(), + local_output_grad.LDim(), + local_means.LockedBuffer(), + local_means.LDim(), + local_vars.LockedBuffer(), + local_vars.LDim(), + local_means_grad.Buffer(), + local_means_grad.LDim(), + local_vars_grad.Buffer(), + local_vars_grad.LDim()); + } +} + template template void LayerNormalization::apply_grad(const DCTensor& input, @@ -118,7 +249,76 @@ void LayerNormalization::apply_grad(const DCTensor& input, const DCTensor& statistics, const DCTensor& statistics_grad, DCTensor& input_grad) -{} +{ + const auto& input_dims = input.get_local_shape(); + const auto& statistics_dims = statistics.get_local_shape(); + const auto local_num_samples = input_0_dims[3]; + const auto global_num_samples = statistics_dims[3]; + const auto local_sample_size = std::accumulate(input_dims.begin(), + input_dims.end() - 1, + 1, + std::multiplies()); + using LocalMat = El::Matrix; + const LocalMat local_input(local_sample_size, + local_num_samples, + input.get_buffer(), + local_sample_size); + const LocalMat local_output_grad(local_sample_size, + local_num_samples, + output_grad.get_buffer(), + local_sample_size); + + const LocalMat local_statistics(2, + global_num_samples, + statistics.get_local_shape(), + 2); + + const LocalMat local_statistics_grad(2, + global_num_samples, + statistics_grad.get_buffer(), + 2); + + LocalMat local_input_grad(local_sample_size, + local_num_samples, + input_grad.get_buffer(), + local_sample_size); + { + using namespace hydrogen; + auto multisync = + El::MakeMultiSync(gpu::get_sync_info(local_statistics_grad), + gpu::get_sync_info(local_output_grad), + gpu::get_sync_info(local_statistics), + gpu::get_sync_info(local_input)); + constexpr size_t block_size = 256; + dim3 block_dims, grid_dims; + block_dims.x = block_size; + grid_dims.x = (local_sample_size + block_size - 1) / block_size; + grid_dims.y = local_num_samples; + hydrogen::gpu::LaunchKernel(layer_norm_bp_input_grad_kernel, + grid_dims, + block_dims, + 0, + multisync, + sample_size, + local_num_samples, + local_sample_size, + m_epsilon, + local_input.LockedBuffer(), + local_input.LDim(), + local_output_grad.LockedBuffer(), + local_output_grad.LDim(), + local_input_grad.Buffer(), + local_input_grad.LDim(), + local_means.LockedBuffer(), + local_means.LDim(), + local_vars.LockedBuffer(), + local_vars.LDim(), + local_means_grad.LockedBuffer(), + local_means_grad.LDim(), + local_vars_grad.LockedBuffer(), + local_vars_grad.LDim()); + } +} #define ETI(T, Backend) \ template class LayerNormalization; \ diff --git a/src/layers/regularizers/layer_norm.cu b/src/layers/regularizers/layer_norm.cu index b501c404617..dd15451c3ff 100644 --- a/src/layers/regularizers/layer_norm.cu +++ b/src/layers/regularizers/layer_norm.cu @@ -84,11 +84,18 @@ void fp_impl(lbann_comm& comm, grid_dims.y = local_num_samples; hydrogen::gpu::LaunchKernel( layer_norm_fp_sums_kernel, - grid_dims, block_dims, 0, multisync, - local_num_samples, local_sample_size, - local_input.LockedBuffer(), local_input.LDim(), - local_means.Buffer(), local_means.LDim(), - local_vars.Buffer(), local_vars.LDim()); + grid_dims, + block_dims, + 0, + multisync, + local_num_samples, + local_sample_size, + local_input.LockedBuffer(), + local_input.LDim(), + local_means.Buffer(), + local_means.LDim(), + local_vars.Buffer(), + local_vars.LDim()); } comm.allreduce(statistics, statistics.RedundantComm(), El::mpi::SUM); @@ -103,12 +110,17 @@ void fp_impl(lbann_comm& comm, dim3 block_dims, grid_dims; block_dims.x = block_size; grid_dims.x = (local_num_samples + block_size - 1) / block_size; - hydrogen::gpu::LaunchKernel( - layer_norm_fp_statistics_kernel, - grid_dims, block_dims, 0, sync_info, - sample_size, local_num_samples, - local_means.Buffer(), local_means.LDim(), - local_vars.Buffer(), local_vars.LDim()); + hydrogen::gpu::LaunchKernel(layer_norm_fp_statistics_kernel, + grid_dims, + block_dims, + 0, + sync_info, + sample_size, + local_num_samples, + local_means.Buffer(), + local_means.LDim(), + local_vars.Buffer(), + local_vars.LDim()); } // Apply layer norm @@ -121,18 +133,25 @@ void fp_impl(lbann_comm& comm, block_dims.x = block_size; grid_dims.x = (local_sample_size + block_size - 1) / block_size; grid_dims.y = local_num_samples; - hydrogen::gpu::LaunchKernel( - layer_norm_fp_output_kernel, - grid_dims, block_dims, 0, multisync, - local_num_samples, local_sample_size, epsilon, - local_input.LockedBuffer(), local_input.LDim(), - local_output.Buffer(), local_output.LDim(), - local_means.LockedBuffer(), local_means.LDim(), - local_vars.LockedBuffer(), local_vars.LDim()); + hydrogen::gpu::LaunchKernel(layer_norm_fp_output_kernel, + grid_dims, + block_dims, + 0, + multisync, + local_num_samples, + local_sample_size, + epsilon, + local_input.LockedBuffer(), + local_input.LDim(), + local_output.Buffer(), + local_output.LDim(), + local_means.LockedBuffer(), + local_means.LDim(), + local_vars.LockedBuffer(), + local_vars.LDim()); } } - /** @brief Backprop */ template void bp_impl(lbann_comm& comm, @@ -192,14 +211,25 @@ void bp_impl(lbann_comm& comm, grid_dims.y = local_num_samples; hydrogen::gpu::LaunchKernel( layer_norm_bp_statistics_grad_kernel, - grid_dims, block_dims, 0, multisync, - local_num_samples, local_sample_size, epsilon, - local_input.LockedBuffer(), local_input.LDim(), - local_output_grad.LockedBuffer(), local_output_grad.LDim(), - local_means.LockedBuffer(), local_means.LDim(), - local_vars.LockedBuffer(), local_vars.LDim(), - local_means_grad.Buffer(), local_means_grad.LDim(), - local_vars_grad.Buffer(), local_vars_grad.LDim()); + grid_dims, + block_dims, + 0, + multisync, + local_num_samples, + local_sample_size, + epsilon, + local_input.LockedBuffer(), + local_input.LDim(), + local_output_grad.LockedBuffer(), + local_output_grad.LDim(), + local_means.LockedBuffer(), + local_means.LDim(), + local_vars.LockedBuffer(), + local_vars.LDim(), + local_means_grad.Buffer(), + local_means_grad.LDim(), + local_vars_grad.Buffer(), + local_vars_grad.LDim()); } comm.allreduce(statistics_grad, statistics_grad.RedundantComm(), @@ -217,17 +247,29 @@ void bp_impl(lbann_comm& comm, block_dims.x = block_size; grid_dims.x = (local_sample_size + block_size - 1) / block_size; grid_dims.y = local_num_samples; - hydrogen::gpu::LaunchKernel( - layer_norm_bp_input_grad_kernel, - grid_dims, block_dims, 0, multisync, - sample_size, local_num_samples, local_sample_size, epsilon, - local_input.LockedBuffer(), local_input.LDim(), - local_output_grad.LockedBuffer(), local_output_grad.LDim(), - local_input_grad.Buffer(), local_input_grad.LDim(), - local_means.LockedBuffer(), local_means.LDim(), - local_vars.LockedBuffer(), local_vars.LDim(), - local_means_grad.LockedBuffer(), local_means_grad.LDim(), - local_vars_grad.LockedBuffer(), local_vars_grad.LDim()); + hydrogen::gpu::LaunchKernel(layer_norm_bp_input_grad_kernel, + grid_dims, + block_dims, + 0, + multisync, + sample_size, + local_num_samples, + local_sample_size, + epsilon, + local_input.LockedBuffer(), + local_input.LDim(), + local_output_grad.LockedBuffer(), + local_output_grad.LDim(), + local_input_grad.Buffer(), + local_input_grad.LDim(), + local_means.LockedBuffer(), + local_means.LDim(), + local_vars.LockedBuffer(), + local_vars.LDim(), + local_means_grad.LockedBuffer(), + local_means_grad.LDim(), + local_vars_grad.LockedBuffer(), + local_vars_grad.LDim()); } } @@ -239,30 +281,65 @@ void bp_impl(lbann_comm& comm, #ifdef LBANN_HAS_DISTCONV template -void -layer_norm_distconv_adapter -fp_compute(){ +void layer_norm_distconv_adapter fp_compute() +{ + auto& l = dynamic_cast< + channelwise_fully_connected_layer&>( + this->layer()); + lbann_comm& comm = *(l.get_comm()); + + auto& statistics = *l.m_statistics; + assert0(dc::tensor::View(m_statistics, statistics.Buffer())); + using GPUMatType = El::Matrix; + m_layer_norm_operator->calculate_forward_stats(this->get_prev_activations(), + m_statistics); + comm.allreduce(statistics, statistics.RedundantComm(), El::mpi::SUM); + m_layer_norm_operator->apply_normalization(this->get_prev_activations(), + m_statistics, + this->get_activations()); } template -void -layer_norm_distconv_adapter -bp_compute(){ +void layer_norm_distconv_adapter bp_compute() +{ + auto& l = dynamic_cast< + channelwise_fully_connected_layer&>( + this->layer()); + lbann_comm& comm = *(l.get_comm()); + auto& statistics = *l.m_statistics; + auto& statistics_grad = *l.m_statistics_gradient; + assert0(dc::tensor::View(m_statistics, statistics.Buffer())); + assert0(dc::tensor::View(m_statistics_grad, statistics_grad.Buffer())); + + using GPUMatType = El::Matrix; + m_layer_norm_operator->calculate_backward_stats( + this->get_prev_activations(), + this->get_prev_error_signals(), + m_statistics, + m_statistics_grad); + comm.allreduce(statistics_grad, + statistics_grad.RedundantComm(), + El::mpi::SUM); + m_layer_norm_operator->apply_grad(this->get_prev_activations(), + this->get_prev_error_signals(), + m_statistics, + m_statistics_grad, + this->get_error_signals()); } #endif // LBANN_HAS_DISTCONV - // Template instantiation template -void layer_norm_layer::fp_compute() { - #ifdef LBANN_HAS_DISTCONV - if(this->distconv_enabled()){ +void layer_norm_layer::fp_compute() +{ +#ifdef LBANN_HAS_DISTCONV + if (this->distconv_enabled()) { this->get_distconv_adapter().fp_compute(); - return ; + return; } - #endif // LBANN_HAS_DISTCONV +#endif // LBANN_HAS_DISTCONV fp_impl(*this->get_comm(), this->m_epsilon, this->get_prev_activations(), @@ -271,13 +348,14 @@ void layer_norm_layer::fp_compute() { } template -void layer_norm_layer::bp_compute() { - #ifdef LBANN_HAS_DISTCONV - if(this->distconv_enabled()){ +void layer_norm_layer::bp_compute() +{ +#ifdef LBANN_HAS_DISTCONV + if (this->distconv_enabled()) { this->get_distconv_adapter().bp_compute(); - return ; + return; } - #endif // LBANN_HAS_DISTCONV +#endif // LBANN_HAS_DISTCONV bp_impl(*this->get_comm(), this->m_epsilon,