diff --git a/docs/source/ttnn/ttnn/dependencies/tt_lib.rst b/docs/source/ttnn/ttnn/dependencies/tt_lib.rst
index a69a174d155..e6645983994 100644
--- a/docs/source/ttnn/ttnn/dependencies/tt_lib.rst
+++ b/docs/source/ttnn/ttnn/dependencies/tt_lib.rst
@@ -608,14 +608,6 @@ Other Operations

 .. autofunction:: tt_lib.tensor.normalize_global

-.. autofunction:: tt_lib.tensor.glu
-
-.. autofunction:: tt_lib.tensor.geglu
-
-.. autofunction:: tt_lib.tensor.reglu
-
-.. autofunction:: tt_lib.tensor.swiglu
-
 .. autofunction:: tt_lib.tensor.embeddings

 .. autofunction:: tt_lib.tensor.nextafter
diff --git a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_glu_variants.py b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_glu_variants.py
index ed5ef9a6b42..4cb0630a09a 100644
--- a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_glu_variants.py
+++ b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_glu_variants.py
@@ -31,7 +31,10 @@
 @pytest.mark.parametrize("input_mem_config", input_mem_cfgs)
 @pytest.mark.parametrize("output_mem_config", output_mem_cfgs)
 class TestGLUVariants:
-    @pytest.mark.parametrize("fn_kind", ["glu", "reglu", "geglu", "swiglu"])
+    @pytest.mark.parametrize(
+        "fn_kind",
+        ["glu", "reglu", "geglu", "swiglu"],
+    )
     def test_all_glu_ops(
         self,
         input_shapes,
diff --git a/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py b/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py
index 512bdf8034e..cedfc109991 100644
--- a/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py
+++ b/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py
@@ -2846,7 +2846,7 @@ def unpad_from_tile(
 def activation_glu(x, *args, device, dtype, layout, input_mem_config, output_mem_config, **kwargs):
     dim = kwargs.get("dim", -1)
     t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
-    t1 = ttl.tensor.glu(t0, dim, output_mem_config=output_mem_config)
+    t1 = ttnn.glu(t0, dim, memory_config=output_mem_config)
     return tt2torch_tensor(t1)


@@ -2855,7 +2855,7 @@ def activation_glu(x, *args, device, dtype, layout, input_mem_config, output_mem
 def activation_geglu(x, *args, device, dtype, layout, input_mem_config, output_mem_config, **kwargs):
     dim = kwargs.get("dim", -1)
     t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
-    t1 = ttl.tensor.geglu(t0, dim, output_mem_config=output_mem_config)
+    t1 = ttnn.geglu(t0, dim, memory_config=output_mem_config)
     return tt2torch_tensor(t1)


@@ -2864,7 +2864,7 @@ def activation_geglu(x, *args, device, dtype, layout, input_mem_config, output_m
 def activation_reglu(x, *args, device, dtype, layout, input_mem_config, output_mem_config, **kwargs):
     dim = kwargs.get("dim", -1)
     t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
-    t1 = ttl.tensor.reglu(t0, dim, output_mem_config=output_mem_config)
+    t1 = ttnn.reglu(t0, dim, memory_config=output_mem_config)
     return tt2torch_tensor(t1)


@@ -2873,7 +2873,7 @@ def activation_reglu(x, *args, device, dtype, layout, input_mem_config, output_m
 def activation_swiglu(x, *args, device, dtype, layout, input_mem_config, output_mem_config, **kwargs):
     dim = kwargs.get("dim", -1)
     t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
-    t1 = ttl.tensor.swiglu(t0, dim, output_mem_config=output_mem_config)
+    t1 = ttnn.swiglu(t0, dim, memory_config=output_mem_config)
     return tt2torch_tensor(t1)


diff --git a/tests/ttnn/profiling/ops_for_profiling.py b/tests/ttnn/profiling/ops_for_profiling.py
index 88d88eabe8b..6440f7ec598 100644
--- a/tests/ttnn/profiling/ops_for_profiling.py
+++ b/tests/ttnn/profiling/ops_for_profiling.py
@@ -1450,35 +1450,35 @@ def logical_noti(x):


 def glu_1(x):
-    tt_lib.tensor.glu(x, -1)
+    ttnn.glu(x, -1)


 def geglu_1(x):
-    tt_lib.tensor.geglu(x, -1)
+    ttnn.geglu(x, -1)


 def reglu_1(x):
-    tt_lib.tensor.reglu(x, -1)
+    ttnn.reglu(x, -1)


 def swiglu_1(x):
-    tt_lib.tensor.swiglu(x, -1)
+    ttnn.swiglu(x, -1)


 def glu_2(x):
-    tt_lib.tensor.glu(x, -2)
+    ttnn.glu(x, -2)


 def geglu_2(x):
-    tt_lib.tensor.geglu(x, -2)
+    ttnn.geglu(x, -2)


 def reglu_2(x):
-    tt_lib.tensor.reglu(x, -2)
+    ttnn.reglu(x, -2)


 def swiglu_2(x):
-    tt_lib.tensor.swiglu(x, -2)
+    ttnn.swiglu(x, -2)


 def repeat(x):
@@ -2192,35 +2192,35 @@ def clone(x):
     },
     {
         "op": glu_1,
-        "name": "tt_lib.tensor.glu_dim_3",
+        "name": "ttnn.glu_dim_3",
     },
     {
         "op": geglu_1,
-        "name": "tt_lib.tensor.geglu_dim_3",
+        "name": "ttnn.geglu_dim_3",
     },
     {
         "op": reglu_1,
-        "name": "tt_lib.tensor.reglu_dim_3",
+        "name": "ttnn.reglu_dim_3",
     },
     {
         "op": swiglu_1,
-        "name": "tt_lib.tensor.swiglu_dim_3",
+        "name": "ttnn.swiglu_dim_3",
     },
     {
         "op": glu_2,
-        "name": "tt_lib.tensor.glu_dim_2",
+        "name": "ttnn.glu_dim_2",
     },
     {
         "op": geglu_2,
-        "name": "tt_lib.tensor.geglu_dim_2",
+        "name": "ttnn.geglu_dim_2",
     },
     {
         "op": reglu_2,
-        "name": "tt_lib.tensor.reglu_dim_2",
+        "name": "ttnn.reglu_dim_2",
     },
     {
         "op": swiglu_2,
-        "name": "tt_lib.tensor.swiglu_dim_2",
+        "name": "ttnn.swiglu_dim_2",
     },
     {
         "op": repeat,
diff --git a/tests/ttnn/unit_tests/operations/test_composite.py b/tests/ttnn/unit_tests/operations/test_composite.py
index 98997e0f05f..9959ea790f6 100644
--- a/tests/ttnn/unit_tests/operations/test_composite.py
+++ b/tests/ttnn/unit_tests/operations/test_composite.py
@@ -394,3 +394,95 @@ def test_unary_threshold_ttnn(input_shapes, device):

     comp_pass = compare_pcc([output_tensor], [golden_tensor])
     assert comp_pass
+
+
+@pytest.mark.parametrize(
+    "input_shapes",
+    (
+        (torch.Size([1, 1, 32, 64])),
+        (torch.Size([1, 1, 320, 384])),
+        (torch.Size([1, 3, 320, 384])),
+    ),
+)
+@pytest.mark.parametrize(
+    "dim",
+    [-1, 3],
+)
+def test_unary_glu_ttnn(input_shapes, dim, device):
+    in_data, input_tensor = data_gen_with_range(input_shapes, -5, 5, device)
+    golden_fn = ttnn.get_golden_function(ttnn.glu)
+
+    output_tensor = ttnn.glu(input_tensor, dim)
+    golden_tensor = golden_fn(in_data, dim)
+
+    comp_pass = compare_pcc([output_tensor], [golden_tensor])
+    assert comp_pass
+
+
+@pytest.mark.parametrize(
+    "input_shapes",
+    (
+        (torch.Size([1, 1, 32, 64])),
+        (torch.Size([1, 1, 320, 384])),
+        (torch.Size([1, 3, 320, 384])),
+    ),
+)
+@pytest.mark.parametrize(
+    "dim",
+    [-1, 3],
+)
+def test_unary_reglu_ttnn(input_shapes, dim, device):
+    in_data, input_tensor = data_gen_with_range(input_shapes, -5, 5, device)
+    golden_fn = ttnn.get_golden_function(ttnn.reglu)
+
+    output_tensor = ttnn.reglu(input_tensor, dim)
+    golden_tensor = golden_fn(in_data, dim)
+
+    comp_pass = compare_pcc([output_tensor], [golden_tensor])
+    assert comp_pass
+
+
+@pytest.mark.parametrize(
+    "input_shapes",
+    (
+        (torch.Size([1, 1, 32, 64])),
+        (torch.Size([1, 1, 320, 384])),
+        (torch.Size([1, 3, 320, 384])),
+    ),
+)
+@pytest.mark.parametrize(
+    "dim",
+    [-1, 3],
+)
+def test_unary_geglu_ttnn(input_shapes, dim, device):
+    in_data, input_tensor = data_gen_with_range(input_shapes, -5, 5, device)
+    golden_fn = ttnn.get_golden_function(ttnn.geglu)
+
+    output_tensor = ttnn.geglu(input_tensor, dim)
+    golden_tensor = golden_fn(in_data, dim)
+
+    comp_pass = compare_pcc([output_tensor], [golden_tensor])
+    assert comp_pass
+
+
+@pytest.mark.parametrize(
+    "input_shapes",
+    (
+        (torch.Size([1, 1, 32, 64])),
+        (torch.Size([1, 1, 320, 384])),
+        (torch.Size([1, 3, 320, 384])),
+    ),
+)
+@pytest.mark.parametrize(
+    "dim",
+    [-1, 3],
+)
+def test_unary_swiglu_ttnn(input_shapes, dim, device):
+    in_data, input_tensor = data_gen_with_range(input_shapes, -5, 5, device)
+    golden_fn = ttnn.get_golden_function(ttnn.swiglu)
+
+    output_tensor = ttnn.swiglu(input_tensor, dim)
+    golden_tensor = golden_fn(in_data, dim)
+
+    comp_pass = compare_pcc([output_tensor], [golden_tensor])
+    assert comp_pass
diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/composite/composite_ops.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/composite/composite_ops.cpp
index f8bdf35c6da..961e7921928 100644
--- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/composite/composite_ops.cpp
+++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/composite/composite_ops.cpp
@@ -1879,91 +1879,6 @@ std::vector split_tensor_for_glu(const Tensor& input_a, int32_t dim, con
     return t_split;
 }

-// Gated Linear Unit activation: matmul(split[0],sigmoid(split[1]))
-Tensor _glu(
-    const Tensor& input_a,
-    int32_t dim /* = -1 */,
-    const MemoryConfig& output_mem_config /* = operation::DEFAULT_OUTPUT_MEMORY_CONFIG */) {
-    TT_ASSERT(dim == -1 || dim == 3, "last dim GLU only supported at this time ");
-    if (dim == -1)
-        dim = 3;
-
-    std::vector<Tensor> ab = split_tensor_for_glu(input_a, dim, output_mem_config);
-    Tensor sigmoid_b = ttnn::sigmoid(ab[1], output_mem_config);
-    Tensor glu_result = ttnn::multiply(ab[0], sigmoid_b, std::nullopt, output_mem_config);
-    return glu_result;
-}
-Tensor glu(
-    const Tensor& input_a,
-    int32_t dim /* = -1 */,
-    const MemoryConfig& output_mem_config /* = operation::DEFAULT_OUTPUT_MEMORY_CONFIG */) {
-    return operation::decorate_as_composite(__func__, _glu)(input_a, dim, output_mem_config);
-}
-
-// ReLU Gated Linear Unit activation: matmul(split[0],relu(split[1]))
-Tensor _reglu(
-    const Tensor& input_a,
-    int32_t dim /* = -1 */,
-    const MemoryConfig& output_mem_config /* = operation::DEFAULT_OUTPUT_MEMORY_CONFIG */) {
-    TT_ASSERT(dim == -1 || dim == 3, "last dim REGLU only supported at this time ");
-    if (dim == -1)
-        dim = 3;
-    std::vector<Tensor> ab = split_tensor_for_glu(input_a, dim, output_mem_config);
-    Tensor relu_b = ttnn::relu(ab[1], output_mem_config);
-    Tensor reglu_result = ttnn::multiply(ab[0], relu_b, std::nullopt, output_mem_config);
-    return reglu_result;
-}
-Tensor reglu(
-    const Tensor& input_a,
-    int32_t dim /* = -1 */,
-    const MemoryConfig& output_mem_config /* = operation::DEFAULT_OUTPUT_MEMORY_CONFIG */) {
-    return operation::decorate_as_composite(__func__, _reglu)(input_a, dim, output_mem_config);
-}
-
-// Gaussian Error Gated Linear Unit activation: matmul(split[0],gelu(split[1]))
-Tensor _geglu(
-    const Tensor& input_a,
-    int32_t dim /* = -1 */,
-    const MemoryConfig& output_mem_config /* = operation::DEFAULT_OUTPUT_MEMORY_CONFIG */) {
-    TT_ASSERT(dim == -1 || dim == 3, "last dim GEGLU only supported at this time ");
-    if (dim == -1)
-        dim = 3;
-
-    std::vector<Tensor> ab = split_tensor_for_glu(input_a, dim, output_mem_config);
-
-    constexpr bool fast_appx = true;
-    Tensor gelu_b = ttnn::gelu(ab[1], fast_appx, output_mem_config);
-    Tensor geglu_result = ttnn::multiply(ab[0], gelu_b, std::nullopt, output_mem_config);
-    return geglu_result;
-}
-Tensor geglu(
-    const Tensor& input_a,
-    int32_t dim /* = -1 */,
-    const MemoryConfig& output_mem_config /* = operation::DEFAULT_OUTPUT_MEMORY_CONFIG */) {
-    return operation::decorate_as_composite(__func__, _geglu)(input_a, dim, output_mem_config);
-}
-
-// Swish Gated Linear Unit activation: matmul(split[0],swish(split[1]))
-Tensor _swiglu(
-    const Tensor& input_a,
-    int32_t dim /* = -1 */,
-    const MemoryConfig& output_mem_config /* = operation::DEFAULT_OUTPUT_MEMORY_CONFIG */) {
-    TT_ASSERT(dim == -1 || dim == 3, "last dim SWIGLU only supported at this time ");
-    if (dim == -1)
-        dim = 3;
-
-    std::vector<Tensor> ab = split_tensor_for_glu(input_a, dim, output_mem_config);
-
-    Tensor swish_b = swish(ab[1], output_mem_config);
-    Tensor swiglu_result = ttnn::multiply(ab[0], swish_b, std::nullopt, output_mem_config);
-    return swiglu_result;
-}
-Tensor swiglu(
-    const Tensor& input_a,
-    int32_t dim /* = -1 */,
-    const MemoryConfig& output_mem_config /* = operation::DEFAULT_OUTPUT_MEMORY_CONFIG */) {
-    return operation::decorate_as_composite(__func__, _swiglu)(input_a, dim, output_mem_config);
-}

 // on-device tensor creation with shape and filled with value
 Tensor _sfpu_eps(const Shape shape, Layout layout, Device* device, const MemoryConfig& output_mem_config) {
diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/composite/composite_ops.hpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/composite/composite_ops.hpp
index c5ea4f7efe6..5927179780c 100644
--- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/composite/composite_ops.hpp
+++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/composite/composite_ops.hpp
@@ -563,27 +563,6 @@ Tensor logical_ori(
     float immediate,
     const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

-// Gated Linear Unit activation
-Tensor glu(
-    const Tensor& input_a,
-    int32_t dim = -1,
-    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-// ReLU based GLU
-Tensor reglu(
-    const Tensor& input_a,
-    int32_t dim = -1,
-    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-// Gelu based GLU
-Tensor geglu(
-    const Tensor& input_a,
-    int32_t dim = -1,
-    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-// Swish based GLU
-Tensor swiglu(
-    const Tensor& input_a,
-    int32_t dim = -1,
-    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
-
 // on-device tensor creation with shape and filled with value
 Tensor sfpu_eps(const Shape shape, Layout layout, Device* device, const MemoryConfig& output_mem_config);

diff --git a/ttnn/cpp/ttnn/deprecated/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp b/ttnn/cpp/ttnn/deprecated/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp
index d72e7d84fa3..21564c7794a 100644
--- a/ttnn/cpp/ttnn/deprecated/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp
+++ b/ttnn/cpp/ttnn/deprecated/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp
@@ -332,13 +332,6 @@ void TensorModuleCompositeOPs(py::module& m_tensor) {
         R"doc(Returns tensor with the polyval of all of elements of the input tensor ``{0}`` with coefficients ``{1}``.)doc",
         R"doc("coefficients value with highest degree first", "List of float", "List size > 0")doc");

-    detail::bind_unary_op_with_param(
-        m_tensor,
-        "glu",
-        &glu,
-        py::arg("dim") = -1,
-        R"doc(Applies the Gated Linear Units (GLU) function to the elements of the input tensor ``{0}`` split along dim ``{1}``.)doc",
-        R"doc(dimension to split)doc");
     m_tensor.def(
         "prod",
         &prod,
@@ -362,27 +355,7 @@ void TensorModuleCompositeOPs(py::module& m_tensor) {
         "dim", "Dimension to perform prod", "int", "default to 0", "Yes"
         "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
     )doc");
-    detail::bind_unary_op_with_param(
-        m_tensor,
-        "geglu",
-        &geglu,
-        py::arg("dim") = -1,
-        R"doc(Applies the Gaussian Error Gated Linear Units function to the elements of the input tensor ``{0}`` split along dim ``{1}``.)doc",
-        R"doc(dimension to split)doc");
-    detail::bind_unary_op_with_param(
-        m_tensor,
-        "reglu",
-        &reglu,
-        py::arg("dim") = -1,
-        R"doc(Applies the Rectified Linear Gated Linear Units (ReGLU) function to the elements of the input tensor ``{0}`` split along dim ``{1}``.)doc",
-        R"doc(dimension to split)doc");
-    detail::bind_unary_op_with_param(
-        m_tensor,
-        "swiglu",
-        &swiglu,
-        py::arg("dim") = -1,
-        R"doc(Applies the Swish Gated Linear Units (SwiGLU) function to the elements of the input tensor ``{0}`` split along dim ``{1}``.)doc",
-        R"doc(dimension to split)doc");
+
     detail::bind_unary_op_with_param(
         m_tensor,
         "logical_andi",
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp
index a3b8294a97a..593ae8f0faf 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp
@@ -15,7 +15,8 @@
 #include "ttnn/run_operation.hpp"
 #include "ttnn/types.hpp"
 #include "tt_metal/common/bfloat16.hpp"
-#include "ttnn/deprecated/tt_dnn/op_library/reduce/reduce_op.hpp"
+#include "tt_dnn/op_library/reduce/reduce_op.hpp"
+#include "ttnn/operations/data_movement/slice/slice.hpp"

 namespace ttnn::operations::unary{

@@ -399,7 +400,6 @@ Tensor _normalize(const Tensor& y, const std::optional& output_mem
 // PyTorch version:
 // hard sigmoid(x) = { x <= -3: 0, x >= +3: +3, x/6 + 0.5 otherwise}
 Tensor _hardsigmoid(const Tensor& a, float value_1, float value_2, const std::optional<MemoryConfig>& output_mem_config) {
-// std::cout<<"\n\n hit in ttnn hardsigmoid";
     Tensor a_t = ttnn::full_like(a,value_1);
     Tensor b_t = ttnn::full_like(a,value_2);

@@ -413,7 +413,6 @@ Tensor _hardsigmoid(const Tensor& a, float value_1, float value_2, const std::op
 // Ref: PyTorch
 // hard swish(x) = x*hardsigmoid(x,scale,shift)
 Tensor _hardswish(const Tensor& a, float value_1, float value_2, const std::optional<MemoryConfig>& output_mem_config) {
-// std::cout<<"\n\n hit in ttnn hardswish";
     Tensor a_sigmoid = _hardsigmoid(a, value_1, value_2, output_mem_config);
     Tensor result_sq = ttnn::multiply(a_sigmoid, a, std::nullopt);
     return result_sq;
@@ -496,4 +495,83 @@ Tensor _threshold(const Tensor& input_tensor, float threshold, float value, cons
     Tensor t2 = ttnn::multiply(ttnn::gtz(t0, output_mem_config), input_tensor, std::nullopt, output_mem_config);
     return ttnn::add(t1, t2, std::nullopt, output_mem_config);
 }
+
+
+std::vector<Tensor> split_tensor_for_glu(const Tensor& input_a, int32_t dim, const std::optional<MemoryConfig>& output_mem_config) {
+    std::vector<Tensor> t_split;
+    Shape inshape(input_a.get_legacy_shape());
+    TT_FATAL(((inshape[dim] / 2) % TILE_WIDTH == 0), "Split tensor dimension should be in full tile");
+    std::vector<uint32_t> s_a = {0, 0, 0, 0};
+    std::vector<uint32_t> e_a = {input_a.get_legacy_shape()[0] - 1, inshape[1] - 1, inshape[2] - 1, inshape[3] / 2 - 1};
+
+    std::vector<uint32_t> s_b = {0, 0, 0, inshape[3] / 2};
+    std::vector<uint32_t> e_b = {inshape[0] - 1, inshape[1] - 1, inshape[2] - 1, inshape[3] - 1};
+
+    Tensor t_a = ttnn::slice(0, input_a, s_a, e_a, output_mem_config);
+    Tensor t_b = ttnn::slice(0, input_a, s_b, e_b, output_mem_config);
+
+    t_split.emplace_back(t_a);
+    t_split.emplace_back(t_b);
+
+    return t_split;
+}
+
+// Gated Linear Unit activation: matmul(split[0],sigmoid(split[1]))
+Tensor _glu(const Tensor& input_a, int32_t dim, const std::optional<MemoryConfig>& output_mem_config) {
+    TT_ASSERT(dim == -1 || dim == 3, "last dim GLU only supported at this time ");
+    if (dim == -1)
+        dim = 3;
+    std::vector<Tensor> ab = split_tensor_for_glu(input_a, dim, output_mem_config);
+    Tensor sigmoid_b = ttnn::sigmoid(ab[1], output_mem_config);
+    Tensor glu_result = ttnn::multiply(ab[0], sigmoid_b, std::nullopt, output_mem_config);
+    return glu_result;
+}
+
+// ReLU Gated Linear Unit activation: matmul(split[0],relu(split[1]))
+Tensor _reglu(
+    const Tensor& input_a,
+    int32_t dim,
+    const std::optional<MemoryConfig>& output_mem_config) {
+    TT_ASSERT(dim == -1 || dim == 3, "last dim REGLU only supported at this time ");
+    if (dim == -1)
+        dim = 3;
+    std::vector<Tensor> ab = split_tensor_for_glu(input_a, dim, output_mem_config);
+    Tensor relu_b = ttnn::relu(ab[1], output_mem_config);
+    Tensor reglu_result = ttnn::multiply(ab[0], relu_b, std::nullopt, output_mem_config);
+    return reglu_result;
+}
+
+// Gaussian Error Gated Linear Unit activation: matmul(split[0],gelu(split[1]))
+Tensor _geglu(
+    const Tensor& input_a,
+    int32_t dim,
+    const std::optional<MemoryConfig>& output_mem_config) {
+    TT_ASSERT(dim == -1 || dim == 3, "last dim GEGLU only supported at this time ");
+    if (dim == -1)
+        dim = 3;
+
+    std::vector<Tensor> ab = split_tensor_for_glu(input_a, dim, output_mem_config);
+
+    constexpr bool fast_appx = true;
+    Tensor gelu_b = ttnn::gelu(ab[1], fast_appx, output_mem_config);
+    Tensor geglu_result = ttnn::multiply(ab[0], gelu_b, std::nullopt, output_mem_config);
+    return geglu_result;
+}
+
+// Swish Gated Linear Unit activation: matmul(split[0],swish(split[1]))
+Tensor _swiglu(
+    const Tensor& input_a,
+    int32_t dim,
+    const std::optional<MemoryConfig>& output_mem_config) {
+    TT_ASSERT(dim == -1 || dim == 3, "last dim SWIGLU only supported at this time ");
+    if (dim == -1)
+        dim = 3;
+
+    std::vector<Tensor> ab = split_tensor_for_glu(input_a, dim, output_mem_config);
+
+    Tensor swish_b = _swish(ab[1], output_mem_config);
+    Tensor swiglu_result = ttnn::multiply(ab[0], swish_b, std::nullopt, output_mem_config);
+    return swiglu_result;
+}
+
 }  // namespace ttnn::operations::unary
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.hpp
index d8f14c3ab29..14252408113 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.hpp
@@ -38,6 +38,10 @@ enum class UnaryCompositeOpType {
     CLAMP,
     SELU,
     THRESHOLD,
+    GLU,
+    REGLU,
+    GEGLU,
+    SWIGLU,
 };

 Tensor _tanhshrink (const Tensor&, const std::optional<MemoryConfig>&);
@@ -71,6 +75,10 @@ Tensor _clip(const Tensor&, float, float, const std::optional& );
 Tensor _clamp(const Tensor&, float, float, const std::optional<MemoryConfig>& );
 Tensor _selu(const Tensor&, float, float, const std::optional<MemoryConfig>& );
 Tensor _threshold(const Tensor&, float, float, const std::optional<MemoryConfig>& );
+Tensor _glu(const Tensor&, int32_t, const std::optional<MemoryConfig>& );
+Tensor _reglu(const Tensor&, int32_t, const std::optional<MemoryConfig>& );
+Tensor _geglu(const Tensor&, int32_t, const std::optional<MemoryConfig>& );
+Tensor _swiglu(const Tensor&, int32_t, const std::optional<MemoryConfig>& );

 // OpHandler struct template
 template
@@ -88,6 +96,9 @@ struct OpHandler_low_high;
 template
 struct OpHandler_threshold_value;

+template
+struct OpHandler_dim;
+
 template <>
 struct OpHandler {
     static Tensor handle(const Tensor& t1, const std::optional<MemoryConfig>& mem_cfg ) {
@@ -277,6 +288,35 @@ struct OpHandler_threshold_value {
     }
 };

+// glu (geglu, reglu, swiglu, glu) variants are supported only for the last dimension.
+template <>
+struct OpHandler_dim<UnaryCompositeOpType::GLU> {
+    static Tensor handle(const Tensor& t1, int32_t dim, const std::optional<MemoryConfig>& mem_cfg ) {
+        return _glu(t1, dim, mem_cfg);
+    }
+};
+
+template <>
+struct OpHandler_dim<UnaryCompositeOpType::REGLU> {
+    static Tensor handle(const Tensor& t1, int32_t dim, const std::optional<MemoryConfig>& mem_cfg ) {
+        return _reglu(t1, dim, mem_cfg);
+    }
+};
+
+template <>
+struct OpHandler_dim<UnaryCompositeOpType::GEGLU> {
+    static Tensor handle(const Tensor& t1, int32_t dim, const std::optional<MemoryConfig>& mem_cfg ) {
+        return _geglu(t1, dim, mem_cfg);
+    }
+};
+
+template <>
+struct OpHandler_dim<UnaryCompositeOpType::SWIGLU> {
+    static Tensor handle(const Tensor& t1, int32_t dim, const std::optional<MemoryConfig>& mem_cfg ) {
+        return _swiglu(t1, dim, mem_cfg);
+    }
+};
+
 // Template functions to get the function pointers
 template
 auto get_function_type1() {
@@ -303,4 +343,8 @@ auto get_function_type5() {
     return &OpHandler_threshold_value::handle;
 }

+template
+auto get_glu_fn() {
+    return &OpHandler_dim::handle;
+}
 }
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/unary_composite.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/unary_composite.hpp
index b3be58f25ac..7a5bfb9cbe6 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary/unary_composite.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary/unary_composite.hpp
@@ -142,6 +142,20 @@ struct ExecuteUnaryCompositeOpWithScaleAlpha
 };


+template
+struct ExecuteUnaryCompositeOpWithDim
+{
+    static ttnn::Tensor execute_on_worker_thread(
+        const Tensor& input_tensor,
+        int32_t dim,
+        const std::optional<MemoryConfig>& memory_config = std::nullopt)
+    {
+        auto op_type = get_glu_fn();
+        return op_type(input_tensor, dim, memory_config);
+    }
+};
+
+
 template
 struct ExecuteUnaryCompositeOpWithThresholdValue
 {
@@ -278,4 +292,9 @@
 constexpr auto clamp = ttnn::register_operation>("ttnn::clamp");
 constexpr auto selu = ttnn::register_operation>("ttnn::selu");
 constexpr auto threshold = ttnn::register_operation>("ttnn::threshold");
+constexpr auto glu = ttnn::register_operation>("ttnn::glu");
+constexpr auto reglu = ttnn::register_operation>("ttnn::reglu");
+constexpr auto geglu = ttnn::register_operation>("ttnn::geglu");
+constexpr auto swiglu = ttnn::register_operation>("ttnn::swiglu");
+
 }  // namespace ttnn
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp
index a8a267b9fa3..bd6bdb041c4 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp
@@ -357,6 +357,63 @@ void bind_unary_operation_with_integer_parameter(
             py::arg("queue_id") = 0});
 }

+
+template <typename unary_operation_t>
+void bind_unary_operation_with_dim_parameter(
+    py::module& module,
+    const unary_operation_t& operation,
+    const std::string& parameter_name,
+    const std::string& parameter_doc,
+    const std::string& info_doc) {
+
+    auto doc = fmt::format(
+        R"doc({0}(input_tensor: ttnn.Tensor, {2}: int = -1, *, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor
+
+        Applies {0} to :attr:`input_tensor`, split along :attr:`{2}`.
+
+        {4}
+
+        .. math::
+            {0}(\\mathrm{{input\\_tensor}}_i)
+
+        Args:
+            * :attr:`input_tensor`
+            * :attr:`{2}` (int): {3}.
+
+        Keyword Args:
+            * :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation.
+            * :attr:`output_tensor` (Optional[ttnn.Tensor]): preallocated output tensor
+            * :attr:`queue_id` (Optional[uint8]): command queue id
+
+        Example:
+
+            >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)
+            >>> output = {1}(tensor, {2})
+        )doc",
+        operation.base_name(),
+        operation.python_fully_qualified_name(),
+        parameter_name,
+        parameter_doc,
+        info_doc);
+
+    bind_registered_operation(
+        module,
+        operation,
+        doc,
+        ttnn::pybind_overload_t{
+            [](const unary_operation_t& self,
+               const Tensor& input_tensor,
+               int dim,
+               const std::optional<MemoryConfig>& memory_config) {
+                return self(input_tensor, dim, memory_config);
+            },
+            py::arg("input_tensor"),
+            py::arg("dim") = -1,
+            py::kw_only(),
+            py::arg("memory_config") = std::nullopt});
+}
+
+
 template <typename unary_operation_t>
 void bind_softplus(py::module& module, const unary_operation_t& operation) {
     auto doc = fmt::format(
@@ -1099,6 +1156,12 @@ void py_module(py::module& module) {

     detail::bind_unary_operation_with_integer_parameter(module, ttnn::bitwise_not, "value", "scalar value", "Input tensor needs to be in the range [-2147483647, 2147483647], INT32 dtype. Support provided only for Wormhole_B0.");

+    // Unary ops with dim parameter
+    detail::bind_unary_operation_with_dim_parameter(module, ttnn::glu, "dim", "Dimension to split input tensor. Supported dimension: -1 or 3", "Split the tensor into two halves along the given dimension, apply the sigmoid function to the second half, then multiply it with the first half");
+    detail::bind_unary_operation_with_dim_parameter(module, ttnn::reglu, "dim", "Dimension to split input tensor. Supported dimension: -1 or 3", "Split the tensor into two halves along the given dimension, apply the relu function to the second half, then multiply it with the first half");
+    detail::bind_unary_operation_with_dim_parameter(module, ttnn::geglu, "dim", "Dimension to split input tensor. Supported dimension: -1 or 3", "Split the tensor into two halves along the given dimension, apply the gelu function to the second half, then multiply it with the first half");
+    detail::bind_unary_operation_with_dim_parameter(module, ttnn::swiglu, "dim", "Dimension to split input tensor. Supported dimension: -1 or 3", "Split the tensor into two halves along the given dimension, apply the silu function to the second half, then multiply it with the first half");
+
     // Other unaries (unary chain operations)
     detail::bind_softplus(module, ttnn::softplus);
     detail::bind_sigmoid_accurate(module, ttnn::sigmoid_accurate);
diff --git a/ttnn/ttnn/operations/unary.py b/ttnn/ttnn/operations/unary.py
index d3cbdd51211..f75531f30fd 100644
--- a/ttnn/ttnn/operations/unary.py
+++ b/ttnn/ttnn/operations/unary.py
@@ -376,6 +376,57 @@ def _golden_function_bitwise_not(input_tensor_a, value, *args, **kwargs):
 ttnn.attach_golden_function(ttnn._ttnn.operations.unary.bitwise_not, golden_function=_golden_function_bitwise_not)


+def _golden_function_glu(input_tensor_a, dim, *args, **kwargs):
+    import torch
+
+    return torch.nn.functional.glu(input_tensor_a, dim)
+
+
+ttnn.attach_golden_function(ttnn._ttnn.operations.unary.glu, golden_function=_golden_function_glu)
+
+
+def _golden_function_reglu(input_tensor_a, dim, *args, **kwargs):
+    import torch
+
+    assert isinstance(dim, int), "dim must be an integer"
+    assert dim in [-1, 3], "dim must be -1 or 3"
+    split_size = input_tensor_a.size(-1) // 2
+    split_tensors = torch.split(input_tensor_a, split_size_or_sections=[split_size, split_size], dim=dim)
+    tensA, tensB = split_tensors[0], split_tensors[1]
+    return tensA * torch.nn.functional.relu(tensB)
+
+
+ttnn.attach_golden_function(ttnn._ttnn.operations.unary.reglu, golden_function=_golden_function_reglu)
+
+
+def _golden_function_geglu(input_tensor_a, dim, *args, **kwargs):
+    import torch
+
+    assert isinstance(dim, int), "dim must be an integer"
+    assert dim in [-1, 3], "dim must be -1 or 3"
+    split_size = input_tensor_a.size(-1) // 2
+    split_tensors = torch.split(input_tensor_a, split_size_or_sections=[split_size, split_size], dim=dim)
+    tensA, tensB = split_tensors[0], split_tensors[1]
+    return tensA * torch.nn.functional.gelu(tensB)
+
+
+ttnn.attach_golden_function(ttnn._ttnn.operations.unary.geglu, golden_function=_golden_function_geglu)
+
+
+def _golden_function_swiglu(input_tensor_a, dim, *args, **kwargs):
+    import torch
+
+    assert isinstance(dim, int), "dim must be an integer"
+    assert dim in [-1, 3], "dim must be -1 or 3"
+    split_size = input_tensor_a.size(-1) // 2
+    split_tensors = torch.split(input_tensor_a, split_size_or_sections=[split_size, split_size], dim=dim)
+    tensA, tensB = split_tensors[0], split_tensors[1]
+    return tensA * torch.nn.functional.silu(tensB)
+
+
+ttnn.attach_golden_function(ttnn._ttnn.operations.unary.swiglu, golden_function=_golden_function_swiglu)
+
+
 def _is_scalar(value):
     return isinstance(value, (int, float))

@@ -525,102 +576,4 @@ def activation_function(
         activation_function_name, ttl_activation_function, param1, param2
     )

-
-def torch_reglu(input_tensor, *args, **kwargs):
-    import torch
-
-    split_size = input_tensor.size(-1) // 2
-    split_tensors = torch.split(input_tensor, split_size_or_sections=[split_size, split_size], dim=-1)
-    tensA, tensB = split_tensors[0], split_tensors[1]
-    return tensA * torch.nn.functional.relu(tensB)
-
-
-def torch_swiglu(input_tensor, *args, **kwargs):
-    import torch
-
-    split_size = input_tensor.size(-1) // 2
-    split_tensors = torch.split(input_tensor, split_size_or_sections=[split_size, split_size], dim=-1)
-    tensA, tensB = split_tensors[0], split_tensors[1]
-    return tensA * torch.nn.functional.silu(tensB)
-
-
-def torch_geglu(input_tensor, *args, **kwargs):
-    import torch
-
-    split_size = input_tensor.size(-1) // 2
-    split_tensors = torch.split(input_tensor, split_size_or_sections=[split_size, split_size], dim=-1)
-    tensA, tensB = split_tensors[0], split_tensors[1]
-    return tensA * torch.nn.functional.gelu(tensB)
-
-
-def register_ttl_activation_function_glu(name, ttl_activation_function, param):
-    def _golden_function(input_tensor: ttnn.Tensor, dim: int = -1, **_):
-        import torch
-
-        name_to_torch_function = {
-            "glu": torch.nn.functional.glu,
-            "reglu": torch_reglu,
-            "swiglu": torch_swiglu,
-            "geglu": torch_geglu,
-        }
-        torch_function = name_to_torch_function[name]
-        input_tensor = ttnn.to_torch(input_tensor)
-
-        return torch_function(input_tensor, dim=dim)
-
-    doc = f"""{(name)}(input_tensor: ttnn.Tensor, dim: int = -1, *, memory_config: ttnn.MemoryConfig = ttnn.DRAM_MEMORY_CONFIG) -> ttnn.Tensor
-
-    Applies the {name} function to the elements of the input tensor :attr:`input_tensor` split along :attr:`{param}`.
-
-    .. math::
-        {(name)}(\\mathrm{{input\\_tensor}}_i \\; , \\; {param})
-
-    Args:
-        * :attr:`input_tensor`
-        * :attr:`{param}`
-
-    Example::
-
-        >>> tensor = ttnn.from_torch(torch.tensor((32, 64), dtype=torch.bfloat16), device=device)
-        >>> output = ttnn.{(name)}(tensor, {param})
-
-    """
-
-    @ttnn.register_python_operation(name=f"ttnn.{name}", golden_function=_golden_function, doc=doc)
-    def activation_function(
-        input_tensor: ttnn.Tensor, dim: int = -1, *, memory_config: ttnn.MemoryConfig = ttnn.DRAM_MEMORY_CONFIG
-    ) -> ttnn.Tensor:
-        input_shape = tuple(input_tensor.shape)
-        last_dim = input_shape[-1]
-        glu_shape = input_shape[:-1] + (int(last_dim / 2),)
-
-        input_tensor = ttnn.unsqueeze_to_4D(input_tensor)
-
-        if not isinstance(input_tensor, ttnn.Tensor):
-            raise TypeError("Expected first argument to be a ttnn.Tensor")
-
-        if not _is_scalar(dim):
-            raise TypeError("Expected second argument to be a float")
-
-        if not ttnn.is_tensor_storage_on_device(input_tensor):
-            raise RuntimeError("input_tensor must be on device!")
-
-        output_tensor = ttl_activation_function(input_tensor, dim, output_mem_config=memory_config)
-
-        output_tensor = ttnn.reshape(output_tensor, ttnn.Shape(glu_shape))
-        return output_tensor
-
-
-TTL_ACTIVATION_FUNCTIONS_GLU = [
-    ("glu", ttl.tensor.glu, "dim"),  # composite
-    ("reglu", ttl.tensor.reglu, "dim"),  # composite
-    ("swiglu", ttl.tensor.swiglu, "dim"),  # composite
-    ("geglu", ttl.tensor.geglu, "dim"),  # composite
-]
-
-
-for activation_function_name, ttl_activation_function, param in TTL_ACTIVATION_FUNCTIONS_GLU:
-    register_ttl_activation_function_glu(activation_function_name, ttl_activation_function, param)
-
-
 __all__ = []
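Usage note (not part of the patch): the sketch below shows how the migrated ops are called after this change. It assumes a device is available and opened as device id 0; the shape and the golden comparison mirror the new unit tests, and the input values are arbitrary.

import torch
import ttnn

# Assumption for this sketch: device id 0 exists; any open ttnn device works.
device = ttnn.open_device(device_id=0)

# The last dimension must split into two full tiles (64 / 2 = 32 = TILE_WIDTH).
torch_input = torch.randn(1, 1, 32, 64, dtype=torch.bfloat16)
input_tensor = ttnn.from_torch(torch_input, layout=ttnn.TILE_LAYOUT, device=device)

# New-style call: the split dimension is positional, memory_config is keyword-only.
output_tensor = ttnn.glu(input_tensor, -1, memory_config=ttnn.DRAM_MEMORY_CONFIG)

# The attached golden function reproduces the op in torch for comparison.
golden_fn = ttnn.get_golden_function(ttnn.glu)
golden_tensor = golden_fn(torch_input, -1)

result = ttnn.to_torch(output_tensor)
ttnn.close_device(device)

ttnn.reglu, ttnn.geglu, and ttnn.swiglu follow the same call shape; only the gating function applied to the second half (relu, gelu, or silu instead of sigmoid) differs.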