diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml
index 5bed4846dc0..e775c690fc6 100644
--- a/.github/workflows/build.yaml
+++ b/.github/workflows/build.yaml
@@ -96,7 +96,7 @@ jobs:
       # So we can get all the makefile output we want
       SILENT: 0
     runs-on: ${{ matrix.config.runs-on }}
-    name: build ${{ matrix.config.type }} ${{ matrix.arch }}
+    name: make build ${{ matrix.config.type }} ${{ matrix.arch }}
     steps:
       - uses: actions/checkout@v4
         with:
@@ -131,7 +131,7 @@ jobs:
        arch: [grayskull, wormhole_b0]
        os: [ubuntu-20.04]
    needs: build-lib
-    name: build cpptest ${{ matrix.config }} ${{ matrix.arch }}
+    name: make build cpptest ${{ matrix.config }} ${{ matrix.arch }}
    env:
      ARCH_NAME: ${{ matrix.arch }}
      CONFIG: ${{ matrix.config }}
diff --git a/models/demos/mamba/tt/mamba_one_step_ssm.py b/models/demos/mamba/tt/mamba_one_step_ssm.py
index a4a5c18acbb..5cf769e75ae 100644
--- a/models/demos/mamba/tt/mamba_one_step_ssm.py
+++ b/models/demos/mamba/tt/mamba_one_step_ssm.py
@@ -129,7 +129,7 @@ def forward(self, x):
         )
         ttnn.deallocate(delta_t0)
 
-        delta_t2 = ttnn.softplus(delta_t1, parameter1=1.0, parameter2=20.0, memory_config=ttnn.L1_MEMORY_CONFIG)
+        delta_t2 = ttnn.softplus(delta_t1, beta=1.0, threshold=20.0, memory_config=ttnn.L1_MEMORY_CONFIG)
         ttnn.deallocate(delta_t1)
 
         # calculate abar
diff --git a/tests/ttnn/python_api_testing/non_working_unit_tests/grayskull/test_eltwise_softplus.py b/tests/ttnn/python_api_testing/non_working_unit_tests/grayskull/test_eltwise_softplus.py
index 4294cf77b5c..3e142acbea4 100644
--- a/tests/ttnn/python_api_testing/non_working_unit_tests/grayskull/test_eltwise_softplus.py
+++ b/tests/ttnn/python_api_testing/non_working_unit_tests/grayskull/test_eltwise_softplus.py
@@ -32,7 +32,7 @@ def run_eltwise_softplus_tests(
     ref_value = torch.nn.functional.softplus(x, beta=beta, threshold=threshold)
 
     x = ttnn_ops.setup_ttnn_tensor(x, device, dlayout[0], in_mem_config[0], dtype[0])
-    tt_result = ttnn.softplus(x, beta, threshold, memory_config=output_mem_config)
+    tt_result = ttnn.softplus(x, beta=beta, threshold=threshold, memory_config=output_mem_config)
 
     tt_result = ttnn_ops.ttnn_tensor_to_torch(tt_result, output_mem_config)
 
diff --git a/tests/ttnn/python_api_testing/non_working_unit_tests/grayskull/test_eltwise_softplus_inf.py b/tests/ttnn/python_api_testing/non_working_unit_tests/grayskull/test_eltwise_softplus_inf.py
index 3f04ed526a6..b1bbd87fbd5 100644
--- a/tests/ttnn/python_api_testing/non_working_unit_tests/grayskull/test_eltwise_softplus_inf.py
+++ b/tests/ttnn/python_api_testing/non_working_unit_tests/grayskull/test_eltwise_softplus_inf.py
@@ -32,7 +32,7 @@ def run_eltwise_softplus_tests(
     ref_value = torch.nn.functional.softplus(x, beta=beta, threshold=threshold)
 
     x = ttnn_ops.setup_ttnn_tensor(x, device, dlayout[0], in_mem_config[0], dtype[0])
-    tt_result = ttnn.softplus(x, beta, threshold, memory_config=output_mem_config)
+    tt_result = ttnn.softplus(x, beta=beta, threshold=threshold, memory_config=output_mem_config)
 
     tt_result = ttnn_ops.ttnn_tensor_to_torch(tt_result, output_mem_config)
 
diff --git a/tests/ttnn/python_api_testing/sweep_tests/ttnn_ops.py b/tests/ttnn/python_api_testing/sweep_tests/ttnn_ops.py
index ace463764cb..f791681c0b4 100644
--- a/tests/ttnn/python_api_testing/sweep_tests/ttnn_ops.py
+++ b/tests/ttnn/python_api_testing/sweep_tests/ttnn_ops.py
@@ -987,7 +987,7 @@ def eltwise_softplus(
     **kwargs,
 ):
     t0 = setup_ttnn_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
-    t1 = ttnn.softplus(t0, beta, threshold, memory_config=memory_config_to_ttnn(output_mem_config))
+    t1 = ttnn.softplus(t0, beta=beta, threshold=threshold, memory_config=memory_config_to_ttnn(output_mem_config))
 
     return ttnn_tensor_to_torch(t1)
 
diff --git a/tests/ttnn/sweep_tests/sweeps/sweeps/softplus.py b/tests/ttnn/sweep_tests/sweeps/sweeps/softplus.py
index 6962d62b0f3..be080eb4ef3 100644
--- a/tests/ttnn/sweep_tests/sweeps/sweeps/softplus.py
+++ b/tests/ttnn/sweep_tests/sweeps/sweeps/softplus.py
@@ -55,7 +55,7 @@ def run(
         torch_input_tensor, dtype=input_dtype, device=device, memory_config=input_memory_config, layout=layout
     )
 
-    output_tensor = ttnn.softplus(input_tensor, beta, threshold, memory_config=output_memory_config)
+    output_tensor = ttnn.softplus(input_tensor, beta=beta, threshold=threshold, memory_config=output_memory_config)
     output_tensor = ttnn.to_torch(output_tensor)
 
     return check_with_pcc(torch_output_tensor, output_tensor, 0.99)
diff --git a/tests/ttnn/unit_tests/operations/test_activation.py b/tests/ttnn/unit_tests/operations/test_activation.py
index fd3db4964c5..779607c4bfa 100644
--- a/tests/ttnn/unit_tests/operations/test_activation.py
+++ b/tests/ttnn/unit_tests/operations/test_activation.py
@@ -109,7 +109,7 @@ def run_activation_softplus_test(device, h, w, beta, threshold, ttnn_function, t
 
     input_tensor_a = ttnn.from_torch(torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device)
 
-    output_tensor = ttnn_function(input_tensor_a, beta, threshold)
+    output_tensor = ttnn_function(input_tensor_a, beta=beta, threshold=threshold)
     output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT)
     output_tensor = ttnn.from_device(output_tensor)
     output_tensor = ttnn.to_torch(output_tensor)
diff --git a/tt_eager/module.mk b/tt_eager/module.mk
index db07c146e73..c228d86778b 100644
--- a/tt_eager/module.mk
+++ b/tt_eager/module.mk
@@ -5,9 +5,9 @@ EAGER_OUTPUT_DIR = $(OUT)/dist
 
 TT_EAGER_INCLUDES = $(TT_METAL_BASE_INCLUDES) -Itt_eager/ -I ttnn/cpp/
 
+include tt_eager/queue/module.mk
 include tt_eager/tensor/module.mk
 include tt_eager/tt_dnn/module.mk
-include tt_eager/queue/module.mk
 include tt_eager/tt_lib/module.mk
 
 TT_LIBS_TO_BUILD = tt_eager/tensor \
diff --git a/tt_eager/tensor/module.mk b/tt_eager/tensor/module.mk
index 21964bb6279..0a52cc25046 100644
--- a/tt_eager/tensor/module.mk
+++ b/tt_eager/tensor/module.mk
@@ -19,7 +19,7 @@ TENSOR_DEPS = $(addprefix $(OBJDIR)/, $(TENSOR_SRCS:.cpp=.d))
 
 # Each module has a top level target as the entrypoint which must match the subdir name
 tt_eager/tensor: $(TENSOR_LIB)
 
-$(TENSOR_LIB): $(COMMON_LIB) $(TT_METAL_LIB) $(TENSOR_OBJS)
+$(TENSOR_LIB): $(COMMON_LIB) $(TT_METAL_LIB) $(TENSOR_OBJS) $(QUEUE_LIB)
 	@mkdir -p $(LIBDIR)
 	ar rcs -o $@ $(TENSOR_OBJS)
diff --git a/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp b/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp
index d41f8f5aa8f..a11e3e15731 100644
--- a/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp
+++ b/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp
@@ -20,6 +20,7 @@
 #include "tt_dnn/op_library/prod/prod_op_all.hpp"
 #include "tt_dnn/op_library/permute/permute_op.hpp"
 #include "tt_eager/tt_dnn/op_library/unpad/unpad_op.hpp"
+
 namespace tt {
 
 namespace tt_metal {
diff --git a/ttnn/cpp/pybind11/operations/unary.hpp b/ttnn/cpp/pybind11/operations/unary.hpp
index 7cd5b3cd4f3..f98b4c1aecb 100644
--- a/ttnn/cpp/pybind11/operations/unary.hpp
+++ b/ttnn/cpp/pybind11/operations/unary.hpp
@@ -94,11 +94,49 @@ void bind_unary_with_bool_parameter_set_to_false_by_default(py::module& module,
             py::arg("memory_config") = std::nullopt});
 }
 
+void bind_softplus(py::module& module) {
+    auto doc = fmt::format(
+        R"doc({0}(input_tensor: ttnn.Tensor, *, beta: float = 1.0, threshold: float = 20.0, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor
+
+        Applies {0} to :attr:`input_tensor` element-wise.
+
+        .. math::
+            {0}(\\mathrm{{input\\_tensor}}_i)
+
+        Args:
+            * :attr:`input_tensor`
+
+        Keyword Args:
+            * :attr:`beta` (float): Scales the input before applying the Softplus function. By modifying :attr:`beta`, you can adjust the steepness of the function. A higher :attr:`beta` value makes the function steeper, approaching a hard threshold like the ReLU function for large values of :attr:`beta`.
+            * :attr:`threshold` (float): Used to switch to a linear function for large values to improve numerical stability. This avoids issues with floating-point representation for very large values.
+            * :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation.
+
+        Example::
+
+            >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)
+            >>> output = {1}(tensor, beta=1.0, threshold=20.0)
+        )doc",
+        ttnn::softplus.name(),
+        ttnn::softplus.python_fully_qualified_name());
+
+    bind_registered_operation(
+        module,
+        ttnn::softplus,
+        doc,
+        ttnn::pybind_arguments_t{
+            py::arg("input_tensor"),
+            py::kw_only(),
+            py::arg("beta") = 1.0f,
+            py::arg("threshold") = 20.0f,
+            py::arg("memory_config") = std::nullopt});
+}
+
 } // namespace detail
 
 void py_module(py::module& module) {
     detail::bind_unary_with_bool_parameter_set_to_false_by_default(module, ttnn::exp);
     detail::bind_unary(module, ttnn::silu);
+    detail::bind_softplus(module);
 }
 
 } // namespace unary
diff --git a/ttnn/cpp/ttnn/operations/unary.hpp b/ttnn/cpp/ttnn/operations/unary.hpp
index a485d72d5c9..7f39e48480a 100644
--- a/ttnn/cpp/ttnn/operations/unary.hpp
+++ b/ttnn/cpp/ttnn/operations/unary.hpp
@@ -6,6 +6,8 @@
 
 #include "tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp"
 #include "tt_eager/tt_dnn/op_library/run_operation.hpp"
+#include "tt_eager/tt_dnn/op_library/composite/composite_ops.hpp"
+#include "ttnn/operations/core.hpp"
 #include "ttnn/decorators.hpp"
 #include "ttnn/validation.hpp"
 
@@ -102,11 +104,42 @@
             memory_config);
     }
 };
+
+struct Softplus : public EltwiseUnary {
+    static const std::array<TensorSchema, 1> input_tensor_schemas() { return detail::input_tensor_schemas(); }
+
+    template <typename... Args>
+    static auto input_tensors_to_validate(const Tensor& input_tensor, Args&&... args) {
+        return detail::input_tensors_to_validate(input_tensor, std::forward<Args>(args)...);
+    };
+
+    template <typename... Args>
+    static auto map_launch_op_args_to_execute(
+        const std::vector<Tensor>& input_tensors,
+        const std::vector<std::optional<const Tensor>>& optional_input_tensors,
+        Args&&... args) {
+        return std::make_tuple(input_tensors.at(0), std::forward<Args>(args)...);
+    }
+
+    static Tensor execute(
+        const Tensor& input, const float beta, const float threshold, const std::optional<MemoryConfig>& memory_config_arg = std::nullopt) {
+        auto original_input_shape = input.get_shape();
+        auto input_4D = ttnn::unsqueeze_to_4D(input);
+
+        auto memory_config = memory_config_arg.value_or(input_4D.memory_config());
+        auto result = tt::tt_metal::softplus(input_4D, beta, threshold, memory_config);
+
+        result = ttnn::reshape(result, original_input_shape);
+
+        return result;
+    }
+};
 } // namespace unary
 
 } // namespace operations
 
 constexpr auto exp = ttnn::register_operation<ttnn::operations::unary::Exp>("ttnn::exp");
+constexpr auto softplus = ttnn::register_operation<ttnn::operations::unary::Softplus>("ttnn::softplus");
+
 constexpr auto silu = ttnn::register_operation<ttnn::operations::unary::EltwiseUnary<UnaryOpType::SILU>>("ttnn::silu");
diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py
index 8333fc5e723..10168342bc2 100644
--- a/ttnn/ttnn/__init__.py
+++ b/ttnn/ttnn/__init__.py
@@ -334,6 +334,7 @@ def manage_config(name, value):
     logical_not,
     logit,
     signbit,
+    softplus,
 )
 
 from ttnn.operations.binary import (
@@ -403,7 +404,6 @@ def manage_config(name, value):
     softshrink,
     softsign,
     swish,
-    softplus,
     tanhshrink,
     threshold,
     glu,
diff --git a/ttnn/ttnn/operations/activation.py b/ttnn/ttnn/operations/activation.py
index c5c6652930f..c93ae9ac7f8 100644
--- a/ttnn/ttnn/operations/activation.py
+++ b/ttnn/ttnn/operations/activation.py
@@ -378,7 +378,6 @@ def activation_function(
 
 TTL_ACTIVATION_FUNCTIONS_WITH_TWO_FLOAT_PARAM = [
     ("clip", ttl.tensor.clip, "clip", "min", "max"),
     ("threshold", ttl.tensor.threshold, "threshold", "value", "threshold"),
-    ("softplus", ttl.tensor.softplus, "softplus", "beta", "threshold"),
 ]
 
 TTL_ACTIVATION_FUNCTIONS_GLU = [
diff --git a/ttnn/ttnn/operations/unary.py b/ttnn/ttnn/operations/unary.py
index 875f1a57585..ea62409cc25 100644
--- a/ttnn/ttnn/operations/unary.py
+++ b/ttnn/ttnn/operations/unary.py
@@ -208,6 +208,7 @@ def _golden_function(input_tensor: ttnn.Tensor, **_):
     ttnn_function_to_golden_function = {
         ttnn._ttnn.operations.unary.exp: torch.exp,
         ttnn._ttnn.operations.unary.silu: torch.nn.functional.silu,
+        ttnn._ttnn.operations.unary.softplus: torch.nn.functional.softplus,
     }
     torch_function = ttnn_function_to_golden_function[unary_function]
     return torch_function(input_tensor)
@@ -219,6 +220,7 @@ def _golden_function(input_tensor: ttnn.Tensor, **_):
 TTNN_ELTWISE_UNARY_CPP_FUNCTIONS = [
     ttnn._ttnn.operations.unary.exp,
     ttnn._ttnn.operations.unary.silu,
+    ttnn._ttnn.operations.unary.softplus,
 ]
 for unary_function in TTNN_ELTWISE_UNARY_CPP_FUNCTIONS:
     register_eltwise_unary_cpp_function(unary_function)
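
After this change, `beta` and `threshold` are keyword-only arguments on `ttnn.softplus` (note the `py::kw_only()` in the binding), with the same semantics as `torch.nn.functional.softplus`. Below is a minimal usage sketch of the new call shape; the device setup, tensor shape, and memory config are illustrative assumptions and are not part of the diff:

```python
# Sketch: calling the migrated ttnn.softplus with the new keyword-only API.
# Assumes a locally attached Tenstorrent device; shape/memory_config are examples.
import torch
import ttnn

device = ttnn.open_device(device_id=0)

torch_input = torch.rand((32, 32), dtype=torch.bfloat16)
# Golden reference with matching beta/threshold semantics
golden = torch.nn.functional.softplus(torch_input, beta=1.0, threshold=20.0)

input_tensor = ttnn.from_torch(torch_input, layout=ttnn.TILE_LAYOUT, device=device)
# Positional calls like ttnn.softplus(x, beta, threshold) no longer work;
# beta and threshold must be passed by keyword, as in the updated tests above.
output_tensor = ttnn.softplus(input_tensor, beta=1.0, threshold=20.0, memory_config=ttnn.L1_MEMORY_CONFIG)
output = ttnn.to_torch(output_tensor)

ttnn.close_device(device)
```

Keeping the signature aligned with the PyTorch function also lets the golden-function registration added in `ttnn/ttnn/operations/unary.py` map `ttnn._ttnn.operations.unary.softplus` directly onto `torch.nn.functional.softplus` for the sweep and unit tests.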