Skip to content

Commit

Permalink
#5725: Adding bilinear support in upsample
Browse files Browse the repository at this point in the history
  • Loading branch information
shwetankTT committed Sep 19, 2024
1 parent eedf9af commit 1673fd9
Show file tree
Hide file tree
Showing 14 changed files with 639 additions and 56 deletions.
141 changes: 139 additions & 2 deletions tests/ttnn/unit_tests/operations/test_upsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@

import pytest
import math
from loguru import logger
from typing import Union, Tuple

import torch
import torch.nn as nn
import ttnn

from tests.ttnn.utils_for_testing import assert_with_pcc
from models.utility_functions import skip_for_grayskull, skip_for_blackhole
from tests.ttnn.utils_for_testing import assert_with_pcc, check_with_pcc_without_tensor_printout


TILE_WIDTH = 32
Expand Down Expand Up @@ -222,3 +223,139 @@ def test_upsample_multi_core(device, input_shape, scale_h, scale_w, shard_strate
assert allclose
assert isclose
assert isequal


@skip_for_grayskull()
@skip_for_blackhole()
@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True)
@pytest.mark.parametrize(
"batch_size, num_channels, height, width, scale_h, scale_w",
(
(1, 256, 16, 16, 8, 8), # 256x256
(1, 256, 32, 32, 4, 4), # 256x256
(1, 256, 64, 64, 2, 2), # 256x256
(1, 256, 128, 128, 1, 1), # 256x256
),
)
@pytest.mark.parametrize("shard_strategy", [ttnn.ShardStrategy.HEIGHT])
@pytest.mark.parametrize("math_fidelity", [ttnn.MathFidelity.HiFi4, ttnn.MathFidelity.LoFi])
@pytest.mark.parametrize("math_approx_mode", [True, False])
def test_bilinear_multi_core(
device,
use_program_cache,
batch_size,
num_channels,
height,
width,
scale_h,
scale_w,
shard_strategy,
math_fidelity,
math_approx_mode,
):
## input shape is N C H W
input_shape = [batch_size, num_channels, height, width]
torch.manual_seed(0)
input = torch.rand(input_shape, dtype=torch.bfloat16)

## golden reference using torch
scale_factor = (scale_h, scale_w)
torch_upsample = nn.Upsample(scale_factor=scale_factor, mode="bilinear", align_corners=False)
torch_result = torch_upsample(input)

## permute to N H W C, which is what the upsample op expects
tt_input = input.permute(0, 2, 3, 1)

num_bytes = 2 ## only BFLOAT16 is supported

## calculate ncores, corresponding grid_size and in_shard_shape based on the input_shape
ncores = None
device_grid = device.compute_with_storage_grid_size()
max_grid_size = (device_grid.y, device_grid.x)
if shard_strategy == ttnn.ShardStrategy.HEIGHT:
## nsticks per shard should be divisible by in_w
max_nshards = min(batch_size * height * width, max_grid_size[0] * max_grid_size[1])
nshards = max_nshards
while nshards > 0:
if batch_size * height * width % (nshards * TILE_WIDTH) == 0:
break
nshards -= 1
ncores = nshards
elif shard_strategy == ttnn.ShardStrategy.BLOCK:
max_nshards_h = min(batch_size * height, max_grid_size[0]) ## height along NHW
max_nshards_w = min(num_channels, max_grid_size[1]) ## width along C
## find nshards_h along NHW
nshards_h = max_nshards_h
while nshards_h > 0:
if batch_size * height % nshards_h == 0:
break
nshards_h -= 1
## find nshards_w along C
nshards_w = max_nshards_w
while nshards_w > 0:
## make sure: 1. nshards_w divides num_channels, and 2. shard_shape[1] is aligned to 32B
if num_channels % nshards_w == 0 and math.ceil(num_channels * num_bytes / nshards_w) % TILE_WIDTH == 0:
break
nshards_w -= 1
if nshards_w == 0 or nshards_h == 0:
raise ValueError("nshards_h or nshards_w is 0")
ncores = (nshards_h, nshards_w)

