Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#8114: Move softplus to ttnn c++ #8366

Merged
merged 2 commits into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ jobs:
# So we can get all the makefile output we want
SILENT: 0
runs-on: ${{ matrix.config.runs-on }}
name: build ${{ matrix.config.type }} ${{ matrix.arch }}
name: make build ${{ matrix.config.type }} ${{ matrix.arch }}
steps:
- uses: actions/checkout@v4
with:
Expand Down Expand Up @@ -131,7 +131,7 @@ jobs:
arch: [grayskull, wormhole_b0]
os: [ubuntu-20.04]
needs: build-lib
name: build cpptest ${{ matrix.config }} ${{ matrix.arch }}
name: make build cpptest ${{ matrix.config }} ${{ matrix.arch }}
env:
ARCH_NAME: ${{ matrix.arch }}
CONFIG: ${{ matrix.config }}
Expand Down
2 changes: 1 addition & 1 deletion models/demos/mamba/tt/mamba_one_step_ssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def forward(self, x):
)
ttnn.deallocate(delta_t0)

delta_t2 = ttnn.softplus(delta_t1, parameter1=1.0, parameter2=20.0, memory_config=ttnn.L1_MEMORY_CONFIG)
delta_t2 = ttnn.softplus(delta_t1, beta=1.0, threshold=20.0, memory_config=ttnn.L1_MEMORY_CONFIG)
ttnn.deallocate(delta_t1)

