Skip to content

Commit

Permalink
#0: Add mixed precision support for bcast
Browse files: browse the repository at this point in the history
  • Loading branch information
tt-aho committed May 17, 2024
1 parent f1c3130 commit 8024d15
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,13 @@ operation::ProgramWithCallbacks bcast_multi_core_h(const Tensor &a, const Tensor

tt_metal::Device *device = a.device();

tt::DataFormat cb_data_format = tt_metal::datatype_to_dataformat_converter(a.get_dtype());
tt::DataFormat src0_cb_data_format = tt_metal::datatype_to_dataformat_converter(a.get_dtype());
tt::DataFormat src1_cb_data_format = tt_metal::datatype_to_dataformat_converter(b.get_dtype());
tt::DataFormat dst_cb_data_format = tt_metal::datatype_to_dataformat_converter(output.get_dtype());

uint32_t single_tile_size = tt_metal::detail::TileSize(cb_data_format);
uint32_t src0_single_tile_size = tt_metal::detail::TileSize(src0_cb_data_format);
uint32_t src1_single_tile_size = tt_metal::detail::TileSize(src1_cb_data_format);
uint32_t dst_single_tile_size = tt_metal::detail::TileSize(dst_cb_data_format);

auto compute_with_storage_grid_size = device->compute_with_storage_grid_size();
uint32_t num_cores_x = compute_with_storage_grid_size.x;
Expand All @@ -69,19 +73,19 @@ operation::ProgramWithCallbacks bcast_multi_core_h(const Tensor &a, const Tensor
uint32_t src0_cb_index = 0;
uint32_t num_input_tiles = 2;

tt_metal::CircularBufferConfig src0_cb_config = tt_metal::CircularBufferConfig(num_input_tiles * single_tile_size, {{src0_cb_index, cb_data_format}})
.set_page_size(src0_cb_index, single_tile_size);
tt_metal::CircularBufferConfig src0_cb_config = tt_metal::CircularBufferConfig(num_input_tiles * src0_single_tile_size, {{src0_cb_index, src0_cb_data_format}})
.set_page_size(src0_cb_index, src0_single_tile_size);
auto cb_src0 = tt_metal::CreateCircularBuffer(program, all_device_cores, src0_cb_config);

uint32_t src1_cb_index = 1;
tt_metal::CircularBufferConfig src1_cb_config = tt_metal::CircularBufferConfig(num_input_tiles * single_tile_size, {{src1_cb_index, cb_data_format}})
.set_page_size(src1_cb_index, single_tile_size);
tt_metal::CircularBufferConfig src1_cb_config = tt_metal::CircularBufferConfig(num_input_tiles * src1_single_tile_size, {{src1_cb_index, src1_cb_data_format}})
.set_page_size(src1_cb_index, src1_single_tile_size);
auto cb_src1 = tt_metal::CreateCircularBuffer(program, all_device_cores, src1_cb_config);

uint32_t output_cb_index = 16; // output operands start at index 16
uint32_t num_output_tiles = 2;
tt_metal::CircularBufferConfig output_cb_config = tt_metal::CircularBufferConfig(num_output_tiles * single_tile_size, {{output_cb_index, cb_data_format}})
.set_page_size(output_cb_index, single_tile_size);
tt_metal::CircularBufferConfig output_cb_config = tt_metal::CircularBufferConfig(num_output_tiles * dst_single_tile_size, {{output_cb_index, dst_cb_data_format}})
.set_page_size(output_cb_index, dst_single_tile_size);
auto cb_output = tt_metal::CreateCircularBuffer(program, all_device_cores, output_cb_config);

bool src0_is_dram = src0_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,13 @@ operation::ProgramWithCallbacks bcast_multi_core_hw(const Tensor &a, const Tenso
shard_spec = output.shard_spec().value();
}

tt::DataFormat cb_data_format = tt_metal::datatype_to_dataformat_converter(a.get_dtype());
tt::DataFormat src0_cb_data_format = tt_metal::datatype_to_dataformat_converter(a.get_dtype());
tt::DataFormat src1_cb_data_format = tt_metal::datatype_to_dataformat_converter(b.get_dtype());
tt::DataFormat dst_cb_data_format = tt_metal::datatype_to_dataformat_converter(output.get_dtype());

uint32_t single_tile_size = tt_metal::detail::TileSize(cb_data_format);
uint32_t src0_single_tile_size = tt_metal::detail::TileSize(src0_cb_data_format);
uint32_t src1_single_tile_size = tt_metal::detail::TileSize(src1_cb_data_format);
uint32_t dst_single_tile_size = tt_metal::detail::TileSize(dst_cb_data_format);

auto compute_with_storage_grid_size = device->compute_with_storage_grid_size();
uint32_t num_cores_x = compute_with_storage_grid_size.x;
Expand Down Expand Up @@ -91,22 +95,22 @@ operation::ProgramWithCallbacks bcast_multi_core_hw(const Tensor &a, const Tenso

uint32_t num_input_tiles_cb0 = src0_sharded ? num_tiles_per_shard : num_input_tiles;

tt_metal::CircularBufferConfig src0_cb_config = tt_metal::CircularBufferConfig(num_input_tiles_cb0 * single_tile_size, {{src0_cb_index, cb_data_format}})
.set_page_size(src0_cb_index, single_tile_size);
tt_metal::CircularBufferConfig src0_cb_config = tt_metal::CircularBufferConfig(num_input_tiles_cb0 * src0_single_tile_size, {{src0_cb_index, src0_cb_data_format}})
.set_page_size(src0_cb_index, src0_single_tile_size);
if (src0_sharded) {
src0_cb_config = src0_cb_config.set_globally_allocated_address(*a.buffer());
}
auto cb_src0 = tt_metal::CreateCircularBuffer(program, all_device_cores, src0_cb_config);

uint32_t src1_cb_index = 1;
tt_metal::CircularBufferConfig src1_cb_config = tt_metal::CircularBufferConfig(num_input_tiles * single_tile_size, {{src1_cb_index, cb_data_format}})
.set_page_size(src1_cb_index, single_tile_size);
tt_metal::CircularBufferConfig src1_cb_config = tt_metal::CircularBufferConfig(num_input_tiles * src1_single_tile_size, {{src1_cb_index, src1_cb_data_format}})
.set_page_size(src1_cb_index, src1_single_tile_size);
auto cb_src1 = tt_metal::CreateCircularBuffer(program, all_device_cores, src1_cb_config);

uint32_t output_cb_index = 16; // output operands start at index 16
uint32_t num_output_tiles = output_sharded ? num_tiles_per_shard : 2;
tt_metal::CircularBufferConfig output_cb_config = tt_metal::CircularBufferConfig(num_output_tiles * single_tile_size, {{output_cb_index, cb_data_format}})
.set_page_size(output_cb_index, single_tile_size);
tt_metal::CircularBufferConfig output_cb_config = tt_metal::CircularBufferConfig(num_output_tiles * dst_single_tile_size, {{output_cb_index, dst_cb_data_format}})
.set_page_size(output_cb_index, dst_single_tile_size);
if (output_sharded) {
output_cb_config = output_cb_config.set_globally_allocated_address(*output.buffer());
}
Expand Down Expand Up @@ -211,7 +215,9 @@ operation::ProgramWithCallbacks bcast_multi_core_hw(const Tensor &a, const Tenso
bcast_kernel_id,
compute_with_storage_grid_size,
cb_src0,
single_tile_size,
src0_single_tile_size,
src1_single_tile_size,
dst_single_tile_size,
cb_output
]
(
Expand Down Expand Up @@ -324,12 +330,12 @@ operation::ProgramWithCallbacks bcast_multi_core_hw(const Tensor &a, const Tenso

if (src0_sharded) {
UpdateDynamicCircularBufferAddress(program, cb_src0, *src_buffer_a);
UpdateCircularBufferTotalSize(program, cb_src0, num_tiles_per_core_group_1 * single_tile_size);
UpdateCircularBufferTotalSize(program, cb_src0, num_tiles_per_core_group_1 * src0_single_tile_size);
}

if (out_sharded) {
UpdateDynamicCircularBufferAddress(program, cb_output, *dst_buffer);
UpdateCircularBufferTotalSize(program, cb_output, num_tiles_per_core_group_1 * single_tile_size);
UpdateCircularBufferTotalSize(program, cb_output, num_tiles_per_core_group_1 * dst_single_tile_size);
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,13 @@ operation::ProgramWithCallbacks bcast_multi_core_w(const Tensor &a, const Tensor

tt_metal::Device *device = a.device();

tt::DataFormat cb_data_format = tt_metal::datatype_to_dataformat_converter(a.get_dtype());
tt::DataFormat src0_cb_data_format = tt_metal::datatype_to_dataformat_converter(a.get_dtype());
tt::DataFormat src1_cb_data_format = tt_metal::datatype_to_dataformat_converter(b.get_dtype());
tt::DataFormat dst_cb_data_format = tt_metal::datatype_to_dataformat_converter(output.get_dtype());

uint32_t single_tile_size = tt_metal::detail::TileSize(cb_data_format);
uint32_t src0_single_tile_size = tt_metal::detail::TileSize(src0_cb_data_format);
uint32_t src1_single_tile_size = tt_metal::detail::TileSize(src1_cb_data_format);
uint32_t dst_single_tile_size = tt_metal::detail::TileSize(dst_cb_data_format);

auto compute_with_storage_grid_size = device->compute_with_storage_grid_size();
uint32_t num_cores_x = compute_with_storage_grid_size.x;
Expand All @@ -69,19 +73,19 @@ operation::ProgramWithCallbacks bcast_multi_core_w(const Tensor &a, const Tensor
uint32_t src0_cb_index = 0;
uint32_t num_input_tiles = 2;

tt_metal::CircularBufferConfig src0_cb_config = tt_metal::CircularBufferConfig(num_input_tiles * single_tile_size, {{src0_cb_index, cb_data_format}})
.set_page_size(src0_cb_index, single_tile_size);
tt_metal::CircularBufferConfig src0_cb_config = tt_metal::CircularBufferConfig(num_input_tiles * src0_single_tile_size, {{src0_cb_index, src0_cb_data_format}})
.set_page_size(src0_cb_index, src0_single_tile_size);
auto cb_src0 = tt_metal::CreateCircularBuffer(program, all_device_cores, src0_cb_config);

uint32_t src1_cb_index = 1;
tt_metal::CircularBufferConfig src1_cb_config = tt_metal::CircularBufferConfig(num_input_tiles * single_tile_size, {{src1_cb_index, cb_data_format}})
.set_page_size(src1_cb_index, single_tile_size);
tt_metal::CircularBufferConfig src1_cb_config = tt_metal::CircularBufferConfig(num_input_tiles * src1_single_tile_size, {{src1_cb_index, src1_cb_data_format}})
.set_page_size(src1_cb_index, src1_single_tile_size);
auto cb_src1 = tt_metal::CreateCircularBuffer(program, all_device_cores, src1_cb_config);

uint32_t output_cb_index = 16; // output operands start at index 16
uint32_t num_output_tiles = 2;
tt_metal::CircularBufferConfig output_cb_config = tt_metal::CircularBufferConfig(num_output_tiles * single_tile_size, {{output_cb_index, cb_data_format}})
.set_page_size(output_cb_index, single_tile_size);
tt_metal::CircularBufferConfig output_cb_config = tt_metal::CircularBufferConfig(num_output_tiles * dst_single_tile_size, {{output_cb_index, dst_cb_data_format}})
.set_page_size(output_cb_index, dst_single_tile_size);
auto cb_output = tt_metal::CreateCircularBuffer(program, all_device_cores, output_cb_config);


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,25 +45,29 @@ operation::ProgramWithCallbacks bcast_single_core(const Tensor &a, const Tensor
// This should allocate a DRAM buffer on the device
tt_metal::Device *device = a.device();

tt::DataFormat cb_data_format = tt_metal::datatype_to_dataformat_converter(a.get_dtype());
tt::DataFormat src0_cb_data_format = tt_metal::datatype_to_dataformat_converter(a.get_dtype());
tt::DataFormat src1_cb_data_format = tt_metal::datatype_to_dataformat_converter(b.get_dtype());
tt::DataFormat dst_cb_data_format = tt_metal::datatype_to_dataformat_converter(output.get_dtype());

uint32_t single_tile_size = tt_metal::detail::TileSize(cb_data_format);
uint32_t src0_single_tile_size = tt_metal::detail::TileSize(src0_cb_data_format);
uint32_t src1_single_tile_size = tt_metal::detail::TileSize(src1_cb_data_format);
uint32_t dst_single_tile_size = tt_metal::detail::TileSize(dst_cb_data_format);

uint32_t src0_cb_index = 0;
uint32_t num_input_tiles = 2;
tt_metal::CircularBufferConfig src0_cb_config = tt_metal::CircularBufferConfig(num_input_tiles * single_tile_size, {{src0_cb_index, cb_data_format}})
.set_page_size(src0_cb_index, single_tile_size);
tt_metal::CircularBufferConfig src0_cb_config = tt_metal::CircularBufferConfig(num_input_tiles * src0_single_tile_size, {{src0_cb_index, src0_cb_data_format}})
.set_page_size(src0_cb_index, src0_single_tile_size);
auto cb_src0 = tt_metal::CreateCircularBuffer(program, core, src0_cb_config);

uint32_t src1_cb_index = 1;
tt_metal::CircularBufferConfig src1_cb_config = tt_metal::CircularBufferConfig(num_input_tiles * single_tile_size, {{src1_cb_index, cb_data_format}})
.set_page_size(src1_cb_index, single_tile_size);
tt_metal::CircularBufferConfig src1_cb_config = tt_metal::CircularBufferConfig(num_input_tiles * src1_single_tile_size, {{src1_cb_index, src1_cb_data_format}})
.set_page_size(src1_cb_index, src1_single_tile_size);
auto cb_src1 = tt_metal::CreateCircularBuffer(program, core, src1_cb_config);

uint32_t output_cb_index = 16; // output operands start at index 16
uint32_t num_output_tiles = 2;
tt_metal::CircularBufferConfig output_cb_config = tt_metal::CircularBufferConfig(num_output_tiles * single_tile_size, {{output_cb_index, cb_data_format}})
.set_page_size(output_cb_index, single_tile_size);
tt_metal::CircularBufferConfig output_cb_config = tt_metal::CircularBufferConfig(num_output_tiles * dst_single_tile_size, {{output_cb_index, dst_cb_data_format}})
.set_page_size(output_cb_index, dst_single_tile_size);
auto cb_output = tt_metal::CreateCircularBuffer(program, core, output_cb_config);

uint32_t bnc1 = (bN*bC == 1);
Expand Down

0 comments on commit 8024d15

Please sign in to comment.