diff --git a/tests/ttnn/unit_tests/operations/test_matmul.py b/tests/ttnn/unit_tests/operations/test_matmul.py
index b67ede0f933..695dee46b6b 100644
--- a/tests/ttnn/unit_tests/operations/test_matmul.py
+++ b/tests/ttnn/unit_tests/operations/test_matmul.py
@@ -483,6 +483,40 @@ def test_matmul_by_passing_in_1D_systolic_array_program_config(device, batch_siz
     assert_with_pcc(torch_output_tensor, output_tensor, pcc=0.997)
 
 
+@pytest.mark.parametrize(
+    "n_size, c, m, k, n",
+    [
+        (1, 1, 2, 3, 4),
+        (1, 1, 1024, 64, 512),
+    ],
+)
+@pytest.mark.parametrize("transpose_b", [True, False])
+@pytest.mark.parametrize("transpose_a", [True, False])
+def test_matmul_with_transpose_a_or_b(device, n_size, c, m, k, n, transpose_a, transpose_b):
+    torch.manual_seed(0)
+
+    torch_input_tensor_a = torch.rand((n_size, c, m, k), dtype=torch.bfloat16)
+    torch_input_tensor_b = torch.rand((n_size, c, k, n), dtype=torch.bfloat16)
+    torch_output_tensor = torch.matmul(torch_input_tensor_a, torch_input_tensor_b)
+
+    if transpose_a:
+        torch_input_tensor_a = torch_input_tensor_a.transpose(-1, -2)
+    if transpose_b:
+        torch_input_tensor_b = torch_input_tensor_b.transpose(-1, -2)
+
+    input_tensor_a = ttnn.from_torch(torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device)
+    input_tensor_b = ttnn.from_torch(torch_input_tensor_b, layout=ttnn.TILE_LAYOUT, device=device)
+    output = ttnn.matmul(input_tensor_a, input_tensor_b, transpose_a=transpose_a, transpose_b=transpose_b)
+    output = ttnn.to_torch(output)
+
+    assert len(output.shape) == len(torch_output_tensor.shape)
+    assert output.shape == torch_output_tensor.shape
+    assert_with_pcc(torch_output_tensor, output, 0.999)
+
+
+##########################
+# MODEL SPECIFIC MATMULS #
+##########################
 @skip_for_wormhole_b0()
 @pytest.mark.parametrize("batch_size", [1])
 @pytest.mark.parametrize("m_size", [128])
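The test above pre-transposes the torch inputs and then asks the op to undo that via the new flags, so the result must match a plain `torch.matmul` on the untransposed tensors. For reference, the intended semantics of the flags expressed with torch only (a sketch; the shapes are illustrative and not part of the change):

```python
import torch

# a is stored pre-transposed as (..., k, m); b is stored as (..., k, n)
a = torch.rand((1, 1, 64, 32), dtype=torch.bfloat16)
b = torch.rand((1, 1, 64, 128), dtype=torch.bfloat16)

# transpose_a=True is expected to behave as if `a` were transposed on its
# last two dims before the multiply, producing a (..., m, n) result:
expected = torch.matmul(a.transpose(-1, -2), b)
assert expected.shape == (1, 1, 32, 128)
```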
diff --git a/tt_eager/tt_dnn/op_library/transpose/transpose_op.cpp b/tt_eager/tt_dnn/op_library/transpose/transpose_op.cpp
index 588289546a1..2eac913e4c9 100644
--- a/tt_eager/tt_dnn/op_library/transpose/transpose_op.cpp
+++ b/tt_eager/tt_dnn/op_library/transpose/transpose_op.cpp
@@ -58,33 +58,34 @@ void Transpose::validate(const std::vector<Tensor> &input_tensors) const {
 std::vector<Shape> Transpose::compute_output_shapes(const std::vector<Tensor> &input_tensors) const {
     const auto& input_tensor = input_tensors.at(0);
     auto out_shape = input_tensor.get_legacy_shape();
+    auto padding = out_shape.padding();
     switch (this->dim){
         case TransposeOpDim::CN:
-            out_shape[0] = input_tensor.get_legacy_shape()[1];
-            out_shape[1] = input_tensor.get_legacy_shape()[0];
+            std::swap(out_shape[0], out_shape[1]);
+            std::swap(padding[0], padding[1]);
             break;
         case TransposeOpDim::HC:
-            out_shape[1] = input_tensor.get_legacy_shape()[2];
-            out_shape[2] = input_tensor.get_legacy_shape()[1];
+            std::swap(out_shape[1], out_shape[2]);
+            std::swap(padding[1], padding[2]);
             break;
         case TransposeOpDim::WH:
-            out_shape[2] = input_tensor.get_legacy_shape()[3];
-            out_shape[3] = input_tensor.get_legacy_shape()[2];
+            std::swap(out_shape[2], out_shape[3]);
+            std::swap(padding[2], padding[3]);
             break;
         case TransposeOpDim::NH:
-            out_shape[0] = input_tensor.get_legacy_shape()[2];
-            out_shape[2] = input_tensor.get_legacy_shape()[0];
+            std::swap(out_shape[0], out_shape[2]);
+            std::swap(padding[0], padding[2]);
             break;
         case TransposeOpDim::NW:
-            out_shape[3] = input_tensor.get_legacy_shape()[0];
-            out_shape[0] = input_tensor.get_legacy_shape()[3];
+            std::swap(out_shape[0], out_shape[3]);
+            std::swap(padding[0], padding[3]);
             break;
         case TransposeOpDim::CW:
-            out_shape[1] = input_tensor.get_legacy_shape()[3];
-            out_shape[3] = input_tensor.get_legacy_shape()[1];
+            std::swap(out_shape[1], out_shape[3]);
+            std::swap(padding[1], padding[3]);
             break;
     }
-    return {out_shape};
+    return {Shape(out_shape, padding)};
 }
diff --git a/ttnn/cpp/pybind11/operations/matmul.hpp b/ttnn/cpp/pybind11/operations/matmul.hpp
index 56bb82534d1..4bd41040893 100644
--- a/ttnn/cpp/pybind11/operations/matmul.hpp
+++ b/ttnn/cpp/pybind11/operations/matmul.hpp
@@ -21,6 +21,8 @@ void py_module(py::module& module) {
         "matmul",
         [](const ttnn::Tensor& input_tensor_a,
            const ttnn::Tensor& input_tensor_b,
+           const bool transpose_a = false,
+           const bool transpose_b = false,
            const ttnn::MemoryConfig& memory_config = ttnn::DRAM_MEMORY_CONFIG,
            const std::optional<const DataType> dtype = std::nullopt,
            const std::optional<const MatmulProgramConfig> program_config = std::nullopt,
@@ -28,11 +30,24 @@ void py_module(py::module& module) {
            const std::optional<const DeviceComputeKernelConfig> compute_kernel_config = std::nullopt,
            const std::optional<const ttnn::CoreGrid> core_grid = std::nullopt) -> ttnn::Tensor {
             return ttnn::operations::matmul::matmul(
-                input_tensor_a, input_tensor_b, /*bias=*/std::nullopt, program_config, memory_config, dtype, activation, compute_kernel_config, core_grid, /*propagate_is_b_batched=*/true);
+                input_tensor_a,
+                input_tensor_b,
+                /*bias=*/std::nullopt,
+                transpose_a,
+                transpose_b,
+                program_config,
+                memory_config,
+                dtype,
+                activation,
+                compute_kernel_config,
+                core_grid,
+                /*propagate_is_b_batched=*/true);
         },
         py::arg("input_tensor_a"),
         py::arg("input_tensor_b"),
         py::kw_only(),
+        py::arg("transpose_a") = false,
+        py::arg("transpose_b") = false,
         py::arg("memory_config") = DRAM_MEMORY_CONFIG,
         py::arg("dtype") = std::nullopt,
         py::arg("program_config") = std::nullopt,
@@ -41,20 +56,24 @@ void py_module(py::module& module) {
         py::arg("core_grid") = std::nullopt);
 
     module.def(
-        "linear",
-        [](const ttnn::Tensor& input_tensor_a,
-           const ttnn::Tensor& input_tensor_b,
-           const std::optional<const ttnn::Tensor>& bias,
-           const ttnn::MemoryConfig& memory_config = ttnn::DRAM_MEMORY_CONFIG,
-           const std::optional<const DataType> dtype = std::nullopt,
-           const std::optional<const MatmulProgramConfig> program_config = std::nullopt,
-           const std::optional<const std::string>& activation = std::nullopt,
-           const std::optional<const DeviceComputeKernelConfig> compute_kernel_config = std::nullopt,
-           const std::optional<const ttnn::CoreGrid> core_grid = std::nullopt) -> ttnn::Tensor {
+        "linear",
+        [](const ttnn::Tensor& input_tensor_a,
+           const ttnn::Tensor& input_tensor_b,
+           const std::optional<const ttnn::Tensor>& bias = std::nullopt,
+           const bool transpose_a = false,
+           const bool transpose_b = false,
+           const ttnn::MemoryConfig& memory_config = ttnn::DRAM_MEMORY_CONFIG,
+           const std::optional<const DataType> dtype = std::nullopt,
+           const std::optional<const MatmulProgramConfig> program_config = std::nullopt,
+           const std::optional<const std::string>& activation = std::nullopt,
+           const std::optional<const DeviceComputeKernelConfig> compute_kernel_config = std::nullopt,
+           const std::optional<const ttnn::CoreGrid> core_grid = std::nullopt) -> ttnn::Tensor {
             return ttnn::operations::matmul::matmul(
                 input_tensor_a,
                 input_tensor_b,
                 bias,
+                transpose_a,
+                transpose_b,
                 program_config,
                 memory_config,
                 dtype,
@@ -66,13 +85,14 @@ void py_module(py::module& module) {
         py::arg("input_tensor_b"),
         py::kw_only(),
         py::arg("bias") = std::nullopt,
+        py::arg("transpose_a") = false,
+        py::arg("transpose_b") = false,
         py::arg("memory_config") = DRAM_MEMORY_CONFIG,
         py::arg("dtype") = std::nullopt,
         py::arg("program_config") = std::nullopt,
         py::arg("activation") = std::nullopt,
         py::arg("compute_kernel_config") = std::nullopt,
         py::arg("core_grid") = std::nullopt);
-
 }
 
 }  // namespace matmul
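The transpose fix above carries each dimension's padding along with its logical extent; previously only the extents were rewritten, so a tile-padded tensor came out of a transpose with padding attached to the wrong axes. A toy model of the swap in plain Python (the real `Shape`/`Padding` types live in tt_metal; the names here are illustrative only):

```python
def transpose_wh(shape, padding):
    """shape: logical extent per dim; padding: (front, back) pad per dim."""
    shape = list(shape)
    padding = list(padding)
    shape[2], shape[3] = shape[3], shape[2]          # swap H and W extents
    padding[2], padding[3] = padding[3], padding[2]  # keep each pad with its dim
    return shape, padding

# A (1, 1, 2, 3) tensor tiled to 32x32 carries back-pads of 30 on H and 29 on W;
# after a WH transpose the 29-pad must follow what is now the H position.
print(transpose_wh([1, 1, 2, 3], [(0, 0), (0, 0), (0, 30), (0, 29)]))
# -> ([1, 1, 3, 2], [(0, 0), (0, 0), (0, 29), (0, 30)])
```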
diff --git a/ttnn/cpp/ttnn/operations/conv2d.cpp b/ttnn/cpp/ttnn/operations/conv2d.cpp
index f852dc02c30..3484c7e4fd8 100644
--- a/ttnn/cpp/ttnn/operations/conv2d.cpp
+++ b/ttnn/cpp/ttnn/operations/conv2d.cpp
@@ -708,6 +708,8 @@
diff --git a/ttnn/cpp/ttnn/operations/matmul.cpp b/ttnn/cpp/ttnn/operations/matmul.cpp
--- a/ttnn/cpp/ttnn/operations/matmul.cpp
+++ b/ttnn/cpp/ttnn/operations/matmul.cpp
@@ ... @@
 std::array<ttnn::TensorSchema, 3> input_tensor_schemas() {
     return {
         ttnn::TensorSchema{
-            2, 4, {ttnn::float32, ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b}, {ttnn::TILE_LAYOUT}, true, false, true, false},
+            2,
+            4,
+            {ttnn::float32, ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b},
+            {ttnn::TILE_LAYOUT},
+            true,
+            false,
+            true,
+            false},
         ttnn::TensorSchema{
-            2, 4, {ttnn::float32, ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b}, {ttnn::TILE_LAYOUT}, true, false, true, false},
+            2,
+            4,
+            {ttnn::float32, ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b},
+            {ttnn::TILE_LAYOUT},
+            true,
+            false,
+            true,
+            false},
         ttnn::TensorSchema{
-            2, 4, {ttnn::float32, ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b}, {ttnn::TILE_LAYOUT}, true, false, true, true}};
+            2,
+            4,
+            {ttnn::float32, ttnn::bfloat16, ttnn::bfloat8_b, ttnn::bfloat4_b},
+            {ttnn::TILE_LAYOUT},
+            true,
+            false,
+            true,
+            true}};
 }
 
 std::optional<UnaryWithParam> get_fused_activation(const std::optional<const std::string>& activation) {
     if (!activation.has_value()) {
-       return std::nullopt;
+        return std::nullopt;
     }
     return string_to_unary_with_param(activation.value());
 }
@@ -53,6 +77,8 @@ ttnn::Tensor matmul(
     const ttnn::Tensor& input_tensor_a,
     const ttnn::Tensor& input_tensor_b,
     const std::optional<const ttnn::Tensor>& bias,
+    const bool transpose_a,
+    const bool transpose_b,
     const std::optional<const MatmulProgramConfig> program_config,
     const ttnn::MemoryConfig& memory_config,
     std::optional<const DataType> dtype,
@@ -64,8 +90,11 @@ ttnn::Tensor matmul(
     ttnn::validate_input_tensor("ttnn.matmul", input_tensor_b, input_tensor_schemas()[1]);
     ttnn::validate_input_tensor("ttnn.matmul", bias, input_tensor_schemas()[2]);
 
-    const auto input_tensor_a_shape = input_tensor_a.get_shape();
-    const auto input_tensor_b_shape = input_tensor_b.get_shape();
+    const auto& input_tensor_a_adjusted = transpose_a ? tt::tt_metal::transpose(input_tensor_a, -1, -2, input_tensor_a.memory_config()) : input_tensor_a;
+    const auto& input_tensor_b_adjusted = transpose_b ? tt::tt_metal::transpose(input_tensor_b, -1, -2, input_tensor_b.memory_config()) : input_tensor_b;
+
+    const auto input_tensor_a_shape = input_tensor_a_adjusted.get_shape();
+    const auto input_tensor_b_shape = input_tensor_b_adjusted.get_shape();
 
     const auto width_a = input_tensor_a_shape[-1];
     const auto height_b = input_tensor_b_shape[-2];
@@ -81,19 +110,29 @@ ttnn::Tensor matmul(
     std::optional<CoreCoord> user_core_coord;
     const bool has_user_grid = core_grid.has_value();
     if (has_user_grid) {
-       user_core_coord = CoreCoord(core_grid->x, core_grid->y);
+        user_core_coord = CoreCoord(core_grid->x, core_grid->y);
     }
 
     const bool has_program_config = program_config.has_value();
     bool post_process_bias = false;
     if (bias.has_value()) {
         if (!has_program_config && !has_user_grid) {
-           post_process_bias = true;
-       }
+            post_process_bias = true;
+        }
     }
 
     auto output_tensor = tt::operations::primary::matmul(
-        input_tensor_a, input_tensor_b, post_process_bias ? std::nullopt : bias, program_config, memory_config, dtype, compute_kernel_config, false /*untilize_out*/, user_core_coord, get_fused_activation(activation), propagate_is_b_batched && input_b_is_batched);
+        input_tensor_a_adjusted,
+        input_tensor_b_adjusted,
+        post_process_bias ? std::nullopt : bias,
+        program_config,
+        memory_config,
+        dtype,
+        compute_kernel_config,
+        false /*untilize_out*/,
+        user_core_coord,
+        get_fused_activation(activation),
+        propagate_is_b_batched && input_b_is_batched);
 
     if (post_process_bias) {
         output_tensor = tt::operations::primary::bcast(
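In the implementation above, the flags are realized by materializing a transposed tensor with `tt::tt_metal::transpose` before dispatching to the primary matmul, not by fusing the transpose into the matmul kernel. At the Python level, a flagged call should therefore match an explicit transpose followed by a plain matmul (a sketch; it assumes `ttnn.transpose` accepts torch-style dim arguments and that `device` is an open device handle):

```python
import torch
import ttnn

# a_t is stored pre-transposed as (..., k, m); b is stored as (..., k, n)
a_t = ttnn.from_torch(torch.rand((1, 1, 64, 32), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)
b = ttnn.from_torch(torch.rand((1, 1, 64, 128), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)

out_flag = ttnn.matmul(a_t, b, transpose_a=True)           # new keyword path
out_manual = ttnn.matmul(ttnn.transpose(a_t, -2, -1), b)   # explicit transpose
# both should be (1, 1, 32, 128) and numerically identical
```

Because the transpose is materialized, a flagged call costs an extra op and intermediate buffer; the keywords are a convenience, not a fusion.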
diff --git a/ttnn/cpp/ttnn/operations/matmul.hpp b/ttnn/cpp/ttnn/operations/matmul.hpp
index 1b89ee82412..139a020095a 100644
--- a/ttnn/cpp/ttnn/operations/matmul.hpp
+++ b/ttnn/cpp/ttnn/operations/matmul.hpp
@@ -36,6 +36,8 @@ ttnn::Tensor matmul(
     const ttnn::Tensor& input_tensor_a,
     const ttnn::Tensor& input_tensor_b,
     const std::optional<const ttnn::Tensor>& bias,
+    const bool transpose_a = false,
+    const bool transpose_b = false,
     const std::optional<const MatmulProgramConfig> program_config = std::nullopt,
     const ttnn::MemoryConfig& memory_config = ttnn::DRAM_MEMORY_CONFIG,
     std::optional<const DataType> dtype = std::nullopt,
diff --git a/ttnn/ttnn/operations/matmul.py b/ttnn/ttnn/operations/matmul.py
index 9bde7070e9a..09c977531f5 100644
--- a/ttnn/ttnn/operations/matmul.py
+++ b/ttnn/ttnn/operations/matmul.py
@@ -23,6 +23,10 @@ def _golden_function(input_tensor_a, input_tensor_b, *args, **kwargs):
     import torch
 
+    if kwargs.get("transpose_a", False):
+        input_tensor_a = input_tensor_a.transpose(-1, -2)
+    if kwargs.get("transpose_b", False):
+        input_tensor_b = input_tensor_b.transpose(-1, -2)
     output_tensor = input_tensor_a @ input_tensor_b.to(input_tensor_a.dtype)
 
     if activation == "gelu":
@@ -44,6 +48,8 @@ def matmul(
     input_tensor_a: ttnn.Tensor,
     input_tensor_b: ttnn.Tensor,
     *,
+    transpose_a: bool = False,
+    transpose_b: bool = False,
     memory_config: ttnn.MemoryConfig = ttnn.DRAM_MEMORY_CONFIG,
     dtype: Optional[ttnn.DataType] = None,
     core_grid: Optional[ttnn.CoreGrid] = None,
@@ -141,6 +147,8 @@ def matmul(
     return ttnn._ttnn.operations.matmul.matmul(
         input_tensor_a,
         input_tensor_b,
+        transpose_a=transpose_a,
+        transpose_b=transpose_b,
         memory_config=memory_config,
         dtype=dtype,
         program_config=program_config,
@@ -153,6 +161,10 @@ def matmul(
 def _golden_function(input_tensor_a, input_tensor_b, *, bias=None, activation=None, **kwargs):
     import torch
 
+    if kwargs.get("transpose_a", False):
+        input_tensor_a = input_tensor_a.transpose(-1, -2)
+    if kwargs.get("transpose_b", False):
+        input_tensor_b = input_tensor_b.transpose(-1, -2)
     output_tensor = input_tensor_a @ input_tensor_b.to(input_tensor_a.dtype)
 
     if bias is not None:
@@ -182,6 +194,8 @@ def linear(
     input_tensor_a: ttnn.Tensor,
     input_tensor_b: ttnn.Tensor,
     *,
     bias: Optional[ttnn.Tensor] = None,
+    transpose_a: bool = False,
+    transpose_b: bool = False,
     memory_config: ttnn.MemoryConfig = ttnn.DRAM_MEMORY_CONFIG,
     dtype: Optional[ttnn.DataType] = None,
     core_grid: Optional[ttnn.CoreGrid] = None,
@@ -228,6 +242,8 @@ def linear(
         input_tensor_a,
         input_tensor_b,
         bias=bias,
+        transpose_a=transpose_a,
+        transpose_b=transpose_b,
         memory_config=memory_config,
         dtype=dtype,
         program_config=program_config,
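End to end, the Python wrappers simply forward the new keywords to the bound C++ op. A usage sketch for `ttnn.linear` with a PyTorch-style weight stored as (out_features, in_features), so no separate transpose call is needed (assumes an open `device` handle; the shapes are illustrative):

```python
import torch
import ttnn

torch_x = torch.rand((1, 1, 32, 64), dtype=torch.bfloat16)
torch_w = torch.rand((128, 64), dtype=torch.bfloat16)  # (out, in), as stored by torch.nn.Linear

x = ttnn.from_torch(torch_x, layout=ttnn.TILE_LAYOUT, device=device)
w = ttnn.from_torch(torch_w, layout=ttnn.TILE_LAYOUT, device=device)

y = ttnn.linear(x, w, transpose_b=True)  # computes x @ w.T -> (1, 1, 32, 128)
```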