bilinear support for upsample. #12385

Merged · 3 commits · Sep 19, 2024
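This PR adds a bilinear mode to ttnn.upsample. The new test below validates it against torch's bilinear upsample with align_corners=False; a minimal torch-only sketch of that golden reference, using the first parametrization's shapes (illustrative only, not part of the diff):

import torch
import torch.nn as nn

# Golden reference used by the new test: bilinear upsample, align_corners=False.
x = torch.rand(1, 256, 16, 16, dtype=torch.bfloat16)                  # N C H W
golden = nn.Upsample(scale_factor=(8, 8), mode="bilinear", align_corners=False)(x)
golden_nhwc = golden.permute(0, 2, 3, 1)                              # ttnn.upsample consumes N H W C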
141 changes: 139 additions & 2 deletions tests/ttnn/unit_tests/operations/test_upsample.py
@@ -4,13 +4,14 @@

import pytest
import math
from loguru import logger
from typing import Union, Tuple

import torch
import torch.nn as nn
import ttnn

from tests.ttnn.utils_for_testing import assert_with_pcc
from models.utility_functions import skip_for_grayskull, skip_for_blackhole
from tests.ttnn.utils_for_testing import assert_with_pcc, check_with_pcc_without_tensor_printout


TILE_WIDTH = 32
@@ -222,3 +223,139 @@ def test_upsample_multi_core(device, input_shape, scale_h, scale_w, shard_strategy):
assert allclose
assert isclose
assert isequal


@skip_for_grayskull()
@skip_for_blackhole()
@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True)
@pytest.mark.parametrize(
"batch_size, num_channels, height, width, scale_h, scale_w",
(
(1, 256, 16, 16, 8, 8), # 256x256
(1, 256, 32, 32, 4, 4), # 256x256
(1, 256, 64, 64, 2, 2), # 256x256
(1, 256, 128, 128, 1, 1), # 256x256
),
)
@pytest.mark.parametrize("shard_strategy", [ttnn.ShardStrategy.HEIGHT])
@pytest.mark.parametrize("math_fidelity", [ttnn.MathFidelity.HiFi4, ttnn.MathFidelity.LoFi])
@pytest.mark.parametrize("math_approx_mode", [True, False])
def test_bilinear_multi_core(
device,
use_program_cache,
batch_size,
num_channels,
height,
width,
scale_h,
scale_w,
shard_strategy,
math_fidelity,
math_approx_mode,
):
## input shape is N C H W
input_shape = [batch_size, num_channels, height, width]
torch.manual_seed(0)
input = torch.rand(input_shape, dtype=torch.bfloat16)

## golden reference using torch
scale_factor = (scale_h, scale_w)
torch_upsample = nn.Upsample(scale_factor=scale_factor, mode="bilinear", align_corners=False)
torch_result = torch_upsample(input)

## permute to N H W C, which is what the upsample op expects
tt_input = input.permute(0, 2, 3, 1)

num_bytes = 2 ## only BFLOAT16 is supported

## calculate ncores, corresponding grid_size and in_shard_shape based on the input_shape
ncores = None
device_grid = device.compute_with_storage_grid_size()
max_grid_size = (device_grid.y, device_grid.x)
if shard_strategy == ttnn.ShardStrategy.HEIGHT:
## nsticks per shard should be divisible by in_w
max_nshards = min(batch_size * height * width, max_grid_size[0] * max_grid_size[1])
nshards = max_nshards
while nshards > 0:
if batch_size * height * width % (nshards * TILE_WIDTH) == 0:
break
nshards -= 1
ncores = nshards
elif shard_strategy == ttnn.ShardStrategy.BLOCK:
max_nshards_h = min(batch_size * height, max_grid_size[0]) ## height along NHW
max_nshards_w = min(num_channels, max_grid_size[1]) ## width along C
## find nshards_h along NHW
nshards_h = max_nshards_h
while nshards_h > 0:
if batch_size * height % nshards_h == 0:
break
nshards_h -= 1
## find nshards_w along C
nshards_w = max_nshards_w
while nshards_w > 0:
## make sure: 1. nshards_w divides num_channels, and 2. shard_shape[1] is aligned to 32B
if num_channels % nshards_w == 0 and math.ceil(num_channels * num_bytes / nshards_w) % TILE_WIDTH == 0:
break
nshards_w -= 1
if nshards_w == 0 or nshards_h == 0:
raise ValueError("nshards_h or nshards_w is 0")
ncores = (nshards_h, nshards_w)

shard_grid = get_shard_grid_from_num_cores(device, ncores)
shard_orientation = ttnn.ShardOrientation.ROW_MAJOR

if shard_strategy == ttnn.ShardStrategy.BLOCK:
tensor_memory_layout = ttnn.types.TensorMemoryLayout.BLOCK_SHARDED
elif shard_strategy == ttnn.ShardStrategy.HEIGHT:
tensor_memory_layout = ttnn.types.TensorMemoryLayout.HEIGHT_SHARDED

## input shard
if shard_strategy == ttnn.ShardStrategy.BLOCK:
shard_height = math.ceil(batch_size * height * width / ncores[0])
shard_width = math.ceil(num_channels / ncores[1])
elif shard_strategy == ttnn.ShardStrategy.HEIGHT:
shard_height = math.ceil(batch_size * height * width / ncores)
shard_width = num_channels
# breakpoint()
shard_shape = (shard_height, shard_width)
shard_spec = ttnn.ShardSpec(shard_grid, shard_shape, shard_orientation, False)
in_sharded_mem_config = ttnn.MemoryConfig(tensor_memory_layout, ttnn.types.BufferType.L1, shard_spec)

## output shard
shard_height = shard_height * scale_h * scale_w
shard_shape = (shard_height, shard_width)
shard_spec = ttnn.ShardSpec(shard_grid, shard_shape, shard_orientation, False)

compute_kernel_config = ttnn.WormholeComputeKernelConfig(
math_fidelity=math_fidelity,
math_approx_mode=math_approx_mode,
fp32_dest_acc_en=False,
)

out_sharded_mem_config = ttnn.MemoryConfig(tensor_memory_layout, ttnn.types.BufferType.L1, shard_spec)

logger.debug(f"in_shard_mem_config: {in_sharded_mem_config}")
logger.debug(f"out_shard_mem_config: {out_sharded_mem_config}")

## ttnn uses NHWC, so need to set scale_factor_c = 1
scale_factor = (scale_h, scale_w, 1)
input_tensor = ttnn.from_torch(tt_input, device=device)
input_tensor = ttnn.to_memory_config(input_tensor, memory_config=in_sharded_mem_config)
output_tensor = ttnn.upsample(
input_tensor,
scale_factor,
mode="bilinear",
memory_config=out_sharded_mem_config,
compute_kernel_config=compute_kernel_config,
)
output_tensor = ttnn.to_memory_config(output_tensor, memory_config=ttnn.L1_MEMORY_CONFIG)
output_tensor = ttnn.to_torch(output_tensor)

## compare the results
torch_result = torch_result.permute(0, 2, 3, 1)
passing, pcc_msg = check_with_pcc_without_tensor_printout(torch_result, output_tensor, pcc=0.999)
allclose = torch.allclose(output_tensor, torch_result, atol=1e-1, rtol=1e-1)
logger.info(pcc_msg)

assert allclose
assert passing
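For orientation, the HEIGHT-sharding search above reduces to finding the largest core count whose shard height stays a multiple of TILE_WIDTH. A small Python sketch of that search with the first parametrization plugged in (the 8x8 = 64-core grid is an assumption for illustration, not read from the diff):

TILE_WIDTH = 32

def pick_height_sharding(batch, height, width, channels, max_cores):
    # Largest ncores such that the N*H*W sticks split into TILE_WIDTH-aligned shards.
    nhw = batch * height * width
    for ncores in range(min(nhw, max_cores), 0, -1):
        if nhw % (ncores * TILE_WIDTH) == 0:
            return ncores, (nhw // ncores, channels)
    raise ValueError("no valid core count")

# batch=1, C=256, H=W=16 on an assumed 64-core grid: ncores=8, input shard (32, 256);
# with scale 8x8 the output shard height becomes 32 * 8 * 8 = 2048.
print(pick_height_sharding(1, 16, 16, 256, 64))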
1 change: 1 addition & 0 deletions ttnn/CMakeLists.txt
@@ -288,6 +288,7 @@ set(ALL_TTNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/maxpool/max_pool2d_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/upsample/device/upsample_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/upsample/device/upsample_bilinear_program_factory_multicore.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_singlecore.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/upsample/upsample.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/pool/upsample/upsample_pybind.cpp
@@ -0,0 +1,64 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
[Contributor review comment on this line: 2024]
//
// SPDX-License-Identifier: Apache-2.0

#include <cstdint>

#include "compute_kernel_api/tilize.h"
#include "compute_kernel_api/reduce.h"
#include "compute_kernel_api/pack_untilize.h"

template<uint32_t in_ntiles_hw, uint32_t in_ntiles_c, uint32_t out_ntiles_c, uint32_t unpA_face_r_dim>
inline void reduce_h_fused(
const uint32_t in_cb_id,
const uint32_t in_scalar_cb_id,
const uint32_t in_ntiles_hwc,
const uint32_t in_stick_index,
const uint32_t out_cb_id) {

cb_reserve_back(out_cb_id, 1);
tile_regs_acquire();
cb_wait_front(in_cb_id, 4);
unpack_tilizeA_B_block(in_cb_id, in_scalar_cb_id, in_ntiles_hwc, 0 /*tile idx for Src b is 0 because only 1 tile of constants is loaded*/, 2 /* unpack 1 or 2 faces */, unpA_face_r_dim);
for (uint32_t c_i = 0; c_i < in_ntiles_c; ++c_i) {
reduce_tile_math(c_i, 2 /* reduce 1 or 2 faces */);
}
cb_pop_front(in_cb_id, 4);

tile_regs_wait();
tile_regs_commit();
pack_untilize_dst<out_ntiles_c>(out_cb_id, 1, 0, 1, 2); /* pack 1 row (1x16 or 1x32) */
tile_regs_release();

cb_push_back(out_cb_id, 1);
}

namespace NAMESPACE {
void MAIN {
constexpr uint32_t out_cb_id = tt::CB::c_out0;
constexpr uint32_t in1_cb_id = tt::CB::c_in1;
constexpr uint32_t bias_cb_id = tt::CB::c_in2;
constexpr uint32_t in_scalar_cb_id = tt::CB::c_in4;
constexpr uint32_t in2_cb_id = tt::CB::c_intermed0;

constexpr uint32_t in_ntiles_hw = get_compile_time_arg_val(0);
constexpr uint32_t in_ntiles_c = get_compile_time_arg_val(1);
constexpr uint32_t in_ntiles_hwc = get_compile_time_arg_val(2);
constexpr uint32_t window_size_hw = get_compile_time_arg_val(3);
constexpr uint32_t out_h = get_compile_time_arg_val(4);
constexpr uint32_t out_w = get_compile_time_arg_val(5);
constexpr uint32_t out_ntiles_c = get_compile_time_arg_val(7);

constexpr uint32_t nsticks_per_core_by_nblocks = get_compile_time_arg_val(8);
constexpr uint32_t num_output_tiles = out_ntiles_c; //* nblocks;

tilizeA_B_reduce_init<false, true>(in1_cb_id, in_scalar_cb_id, in_ntiles_hwc, out_cb_id, 2, 4);
pack_untilize_dst_init_short<num_output_tiles>(out_cb_id, 1, 2); /* pack 1 row (1x16 or 1x32) */
for(uint32_t i = 0; i < nsticks_per_core_by_nblocks; i++){
cb_wait_front(in_scalar_cb_id, 1);
reduce_h_fused<in_ntiles_hw, in_ntiles_c, out_ntiles_c, window_size_hw>(in1_cb_id,
in_scalar_cb_id, in_ntiles_hwc, i, out_cb_id);
cb_pop_front(in_scalar_cb_id, 1);
}
} // MAIN
} //NAMESPACE
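The compute kernel above tilizes each group of four neighbour sticks together with the packed weight scalars and reduces across them, which amounts to a per-channel weighted sum. A rough Python model of that reduction (names are illustrative, not taken from the kernel):

def weighted_reduce(sticks, weights):
    # sticks: four rows of num_channels values; weights: the four bilinear scalars.
    num_channels = len(sticks[0])
    out = [0.0] * num_channels
    for stick, w in zip(sticks, weights):
        for c in range(num_channels):
            out[c] += w * stick[c]
    return out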
@@ -0,0 +1,113 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "dataflow_api.h"
#include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/dataflow/moreh_common.hpp"

#define ALWI inline __attribute__((always_inline))

// Fill given four values into the memory starting at the given address.
// WARNING: Use with caution as there's no memory protection. Make sure size is within limits
ALWI void fill_four_val(uint32_t begin_addr, uint16_t val, uint16_t val1, uint16_t val2, uint16_t val3) {
volatile tt_l1_ptr uint32_t* ptr = reinterpret_cast<volatile tt_l1_ptr uint32_t*>(begin_addr);

ptr[0] = (val | (val1 << 16));
ptr[1] = (val2 | (val3 << 16));
}

ALWI float uint32_to_float(uint32_t f)
{
float ret;
std::memcpy(&ret, &f, sizeof(float));
return ret;
}


void kernel_main() {

uint32_t stick_nbytes = get_arg_val<uint32_t>(0);
uint32_t in_image_rows_per_core = get_arg_val<uint32_t>(1);
uint32_t scale_h = get_arg_val<uint32_t>(2);
uint32_t scale_w = get_arg_val<uint32_t>(3);
uint32_t in_w = get_arg_val<uint32_t>(4);
uint32_t out_w = get_arg_val<uint32_t>(5);
uint32_t src1_addr = get_arg_val<uint32_t>(6);
uint32_t read_offset = get_arg_val<uint32_t>(8);
uint32_t is_last_row = get_arg_val<uint32_t>(9);
uint32_t in_h = 1;
constexpr bool src1_is_dram = false;

constexpr uint32_t in_cb_id = get_compile_time_arg_val(0);
constexpr uint32_t out_cb_id = tt::CB::c_in1;
//constexpr uint32_t is_reader = get_compile_time_arg_val(2);
constexpr uint32_t scale_h_inv_comp = get_compile_time_arg_val(3);
constexpr uint32_t scale_w_inv_comp = get_compile_time_arg_val(4);
constexpr uint32_t y_index_comp = get_compile_time_arg_val(5);
constexpr uint32_t x_index_compute_comp = get_compile_time_arg_val(6);

uint32_t l1_read_addr = get_read_ptr(in_cb_id);
constexpr uint32_t in_scalar_cb_id = tt::CB::c_in4;

// assuming shard begins with a new row. TODO: generalize?
float scale_h_inv = uint32_to_float(scale_h_inv_comp);
float scale_w_inv = uint32_to_float(scale_w_inv_comp);
float x, y, x_index, y_index, dx, dy;
y_index = uint32_to_float(y_index_comp);
float x_index_compute = uint32_to_float(x_index_compute_comp);
for (uint32_t image_row = 0 ; image_row < in_image_rows_per_core * scale_h; ++image_row){
x_index = x_index_compute;
for(uint32_t j=0; j < in_w * scale_w; j++){
cb_reserve_back(out_cb_id, 4);
cb_reserve_back(in_scalar_cb_id, 1);

x = x_index < 0 ? 0 : x_index;
y = y_index < read_offset ? read_offset : y_index;
dx = x - int(x);
dy = y - int(y);

uint32_t x1 = int(x);
uint32_t y1 = int(y);
uint32_t x2 = min(x1 + 1, in_w-1);
uint32_t y2 = y1 + 1;
if(is_last_row){
y2 = min(y2, in_image_rows_per_core); //if last row, y2 should be in_image_rows_per_core
}

fill_four_val(get_write_ptr(in_scalar_cb_id), float_to_bfloat16((1-dx) * (1-dy)),
float_to_bfloat16(dx * (1 - dy)), float_to_bfloat16((1 - dx) * dy), float_to_bfloat16(dx * dy));

uint32_t l1_write_addr = get_write_ptr(out_cb_id);
uint32_t l1_read_addr_temp = l1_read_addr + x1 * stick_nbytes + y1 * in_w * stick_nbytes;
// 1st neighbour stick (x1, y1)
uint64_t src_noc_addr = get_noc_addr(l1_read_addr_temp);
noc_async_read(src_noc_addr, l1_write_addr, stick_nbytes);
l1_write_addr += stick_nbytes;

// 2nd neighbour stick (x2, y1)
l1_read_addr_temp = l1_read_addr + y1 * in_w * stick_nbytes + x2 * stick_nbytes;
src_noc_addr = get_noc_addr(l1_read_addr_temp);
noc_async_read(src_noc_addr, l1_write_addr, stick_nbytes);
l1_write_addr += stick_nbytes;

// 3rd neighbour stick (x1, y2)
l1_read_addr_temp = l1_read_addr + y2 * in_w * stick_nbytes + x1 * stick_nbytes;
src_noc_addr = get_noc_addr(l1_read_addr_temp);
noc_async_read(src_noc_addr, l1_write_addr, stick_nbytes);
l1_write_addr += stick_nbytes;

// 4th neighbour stick (x2, y2)
l1_read_addr_temp = l1_read_addr + y2 * in_w * stick_nbytes + x2 * stick_nbytes;
src_noc_addr = get_noc_addr(l1_read_addr_temp);
noc_async_read(src_noc_addr, l1_write_addr, stick_nbytes);
l1_write_addr += stick_nbytes;

// push the scalar weights and the four sticks into their CBs
noc_async_read_barrier();
cb_push_back(out_cb_id, 4);
cb_push_back(in_scalar_cb_id, 1);
x_index += scale_w_inv;
}
y_index += scale_h_inv;
}
}
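The reader kernel maps each output pixel back to a source coordinate using the reciprocal scales passed as compile-time args, clamps at the shard edges, derives the four bilinear weights, and issues four NOC reads for the neighbouring sticks. A condensed Python model of the index and weight computation (variable names mirror the kernel, but the snippet itself is illustrative and not part of the PR):

def bilinear_source(j, i, scale_w_inv, scale_h_inv, in_w, x_start, y_start, read_offset):
    # Source coordinate for output column j, row i (floats, may fall between sticks).
    x = max(x_start + j * scale_w_inv, 0.0)
    y = max(y_start + i * scale_h_inv, float(read_offset))
    x1, y1 = int(x), int(y)
    x2 = min(x1 + 1, in_w - 1)   # clamp to the last column
    y2 = y1 + 1                  # the kernel additionally clamps y2 on the last shard row
    dx, dy = x - x1, y - y1
    weights = ((1 - dx) * (1 - dy), dx * (1 - dy), (1 - dx) * dy, dx * dy)
    return [(x1, y1), (x2, y1), (x1, y2), (x2, y2)], weights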