diff --git a/docs/source/ttnn/ttnn/dependencies/tt_lib.rst b/docs/source/ttnn/ttnn/dependencies/tt_lib.rst
index b8fb9363b8fe..85fdd7c69c93 100644
--- a/docs/source/ttnn/ttnn/dependencies/tt_lib.rst
+++ b/docs/source/ttnn/ttnn/dependencies/tt_lib.rst
@@ -608,8 +608,6 @@ Other Operations
 
 .. autofunction:: tt_lib.tensor.glu
 
-.. autofunction:: tt_lib.tensor.swiglu
-
 .. autofunction:: tt_lib.tensor.embeddings
 
 .. autofunction:: tt_lib.tensor.nextafter
diff --git a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_glu_variants.py b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_glu_variants.py
index 923499047ea7..4cb0630a09a0 100644
--- a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_glu_variants.py
+++ b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_glu_variants.py
@@ -31,7 +31,10 @@
 @pytest.mark.parametrize("input_mem_config", input_mem_cfgs)
 @pytest.mark.parametrize("output_mem_config", output_mem_cfgs)
 class TestGLUVariants:
-    @pytest.mark.parametrize("fn_kind", ["glu", "swiglu"])
+    @pytest.mark.parametrize(
+        "fn_kind",
+        ["glu", "reglu", "geglu", "swiglu"],
+    )
     def test_all_glu_ops(
         self,
         input_shapes,
diff --git a/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py b/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py
index 46bf0d3066d0..eca5a3663974 100644
--- a/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py
+++ b/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py
@@ -2916,7 +2916,7 @@ def activation_glu(x, *args, device, dtype, layout, input_mem_config, output_mem
 def activation_geglu(x, *args, device, dtype, layout, input_mem_config, output_mem_config, **kwargs):
     dim = kwargs.get("dim", -1)
     t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
-    t1 = ttnn.geglu(t0, dim=dim, memory_config=output_mem_config)
+    t1 = ttnn.geglu(t0, dim, memory_config=output_mem_config)
 
     return tt2torch_tensor(t1)
 
@@ -2925,7 +2925,7 @@ def activation_geglu(x, *args, device, dtype, layout, input_mem_config, output_m
 def activation_reglu(x, *args, device, dtype, layout, input_mem_config, output_mem_config, **kwargs):
     dim = kwargs.get("dim", -1)
     t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
-    t1 = ttnn.reglu(t0, dim=dim, memory_config=output_mem_config)
+    t1 = ttnn.reglu(t0, dim, memory_config=output_mem_config)
 
     return tt2torch_tensor(t1)
 
@@ -2934,7 +2934,7 @@ def activation_reglu(x, *args, device, dtype, layout, input_mem_config, output_m
 def activation_swiglu(x, *args, device, dtype, layout, input_mem_config, output_mem_config, **kwargs):
     dim = kwargs.get("dim", -1)
     t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
-    t1 = ttl.tensor.swiglu(t0, dim, output_mem_config=output_mem_config)
+    t1 = ttnn.swiglu(t0, dim, memory_config=output_mem_config)
 
     return tt2torch_tensor(t1)
 
diff --git a/tests/ttnn/profiling/ops_for_profiling.py b/tests/ttnn/profiling/ops_for_profiling.py
index 2d9b5c8ddcee..8776703da37b 100644
--- a/tests/ttnn/profiling/ops_for_profiling.py
+++ b/tests/ttnn/profiling/ops_for_profiling.py
@@ -1462,7 +1462,7 @@ def reglu_1(x):
 
 
 def swiglu_1(x):
-    tt_lib.tensor.swiglu(x, -1)
+    ttnn.swiglu(x, -1)
 
 
 def glu_2(x):
@@ -1478,7 +1478,7 @@ def reglu_2(x):
 
 
 def swiglu_2(x):
-    tt_lib.tensor.swiglu(x, -2)
+    ttnn.swiglu(x, -2)
 
 
 def repeat(x):
@@ -2204,7 +2204,7 @@ def clone(x):
     },
     {
        "op": swiglu_1,
-        "name": "tt_lib.tensor.swiglu_dim_3",
+        "name": "ttnn.swiglu_dim_3",
     },
     {
        "op": glu_2,
@@ -2220,7 +2220,7 @@ def clone(x):
     },
     {
        "op": swiglu_2,
-        "name": "tt_lib.tensor.swiglu_dim_2",
+        "name": "ttnn.swiglu_dim_2",
     },
     {
        "op": repeat,
diff --git a/tests/ttnn/unit_tests/operations/test_composite.py b/tests/ttnn/unit_tests/operations/test_composite.py
index 9c809ebf33be..e6afa37e6e89 100644
--- a/tests/ttnn/unit_tests/operations/test_composite.py
+++ b/tests/ttnn/unit_tests/operations/test_composite.py
@@ -412,8 +412,8 @@ def test_unary_reglu_ttnn(input_shapes, dim, device):
     in_data, input_tensor = data_gen_with_range(input_shapes, -5, 5, device)
     golden_fn = ttnn.get_golden_function(ttnn.reglu)
 
-    output_tensor = ttnn.reglu(input_tensor, dim=dim)
-    golden_tensor = golden_fn(in_data, dim=dim)
+    output_tensor = ttnn.reglu(input_tensor, dim)
+    golden_tensor = golden_fn(in_data, dim)
 
     comp_pass = compare_pcc([output_tensor], [golden_tensor])
     assert comp_pass
@@ -435,8 +435,31 @@ def test_unary_geglu_ttnn(input_shapes, dim, device):
     in_data, input_tensor = data_gen_with_range(input_shapes, -5, 5, device)
     golden_fn = ttnn.get_golden_function(ttnn.geglu)
 
-    output_tensor = ttnn.geglu(input_tensor, dim=dim)
-    golden_tensor = golden_fn(in_data, dim=dim)
+    output_tensor = ttnn.geglu(input_tensor, dim)
+    golden_tensor = golden_fn(in_data, dim)
+
+    comp_pass = compare_pcc([output_tensor], [golden_tensor])
+    assert comp_pass
+
+
+@pytest.mark.parametrize(
+    "input_shapes",
+    (
+        (torch.Size([1, 1, 32, 64])),
+        (torch.Size([1, 1, 320, 384])),
+        (torch.Size([1, 3, 320, 384])),
+    ),
+)
+@pytest.mark.parametrize(
+    "dim",
+    [-1, 3],
+)
+def test_unary_swiglu_ttnn(input_shapes, dim, device):
+    in_data, input_tensor = data_gen_with_range(input_shapes, -5, 5, device)
+    golden_fn = ttnn.get_golden_function(ttnn.swiglu)
+
+    output_tensor = ttnn.swiglu(input_tensor, dim)
+    golden_tensor = golden_fn(in_data, dim)
 
     comp_pass = compare_pcc([output_tensor], [golden_tensor])
     assert comp_pass
diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/composite/composite_ops.cpp b/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/composite/composite_ops.cpp
index 37c6801dc57e..3db500a01da3 100644
--- a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/composite/composite_ops.cpp
+++ b/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/composite/composite_ops.cpp
@@ -1901,28 +1901,6 @@ Tensor glu(
 }
 
-// Swish Gated Linear Unit activation: matmul(split[0],swish(split[1]))
-Tensor _swiglu(
-    const Tensor& input_a,
-    int32_t dim /* = -1 */,
-    const MemoryConfig& output_mem_config /* = operation::DEFAULT_OUTPUT_MEMORY_CONFIG */) {
-    TT_ASSERT(dim == -1 || dim == 3, "last dim SWIGLU only supported at this time ");
-    if (dim == -1)
-        dim = 3;
-
-    std::vector<Tensor> ab = split_tensor_for_glu(input_a, dim, output_mem_config);
-
-    Tensor swish_b = swish(ab[1], output_mem_config);
-    Tensor swiglu_result = ttnn::multiply(ab[0], swish_b, std::nullopt, output_mem_config);
-    return swiglu_result;
-}
-Tensor swiglu(
-    const Tensor& input_a,
-    int32_t dim /* = -1 */,
-    const MemoryConfig& output_mem_config /* = operation::DEFAULT_OUTPUT_MEMORY_CONFIG */) {
-    return operation::decorate_as_composite(__func__, _swiglu)(input_a, dim, output_mem_config);
-}
-
 // on-device tensor creation with shape and filled with value
 Tensor _sfpu_eps(const Shape shape, Layout layout, Device* device, const MemoryConfig& output_mem_config) {
     float value = device->sfpu_eps();
diff --git a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/composite/composite_ops.hpp b/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/composite/composite_ops.hpp
index 0bd9ae5d4c90..d4672c9134ad 100644
--- a/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/composite/composite_ops.hpp
+++ b/ttnn/cpp/ttnn/experimental/tt_dnn/op_library/composite/composite_ops.hpp
@@ -569,12 +569,6 @@ Tensor glu(
     int32_t dim = -1,
     const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
 
-// Swish based GLU
-Tensor swiglu(
-    const Tensor& input_a,
-    int32_t dim = -1,
-    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
 // on-device tensor creation with shape and filled with value
 Tensor sfpu_eps(const Shape shape, Layout layout, Device* device, const MemoryConfig& output_mem_config);
 
diff --git a/ttnn/cpp/ttnn/experimental/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp b/ttnn/cpp/ttnn/experimental/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp
index d52571691959..b8c6889c7e85 100644
--- a/ttnn/cpp/ttnn/experimental/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp
+++ b/ttnn/cpp/ttnn/experimental/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp
@@ -360,13 +360,6 @@ void TensorModuleCompositeOPs(py::module& m_tensor) {
                 "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
         )doc");
 
-        detail::bind_unary_op_with_param(
-            m_tensor,
-            "swiglu",
-            &swiglu,
-            py::arg("dim") = -1,
-            R"doc(Applies the Swish Gated Linear Units (SwiGLU) function to the elements of the input tensor ``{0}`` split along dim ``{1}``.)doc",
-            R"doc(dimension to split)doc");
         detail::bind_unary_op_with_param(
             m_tensor,
             "logical_andi",
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp
index 28517314a0eb..e38879b633c3 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp
@@ -551,4 +551,21 @@ Tensor _geglu(
     return geglu_result;
 }
 
+
+// Swish Gated Linear Unit activation: matmul(split[0],swish(split[1]))
+Tensor _swiglu(
+    const Tensor& input_a,
+    int32_t dim,
+    const std::optional<MemoryConfig>& output_mem_config ) {
+    TT_ASSERT(dim == -1 || dim == 3, "last dim SWIGLU only supported at this time ");
+    if (dim == -1)
+        dim = 3;
+
+    std::vector<Tensor> ab = split_tensor_for_glu(input_a, dim, output_mem_config);
+
+    Tensor swish_b = _swish(ab[1], output_mem_config);
+    Tensor swiglu_result = ttnn::multiply(ab[0], swish_b, std::nullopt, output_mem_config);
+    return swiglu_result;
+}
+
 } // namespace ttnn::operations::unary
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.hpp
index e97536d19051..0984ed8619e4 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.hpp
@@ -40,6 +40,7 @@ enum class UnaryCompositeOpType {
     THRESHOLD,
     REGLU,
     GEGLU,
+    SWIGLU,
 };
 
 Tensor _tanhshrink (const Tensor&, const std::optional<MemoryConfig>&);
@@ -76,6 +77,7 @@ Tensor _threshold(const Tensor&, float, float, const std::optional<MemoryConfig
 
 Tensor _reglu(const Tensor&, int32_t, const std::optional<MemoryConfig>& );
 Tensor _geglu(const Tensor&, int32_t, const std::optional<MemoryConfig>& );
+Tensor _swiglu(const Tensor&, int32_t, const std::optional<MemoryConfig>& );
 
 // OpHandler struct template
 template <UnaryCompositeOpType unary_comp_op_type>
@@ -300,6 +302,13 @@ struct OpHandler_dim<UnaryCompositeOpType::GEGLU> {
     }
 };
 
+template <>
+struct OpHandler_dim<UnaryCompositeOpType::SWIGLU> {
+    static Tensor handle(const Tensor& t1, int32_t dim, const std::optional<MemoryConfig>& mem_cfg ) {
+        return _swiglu(t1, dim, mem_cfg);
+    }
+};
+
 // Template functions to get the function pointers
 template <UnaryCompositeOpType unary_comp_op_type>
 auto get_function_type1() {
@@ -327,7 +336,7 @@ auto get_function_type5() {
 }
 
 template <UnaryCompositeOpType unary_comp_op_type>
-auto get_function_type6() {
+auto get_glu_fn() {
     return &OpHandler_dim<unary_comp_op_type>::handle;
 }
 }
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/unary_composite.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/unary_composite.hpp
index 2452952491ce..d349a94cf7c9 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary/unary_composite.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary/unary_composite.hpp
@@ -150,7 +150,7 @@ struct ExecuteUnaryCompositeOpWithDim
         int32_t dim,
         const std::optional<MemoryConfig>& memory_config = std::nullopt) {
 
-        auto op_type = get_function_type6<unary_comp_op_type>();
+        auto op_type = get_glu_fn<unary_comp_op_type>();
         return op_type(input_tensor, dim, memory_config);
     }
 };
@@ -294,5 +294,6 @@ constexpr auto threshold = ttnn::register_operation<operations::unary::ExecuteU
 
 constexpr auto reglu = ttnn::register_operation<operations::unary::ExecuteUnaryCompositeOpWithDim<operations::unary::UnaryCompositeOpType::REGLU>>("ttnn::reglu");
 constexpr auto geglu = ttnn::register_operation<operations::unary::ExecuteUnaryCompositeOpWithDim<operations::unary::UnaryCompositeOpType::GEGLU>>("ttnn::geglu");
+constexpr auto swiglu = ttnn::register_operation<operations::unary::ExecuteUnaryCompositeOpWithDim<operations::unary::UnaryCompositeOpType::SWIGLU>>("ttnn::swiglu");
 
 } // namespace ttnn
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp
index 7eceab1eafa8..f474b85c1efc 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp
@@ -1098,6 +1098,7 @@ void py_module(py::module& module) {
     // Unary ops with dim parameter
     detail::bind_unary_operation_with_dim_parameter(module, ttnn::reglu, "dim", "Dimenstion to split input tensor. Supported dimension -1 or 3", "Split the tensor into two, apply relu function on second tensor followed by mul op with first tensor");
     detail::bind_unary_operation_with_dim_parameter(module, ttnn::geglu, "dim", "Dimenstion to split input tensor. Supported dimension -1 or 3", "Split the tensor into two, apply gelu function on second tensor followed by mul op with first tensor");
+    detail::bind_unary_operation_with_dim_parameter(module, ttnn::swiglu, "dim", "Dimenstion to split input tensor. Supported dimension -1 or 3", "Split the tensor into two, apply silu function on second tensor followed by mul op with first tensor");
 
     // Other unaries (unary chain operations)
     detail::bind_softplus(module, ttnn::softplus);
diff --git a/ttnn/ttnn/operations/unary.py b/ttnn/ttnn/operations/unary.py
index 84678cae301e..5681c0594093 100644
--- a/ttnn/ttnn/operations/unary.py
+++ b/ttnn/ttnn/operations/unary.py
@@ -404,6 +404,20 @@ def _golden_function_geglu(input_tensor_a, dim, *args, **kwargs):
 ttnn.attach_golden_function(ttnn._ttnn.operations.unary.geglu, golden_function=_golden_function_geglu)
 
 
+def _golden_function_swiglu(input_tensor_a, dim, *args, **kwargs):
+    import torch
+
+    assert isinstance(dim, int), "dim must be an integer"
+    assert dim in [-1, 3], "dim must be -1 or 3"
+    split_size = input_tensor_a.size(-1) // 2
+    split_tensors = torch.split(input_tensor_a, split_size_or_sections=[split_size, split_size], dim=dim)
+    tensA, tensB = split_tensors[0], split_tensors[1]
+    return tensA * torch.nn.functional.silu(tensB)
+
+
+ttnn.attach_golden_function(ttnn._ttnn.operations.unary.swiglu, golden_function=_golden_function_swiglu)
+
+
 def _is_scalar(value):
     return isinstance(value, (int, float))
 
@@ -554,33 +568,12 @@ def activation_function(
     )
 
 
-def torch_reglu(input_tensor, *args, **kwargs):
-    import torch
-
-    split_size = input_tensor.size(-1) // 2
-    split_tensors = torch.split(input_tensor, split_size_or_sections=[split_size, split_size], dim=-1)
-    tensA, tensB = split_tensors[0], split_tensors[1]
-    return tensA * torch.nn.functional.relu(tensB)
-
-
-def torch_swiglu(input_tensor, *args, **kwargs):
-    import torch
-
-    split_size = input_tensor.size(-1) // 2
-    split_tensors = torch.split(input_tensor, split_size_or_sections=[split_size, split_size], dim=-1)
-    tensA, tensB = split_tensors[0], split_tensors[1]
-    return tensA * torch.nn.functional.silu(tensB)
-
-
 def register_ttl_activation_function_glu(name, ttl_activation_function, param):
     def _golden_function(input_tensor: ttnn.Tensor, dim: int = -1, **_):
         import torch
 
         name_to_torch_function = {
             "glu": torch.nn.functional.glu,
-            "reglu": torch_reglu,
-            "swiglu": torch_swiglu,
-            "geglu": torch_geglu,
         }
         torch_function = name_to_torch_function[name]
         input_tensor = ttnn.to_torch(input_tensor)
@@ -632,7 +625,6 @@ def activation_function(
 
 
 TTL_ACTIVATION_FUNCTIONS_GLU = [
     ("glu", ttl.tensor.glu, "dim"),  # composite
-    ("swiglu", ttl.tensor.swiglu, "dim"),  # composite
 ]