#7512: Use bmm autoformat path for outer op
bbradelTT committed May 6, 2024
1 parent 5211c52 commit e39cadd
Showing 2 changed files with 19 additions and 5 deletions.
22 changes: 18 additions & 4 deletions tt_eager/tt_dnn/op_library/bmm/bmm_op.hpp
@@ -388,7 +388,8 @@ inline Tensor matmul(
     std::optional<const DeviceComputeKernelConfig> compute_kernel_config = std::nullopt,
     bool untilize_out = false,
     std::optional<const CoreCoord> user_core_coord = std::nullopt,
-    std::optional<const bool> input_b_is_batched = std::nullopt) {
+    std::optional<const bool> input_b_is_batched = std::nullopt,
+    const bool needs_autoformat = false) {
     std::vector<std::optional<const Tensor>> optional_input_tensors = {};
     std::vector<Tensor> output_tensors;
     if (bias) {
@@ -399,8 +400,9 @@
         output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor_a, input_tensor_b}))};
     }

-    operation::launch_op(
-        [program_config, mem_config, output_dtype, compute_kernel_config, untilize_out, user_core_coord, input_b_is_batched] (const std::vector<Tensor>& input_tensors, const std::vector<std::optional<const Tensor>>& optional_input_tensors) mutable -> std::vector<Tensor> {
+    if (!needs_autoformat) {
+        operation::launch_op(
+            [program_config, mem_config, output_dtype, compute_kernel_config, untilize_out, user_core_coord, input_b_is_batched] (const std::vector<Tensor>& input_tensors, const std::vector<std::optional<const Tensor>>& optional_input_tensors) mutable -> std::vector<Tensor> {
             const auto& input_tensor_a = input_tensors.at(0);
             const auto& input_tensor_b = input_tensors.at(1);
             auto arch = input_tensor_a.device()->arch();
@@ -413,7 +415,19 @@
             }
             return operation::run(Matmul{matmul_program_config, broadcast_batch, mem_config, output_dtype.value_or(input_tensor_a.get_dtype()), kernel_config_val, untilize_out}, {input_tensor_a, input_tensor_b}, optional_input_tensors);
         },
-        {input_tensor_a, input_tensor_b}, output_tensors, optional_input_tensors);
+            {input_tensor_a, input_tensor_b}, output_tensors, optional_input_tensors);
+    } else {
+        operation::launch_with_autoformat(
+            [program_config, mem_config, output_dtype, compute_kernel_config, untilize_out, user_core_coord, input_b_is_batched] (const std::vector<Tensor>& input_tensors, const std::vector<std::optional<const Tensor>>& optional_input_tensors) mutable -> std::vector<Tensor> {
+                const auto& input_tensor_a = input_tensors.at(0);
+                const auto& input_tensor_b = input_tensors.at(1);
+                auto arch = input_tensor_a.storage_type() == StorageType::DEVICE ? input_tensor_a.device()->arch() : AutoFormat::GetDefaultDevice()->arch();
+                auto kernel_config_val = init_device_compute_kernel_config(arch, compute_kernel_config);
+                bool broadcast_batch = get_broadcast_batch(input_tensor_a, input_tensor_b, program_config);
+                return operation::run_with_autoformat(Matmul{program_config, broadcast_batch, mem_config, output_dtype.value_or(input_tensor_a.get_dtype()), kernel_config_val, untilize_out}, {input_tensor_a, input_tensor_b}, optional_input_tensors);
+            },
+            {input_tensor_a, input_tensor_b}, output_tensors, optional_input_tensors);
+    }
     return output_tensors.at(0);
 }
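
The hunks above make the autoformat path opt-in: existing callers keep the direct launch_op path, while callers whose tensors may not already be in the layout the kernels expect (such as the outer op below) set needs_autoformat and go through launch_with_autoformat / run_with_autoformat. Note that the autoformat branch cannot assume the input is on device, so it falls back to AutoFormat::GetDefaultDevice() when the tensor's storage is not StorageType::DEVICE. A minimal standalone sketch of this dispatch pattern, using hypothetical stub launchers rather than the tt-metal API:

    #include <functional>
    #include <iostream>

    // Hypothetical stand-ins for the two launch helpers, for illustration only.
    using Task = std::function<void()>;

    void launch_direct(const Task& task) {
        // Fast path: assumes inputs already have the layout the kernels expect.
        std::cout << "direct launch\n";
        task();
    }

    void launch_with_autoformat(const Task& task) {
        // Formatting path: would convert inputs to the required layout first.
        std::cout << "autoformat launch\n";
        task();
    }

    // Mirrors the diff's structure: a defaulted flag keeps every existing call
    // site on the direct path, while ops that need formatting opt in explicitly.
    void run_op(const Task& task, const bool needs_autoformat = false) {
        if (!needs_autoformat) {
            launch_direct(task);
        } else {
            launch_with_autoformat(task);
        }
    }

    int main() {
        run_op([] { std::cout << "matmul kernel\n"; });        // existing callers, unchanged
        run_op([] { std::cout << "matmul kernel\n"; }, true);  // the outer op after this commit
    }

Defaulting the flag to false keeps the change backward compatible: no existing matmul call site has to be touched.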

2 changes: 1 addition & 1 deletion tt_eager/tt_dnn/op_library/composite/composite_ops.cpp
@@ -1351,7 +1351,7 @@ Tensor _outer(Tensor& a, Tensor& b, const MemoryConfig& output_mem_config) {
         b_slim = reshape(b, 1, 1, 1, b.volume(), output_mem_config);
     }

-    return tt::operations::primary::matmul(a_slim, b_slim, std::nullopt, tt::operations::primary::MatmulDefaultProgramConfig{}, output_mem_config);
+    return tt::operations::primary::matmul(a_slim, b_slim, std::nullopt, tt::operations::primary::MatmulDefaultProgramConfig{}, output_mem_config, std::nullopt /*output_dtype*/, std::nullopt /*compute_kernel_config*/, false /*untilize_out*/, std::nullopt /*user_core_coord*/, std::nullopt /*input_b_is_batched*/, true /*needs_autoformat*/);
 }
 Tensor outer(Tensor& a, Tensor& b, const MemoryConfig& output_mem_config) {
     return operation::decorate_as_composite(__func__, _outer)(a, b, output_mem_config);
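
Because needs_autoformat is the last parameter, the call site spells out every intermediate default, with inline /*...*/ comments labeling each slot. Semantically, the composite computes an outer product, outer(a, b)[i][j] = a[i] * b[j], by reshaping one operand to a row (b_slim above is 1 x 1 x 1 x N) and multiplying. A plain-C++ sketch of that reference semantics, with no tt-metal types, just to make the math concrete:

    #include <cstddef>
    #include <iostream>
    #include <vector>

    // Reference semantics: outer(a, b)[i][j] = a[i] * b[j], i.e. the matrix
    // product of a as an (M x 1) column with b as a (1 x N) row.
    std::vector<std::vector<float>> outer(const std::vector<float>& a,
                                          const std::vector<float>& b) {
        std::vector<std::vector<float>> out(a.size(), std::vector<float>(b.size()));
        for (std::size_t i = 0; i < a.size(); ++i)
            for (std::size_t j = 0; j < b.size(); ++j)
                out[i][j] = a[i] * b[j];
        return out;
    }

    int main() {
        const auto out = outer({1.0f, 2.0f, 3.0f}, {4.0f, 5.0f});
        for (const auto& row : out) {
            for (const float v : row) std::cout << v << ' ';
            std::cout << '\n';  // prints rows: 4 5 / 8 10 / 12 15
        }
    }

Such reshaped 1 x N and M x 1 operands are generally not tile-aligned, which is why this call is the one that opts into the autoformat path.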