shard_grid = get_shard_grid_from_num_cores(device, ncores)
shard_orientation = ttnn.ShardOrientation.ROW_MAJOR

if shard_strategy == ttnn.ShardStrategy.BLOCK:
tensor_memory_layout = ttnn.types.TensorMemoryLayout.BLOCK_SHARDED
elif shard_strategy == ttnn.ShardStrategy.HEIGHT:
tensor_memory_layout = ttnn.types.TensorMemoryLayout.HEIGHT_SHARDED

## input shard
if shard_strategy == ttnn.ShardStrategy.BLOCK:
shard_height = math.ceil(batch_size * height * width / ncores[0])
shard_width = math.ceil(num_channels / ncores[1])
elif shard_strategy == ttnn.ShardStrategy.HEIGHT:
shard_height = math.ceil(batch_size * height * width / ncores)
shard_width = num_channels
# breakpoint()
shard_shape = (shard_height, shard_width)
shard_spec = ttnn.ShardSpec(shard_grid, shard_shape, shard_orientation, False)
in_sharded_mem_config = ttnn.MemoryConfig(tensor_memory_layout, ttnn.types.BufferType.L1, shard_spec)

## output shard
shard_height = shard_height * scale_h * scale_w
shard_shape = (shard_height, shard_width)
shard_spec = ttnn.ShardSpec(shard_grid, shard_shape, shard_orientation, False)

compute_kernel_config = ttnn.WormholeComputeKernelConfig(
math_fidelity=math_fidelity,
math_approx_mode=math_approx_mode,
fp32_dest_acc_en=False,
)

out_sharded_mem_config = ttnn.MemoryConfig(tensor_memory_layout, ttnn.types.BufferType.L1, shard_spec)

logger.debug(f"in_shard_mem_config: {in_sharded_mem_config}")
logger.debug(f"out_shard_mem_config: {out_sharded_mem_config}")

## ttnn uses NHWC, so need to set scale_factor_c = 1
scale_factor = (scale_h, scale_w, 1)
input_tensor = ttnn.from_torch(tt_input, device=device)
input_tensor = ttnn.to_memory_config(input_tensor, memory_config=in_sharded_mem_config)
output_tensor = ttnn.upsample(
input_tensor,
scale_factor,
mode="bilinear",
memory_config=out_sharded_mem_config,
compute_kernel_config=compute_kernel_config,
)
output_tensor = ttnn.to_memory_config(output_tensor, memory_config=ttnn.L1_MEMORY_CONFIG)
output_tensor = ttnn.to_torch(output_tensor)

## compare the results
torch_result = torch_result.permute(0, 2, 3, 1)
passing, pcc_msg = check_with_pcc_without_tensor_printout(torch_result, output_tensor, pcc=0.999)
allclose = torch.allclose(output_tensor, torch_result, atol=1e-1, rtol=1e-1)
logger.info(pcc_msg)