# calculate abar
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def run_eltwise_softplus_tests(
ref_value = torch.nn.functional.softplus(x, beta=beta, threshold=threshold)

x = ttnn_ops.setup_ttnn_tensor(x, device, dlayout[0], in_mem_config[0], dtype[0])
tt_result = ttnn.softplus(x, beta, threshold, memory_config=output_mem_config)
tt_result = ttnn.softplus(x, beta=beta, threshold=threshold, memory_config=output_mem_config)

tt_result = ttnn_ops.ttnn_tensor_to_torch(tt_result, output_mem_config)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def run_eltwise_softplus_tests(
ref_value = torch.nn.functional.softplus(x, beta=beta, threshold=threshold)

x = ttnn_ops.setup_ttnn_tensor(x, device, dlayout[0], in_mem_config[0], dtype[0])
tt_result = ttnn.softplus(x, beta, threshold, memory_config=output_mem_config)
tt_result = ttnn.softplus(x, beta=beta, threshold=threshold, memory_config=output_mem_config)

tt_result = ttnn_ops.ttnn_tensor_to_torch(tt_result, output_mem_config)

Expand Down
2 changes: 1 addition & 1 deletion tests/ttnn/python_api_testing/sweep_tests/ttnn_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -987,7 +987,7 @@ def eltwise_softplus(
**kwargs,
):
t0 = setup_ttnn_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
t1 = ttnn.softplus(t0, beta, threshold, memory_config=memory_config_to_ttnn(output_mem_config))
t1 = ttnn.softplus(t0, beta=beta, threshold=threshold, memory_config=memory_config_to_ttnn(output_mem_config))
return ttnn_tensor_to_torch(t1)


Expand Down
2 changes: 1 addition & 1 deletion tests/ttnn/sweep_tests/sweeps/sweeps/softplus.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def run(
torch_input_tensor, dtype=input_dtype, device=device, memory_config=input_memory_config, layout=layout
)

output_tensor = ttnn.softplus(input_tensor, beta, threshold, memory_config=output_memory_config)
output_tensor = ttnn.softplus(input_tensor, beta=beta, threshold=threshold, memory_config=output_memory_config)
output_tensor = ttnn.to_torch(output_tensor)

return check_with_pcc(torch_output_tensor, output_tensor, 0.99)
2 changes: 1 addition & 1 deletion tests/ttnn/unit_tests/operations/test_activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def run_activation_softplus_test(device, h, w, beta, threshold, ttnn_function, t

input_tensor_a = ttnn.from_torch(torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device)

output_tensor = ttnn_function(input_tensor_a, beta, threshold)
output_tensor = ttnn_function(input_tensor_a, beta=beta, threshold=threshold)
output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT)
output_tensor = ttnn.from_device(output_tensor)
output_tensor = ttnn.to_torch(output_tensor)
Expand Down
2 changes: 1 addition & 1 deletion tt_eager/module.mk
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ EAGER_OUTPUT_DIR = $(OUT)/dist

TT_EAGER_INCLUDES = $(TT_METAL_BASE_INCLUDES) -Itt_eager/ -I ttnn/cpp/

include tt_eager/queue/module.mk
include tt_eager/tensor/module.mk
include tt_eager/tt_dnn/module.mk
include tt_eager/queue/module.mk
include tt_eager/tt_lib/module.mk

TT_LIBS_TO_BUILD = tt_eager/tensor \
Expand Down
2 changes: 1 addition & 1 deletion tt_eager/tensor/module.mk
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ TENSOR_DEPS = $(addprefix $(OBJDIR)/, $(TENSOR_SRCS:.cpp=.d))
# Each module has a top level target as the entrypoint which must match the subdir name
tt_eager/tensor: $(TENSOR_LIB)

$(TENSOR_LIB): $(COMMON_LIB) $(TT_METAL_LIB) $(TENSOR_OBJS)
$(TENSOR_LIB): $(COMMON_LIB) $(TT_METAL_LIB) $(TENSOR_OBJS) $(QUEUE_LIB)
@mkdir -p $(LIBDIR)
ar rcs -o $@ $(TENSOR_OBJS)

Expand Down
1 change: 1 addition & 0 deletions tt_eager/tt_dnn/op_library/composite/composite_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "tt_dnn/op_library/prod/prod_op_all.hpp"
#include "tt_dnn/op_library/permute/permute_op.hpp"
#include "tt_eager/tt_dnn/op_library/unpad/unpad_op.hpp"

namespace tt {

namespace tt_metal {
Expand Down
38 changes: 38 additions & 0 deletions ttnn/cpp/pybind11/operations/unary.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,49 @@ void bind_unary_with_bool_parameter_set_to_false_by_default(py::module& module,
py::arg("memory_config") = std::nullopt});
}

void bind_softplus(py::module& module) {
auto doc = fmt::format(
R"doc({0}(input_tensor: ttnn.Tensor, *, beta: float = 1.0, threshold: float = 20.0, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor

Applies {0} to :attr:`input_tensor` element-wise.

.. math::
{0}(\\mathrm{{input\\_tensor}}_i)

Args:
* :attr:`input_tensor`

Keyword Args:
* :attr:`beta` (float): Scales the input before applying the Softplus function. By modifying beta, you can adjust the steepness of the function. A higher beta value makes the function steeper, approaching a hard threshold like the ReLU function for large values of beta
* :attr:`threshold` (float): Used to switch to a linear function for large values to improve numerical stability. This avoids issues with floating-point representation for very large values
* :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation.

Example::

>>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)
>>> output = {1}(tensor, parameter=true)
)doc",
ttnn::softplus.name(),
ttnn::softplus.python_fully_qualified_name());

bind_registered_operation(
module,
ttnn::softplus,
doc,
ttnn::pybind_arguments_t{
py::arg("input_tensor"),
py::kw_only(),
py::arg("beta") = 1.0f,
py::arg("threshold") = 20.0f,
py::arg("memory_config") = std::nullopt});
}

} // namespace detail

void py_module(py::module& module) {
detail::bind_unary_with_bool_parameter_set_to_false_by_default(module, ttnn::exp);
detail::bind_unary(module, ttnn::silu);
detail::bind_softplus(module);
}

} // namespace unary
Expand Down
33 changes: 33 additions & 0 deletions ttnn/cpp/ttnn/operations/unary.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

#include "tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp"
#include "tt_eager/tt_dnn/op_library/run_operation.hpp"
#include "tt_eager/tt_dnn/op_library/composite/composite_ops.hpp"
#include "ttnn/operations/core.hpp"
#include "ttnn/decorators.hpp"
#include "ttnn/validation.hpp"

