Skip to content

Commit

Permalink
#9709: Add optional transpose_a and transpose_b to ttnn matmul and li…
Browse files Browse the repository at this point in the history
…near

- Fix transpose to respect padding when swapping dims during compute_output_shapes
  • Loading branch information
TT-BrianLiu committed Jul 2, 2024
1 parent 4a8f2da commit 2008386
Show file tree
Hide file tree
Showing 7 changed files with 150 additions and 36 deletions.
34 changes: 34 additions & 0 deletions tests/ttnn/unit_tests/operations/test_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,40 @@ def test_matmul_by_passing_in_1D_systolic_array_program_config(device, batch_siz
assert_with_pcc(torch_output_tensor, output_tensor, pcc=0.997)


@pytest.mark.parametrize(
"n_size, c, m, k, n",
[
(1, 1, 2, 3, 4),
(1, 1, 1024, 64, 512),
],
)
@pytest.mark.parametrize("transpose_b", [True, False])
@pytest.mark.parametrize("transpose_a", [True, False])
def test_matmul_with_transpose_a_or_b(device, n_size, c, m, k, n, transpose_a, transpose_b):
torch.manual_seed(0)

torch_input_tensor_a = torch.rand((n_size, c, m, k), dtype=torch.bfloat16)
torch_input_tensor_b = torch.rand((n_size, c, k, n), dtype=torch.bfloat16)
torch_output_tensor = torch.matmul(torch_input_tensor_a, torch_input_tensor_b)

if transpose_a:
torch_input_tensor_a = torch_input_tensor_a.transpose(-1, -2)
if transpose_b:
torch_input_tensor_b = torch_input_tensor_b.transpose(-1, -2)

input_tensor_a = ttnn.from_torch(torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device)
input_tensor_b = ttnn.from_torch(torch_input_tensor_b, layout=ttnn.TILE_LAYOUT, device=device)
output = ttnn.matmul(input_tensor_a, input_tensor_b, transpose_a=transpose_a, transpose_b=transpose_b)
output = ttnn.to_torch(output)

assert len(output.shape) == len(torch_output_tensor.shape)
assert output.shape == torch_output_tensor.shape
assert_with_pcc(torch_output_tensor, output, 0.999)


##########################
# MODEL SPECIFIC MATMULS #
##########################
@skip_for_wormhole_b0()
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("m_size", [128])
Expand Down
27 changes: 14 additions & 13 deletions tt_eager/tt_dnn/op_library/transpose/transpose_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,33 +58,34 @@ void Transpose::validate(const std::vector<Tensor> &input_tensors) const {
std::vector<Shape> Transpose::compute_output_shapes(const std::vector<Tensor> &input_tensors) const {
const auto& input_tensor = input_tensors.at(0);
auto out_shape = input_tensor.get_legacy_shape();
auto padding = out_shape.padding();
switch (this->dim){
case TransposeOpDim::CN:
out_shape[0] = input_tensor.get_legacy_shape()[1];
out_shape[1] = input_tensor.get_legacy_shape()[0];
std::swap(out_shape[0], out_shape[1]);
std::swap(padding[0], padding[1]);
break;
case TransposeOpDim::HC:
out_shape[1] = input_tensor.get_legacy_shape()[2];
out_shape[2] = input_tensor.get_legacy_shape()[1];
std::swap(out_shape[1], out_shape[2]);
std::swap(padding[1], padding[2]);
break;
case TransposeOpDim::WH:
out_shape[2] = input_tensor.get_legacy_shape()[3];
out_shape[3] = input_tensor.get_legacy_shape()[2];
std::swap(out_shape[2], out_shape[3]);
std::swap(padding[2], padding[3]);
break;
case TransposeOpDim::NH:
out_shape[0] = input_tensor.get_legacy_shape()[2];
out_shape[2] = input_tensor.get_legacy_shape()[0];
std::swap(out_shape[0], out_shape[2]);
std::swap(padding[0], padding[2]);
break;
case TransposeOpDim::NW:
out_shape[3] = input_tensor.get_legacy_shape()[0];
out_shape[0] = input_tensor.get_legacy_shape()[3];
std::swap(out_shape[0], out_shape[3]);
std::swap(padding[0], padding[3]);
break;
case TransposeOpDim::CW:
out_shape[1] = input_tensor.get_legacy_shape()[3];
out_shape[3] = input_tensor.get_legacy_shape()[1];
std::swap(out_shape[1], out_shape[3]);
std::swap(padding[1], padding[3]);
break;
}
return {out_shape};
return {Shape(out_shape, padding)};
}


Expand Down
44 changes: 32 additions & 12 deletions ttnn/cpp/pybind11/operations/matmul.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,33 @@ void py_module(py::module& module) {
"matmul",
[](const ttnn::Tensor& input_tensor_a,
const ttnn::Tensor& input_tensor_b,
const bool transpose_a = false,
const bool transpose_b = false,
const ttnn::MemoryConfig& memory_config = ttnn::DRAM_MEMORY_CONFIG,
const std::optional<const DataType> dtype = std::nullopt,
const std::optional<const ttnn::MatmulProgramConfig> program_config = std::nullopt,
const std::optional<const std::string>& activation = std::nullopt,
const std::optional<const DeviceComputeKernelConfig> compute_kernel_config = std::nullopt,
const std::optional<const ttnn::CoreGrid> core_grid = std::nullopt) -> ttnn::Tensor {
return ttnn::operations::matmul::matmul(
input_tensor_a, input_tensor_b, /*bias=*/std::nullopt, program_config, memory_config, dtype, activation, compute_kernel_config, core_grid, /*propagate_is_b_batched=*/true);
input_tensor_a,
input_tensor_b,
/*bias=*/std::nullopt,
transpose_a,
transpose_b,
program_config,
memory_config,
dtype,
activation,
compute_kernel_config,
core_grid,
/*propagate_is_b_batched=*/true);
},
py::arg("input_tensor_a"),
py::arg("input_tensor_b"),
py::kw_only(),
py::arg("transpose_a") = false,
py::arg("transpose_b") = false,
py::arg("memory_config") = DRAM_MEMORY_CONFIG,
py::arg("dtype") = std::nullopt,
py::arg("program_config") = std::nullopt,
Expand All @@ -41,20 +56,24 @@ void py_module(py::module& module) {
py::arg("core_grid") = std::nullopt);

module.def(
"linear",
[](const ttnn::Tensor& input_tensor_a,
const ttnn::Tensor& input_tensor_b,
const std::optional<const ttnn::Tensor>& bias,
const ttnn::MemoryConfig& memory_config = ttnn::DRAM_MEMORY_CONFIG,
const std::optional<const DataType> dtype = std::nullopt,
const std::optional<const ttnn::MatmulProgramConfig> program_config = std::nullopt,
const std::optional<const std::string>& activation = std::nullopt,
const std::optional<const DeviceComputeKernelConfig> compute_kernel_config = std::nullopt,
const std::optional<const ttnn::CoreGrid> core_grid = std::nullopt) -> ttnn::Tensor {
"linear",
[](const ttnn::Tensor& input_tensor_a,
const ttnn::Tensor& input_tensor_b,
const std::optional<const ttnn::Tensor>& bias = std::nullopt,
const bool transpose_a = false,
const bool transpose_b = false,
const ttnn::MemoryConfig& memory_config = ttnn::DRAM_MEMORY_CONFIG,
const std::optional<const DataType> dtype = std::nullopt,
const std::optional<const ttnn::MatmulProgramConfig> program_config = std::nullopt,
const std::optional<const std::string>& activation = std::nullopt,
const std::optional<const DeviceComputeKernelConfig> compute_kernel_config = std::nullopt,
const std::optional<const ttnn::CoreGrid> core_grid = std::nullopt) -> ttnn::Tensor {
return ttnn::operations::matmul::matmul(
input_tensor_a,
input_tensor_b,
bias,
transpose_a,
transpose_b,
program_config,
memory_config,
dtype,
Expand All @@ -66,13 +85,14 @@ void py_module(py::module& module) {
py::arg("input_tensor_b"),
py::kw_only(),
py::arg("bias") = std::nullopt,
py::arg("transpose_a") = false,
py::arg("transpose_b") = false,
py::arg("memory_config") = DRAM_MEMORY_CONFIG,
py::arg("dtype") = std::nullopt,
py::arg("program_config") = std::nullopt,
py::arg("activation") = std::nullopt,
py::arg("compute_kernel_config") = std::nullopt,
py::arg("core_grid") = std::nullopt);

}

} // namespace matmul
Expand Down
2 changes: 2 additions & 0 deletions ttnn/cpp/ttnn/operations/conv2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -708,6 +708,8 @@ std::tuple<ttnn::Tensor, uint32_t, uint32_t, ttnn::Tensor, std::optional<ttnn::T
matmul_input,
weight_tensor_on_device,
bias_tensor_on_device,
/*transpose_a=*/false,
/*transpose_b=*/false,
matmul_program_config,
conv_out_memory_config,
conv_config.dtype,
Expand Down
61 changes: 50 additions & 11 deletions ttnn/cpp/ttnn/operations/matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
// SPDX-License-Identifier: Apache-2.0

#include "matmul.hpp"
#include "ttnn/cpp/ttnn/validation.hpp"

#include "ttnn/cpp/ttnn/operations/core.hpp"
#include "ttnn/cpp/ttnn/validation.hpp"
#include "tt_dnn/op_library/transpose/transpose_op.hpp"

namespace ttnn {

using MatmulMultiCoreReuseProgramConfig = tt::operations::primary::MatmulMultiCoreReuseProgramConfig;
Expand Down Expand Up @@ -35,16 +38,37 @@ bool is_input_batched(const ttnn::Shape& shape) {
const std::array<ttnn::TensorSchema, 3> input_tensor_schemas() {
return {
ttnn::TensorSchema{
2, 4, {ttnn::float32, ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b}, {ttnn::TILE_LAYOUT}, true, false, true, false},
2,
4,
{ttnn::float32, ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b},
{ttnn::TILE_LAYOUT},
true,
false,
true,
false},
ttnn::TensorSchema{
2, 4, {ttnn::float32, ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b}, {ttnn::TILE_LAYOUT}, true, false, true, false},
2,
4,
{ttnn::float32, ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b},
{ttnn::TILE_LAYOUT},
true,
false,
true,
false},
ttnn::TensorSchema{
2, 4, {ttnn::float32, ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b}, {ttnn::TILE_LAYOUT}, true, false, true, true}};
2,
4,
{ttnn::float32, ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b},
{ttnn::TILE_LAYOUT},
true,
false,
true,
true}};
}

std::optional<UnaryWithParam> get_fused_activation(const std::optional<const std::string>& activation) {
if (!activation.has_value()) {
return std::nullopt;
return std::nullopt;
}
return string_to_unary_with_param(activation.value());
}
Expand All @@ -53,6 +77,8 @@ ttnn::Tensor matmul(
const ttnn::Tensor& input_tensor_a,
const ttnn::Tensor& input_tensor_b,
const std::optional<const ttnn::Tensor>& bias,
const bool transpose_a,
const bool transpose_b,
const std::optional<const MatmulProgramConfig> program_config,
const ttnn::MemoryConfig& memory_config,
std::optional<const DataType> dtype,
Expand All @@ -64,8 +90,11 @@ ttnn::Tensor matmul(
ttnn::validate_input_tensor("ttnn.matmul", input_tensor_b, input_tensor_schemas()[1]);
ttnn::validate_input_tensor("ttnn.matmul", bias, input_tensor_schemas()[2]);

const auto input_tensor_a_shape = input_tensor_a.get_shape();
const auto input_tensor_b_shape = input_tensor_b.get_shape();
const auto& input_tensor_a_adjusted = transpose_a ? tt::tt_metal::transpose(input_tensor_a, -1, -2, input_tensor_a.memory_config()) : input_tensor_a;
const auto& input_tensor_b_adjusted = transpose_b ? tt::tt_metal::transpose(input_tensor_b, -1, -2, input_tensor_b.memory_config()) : input_tensor_b;

const auto input_tensor_a_shape = input_tensor_a_adjusted.get_shape();
const auto input_tensor_b_shape = input_tensor_b_adjusted.get_shape();

const auto width_a = input_tensor_a_shape[-1];
const auto height_b = input_tensor_b_shape[-2];
Expand All @@ -81,19 +110,29 @@ ttnn::Tensor matmul(
std::optional<CoreCoord> user_core_coord;
const bool has_user_grid = core_grid.has_value();
if (has_user_grid) {
user_core_coord = CoreCoord(core_grid->x, core_grid->y);
user_core_coord = CoreCoord(core_grid->x, core_grid->y);
}

const bool has_program_config = program_config.has_value();
bool post_process_bias = false;
if (bias.has_value()) {
if (!has_program_config && !has_user_grid) {
post_process_bias = true;
}
post_process_bias = true;
}
}

auto output_tensor = tt::operations::primary::matmul(
input_tensor_a, input_tensor_b, post_process_bias ? std::nullopt : bias, program_config, memory_config, dtype, compute_kernel_config, false /*untilize_out*/, user_core_coord, get_fused_activation(activation), propagate_is_b_batched && input_b_is_batched);
input_tensor_a_adjusted,
input_tensor_b_adjusted,
post_process_bias ? std::nullopt : bias,
program_config,
memory_config,
dtype,
compute_kernel_config,
false /*untilize_out*/,
user_core_coord,
get_fused_activation(activation),
propagate_is_b_batched && input_b_is_batched);

if (post_process_bias) {
output_tensor = tt::operations::primary::bcast(
Expand Down
2 changes: 2 additions & 0 deletions ttnn/cpp/ttnn/operations/matmul.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ ttnn::Tensor matmul(
const ttnn::Tensor& input_tensor_a,
const ttnn::Tensor& input_tensor_b,
const std::optional<const ttnn::Tensor>& bias,
const bool transpose_a = false,
const bool transpose_b = false,
const std::optional<const MatmulProgramConfig> program_config = std::nullopt,
const ttnn::MemoryConfig& memory_config = ttnn::DRAM_MEMORY_CONFIG,
std::optional<const DataType> dtype = std::nullopt,
Expand Down
16 changes: 16 additions & 0 deletions ttnn/ttnn/operations/matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
def _golden_function(input_tensor_a, input_tensor_b, *args, **kwargs):
import torch

if transpose_a:
input_tensor_a = input_tensor_a.transpose(-1, -2)
if transpose_b:
input_tensor_b = input_tensor_b.transpose(-1, -2)
output_tensor = input_tensor_a @ input_tensor_b.to(input_tensor_a.dtype)

if activation == "gelu":
Expand All @@ -44,6 +48,8 @@ def matmul(
input_tensor_a: ttnn.Tensor,
input_tensor_b: ttnn.Tensor,
*,
transpose_a: bool = False,
transpose_b: bool = False,
memory_config: ttnn.MemoryConfig = ttnn.DRAM_MEMORY_CONFIG,
dtype: Optional[ttnn.DataType] = None,
core_grid: Optional[ttnn.CoreGrid] = None,
Expand Down Expand Up @@ -141,6 +147,8 @@ def matmul(
return ttnn._ttnn.operations.matmul.matmul(
input_tensor_a,
input_tensor_b,
transpose_a=transpose_a,
transpose_b=transpose_b,
memory_config=memory_config,
dtype=dtype,
program_config=program_config,
Expand All @@ -153,6 +161,10 @@ def matmul(
def _golden_function(input_tensor_a, input_tensor_b, *, bias=None, activation=None, **kwargs):
import torch

if transpose_a:
input_tensor_a = input_tensor_a.transpose(-1, -2)
if transpose_b:
input_tensor_b = input_tensor_b.transpose(-1, -2)
output_tensor = input_tensor_a @ input_tensor_b.to(input_tensor_a.dtype)

if bias is not None:
Expand Down Expand Up @@ -182,6 +194,8 @@ def linear(
input_tensor_b: ttnn.Tensor,
*,
bias: Optional[ttnn.Tensor] = None,
transpose_a: bool = False,
transpose_b: bool = False,
memory_config: ttnn.MemoryConfig = ttnn.DRAM_MEMORY_CONFIG,
dtype: Optional[ttnn.DataType] = None,
core_grid: Optional[ttnn.CoreGrid] = None,
Expand Down Expand Up @@ -228,6 +242,8 @@ def linear(
input_tensor_a,
input_tensor_b,
bias=bias,
transpose_a=transpose_a,
transpose_b=transpose_b,
memory_config=memory_config,
dtype=dtype,
program_config=program_config,
Expand Down

0 comments on commit 2008386

Please sign in to comment.