Create Infrastructure to exactly calculate L1 Memory Usage for Conv2D #15088 #15455

Merged · 10 commits · Jan 12, 2025
3 changes: 1 addition & 2 deletions tests/ttnn/unit_tests/operations/test_new_conv2d.py
@@ -357,7 +357,7 @@ def run_conv_with_split(

@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
@pytest.mark.parametrize("stride", [2])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_channels, input_channels, input_height, input_width, shard_layout, config",
(
@@ -589,7 +589,6 @@ def test_conv_ws(
debug = False
groups = 1

-torch.manual_seed(0)
conv_input_shape = [batch_size, input_channels, input_height, input_width]
conv_weight_shape = [output_channels, input_channels // groups, filter_height, filter_width]
conv_bias_shape = [1, 1, 1, output_channels]
12 changes: 12 additions & 0 deletions tt_metal/impl/program/program.cpp
@@ -128,6 +128,7 @@ class Program_ {
void invalidate_circular_buffer_allocation();

void allocate_circular_buffers(const IDevice* device);
+uint32_t get_cb_memory_size() const;

bool is_finalized() const;
void set_finalized();
@@ -768,6 +769,17 @@ void detail::Program_::invalidate_circular_buffer_allocation() {

void Program::invalidate_circular_buffer_allocation() { pimpl_->invalidate_circular_buffer_allocation(); }

+uint32_t Program::get_cb_memory_size() const { return pimpl_->get_cb_memory_size(); }
+uint32_t detail::Program_::get_cb_memory_size() const {
+uint32_t total_cb_size = 0;
+for (const auto& circular_buffer : this->circular_buffers_) {
+if (circular_buffer->globally_allocated()) {
+continue;
+}
+total_cb_size += circular_buffer->size();
+}
+return total_cb_size;
+}
void detail::Program_::allocate_circular_buffers(const IDevice* device) {
//ZoneScoped;
if (not this->local_circular_buffer_allocation_needed_) {
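The new accessor lets a caller ask a constructed program how much L1 its locally allocated circular buffers will claim, before any allocation is attempted. A minimal usage sketch — the wrapper function and the 1 MiB budget constant are illustrative assumptions, not code from this PR:

// Sketch: pre-flight L1 check using the new accessor. Assumes the Program
// header from this PR is on the include path; the budget is a hypothetical
// constant, not a queried device limit.
#include <cstdint>
#include <stdexcept>
#include "tt_metal/impl/program/program.hpp"  // tt::tt_metal::Program

void check_cb_fit(const tt::tt_metal::Program& program) {
    // Sums every locally allocated circular buffer; globally allocated CBs
    // alias existing L1 buffers and are skipped by get_cb_memory_size().
    const uint32_t cb_bytes = program.get_cb_memory_size();

    constexpr uint32_t kAssumedL1CbBudget = 1u << 20;  // hypothetical per-core budget
    if (cb_bytes > kAssumedL1CbBudget) {
        throw std::runtime_error("circular buffers would not fit in the assumed L1 budget");
    }
}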
1 change: 1 addition & 0 deletions tt_metal/impl/program/program.hpp
@@ -192,6 +192,7 @@ class Program {
HWCommandQueue* get_last_used_command_queue() const;
const std::vector<SubDeviceId> &determine_sub_device_ids(const IDevice* device);
void set_kernels_bin_buffer(const std::shared_ptr<Buffer>& buffer);
+uint32_t get_cb_memory_size() const;
private:
std::unique_ptr<detail::Program_> pimpl_;

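The paired header/source edits follow the pimpl convention Program already uses: the public class exposes a thin method that forwards to detail::Program_, which holds the actual state. A generic sketch of that idiom, with hypothetical names:

// Pimpl forwarding sketch (Widget/WidgetImpl are illustrative, not tt-metal API).
#include <cstdint>
#include <memory>

class WidgetImpl {  // private implementation; free to change without breaking the ABI
public:
    uint32_t memory_size() const { return 0; }  // real logic lives here
};

class Widget {      // public interface
public:
    Widget() : pimpl_(std::make_unique<WidgetImpl>()) {}
    // Thin forwarder, mirroring Program::get_cb_memory_size() above.
    uint32_t memory_size() const { return pimpl_->memory_size(); }
private:
    std::unique_ptr<WidgetImpl> pimpl_;
};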
65 changes: 17 additions & 48 deletions ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
@@ -5,23 +5,21 @@
#include <optional>
#include <utility>

#include "common/constants.hpp"
#include "common/math.hpp"

#include "tt_metal/impl/buffers/buffer_constants.hpp"

#include "ttnn/tensor/tensor.hpp"
#include "ttnn/tensor/types.hpp"
#include "ttnn/operations/core/core.hpp"

#include "ttnn/operations/core/compute_kernel/compute_kernel_config.hpp"
#include "ttnn/operations/conv/conv2d/conv2d.hpp"
#include "ttnn/operations/conv/conv2d/conv2d_utils.hpp"
#include "ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp"
#include "ttnn/operations/conv/conv2d/device/conv2d_op.hpp"
#include "ttnn/operations/data_movement/move/move.hpp"
#include "ttnn/operations/matmul/matmul.hpp"
#include "ttnn/operations/sliding_window/halo/halo.hpp"
#include "ttnn/operations/sliding_window/sliding_window.hpp"
#include "ttnn/operations/core/core.hpp"
#include "ttnn/operations/data_movement/move/move.hpp"

using namespace tt;
namespace ttnn {
@@ -61,12 +59,15 @@ Result conv2d(
const uint32_t output_width =
((input_width - kernel_size[1] - ((kernel_size[0] - 1) * (dilation[0] - 1)) + 2 * padding[1]) / stride[1]) + 1;

+DeviceComputeKernelConfig compute_config = compute_config_.value_or(get_conv_default_compute_kernel_config(device));

const auto compute_grid_size = device->compute_with_storage_grid_size();

bool auto_shard = false;
if (!input_tensor.is_sharded() && !conv_config.shard_layout.has_value()) {
// In this case we deduce the shard layout.
-adjust_conv_op_config_for_auto_shard_if_necessary(
+conv_config = determine_conv_config_for_auto_shard(
+conv_config,
Contributor review comment:
⚠️ clang-diagnostic-error ⚠️
no matching function for call to adjust_conv_op_config_for_auto_shard_if_necessary
mm_conv,
batch_size,
in_channels,
@@ -76,7 +77,6 @@
weight_tensor.get_shape()[3],
input_width,
compute_grid_size,
-conv_config,
input_tensor.layout(),
ttnn::is_tensor_on_device_or_multidevice(input_tensor) ? std::make_optional(input_tensor.memory_config())
: std::nullopt);
@@ -87,8 +87,6 @@
conv_config.transpose_shards ? ShardOrientation::COL_MAJOR : ShardOrientation::ROW_MAJOR;
bool is_non_tile_mul_width = check_non_tile_mul_width(device, conv_config, in_channels);

-DeviceComputeKernelConfig compute_config = compute_config_.value_or(
-init_device_compute_kernel_config(device->arch(), std::nullopt, MathFidelity::HiFi4, true, false, false));
auto [input_tensor_post_tm, parallel_config, output_parallel_config, use_non_tile_height] =
shard_or_reshard_tensor_if_required(
device,
@@ -103,48 +101,19 @@
auto_shard,
is_non_tile_mul_width);

-uint32_t round_up_size = !use_non_tile_height ? tt::constants::TILE_HEIGHT : 1;
-uint32_t nhw_out = batch_size * output_height * output_width;
-uint32_t out_channels_padded = tt::round_up(
-out_channels, get_num_cores_channels_from_parallel_config(output_parallel_config) * tt::constants::TILE_WIDTH);
-if (is_non_tile_mul_width) {
-out_channels_padded = tt::round_up(out_channels, 32);
-}
-MemoryConfig conv_out_memory_config = create_sharded_memory_config_from_parallel_config(
-ttnn::Shape(std::array<uint32_t, 4>{1, 1, nhw_out, out_channels_padded}),
-output_parallel_config,
-round_up_size);
-ParallelConfig largest_parallel_config = output_parallel_config.grid.num_cores() > parallel_config.grid.num_cores()
-? output_parallel_config
-: parallel_config;
-
-OptimizedConvParallelizationConfig opt_conv_op_parallel_config =
-determine_conv_op_parallel_config_from_conv_output_mem_config(
-conv_out_memory_config,
-get_num_cores_nhw_from_parallel_config(largest_parallel_config),
-get_num_cores_channels_from_parallel_config(largest_parallel_config));
-
-uint32_t in_channels_padded = tt::round_up(
-in_channels,
-get_num_cores_channels_from_parallel_config(parallel_config) * conv_config.input_channels_alignment);
-if (is_non_tile_mul_width) {
-in_channels_padded = tt::round_up(in_channels, conv_config.input_channels_alignment);
-}
-
-uint32_t nhw_out_padded_ntile_per_core =
-conv_out_memory_config.shard_spec.value().shape[0] / tt::constants::TILE_HEIGHT;
-
-OptimizedConvBlockConfig opt_conv_op_block_config = determine_per_core_conv_block_config(
-parallel_config,
-opt_conv_op_parallel_config,
-in_channels_padded,
-nhw_out_padded_ntile_per_core,
-conv_config.act_block_h_override,
-conv_config.act_block_w_div,
-kernel_size[0],
-kernel_size[1],
-get_fp32_dest_acc_en(compute_config),
-conv_config.enable_split_reader);
+auto [opt_conv_op_parallel_config, opt_conv_op_block_config, conv_out_memory_config] = get_conv_configs(
+conv_config,
+compute_config,
+parallel_config,
+output_parallel_config,
+out_channels,
+batch_size,
+output_height,
+output_width,
+kernel_size,
+device);

Contributor review comment:
⚠️ clang-diagnostic-error ⚠️
no matching function for call to get_conv_configs
bool weight_is_on_device = ttnn::is_tensor_on_device_or_multidevice(weight_tensor);
ttnn::Tensor weight_tensor_on_device = weight_tensor;
std::optional<ttnn::Tensor> bias_tensor_on_device = bias_tensor;
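For reference, the output_height/output_width expressions near the top of conv2d are based on the standard convolution output-size relation. A self-contained check of that relation — the helper name and the sample numbers are illustrative, not code from this PR:

// out = floor((in + 2*pad - dilation*(kernel - 1) - 1) / stride) + 1
#include <cstdint>

constexpr uint32_t conv_out_dim(uint32_t in, uint32_t pad, uint32_t dilation, uint32_t kernel, uint32_t stride) {
    return (in + 2 * pad - dilation * (kernel - 1) - 1) / stride + 1;
}

// e.g. a 32x32 input, 3x3 kernel, stride 2, padding 1, dilation 1 -> 16x16 output.
static_assert(conv_out_dim(32, 1, 1, 3, 2) == 16);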
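The block this hunk deletes was mostly channel-padding arithmetic, which now lives behind get_conv_configs. A small illustration of that rounding, with round_up semantics matching tt::round_up — the core counts and channel numbers are illustrative assumptions:

#include <cstdint>

// Round x up to the next multiple of m, as tt::round_up does for channel padding.
constexpr uint32_t round_up(uint32_t x, uint32_t m) { return ((x + m - 1) / m) * m; }

// Example: 96 output channels split across 8 channel-parallel cores with
// 32-wide tiles must pad to a multiple of 8 * 32 = 256 -> 256 channels.
static_assert(round_up(96, 8 * 32) == 256);
// In the non-tile-multiple-width path, channels round up to 32 instead.
static_assert(round_up(96, 32) == 96);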