
#0: Port Matmul, Conv, and AllGather to compute_output_specs (#15978)
### Ticket

### Problem description
We're continuing to port ops from `compute_output_shapes` to the new
`compute_output_specs` interface.

### What's changed
Ported the `matmul`, `all_gather`, `all_gather_matmul`, `conv2d`, and halo ops
to use `compute_output_specs`.
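
As a rough, declaration-level sketch of what the port means for each op's interface (types simplified; `OldStyleOp`/`NewStyleOp` are illustrative names, and the real declarations are in the headers changed below):

```cpp
// Before: an op reported bare output shapes and separately had to build its
// own output tensors, duplicating dtype/layout/memory-config plumbing.
struct OldStyleOp {
    std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor>& input_tensors) const;
    std::vector<Tensor> create_output_tensors(const std::vector<Tensor>& input_tensors) const;
};

// After: the op returns full TensorSpecs (shape + dtype + layout + memory
// config), so output allocation can fall back to the framework default.
struct NewStyleOp {
    std::vector<ttnn::TensorSpec> compute_output_specs(const std::vector<Tensor>& input_tensors) const;
};
```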

### Checklist
- [x] [Post commit CI
passes](https://github.com/tenstorrent/tt-metal/actions/runs/12304931255)
- [ ] Blackhole Post commit (if applicable)
- [ ] Model regression CI testing passes (if applicable)
- [ ] Device performance regression CI testing passes (if applicable)
- [ ] **(For models and ops writers)** Full [new
models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml)
tests pass
- [x] New/Existing tests provide coverage for changes
sminakov-tt authored Dec 13, 2024
1 parent 660d249 commit b314a49
Showing 11 changed files with 72 additions and 116 deletions.
ttnn/cpp/ttnn/operation.hpp: 6 changes (3 additions & 3 deletions)
@@ -400,6 +400,8 @@ constexpr bool implements_get_parallelization_strategy() {
     return std::experimental::is_detected_v<has_get_parallelization_strategy_t, T, const Tensors&>;
 }
 
+}  // namespace detail
+
 template <typename ConcreteOperation>
 auto default_create_output_tensors(
     const ConcreteOperation& operation,
@@ -427,8 +429,6 @@ auto default_create_output_tensors(
     return output_tensors;
 }
 
-}  // namespace detail
-
 template <class OutputTensorsT = Tensors>
 struct DeviceOperation final {
     using storage_t = std::array<std::byte, 1152>;
@@ -628,7 +628,7 @@ struct DeviceOperation final {
                 "create_output_tensors");
             return operation.create_output_tensors(input_tensors);
         } else if constexpr (detail::implements_compute_output_specs<T>()) {
-            return detail::default_create_output_tensors(operation, input_tensors, output_tensors);
+            return default_create_output_tensors(operation, input_tensors, output_tensors);
         } else {
             static_assert(
                 tt::stl::concepts::always_false_v<T>,
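
The dispatch above hinges on `detail::implements_compute_output_specs<T>()`. Judging from the `implements_get_parallelization_strategy` detector visible in the first hunk, it is presumably built on the same `std::experimental::is_detected_v` member-detection idiom. A self-contained sketch (the exact trait in `operation.hpp` may differ):

```cpp
#include <experimental/type_traits>  // std::experimental::is_detected_v (libstdc++, library fundamentals TS v2)
#include <utility>
#include <vector>

struct Tensor {};  // stand-in for ttnn's Tensor, only for this sketch
using Tensors = std::vector<Tensor>;

// Alias that is well-formed only if T has a member compute_output_specs(const Tensors&).
template <class T>
using has_compute_output_specs_t =
    decltype(std::declval<T>().compute_output_specs(std::declval<const Tensors&>()));

template <class T>
constexpr bool implements_compute_output_specs() {
    return std::experimental::is_detected_v<has_compute_output_specs_t, T>;
}

struct WithSpecs {
    int compute_output_specs(const Tensors&) const;  // any return type is detected
};
struct WithoutSpecs {};

static_assert(implements_compute_output_specs<WithSpecs>());
static_assert(!implements_compute_output_specs<WithoutSpecs>());
```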
ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp: 31 changes (14 additions & 17 deletions)
@@ -186,27 +186,24 @@ void AllGather::validate(const std::vector<Tensor>& input_tensors) const {
     }
 }
 
-std::vector<ttnn::SimpleShape> AllGather::compute_output_shapes(const std::vector<Tensor>& input_tensors) const {
-    auto shape = input_tensors[0].get_padded_shape(); // TODO: Replace with get_logical_shape()
-    shape[this->dim] *= this->ring_size;
-    return std::vector<ttnn::SimpleShape>(input_tensors.size(), shape);
-}
+std::vector<ttnn::TensorSpec> AllGather::compute_output_specs(const std::vector<Tensor>& input_tensors) const {
+    auto output_shape = input_tensors[0].get_padded_shape(); // TODO: Replace with get_logical_shape()
+    output_shape[this->dim] *= this->ring_size;
 
-std::vector<Tensor> AllGather::create_output_tensors(const std::vector<Tensor>& input_tensors) const {
     const auto& input_tensor = input_tensors[0];
-    auto tile = input_tensor.get_tensor_spec().tile();
+    TensorSpec spec(
+        output_shape,
+        TensorLayout(input_tensor.get_dtype(), input_tensor.get_tensor_spec().page_config(), output_mem_config));
     if (this->output_mem_config.is_sharded()) {
-        return {create_device_tensor(
-            this->compute_output_shapes(input_tensors).at(0),
-            input_tensor.get_dtype(),
-            input_tensor.get_layout(),
-            input_tensor.device(),
-            this->output_mem_config,
-            tile)};
-    } else {
-        return operation::generic_create_output_tensors(
-            *this, input_tensors, input_tensor.get_dtype(), input_tensor.get_layout(), this->output_mem_config, tile);
+        return {TensorSpec(
+            output_shape,
+            TensorLayout(input_tensor.get_dtype(), input_tensor.get_tensor_spec().page_config(), output_mem_config))};
     }
+    return std::vector<TensorSpec>(input_tensors.size(), spec);
 }
 
+std::vector<Tensor> AllGather::create_output_tensors(const std::vector<Tensor>& input_tensors) const {
+    return operation::default_create_output_tensors(*this, input_tensors, {});
+}
+
 operation::ProgramWithCallbacks AllGather::create_program(
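
With the specs in hand, `AllGather::create_output_tensors` collapses to a one-line delegation. Loosely, and only as a sketch, the framework-side default can be imagined as allocating one device tensor per reported spec; the real helper is the `default_create_output_tensors` moved out of `detail` in `operation.hpp` above, and the `create_device_tensor(spec, device)` helper named here is an assumption of this sketch:

```cpp
// Sketch only: allocate outputs purely from the specs an op reports.
// Assumes a create_device_tensor overload taking a TensorSpec and a device;
// see ttnn/cpp/ttnn/operation.hpp for the actual default implementation.
template <typename Op>
std::vector<Tensor> create_outputs_from_specs(const Op& op, const std::vector<Tensor>& input_tensors) {
    std::vector<Tensor> output_tensors;
    for (const auto& spec : op.compute_output_specs(input_tensors)) {
        output_tensors.push_back(create_device_tensor(spec, input_tensors.at(0).device()));
    }
    return output_tensors;
}
```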
@@ -132,7 +132,7 @@ struct AllGather {
     const ccl::Topology topology;
 
     void validate(const std::vector<Tensor> &input_tensors) const;
-    std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
+    std::vector<ttnn::TensorSpec> compute_output_specs(const std::vector<Tensor> &input_tensors) const;
     std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &input_tensors) const;
     operation::ProgramWithCallbacks create_program(const std::vector<Tensor>& input_tensors, std::vector<Tensor> &output_tensors) const;
 };
ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.cpp: 36 changes (10 additions & 26 deletions)
@@ -119,7 +119,7 @@ void OptimizedConvNew::validate(const std::vector<Tensor>& input_tensors, const
             sliding_window_config,
             parallelization_config.num_cores_nhw,
             out_block_h_ntiles);
-        uint32_t out_width_ntiles = this->compute_output_shapes(input_tensors).at(0)[-1] / TILE_WIDTH;
+        uint32_t out_width_ntiles = this->compute_output_specs(input_tensors).at(0).padded_shape()[-1] / TILE_WIDTH;
         if(this->memory_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) {
             TT_FATAL(per_core_out_matrix_width_ntiles == out_width_ntiles, "Error");
             TT_FATAL(this->block_config.out_subblock_w_ntiles == out_width_ntiles || this->block_config.out_subblock_h_ntiles == 1, "Error");
@@ -136,22 +136,13 @@ void OptimizedConvNew::validate(const std::vector<Tensor>& input_tensors, const
     }
 }
 
-std::vector<tt::tt_metal::LegacyShape> OptimizedConvNew::compute_output_shapes(const std::vector<Tensor>& input_tensors) const {
+std::vector<TensorSpec> OptimizedConvNew::compute_output_specs(const std::vector<Tensor>& input_tensors) const {
     const auto& input_tensor_a_shape = this->input_tensor_shape;
     uint32_t batch_size = input_tensor_a_shape[0];
     uint32_t conv_activation_h = input_tensor_a_shape[1];
     uint32_t conv_activation_w = input_tensor_a_shape[2];
-    // TODO: clean up here
-    uint32_t filter_h = (uint32_t)sliding_window_config.window_hw.first;  // filter_h
-    uint32_t filter_w = (uint32_t)sliding_window_config.window_hw.second;  // filter_W
-    uint32_t stride_h = (uint32_t)sliding_window_config.stride_hw.first;
-    uint32_t stride_w = (uint32_t)sliding_window_config.stride_hw.second;
-    uint32_t pad_h = (uint32_t)sliding_window_config.pad_hw.first;
-    uint32_t pad_w = (uint32_t)sliding_window_config.pad_hw.second;
-
-    auto output_shape = sliding_window_config.get_output_shape();
-    uint32_t conv_output_h = output_shape[1];
-    uint32_t conv_output_w = output_shape[2];
+    auto sliding_window_output_shape = sliding_window_config.get_output_shape();
+    uint32_t conv_output_h = sliding_window_output_shape[1];
+    uint32_t conv_output_w = sliding_window_output_shape[2];
 
     // Tiled output shape is padded shape. Padded to tile shape.
     auto shape_w = batch_size * conv_output_h * conv_output_w;
@@ -160,16 +151,10 @@ std::vector<tt::tt_metal::LegacyShape> OptimizedConvNew::compute_output_shapes(c
     auto padded_shape_c = tt::round_up(this->output_channels, TILE_WIDTH);
     auto output_padding = Padding(
         {{0, 0}, {0, 0}, {0, (padded_shape_w - shape_w)}, {0, (padded_shape_c - shape_c)}}, Padding::PadValue::Zero);
-    auto output_tensor_shape = ttnn::Shape(tt::tt_metal::LegacyShape({1, 1, padded_shape_w, padded_shape_c}, output_padding));
-    return {output_tensor_shape.value};
-}
+    auto output_shape = tt::tt_metal::LegacyShape({1, 1, padded_shape_w, padded_shape_c}, output_padding);
 
-std::vector<Tensor> OptimizedConvNew::create_output_tensors(const std::vector<Tensor>& input_tensors) const {
-    const auto& input_tensor = input_tensors.at(0);
-    const auto& weight_tensor = input_tensors.at(1);
     auto output_layout = this->untilize_out ? Layout::ROW_MAJOR : Layout::TILE;
     if (this->memory_config.is_sharded()) {
-        auto output_shape = this->compute_output_shapes(input_tensors).at(0);
         if (this->memory_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) {
             uint32_t total_height_tiles = tt::tt_metal::compute_volume(output_shape) / output_shape[-1] / TILE_HEIGHT;
             uint32_t num_cores;
@@ -188,23 +173,22 @@ std::vector<Tensor> OptimizedConvNew::create_output_tensors(const std::vector<Te
             auto shard_spec = ShardSpec{shard_grid, shard_shape, ShardOrientation::ROW_MAJOR};
             auto mem_config = this->memory_config;
             mem_config.shard_spec = shard_spec;
-            return {create_device_tensor(output_shape, this->dtype, output_layout, input_tensor.device(), mem_config)};
+            return {TensorSpec(output_shape.logical_shape(), TensorLayout::fromLegacyPaddedShape(dtype, PageConfig(output_layout), mem_config, ttnn::Shape(output_shape)))};
         } else if(this->memory_config.memory_layout == TensorMemoryLayout::WIDTH_SHARDED) {
             uint32_t total_height_tiles = tt::tt_metal::compute_volume(output_shape) / output_shape[-1] / TILE_HEIGHT;
             std::array<uint32_t, 2> shard_shape = {tt::div_up(this->parallelization_config.per_core_out_matrix_height, TILE_HEIGHT) * TILE_HEIGHT, tt::div_up(this->parallelization_config.per_core_out_matrix_width, TILE_WIDTH) * TILE_WIDTH};
             auto shard_grid = this->memory_config.shard_spec.value().grid;
             auto shard_spec = ShardSpec{shard_grid, shard_shape, this->memory_config.shard_spec.value().orientation};
             auto mem_config = this->memory_config;
             mem_config.shard_spec = shard_spec;
-            return{create_device_tensor(output_shape, this->dtype, output_layout, input_tensor.device(), mem_config)};
-
+            return {TensorSpec(output_shape.logical_shape(), TensorLayout::fromLegacyPaddedShape(dtype, PageConfig(output_layout), mem_config, ttnn::Shape(output_shape)))};
         } else if (this->memory_config.memory_layout == TensorMemoryLayout::BLOCK_SHARDED) {
-            return {create_device_tensor(output_shape, this->dtype, output_layout, input_tensor.device(), this->memory_config)};
+            return {TensorSpec(output_shape.logical_shape(), TensorLayout::fromLegacyPaddedShape(dtype, PageConfig(output_layout), memory_config, ttnn::Shape(output_shape)))};
         } else {
             TT_THROW("Unsupported shard scheme");
         }
     }
-    return operation::generic_create_output_tensors(*this, input_tensors, this->dtype, output_layout, this->memory_config);
+    return {TensorSpec(output_shape.logical_shape(), TensorLayout::fromLegacyPaddedShape(dtype, PageConfig(output_layout), memory_config, ttnn::Shape(output_shape)))};
 }
 
 operation::ProgramWithCallbacks OptimizedConvNew::create_program(const std::vector<Tensor>& input_tensors,
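
The conv port keeps the op's tile-padded output shape by threading the legacy padded shape through `TensorLayout::fromLegacyPaddedShape`, while the `TensorSpec` carries the logical shape. For intuition, the tile rounding this encodes (tt-metal tiles are 32x32) looks like the following, with hypothetical sizes:

```cpp
#include <cstdint>

constexpr uint32_t TILE_HEIGHT = 32;
constexpr uint32_t TILE_WIDTH = 32;

// Same contract as tt::round_up: smallest multiple of `multiple` >= `value`.
constexpr uint32_t round_up(uint32_t value, uint32_t multiple) {
    return ((value + multiple - 1) / multiple) * multiple;
}

// Hypothetical conv output: batch 1, 28x28 spatial, 100 channels.
// Logical shape [1, 1, 1*28*28, 100] pads to [1, 1, 800, 128] for TILE layout.
static_assert(round_up(1 * 28 * 28, TILE_HEIGHT) == 800);
static_assert(round_up(100, TILE_WIDTH) == 128);
```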
ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.hpp: 3 changes (1 addition & 2 deletions)
@@ -108,8 +108,7 @@ struct OptimizedConvNew {
         use_non_tile_height(use_non_tile_height) {}
 
     void validate(const std::vector<Tensor>& input_tensors, const std::vector<std::optional<const Tensor>>& optional_input_tensors) const;
-    std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor>& input_tensors) const;
-    std::vector<Tensor> create_output_tensors(const std::vector<Tensor>& input_tensors) const;
+    std::vector<TensorSpec> compute_output_specs(const std::vector<Tensor>& input_tensors) const;
     operation::ProgramWithCallbacks create_program(const std::vector<Tensor>& input_tensors, const std::vector<std::optional<const Tensor>>& optional_input_tensors, std::vector<Tensor> &output_tensors) const;
 
     operation::OpPerformanceModel create_op_performance_model(const std::vector<Tensor>& input_tensors, const std::vector<std::optional<const Tensor>>& optional_input_tensors, const std::vector<Tensor> &output_tensors) const;
@@ -66,16 +66,16 @@ void AllGatherMatmul::validate(
     }
 }
 
-std::vector<ttnn::SimpleShape> AllGatherMatmul::compute_output_shapes(const std::vector<Tensor>& input_tensors) const {
+std::vector<ttnn::TensorSpec> AllGatherMatmul::compute_output_specs(const std::vector<Tensor>& input_tensors) const {
     // All Gather shape
-    ttnn::SimpleShape all_gather_output_shape = this->all_gather_struct.compute_output_shapes({input_tensors[0]})[0];
-    ttnn::SimpleShape datacopy_output_shape = all_gather_output_shape;
+    ttnn::TensorSpec all_gather_output_shape = this->all_gather_struct.compute_output_specs({input_tensors[0]})[0];
+    ttnn::TensorSpec datacopy_output_shape = all_gather_output_shape;
 
     // Matmul shape
-    ttnn::SimpleShape matmul_output_shapes =
-        this->matmul_struct.compute_output_shapes({input_tensors[1], input_tensors[2]})[0];
+    ttnn::TensorSpec matmul_output_specs =
+        this->matmul_struct.compute_output_specs({input_tensors[1], input_tensors[2]})[0];
 
-    return {all_gather_output_shape, matmul_output_shapes, datacopy_output_shape};
+    return {all_gather_output_shape, matmul_output_specs, datacopy_output_shape};
 }
 
 std::vector<Tensor> AllGatherMatmul::create_output_tensors(const std::vector<Tensor>& input_tensors) const {
@@ -42,7 +42,7 @@ struct AllGatherMatmul {
     void validate(
         const std::vector<Tensor>& input_tensors,
         const std::vector<std::optional<const Tensor>>& optional_input_tensors) const;
-    std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor>& input_tensors) const;
+    std::vector<ttnn::TensorSpec> compute_output_specs(const std::vector<Tensor>& input_tensors) const;
     std::vector<Tensor> create_output_tensors(const std::vector<Tensor>& input_tensors) const;
     operation::ProgramWithCallbacks create_program(
