Skip to content

Commit

Permalink
#5389: revert ef62d9a
Browse files Browse the repository at this point in the history
  • Loading branch information
arakhmati committed May 17, 2024
1 parent 060b8c1 commit 83af07c
Show file tree
Hide file tree
Showing 6 changed files with 407 additions and 488 deletions.
180 changes: 48 additions & 132 deletions ttnn/cpp/pybind11/operations/unary.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,22 @@ void bind_unary(py::module& module, const unary_operation_t& operation) {
auto doc = fmt::format(
R"doc({0}(input_tensor: ttnn.Tensor, *, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor
Applies {0} to :attr:`input_tensor` element-wise.
Applies {0} to :attr:`input_tensor` element-wise.
.. math::
{0}(\\mathrm{{input\\_tensor}}_i)
.. math::
{0}(\\mathrm{{input\\_tensor}}_i)
Args:
* :attr:`input_tensor`
Args:
* :attr:`input_tensor`
Keyword Args:
* :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation.
Keyword Args:
* :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation.
Example::
Example::
>>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)
>>> output = {1}(tensor)
)doc",
>>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)
>>> output = {1}(tensor)
)doc",
operation.name(),
operation.python_fully_qualified_name());

Expand All @@ -51,28 +51,36 @@ void bind_unary(py::module& module, const unary_operation_t& operation) {
}

template <typename unary_operation_t>
void bind_unary_with_fast_and_approximate_mode(py::module& module, const unary_operation_t& operation) {
void bind_unary_with_bool_parameter_set_to_false_by_default(py::module& module, const unary_operation_t& operation) {
std::string parameter_description;
if (operation.name() == "exp") {
parameter_description = "Use fast and approximate mode";
} else {
TT_THROW("Unknown name!");
}

auto doc = fmt::format(
R"doc({0}(input_tensor: ttnn.Tensor, *, fast_and_approximate_mode: bool = False, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor
R"doc({0}(input_tensor: ttnn.Tensor, *, parameter: bool = False, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor
Applies {0} to :attr:`input_tensor` element-wise.
Applies {0} to :attr:`input_tensor` element-wise.
.. math::
{0}(\\mathrm{{input\\_tensor}}_i)
.. math::
{0}(\\mathrm{{input\\_tensor}}_i)
Args:
* :attr:`input_tensor`
Args:
* :attr:`input_tensor`
Keyword Args:
* :attr:`fast_and_approximate_mode` (bool): "Use fast and approximate mode".
* :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation.
Keyword Args:
* :attr:`parameter` (bool): {2}.
* :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation.
Example::
Example::
>>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)
>>> output = {1}(tensor, fast_and_approximate_mode=true)
)doc",
>>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)
>>> output = {1}(tensor, parameter=true)
)doc",
operation.name(),
parameter_description,
operation.python_fully_qualified_name());

bind_registered_operation(
Expand All @@ -82,71 +90,32 @@ void bind_unary_with_fast_and_approximate_mode(py::module& module, const unary_o
ttnn::pybind_arguments_t{
py::arg("input_tensor"),
py::kw_only(),
py::arg("fast_and_approximate_mode") = false,
py::arg("parameter") = false,
py::arg("memory_config") = std::nullopt});
}

template <typename unary_operation_t>
void bind_unary_with_float_parameter(
py::module& module,
const unary_operation_t& operation,
const std::string& parameter_name,
const std::string& parameter_doc) {
auto doc = fmt::format(
R"doc({0}(input_tensor: ttnn.Tensor, *, fast_and_approximate_mode: bool = False, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor
Applies {0} to :attr:`input_tensor` element-wise.
.. math::
{0}(\\mathrm{{input\\_tensor}}_i)
Args:
* :attr:`input_tensor`
Keyword Args:
* :attr:`{2}` (bool): {3}.
* :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation.
Example::
>>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)
>>> output = {1}(tensor, {2}=true)
)doc",
operation.name(),
operation.python_fully_qualified_name(),
parameter_name,
parameter_doc);

bind_registered_operation(
module,
operation,
doc,
ttnn::pybind_arguments_t{
py::arg("input_tensor"), py::arg(parameter_name.c_str()), py::kw_only(), py::arg("memory_config") = std::nullopt});
}