assert allclose
assert passing
1 change: 1 addition & 0 deletions ttnn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ set(ALL_TTNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/maxpool/max_pool2d_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/upsample/device/upsample_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/upsample/device//upsample_bilinear_program_factory_multicore.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_singlecore.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/upsample/upsample.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/upsample/upsample_pybind.cpp
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <cstdint>

#include "compute_kernel_api/tilize.h"
#include "compute_kernel_api/reduce.h"
#include "compute_kernel_api/pack_untilize.h"

template<uint32_t in_ntiles_hw, uint32_t in_ntiles_c, uint32_t out_ntiles_c, uint32_t unpA_face_r_dim>
inline void reduce_h_fused(
const uint32_t in_cb_id,
const uint32_t in_scalar_cb_id,
const uint32_t in_ntiles_hwc,
const uint32_t in_stick_index,
const uint32_t out_cb_id) {

cb_reserve_back(out_cb_id, 1);
tile_regs_acquire();
cb_wait_front(in_cb_id, 4);
unpack_tilizeA_B_block(in_cb_id, in_scalar_cb_id, in_ntiles_hwc, 0 /*tile idx for Src b is 0 because only 1 tile of constants is loaded*/, 2 /* unpack 1 or 2 faces ) */, unpA_face_r_dim);
for (uint32_t c_i = 0; c_i < in_ntiles_c; ++c_i) {
reduce_tile_math(c_i, 2 /* reduce 1 or 2 faces */);
}
cb_pop_front(in_cb_id, 4);

tile_regs_wait();
tile_regs_commit();
pack_untilize_dst<out_ntiles_c>(out_cb_id, 1, 0, 1, 2); /* pack 1 row (1x16 or 1x32) */
tile_regs_release();

cb_push_back(out_cb_id, 1);
}

namespace NAMESPACE{
void MAIN{
constexpr uint32_t out_cb_id = tt::CB::c_out0;
constexpr uint32_t in1_cb_id = tt::CB::c_in1;
constexpr uint32_t bias_cb_id = tt::CB::c_in2;
constexpr uint32_t in_scalar_cb_id = tt::CB::c_in4;
constexpr uint32_t in2_cb_id = tt::CB::c_intermed0;

constexpr uint32_t in_ntiles_hw = get_compile_time_arg_val(0);
constexpr uint32_t in_ntiles_c = get_compile_time_arg_val(1);
constexpr uint32_t in_ntiles_hwc = get_compile_time_arg_val(2);
constexpr uint32_t window_size_hw = get_compile_time_arg_val(3);
constexpr uint32_t out_h = get_compile_time_arg_val(4);
constexpr uint32_t out_w = get_compile_time_arg_val(5);
constexpr uint32_t out_ntiles_c = get_compile_time_arg_val(7);

constexpr uint32_t nsticks_per_core_by_nblocks = get_compile_time_arg_val(8);
constexpr uint32_t num_output_tiles = out_ntiles_c; //* nblocks;

tilizeA_B_reduce_init<false, true>(in1_cb_id, in_scalar_cb_id, in_ntiles_hwc, out_cb_id, 2, 4);
pack_untilize_dst_init_short<num_output_tiles>(out_cb_id, 1, 2); /* pack 1 row (1x16 or 1x32) */
for(uint32_t i = 0; i < nsticks_per_core_by_nblocks; i++){
cb_wait_front(in_scalar_cb_id, 1);
reduce_h_fused<in_ntiles_hw, in_ntiles_c, out_ntiles_c, window_size_hw>(in1_cb_id,
in_scalar_cb_id, in_ntiles_hwc, i, out_cb_id);
cb_pop_front(in_scalar_cb_id, 1);
}
} // MAIN
} //NAMESPACE
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "dataflow_api.h"
#include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/dataflow/moreh_common.hpp"

#define ALWI inline __attribute__((always_inline))

// Fill given four values into the memory starting at the given address.
// WARNING: Use with caution as there's no memory protection. Make sure size is within limits
ALWI bool fill_four_val(uint32_t begin_addr, uint16_t val, uint16_t val1, uint16_t val2, uint16_t val3) {
volatile tt_l1_ptr uint32_t* ptr = reinterpret_cast<volatile tt_l1_ptr uint32_t*>(begin_addr);

ptr[0] = (val | (val1 << 16));
ptr[1] = (val2 | (val3 << 16));
return true;
}


void kernel_main() {

uint32_t stick_nbytes = get_arg_val<uint32_t>(0);
uint32_t in_image_rows_per_core = get_arg_val<uint32_t>(1);
uint32_t scale_h = get_arg_val<uint32_t>(2);
uint32_t scale_w = get_arg_val<uint32_t>(3);
uint32_t in_w = get_arg_val<uint32_t>(4);
uint32_t out_w = get_arg_val<uint32_t>(5);
uint32_t src1_addr = get_arg_val<uint32_t>(6);
uint32_t read_offset = get_arg_val<uint32_t>(8);
uint32_t is_last_row = get_arg_val<uint32_t>(9);
uint32_t in_h = 1;
constexpr bool src1_is_dram = false;

constexpr uint32_t in_cb_id = get_compile_time_arg_val(0);
constexpr uint32_t out_cb_id = tt::CB::c_in1;
constexpr uint32_t is_reader = get_compile_time_arg_val(2);

uint32_t in_image_row_nbytes = in_w * stick_nbytes;
uint32_t out_image_row_nbytes = out_w * stick_nbytes;
uint32_t reader_image_rows_per_core = (in_image_rows_per_core + is_reader) / 2;
uint32_t writer_image_rows_per_core = in_image_rows_per_core / 2;
uint32_t image_row_begin = is_reader ? 0 : reader_image_rows_per_core;
uint32_t image_row_end = is_reader ? reader_image_rows_per_core : in_image_rows_per_core;
uint32_t l1_read_addr = get_read_ptr(in_cb_id); //+ image_row_begin * in_image_row_nbytes;
constexpr uint32_t in_scalar_cb_id = tt::CB::c_in4;

// assuming shard begins with a new row. TODO: generalize?
float scale_h_inv = 1.0f / scale_h;
float scale_w_inv = 1.0f / scale_w;
float x, y, x_index, y_index, dx, dy;
y_index = (float)(0.5f) * (float)scale_h_inv + 0.5f;
for (uint32_t image_row = 0 ; image_row < in_image_rows_per_core * scale_h; ++image_row){
x_index = (float)(0.5f) * (float)scale_w_inv -0.5f;
for(uint32_t j=0; j < in_w * scale_w; j++){
cb_reserve_back(out_cb_id, 4);
cb_reserve_back(in_scalar_cb_id, 1);

x = x_index < 0 ? 0 : x_index;
y = y_index < read_offset ? read_offset : y_index;
dx = x - int(x);
dy = y - int(y);

uint32_t x1 = int(x);
uint32_t y1 = int(y);
uint32_t x2 = min(x1 + 1, in_w-1);
uint32_t y2 = y1 + 1; //, in_image_rows_per_core - 1);
if(is_last_row){
y2 = min(y2, in_image_rows_per_core); //if last row, y2 should be in_image_rows_per_core
}

fill_four_val(get_write_ptr(in_scalar_cb_id), float_to_bfloat16((1-dx) * (1-dy)),
float_to_bfloat16(dx * (1 - dy)), float_to_bfloat16((1 - dx) * dy), float_to_bfloat16(dx * dy));

uint32_t l1_write_addr = get_write_ptr(out_cb_id);
uint32_t l1_read_addr_temp = l1_read_addr + x1 * stick_nbytes + y1 * in_w * stick_nbytes;
//1st tile
uint64_t src_noc_addr = get_noc_addr(l1_read_addr_temp);
noc_async_read(src_noc_addr, l1_write_addr, stick_nbytes);
l1_write_addr += stick_nbytes;

//2nd tile
l1_read_addr_temp = l1_read_addr + y1 * in_w * stick_nbytes + x2 * stick_nbytes;
src_noc_addr = get_noc_addr(l1_read_addr_temp);
noc_async_read(src_noc_addr, l1_write_addr, stick_nbytes);
l1_write_addr += stick_nbytes;

//3rd tile
l1_read_addr_temp = l1_read_addr + y2 * in_w * stick_nbytes + x1 * stick_nbytes;
src_noc_addr = get_noc_addr(l1_read_addr_temp);
noc_async_read(src_noc_addr, l1_write_addr, stick_nbytes);
l1_write_addr += stick_nbytes;

//4th tile
l1_read_addr_temp = l1_read_addr + y2 * in_w * stick_nbytes + x2 * stick_nbytes;
src_noc_addr = get_noc_addr(l1_read_addr_temp);
noc_async_read(src_noc_addr, l1_write_addr, stick_nbytes);
l1_write_addr += stick_nbytes;

//push scaler and data into cb.
noc_async_read_barrier();
cb_push_back(out_cb_id, 4);
cb_push_back(in_scalar_cb_id, 1);
x_index += scale_w_inv;
}
y_index += scale_h_inv;
}
}
Loading

0 comments on commit 1673fd9

Please sign in to comment.