Skip to content

Commit

Permalink
Merge branch 'main' into dimitri/fix-package-workflow-tagging
Browse files Browse the repository at this point in the history
  • Loading branch information
dimitri-tenstorrent authored Dec 10, 2024
2 parents 2a6cdf0 + b32ae29 commit 664f5d9
Show file tree
Hide file tree
Showing 14 changed files with 27 additions and 57 deletions.
32 changes: 1 addition & 31 deletions ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,36 +52,6 @@ ttnn::Tensor convert_tile_to_rm(
(tensor.get_dtype() == DataType::BFLOAT8_B) ? ttnn::typecast(new_tensor, tensor.get_dtype()) : new_tensor;
return new_tensor;
}
// Reshapes `tensor` to `shape` by round-tripping the data through host memory.
//
// Deprecated workaround for embedding issue 15558; delete once that issue is fixed.
// A warning is logged on every call.
//
// Behavior:
//   - If `tensor` is not in DEVICE storage, returns `tensor.reshape(shape)` directly.
//   - Otherwise the tensor is copied to the host, converted to row-major layout,
//     reshaped, converted back to the tensor's original layout, and transferred
//     back to the original device with the original memory config.
//
// @param tensor  Tensor to reshape.
// @param shape   Requested shape.
// @return        The reshaped tensor.
ttnn::Tensor host_reshape(const ttnn::Tensor& tensor, const ttnn::Shape& shape) {
    // This function is due to embedding issue 15558, once the issue is fixed we want to delete it
    tt::log_warning("host_reshape is deprecated and will be removed in the near future");
    if (!ttnn::has_storage_type_of(tensor, ttnn::StorageType::DEVICE)) {
        return tensor.reshape(shape);
    }

    // Capture the device-side properties before moving the data to the host so
    // the result can be restored to the same layout, device and memory config.
    auto tensor_shape = tensor.shape();
    auto layout = tensor.layout();
    auto device = tensor.device();
    auto memory_config = tensor.memory_config();
    auto host_tensor = tensor.cpu();
    // The explicit null Device* argument is spelled with a named cast instead of a C-style cast.
    auto rm_tensor = ttnn::to_layout(
        host_tensor, ttnn::ROW_MAJOR_LAYOUT, std::nullopt, std::nullopt, static_cast<Device*>(nullptr));

    if (tensor_shape.has_tile_padding()) {
        auto host_tensor_4d = unsqueeze_to_4D(rm_tensor);
        auto tensor_shape_4d = host_tensor_4d.shape();
        ttnn::SmallVector<uint32_t> begins({0, 0, 0, 0});
        ttnn::SmallVector<uint32_t> ends(
            {tensor_shape_4d[0], tensor_shape_4d[1], tensor_shape_4d[2], tensor_shape_4d[3]});
        ttnn::SmallVector<uint32_t> step({1, 1, 1, 1});
        host_tensor_4d = ttnn::slice(host_tensor_4d, begins, ends, step, std::nullopt);
        // NOTE(review): the sliced result is stored in `host_tensor`, which is not read
        // again below; the reshape operates on `rm_tensor`. Confirm whether `rm_tensor`
        // was the intended target before changing this, since callers of this workaround
        // may rely on the current behavior.
        host_tensor = squeeze_from_4D(host_tensor_4d, tensor_shape.rank());
    }

    auto host_reshape_tensor = rm_tensor.reshape(shape);
    auto final_layout_tensor = ttnn::to_layout(
        host_reshape_tensor, layout, std::nullopt, std::nullopt, static_cast<Device*>(nullptr));
    return ttnn::data_transfer_to_device(final_layout_tensor, device, memory_config);
}

// Wrapper to turn the ND -> MD problem into 3D -> 3D for tiled and 2D -> 2D for row-major

Expand Down Expand Up @@ -399,7 +369,7 @@ ttnn::Tensor ReshapeViewOperation::invoke(
return tensor.reshape(shape);
}
//This is a completely incorrect test but it is due to issue 15558
return detail::host_reshape(tensor, shape);
TT_FATAL(false, "Attempting to reshape between two shapes with different volumes");
}
// Catch-all
// Do the reshape in row-major
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -302,14 +302,14 @@ BinaryNgDeviceOperation::ProgramFactory::cached_program_t BinaryNgDeviceOperatio
// How many tiles to store per input CB (double buffer)
constexpr uint32_t num_tiles_per_cb = 2;
auto [a_cb, a_cb_handle] =
create_cb(tt::CB::c_in0, program, all_device_cores, a_single_tile_size, num_tiles_per_cb, a_data_format);
create_cb(tt::CBIndex::c_0, program, all_device_cores, a_single_tile_size, num_tiles_per_cb, a_data_format);
auto [c_cb, c_cb_handle] =
create_cb(tt::CB::c_out0, program, all_device_cores, c_single_tile_size, num_tiles_per_cb, c_data_format);
create_cb(tt::CBIndex::c_2, program, all_device_cores, c_single_tile_size, num_tiles_per_cb, c_data_format);