Expand Down Expand Up @@ -102,11 +104,42 @@ struct Exp : public EltwiseUnary {
memory_config);
}
};

struct Softplus : public EltwiseUnary {
static const std::array<TensorSchema, 1> input_tensor_schemas() { return detail::input_tensor_schemas(); }

template <typename... Args>
static auto input_tensors_to_validate(const Tensor& input_tensor, Args&&... args) {
return detail::input_tensors_to_validate(input_tensor, std::forward<Args>(args)...);
};

template <typename... Args>
static auto map_launch_op_args_to_execute(
const std::vector<Tensor>& input_tensors,
const std::vector<std::optional<const Tensor>>& optional_input_tensors,
Args&&... args) {
return std::make_tuple(input_tensors.at(0), std::forward<Args>(args)...);
}

static Tensor execute(const Tensor& input, const float beta, const float threshold, const std::optional<MemoryConfig>& memory_config_arg = std::nullopt) {
auto original_input_shape = input.get_shape();
auto input_4D = ttnn::unsqueeze_to_4D(input);

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to get rid of the unsqueeze_to_4D?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@eyonland , it may be possible if we add validation and rewrite tests to use 4D tensors. Underlying operations only support 4D right now. Keeping the change minimal, migrating from py to c++ for now.

I will consult with @arakhmati whether we want to enforce on the contract level.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, to clarify, this is not an addition. This is moved from py to c++.

auto memory_config = memory_config_arg.value_or(input_4D.memory_config());
auto result = tt::tt_metal::softplus(input_4D, beta, threshold, memory_config);

result = ttnn::reshape(result, original_input_shape);

return result;
}
};
} // namespace unary
} // namespace operations

constexpr auto exp = ttnn::register_operation<ttnn::operations::unary::Exp>("ttnn::exp");

constexpr auto softplus = ttnn::register_operation<ttnn::operations::unary::Softplus>("ttnn::softplus");

constexpr auto silu =
ttnn::register_operation<ttnn::operations::unary::Unary<ttnn::operations::unary::UnaryOpType::SILU>>("ttnn::silu");

Expand Down
2 changes: 1 addition & 1 deletion ttnn/ttnn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ def manage_config(name, value):
logical_not,
logit,
signbit,
softplus,
)

from ttnn.operations.binary import (
Expand Down Expand Up @@ -403,7 +404,6 @@ def manage_config(name, value):
softshrink,
softsign,
swish,
softplus,
tanhshrink,
threshold,
glu,
Expand Down
1 change: 0 additions & 1 deletion ttnn/ttnn/operations/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,6 @@ def activation_function(
TTL_ACTIVATION_FUNCTIONS_WITH_TWO_FLOAT_PARAM = [
("clip", ttl.tensor.clip, "clip", "min", "max"),
("threshold", ttl.tensor.threshold, "threshold", "value", "threshold"),
("softplus", ttl.tensor.softplus, "softplus", "beta", "threshold"),
]

TTL_ACTIVATION_FUNCTIONS_GLU = [
Expand Down
2 changes: 2 additions & 0 deletions ttnn/ttnn/operations/unary.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ def _golden_function(input_tensor: ttnn.Tensor, **_):
ttnn_function_to_golden_function = {
ttnn._ttnn.operations.unary.exp: torch.exp,
ttnn._ttnn.operations.unary.silu: torch.nn.functional.silu,
ttnn._ttnn.operations.unary.softplus: torch.nn.functional.softplus,
}
torch_function = ttnn_function_to_golden_function[unary_function]
return torch_function(input_tensor)
Expand All @@ -219,6 +220,7 @@ def _golden_function(input_tensor: ttnn.Tensor, **_):
TTNN_ELTWISE_UNARY_CPP_FUNCTIONS = [
ttnn._ttnn.operations.unary.exp,
ttnn._ttnn.operations.unary.silu,
ttnn._ttnn.operations.unary.softplus,
]
for unary_function in TTNN_ELTWISE_UNARY_CPP_FUNCTIONS:
register_eltwise_unary_cpp_function(unary_function)
Expand Down
Loading