void bind_softplus(py::module& module) {
auto doc = fmt::format(
R"doc({0}(input_tensor: ttnn.Tensor, *, beta: float = 1.0, threshold: float = 20.0, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor
Applies {0} to :attr:`input_tensor` element-wise.
Applies {0} to :attr:`input_tensor` element-wise.
.. math::
{0}(\\mathrm{{input\\_tensor}}_i)
.. math::
{0}(\\mathrm{{input\\_tensor}}_i)
Args:
* :attr:`input_tensor`
Args:
* :attr:`input_tensor`
Keyword Args:
* :attr:`beta` (float): Scales the input before applying the Softplus function. By modifying beta, you can adjust the steepness of the function. A higher beta value makes the function steeper, approaching a hard threshold like the ReLU function for large values of beta
* :attr:`threshold` (float): Used to switch to a linear function for large values to improve numerical stability. This avoids issues with floating-point representation for very large values
* :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation.
Keyword Args:
* :attr:`beta` (float): Scales the input before applying the Softplus function. By modifying beta, you can adjust the steepness of the function. A higher beta value makes the function steeper, approaching a hard threshold like the ReLU function for large values of beta
* :attr:`threshold` (float): Used to switch to a linear function for large values to improve numerical stability. This avoids issues with floating-point representation for very large values
* :attr:`memory_config` (Optional[ttnn.MemoryConfig]): Memory configuration for the operation.
Example::
Example::
>>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)
>>> output = {1}(tensor, parameter=true)
)doc",
>>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)
>>> output = {1}(tensor, parameter=true)
)doc",
ttnn::softplus.name(),
ttnn::softplus.python_fully_qualified_name());

Expand All @@ -165,61 +134,8 @@ void bind_softplus(py::module& module) {
} // namespace detail

void py_module(py::module& module) {
detail::bind_unary(module, ttnn::abs);
detail::bind_unary(module, ttnn::acos);
detail::bind_unary(module, ttnn::asin);
detail::bind_unary(module, ttnn::atan);
detail::bind_unary(module, ttnn::cos);
detail::bind_unary(module, ttnn::erfinv);
detail::bind_unary(module, ttnn::exp2);
detail::bind_unary(module, ttnn::expm1);
detail::bind_unary(module, ttnn::eqz);
detail::bind_unary(module, ttnn::gez);
detail::bind_unary(module, ttnn::gtz);
detail::bind_unary(module, ttnn::i0);
detail::bind_unary(module, ttnn::isfinite);
detail::bind_unary(module, ttnn::isinf);
detail::bind_unary(module, ttnn::isnan);
detail::bind_unary(module, ttnn::isneginf);
detail::bind_unary(module, ttnn::isposinf);
detail::bind_unary(module, ttnn::lez);
detail::bind_unary(module, ttnn::log);
detail::bind_unary(module, ttnn::log10);
detail::bind_unary(module, ttnn::log2);
detail::bind_unary(module, ttnn::logical_not);
detail::bind_unary(module, ttnn::ltz);
detail::bind_unary(module, ttnn::neg);
detail::bind_unary(module, ttnn::nez);
detail::bind_unary(module, ttnn::reciprocal);
detail::bind_unary(module, ttnn::relu);
detail::bind_unary(module, ttnn::relu6);
detail::bind_unary(module, ttnn::sigmoid);
detail::bind_unary(module, ttnn::sign);
detail::bind_unary(module, ttnn::signbit);
detail::bind_unary_with_bool_parameter_set_to_false_by_default(module, ttnn::exp);
detail::bind_unary(module, ttnn::silu);
detail::bind_unary(module, ttnn::sin);
detail::bind_unary(module, ttnn::sqrt);
detail::bind_unary(module, ttnn::square);
detail::bind_unary(module, ttnn::tan);
detail::bind_unary(module, ttnn::tanh);

// Unaries with fast_and_approximate_mode
detail::bind_unary_with_fast_and_approximate_mode(module, ttnn::exp);
detail::bind_unary_with_fast_and_approximate_mode(module, ttnn::erf);
detail::bind_unary_with_fast_and_approximate_mode(module, ttnn::erfc);
detail::bind_unary_with_fast_and_approximate_mode(module, ttnn::gelu);
detail::bind_unary_with_fast_and_approximate_mode(module, ttnn::rsqrt);

// Unaries with float parameter
detail::bind_unary_with_float_parameter(module, ttnn::elu, "alpha", "The alpha parameter for the ELU function");
detail::bind_unary_with_float_parameter(
module, ttnn::heaviside, "value", "The value parameter for the Heaviside function");
detail::bind_unary_with_float_parameter(
module, ttnn::leaky_relu, "slope", "The slope parameter for the Leaky ReLU function");
// detail::bind_unary_with_float_parameter(module, ttnn::prelu, "weight", "The weight parameter for the PReLU
// function");

// Other unaries (composite operations)
detail::bind_softplus(module);
}