// If b is a scalar, we only need one tile in the CB
uint32_t b_num_tiles_per_cb = b_buffer != nullptr ? num_tiles_per_cb : 1;
auto [b_cb, b_cb_handle] =
create_cb(tt::CB::c_in1, program, all_device_cores, b_single_tile_size, b_num_tiles_per_cb, b_data_format);
create_cb(tt::CBIndex::c_1, program, all_device_cores, b_single_tile_size, b_num_tiles_per_cb, b_data_format);

auto a_is_dram = static_cast<uint32_t>(a_buffer->buffer_type() == tt_metal::BufferType::DRAM);
bool b_is_dram = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ void MAIN {
return;
}

constexpr auto cb_in0 = tt::CB::c_in0;
constexpr auto cb_in1 = tt::CB::c_in1;
constexpr auto cb_out0 = tt::CB::c_out0;
constexpr auto cb_in0 = tt::CBIndex::c_0;
constexpr auto cb_in1 = tt::CBIndex::c_1;
constexpr auto cb_out0 = tt::CBIndex::c_2;

#if BCAST_INPUT
auto cb_bcast = cb_in1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ namespace NAMESPACE {
void MAIN {
uint32_t num_tiles = get_arg_val<uint32_t>(0);

constexpr auto cb_in0 = tt::CB::c_in0;
constexpr auto cb_in1 = tt::CB::c_in1;
constexpr auto cb_out0 = tt::CB::c_out0;
constexpr auto cb_in0 = tt::CBIndex::c_0;
constexpr auto cb_in1 = tt::CBIndex::c_1;
constexpr auto cb_out0 = tt::CBIndex::c_2;

binary_op_init_common(cb_in0, cb_in1, cb_out0);
add_tiles_init();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ namespace NAMESPACE {
void MAIN {
uint32_t num_tiles = get_arg_val<uint32_t>(0);

constexpr auto cb_in0 = tt::CB::c_in0;
constexpr auto cb_in1 = tt::CB::c_in1;
constexpr auto cb_out0 = tt::CB::c_out0;
constexpr auto cb_in0 = tt::CBIndex::c_0;
constexpr auto cb_in1 = tt::CBIndex::c_1;
constexpr auto cb_out0 = tt::CBIndex::c_2;

binary_op_init_common(cb_in0, cb_in1, cb_out0);
add_tiles_init();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ void kernel_main() {

constexpr bool src_is_dram = get_compile_time_arg_val(0) == 1;

constexpr auto cb_id_src = tt::CB::c_in0;
constexpr auto cb_id_src = tt::CBIndex::c_0;
constexpr uint32_t onetile = 1;

const uint32_t src_tile_bytes = get_tile_size(cb_id_src);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ void kernel_main() {

constexpr bool src_is_dram = get_compile_time_arg_val(0) == 1;

constexpr auto cb_id_src = tt::CB::c_in0;
constexpr auto cb_id_src = tt::CBIndex::c_0;
constexpr uint32_t onetile = 1;

const uint32_t src_tile_bytes = get_tile_size(cb_id_src);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ void kernel_main() {

constexpr bool src_is_dram = get_compile_time_arg_val(0) == 1;

constexpr auto cb_id_src = tt::CB::c_in0;
constexpr auto cb_id_src = tt::CBIndex::c_0;
constexpr uint32_t onetile = 1;

const uint32_t src_tile_bytes = get_tile_size(cb_id_src);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ void kernel_main() {

constexpr bool src_is_dram = get_compile_time_arg_val(0) == 1;

constexpr auto cb_id_src = tt::CB::c_in0;
constexpr auto cb_id_src = tt::CBIndex::c_0;
constexpr uint32_t onetile = 1;

const uint32_t src_tile_bytes = get_tile_size(cb_id_src);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@ void kernel_main() {

constexpr uint32_t onetile = 1;

constexpr auto cb_id_src = tt::CB::c_in1;
constexpr auto cb_id_src = tt::CBIndex::c_1;
constexpr bool src_is_dram = get_compile_time_arg_val(0) == 1;
const uint32_t src_tile_bytes = get_tile_size(cb_id_src);
const DataFormat src_data_format = get_dataformat(cb_id_src);

const InterleavedAddrGenFast<src_is_dram> src = {
.bank_base_address = src_addr, .page_size = src_tile_bytes, .data_format = src_data_format};

constexpr auto cb_id_dst = tt::CB::c_out0;
constexpr auto cb_id_dst = tt::CBIndex::c_2;
constexpr bool dst_is_dram = get_compile_time_arg_val(1) == 1;
const uint32_t dst_tile_bytes = get_tile_size(cb_id_dst);
const DataFormat dst_data_format = get_dataformat(cb_id_dst);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@ void kernel_main() {

constexpr uint32_t onetile = 1;

constexpr auto cb_id_src = tt::CB::c_in1;
constexpr auto cb_id_src = tt::CBIndex::c_1;
constexpr bool src_is_dram = get_compile_time_arg_val(0) == 1;
const uint32_t src_tile_bytes = get_tile_size(cb_id_src);
const DataFormat src_data_format = get_dataformat(cb_id_src);

const InterleavedAddrGenFast<src_is_dram> src = {
.bank_base_address = src_addr, .page_size = src_tile_bytes, .data_format = src_data_format};

constexpr auto cb_id_dst = tt::CB::c_out0;
constexpr auto cb_id_dst = tt::CBIndex::c_2;
constexpr bool dst_is_dram = get_compile_time_arg_val(1) == 1;
const uint32_t dst_tile_bytes = get_tile_size(cb_id_dst);
const DataFormat dst_data_format = get_dataformat(cb_id_dst);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@ void kernel_main() {

constexpr uint32_t onetile = 1;

constexpr auto cb_id_src = tt::CB::c_in1;
constexpr auto cb_id_src = tt::CBIndex::c_1;
constexpr bool src_is_dram = get_compile_time_arg_val(0) == 1;
const uint32_t src_tile_bytes = get_tile_size(cb_id_src);
const DataFormat src_data_format = get_dataformat(cb_id_src);

const InterleavedAddrGenFast<src_is_dram> src = {
.bank_base_address = src_addr, .page_size = src_tile_bytes, .data_format = src_data_format};

constexpr auto cb_id_dst = tt::CB::c_out0;
constexpr auto cb_id_dst = tt::CBIndex::c_2;
constexpr bool dst_is_dram = get_compile_time_arg_val(1) == 1;
const uint32_t dst_tile_bytes = get_tile_size(cb_id_dst);
const DataFormat dst_data_format = get_dataformat(cb_id_dst);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ void kernel_main() {

constexpr bool dst_is_dram = get_compile_time_arg_val(1) == 1;

constexpr auto cb_id_src = tt::CB::c_in1;
constexpr auto cb_id_dst = tt::CB::c_out0;
constexpr auto cb_id_src = tt::CBIndex::c_1;
constexpr auto cb_id_dst = tt::CBIndex::c_2;
constexpr uint32_t onetile = 1;

const uint32_t dst_tile_bytes = get_tile_size(cb_id_dst);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@ void kernel_main() {

constexpr uint32_t onetile = 1;

constexpr auto cb_id_src = tt::CB::c_in1;
constexpr auto cb_id_src = tt::CBIndex::c_1;
constexpr bool src_is_dram = get_compile_time_arg_val(0) == 1;
const uint32_t src_tile_bytes = get_tile_size(cb_id_src);
const DataFormat src_data_format = get_dataformat(cb_id_src);

const InterleavedAddrGenFast<src_is_dram> src = {
.bank_base_address = src_addr, .page_size = src_tile_bytes, .data_format = src_data_format};

constexpr auto cb_id_dst = tt::CB::c_out0;
constexpr auto cb_id_dst = tt::CBIndex::c_2;
constexpr bool dst_is_dram = get_compile_time_arg_val(1) == 1;
const uint32_t dst_tile_bytes = get_tile_size(cb_id_dst);
const DataFormat dst_data_format = get_dataformat(cb_id_dst);
Expand Down

0 comments on commit 664f5d9

Please sign in to comment.