Fix up transposed conv
Pavle Josipovic committed Dec 9, 2024
1 parent e17b835 commit 2c941fc
Showing 1 changed file with 30 additions and 27 deletions.
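In brief, per the diff below: the sliding-window setup and the halo micro-op now run only when the op does not reduce to a plain matmul (`!mm_conv`); the matmul path tilizes `input_tensor_post_tm` directly instead of a `halo_output` that is never produced on that path; and the per-core block config takes the padded output NHW tile count (`nhw_out_padded_ntile`) instead of a value derived from the input shard shape.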
ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp (30 additions, 27 deletions)
@@ -194,10 +194,6 @@ Result conv_transpose2d(

    uint32_t round_up_size = !use_non_tile_height ? tt::constants::TILE_HEIGHT : 1;

-    sliding_window_config.num_cores_nhw = get_num_cores_nhw_from_parallel_config(parallel_config);
-    sliding_window_config.core_range_set = input_tensor_post_tm.memory_config().shard_spec.value().grid;
-    sliding_window_config.snap_to_tile = !use_non_tile_height;
-
    if (tensor_manipulated) {
        if (conv_config.deallocate_activation) {
            ttnn::Tensor input_tensor_ = input_tensor; // TODO: allow in place modification of inputs to the op
@@ -207,24 +203,31 @@ Result conv_transpose2d(
        conv_config.deallocate_activation = true;
    }

-    auto halo_output = ttnn::halo(
-        DefaultQueueId,
-        input_tensor_post_tm,
-        sliding_window_config,
-        0,
-        false,
-        parallel_config.shard_orientation == ShardOrientation::COL_MAJOR,
-        0,
-        input_tensor_post_tm.memory_config());
-
-    if(conv_config.deallocate_activation) {
-        input_tensor_post_tm.deallocate();
-        log_debug(tt::LogOp, "Deallocate Input Tensor");
-    }
-    if (conv_config.reallocate_halo_output) {
-        auto move_output = ttnn::operations::core::reallocate(halo_output, halo_output.memory_config());
-        halo_output = move_output;
-        log_debug(tt::LogOp, "Reallocate Halo Output");
+    Tensor halo_output;
+    if (!mm_conv) {
+        sliding_window_config.num_cores_nhw = get_num_cores_nhw_from_parallel_config(parallel_config);
+        sliding_window_config.core_range_set = input_tensor_post_tm.memory_config().shard_spec.value().grid;
+        sliding_window_config.snap_to_tile = !use_non_tile_height;
+
+        halo_output = ttnn::halo(
+            DefaultQueueId,
+            input_tensor_post_tm,
+            sliding_window_config,
+            0,
+            false,
+            parallel_config.shard_orientation == ShardOrientation::COL_MAJOR,
+            0,
+            input_tensor_post_tm.memory_config());
+
+        if(conv_config.deallocate_activation) {
+            input_tensor_post_tm.deallocate();
+            log_debug(tt::LogOp, "Deallocate Input Tensor");
+        }
+        if (conv_config.reallocate_halo_output) {
+            auto move_output = ttnn::operations::core::reallocate(halo_output, halo_output.memory_config());
+            halo_output = move_output;
+            log_debug(tt::LogOp, "Reallocate Halo Output");
+        }
    }

    //Call Conv2d u_op with Stride = 1, Padding = 0.
@@ -243,13 +246,14 @@ Result conv_transpose2d(

    uint32_t in_channels_padded = tt::round_up(
        in_channels,
-        get_num_cores_channels_from_parallel_config(parallel_config) *
-            conv_config.input_channels_alignment);
+        get_num_cores_channels_from_parallel_config(parallel_config) * conv_config.input_channels_alignment);
+    uint32_t nhw_out_padded_ntile = get_num_cores_nhw_from_parallel_config(output_parallel_config) *
+        conv_out_memory_config.shard_spec.value().shape[0] / tt::constants::TILE_HEIGHT;
    auto opt_conv_op_block_config = determine_per_core_conv_block_config(
        parallel_config,
        opt_conv_op_parallel_config,
        in_channels_padded,
-        (input_tensor_post_tm.shard_spec().value().shape[0] * get_num_cores_nhw_from_parallel_config(parallel_config)) / tt::constants::TILE_HEIGHT,
+        nhw_out_padded_ntile,
        conv_config.act_block_h_override,
        conv_config.act_block_w_div,
        kernel_size[0],
@@ -262,7 +266,6 @@ Result conv_transpose2d(
    ttnn::Tensor weight_tensor_on_device = weight_tensor;
    std::optional<ttnn::Tensor> bias_tensor_on_device = bias_tensor;
    if (!weight_is_on_device) {
-
        // prepare weights in desired layout and move to device
        tie(weight_tensor_on_device, bias_tensor_on_device) = prepare_conv_weights_biases_and_move_to_device(
            transform_weights_for_conv_transpose2d(weight_tensor),
@@ -279,7 +282,7 @@ Result conv_transpose2d(
    }
    if (mm_conv) {
        Tensor matmul_input = ttnn::to_layout(
-            halo_output, Layout::TILE, conv_config.dtype, input_tensor_post_tm.memory_config(), device);
+            input_tensor_post_tm, Layout::TILE, conv_config.dtype, input_tensor_post_tm.memory_config(), device);
        std::optional<ttnn::operations::matmul::MatmulProgramConfig> program_config = std::nullopt;
        std::optional<MemoryConfig> mm_output_memory_config = std::nullopt;

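Note (not part of the commit): below is a minimal, self-contained C++ sketch of the control flow this change establishes. The Tensor struct and the halo/to_layout_tiled functions are hypothetical stand-ins for the real ttnn ops; only the mm_conv branching mirrors the diff above.

#include <iostream>
#include <string>

// Stand-in tensor: records which op produced it.
struct Tensor { std::string producer; };

// Hypothetical stubs for the real ttnn ops referenced in the diff.
Tensor halo(const Tensor& in) { return {"halo(" + in.producer + ")"}; }
Tensor to_layout_tiled(const Tensor& in) { return {"to_layout(" + in.producer + ")"}; }

int main() {
    for (bool mm_conv : {false, true}) {
        Tensor input_tensor_post_tm{"input_post_tm"};
        Tensor conv_input;
        if (!mm_conv) {
            // Sliding-window path: configure the window, gather neighboring
            // shards with the halo op, then run the stride-1 conv2d micro-op.
            conv_input = halo(input_tensor_post_tm);
        } else {
            // Matmul path: no sliding window, so no halo. Before this commit
            // the code consumed halo_output here, which the mm_conv path never
            // produces; it now tilizes input_tensor_post_tm directly.
            conv_input = to_layout_tiled(input_tensor_post_tm);
        }
        std::cout << "mm_conv=" << std::boolalpha << mm_conv
                  << " -> conv input from " << conv_input.producer << "\n";
    }
    return 0;
}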