Expand Down
2 changes: 1 addition & 1 deletion ttnn/cpp/ttnn/operations/transformer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ struct AttentionSoftmax : public tt::operations::primary::Softmax {
tt::operations::primary::transformers::SoftmaxDefaultProgramConfig{},
const std::optional<bool> causal_mask = false,
const std::optional<ttnn::MemoryConfig>& memory_config = std::nullopt) {
float head_size = head_size_arg.has_value() ? 1.0f / std::sqrt(head_size_arg.value()) : 1.0f;
float head_size = head_size_arg.has_value() ? 1.0f / ::sqrt(head_size_arg.value()) : 1.0f;
if constexpr (in_place) {
TT_FATAL(attention_mask.has_value(), "Cannot apply divide by sqrt(head_size) using in-place version!");
} else {
Expand Down
99 changes: 11 additions & 88 deletions ttnn/cpp/ttnn/operations/unary.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,7 @@ struct Unary : public EltwiseUnary {
}
};

template <UnaryOpType unary_op_type>
struct UnaryWithFastAndApproximateMode : public EltwiseUnary {
struct Exp : public EltwiseUnary {
static const std::array<TensorSchema, 1> input_tensor_schemas() { return detail::input_tensor_schemas(); }

template <typename... Args>
Expand All @@ -81,29 +80,14 @@ struct UnaryWithFastAndApproximateMode : public EltwiseUnary {
const bool parameter = false,
const std::optional<MemoryConfig>& memory_config = std::nullopt) {
return detail::execute(
input_tensor, {UnaryWithParam{unary_op_type, static_cast<float>(parameter)}}, memory_config);
}
};

template <UnaryOpType unary_op_type>
struct UnaryWithFloatParameter : public EltwiseUnary {
static const std::array<TensorSchema, 1> input_tensor_schemas() { return detail::input_tensor_schemas(); }

template <typename... Args>
static auto input_tensors_to_validate(const Tensor& input_tensor, Args&&... args) {
return detail::input_tensors_to_validate(input_tensor, std::forward<Args>(args)...);
}

static Tensor execute(
const Tensor& input_tensor,
const float parameter,
const std::optional<MemoryConfig>& memory_config = std::nullopt) {
return detail::execute(
input_tensor, {UnaryWithParam{unary_op_type, static_cast<float>(parameter)}}, memory_config);
input_tensor,
{UnaryWithParam{
ttnn::operations::unary::UnaryOpType::EXP, static_cast<float>(parameter)}},
memory_config);
}
};

struct Softplus {
struct Softplus : public EltwiseUnary {
static const std::array<TensorSchema, 1> input_tensor_schemas() { return detail::input_tensor_schemas(); }

template <typename... Args>
Expand All @@ -126,72 +110,11 @@ struct Softplus {
} // namespace unary
} // namespace operations

#define REGISTER_UNARY_OPERATION(operation_name, operation_type) \
constexpr auto operation_name = ttnn::register_operation< \
ttnn::operations::unary::Unary<ttnn::operations::unary::UnaryOpType::operation_type>>( \
"ttnn::" #operation_name);

#define REGISTER_UNARY_OPERATION_WITH_FAST_AND_APPROXIMATE_MODE(operation_name, operation_type) \
constexpr auto operation_name = ttnn::register_operation<ttnn::operations::unary::UnaryWithFastAndApproximateMode< \
ttnn::operations::unary::UnaryOpType::operation_type>>("ttnn::" #operation_name);

#define REGISTER_UNARY_OPERATION_WITH_FLOAT_PARAMETER(operation_name, operation_type) \
constexpr auto operation_name = ttnn::register_operation< \
ttnn::operations::unary::UnaryWithFloatParameter<ttnn::operations::unary::UnaryOpType::operation_type>>( \
"ttnn::" #operation_name);

REGISTER_UNARY_OPERATION(abs, ABS);
REGISTER_UNARY_OPERATION(acos, ACOS);
REGISTER_UNARY_OPERATION(asin, ASIN);
REGISTER_UNARY_OPERATION(atan, ATAN);
REGISTER_UNARY_OPERATION(cos, COS);
REGISTER_UNARY_OPERATION(erfinv, ERFINV);
REGISTER_UNARY_OPERATION(exp2, EXP2);
REGISTER_UNARY_OPERATION(expm1, EXPM1);
REGISTER_UNARY_OPERATION(eqz, EQZ);
REGISTER_UNARY_OPERATION(gez, GEZ);
REGISTER_UNARY_OPERATION(gtz, GTZ);
REGISTER_UNARY_OPERATION(i0, I0);
REGISTER_UNARY_OPERATION(isfinite, ISFINITE);
REGISTER_UNARY_OPERATION(isinf, ISINF);
REGISTER_UNARY_OPERATION(isnan, ISNAN);
REGISTER_UNARY_OPERATION(isneginf, ISNEGINF);
REGISTER_UNARY_OPERATION(isposinf, ISPOSINF);
REGISTER_UNARY_OPERATION(lez, LEZ);
REGISTER_UNARY_OPERATION(log, LOG);
REGISTER_UNARY_OPERATION(log10, LOG10);
REGISTER_UNARY_OPERATION(log2, LOG2);
REGISTER_UNARY_OPERATION(logical_not, LOGICAL_NOT_UNARY);
REGISTER_UNARY_OPERATION(ltz, LTZ);
REGISTER_UNARY_OPERATION(neg, NEG);
REGISTER_UNARY_OPERATION(nez, NEZ);
REGISTER_UNARY_OPERATION(reciprocal, RECIP);
REGISTER_UNARY_OPERATION(relu, RELU);
REGISTER_UNARY_OPERATION(relu6, RELU6);
REGISTER_UNARY_OPERATION(sigmoid, SIGMOID);
REGISTER_UNARY_OPERATION(sign, SIGN);
REGISTER_UNARY_OPERATION(signbit, SIGNBIT);
REGISTER_UNARY_OPERATION(silu, SILU);
REGISTER_UNARY_OPERATION(sin, SIN);
REGISTER_UNARY_OPERATION(sqrt, SQRT);
REGISTER_UNARY_OPERATION(square, SQUARE);
REGISTER_UNARY_OPERATION(tan, TAN);
REGISTER_UNARY_OPERATION(tanh, TANH);

// Unaries with fast_and_approximate_mode
REGISTER_UNARY_OPERATION_WITH_FAST_AND_APPROXIMATE_MODE(exp, EXP);
REGISTER_UNARY_OPERATION_WITH_FAST_AND_APPROXIMATE_MODE(erf, ERF);
REGISTER_UNARY_OPERATION_WITH_FAST_AND_APPROXIMATE_MODE(erfc, ERFC);
REGISTER_UNARY_OPERATION_WITH_FAST_AND_APPROXIMATE_MODE(gelu, GELU);
REGISTER_UNARY_OPERATION_WITH_FAST_AND_APPROXIMATE_MODE(rsqrt, RSQRT);

// Unaries with float parameter
REGISTER_UNARY_OPERATION_WITH_FLOAT_PARAMETER(elu, ELU);
REGISTER_UNARY_OPERATION_WITH_FLOAT_PARAMETER(heaviside, HEAVISIDE);
REGISTER_UNARY_OPERATION_WITH_FLOAT_PARAMETER(leaky_relu, LEAKY_RELU);
auto prelu = leaky_relu; // Alias for leaky_relu. TODO(#8544): implement PReLU properly

// Other unaries (composite operations)
constexpr auto exp = ttnn::register_operation<ttnn::operations::unary::Exp>("ttnn::exp");

constexpr auto softplus = ttnn::register_operation<ttnn::operations::unary::Softplus>("ttnn::softplus");

constexpr auto silu =
ttnn::register_operation<ttnn::operations::unary::Unary<ttnn::operations::unary::UnaryOpType::SILU>>("ttnn::silu");

} // namespace ttnn
Loading

0 comments on commit 83af07c

Please sign in to comment.