#10384: Migrate glu op to TTNN
mouliraj-mcw committed Jul 24, 2024
1 parent f86227f commit a5c7cc1
Showing 12 changed files with 59 additions and 112 deletions.
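
At the call sites, this migration replaces the old tt_lib composite with the registered TTNN op and renames the output-memory keyword from output_mem_config to memory_config. A minimal before/after sketch in Python, assuming a device handle is already open and t0 is a tiled tensor on that device (names are illustrative):

import ttnn

# Before this commit: tt_lib composite op
# t1 = ttl.tensor.glu(t0, -1, output_mem_config=ttnn.DRAM_MEMORY_CONFIG)

# After this commit: registered TTNN op
t1 = ttnn.glu(t0, -1, memory_config=ttnn.DRAM_MEMORY_CONFIG)
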
2 changes: 0 additions & 2 deletions docs/source/ttnn/ttnn/dependencies/tt_lib.rst
@@ -608,8 +608,6 @@ Other Operations

.. autofunction:: tt_lib.tensor.normalize_global

.. autofunction:: tt_lib.tensor.glu

.. autofunction:: tt_lib.tensor.embeddings

.. autofunction:: tt_lib.tensor.nextafter
@@ -2846,7 +2846,7 @@ def unpad_from_tile(
def activation_glu(x, *args, device, dtype, layout, input_mem_config, output_mem_config, **kwargs):
dim = kwargs.get("dim", -1)
t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
t1 = ttl.tensor.glu(t0, dim, output_mem_config=output_mem_config)
t1 = ttnn.glu(t0, dim, memory_config=output_mem_config)

return tt2torch_tensor(t1)

8 changes: 4 additions & 4 deletions tests/ttnn/profiling/ops_for_profiling.py
@@ -1450,7 +1450,7 @@ def logical_noti(x):


def glu_1(x):
tt_lib.tensor.glu(x, -1)
ttnn.glu(x, -1)


def geglu_1(x):
@@ -1466,7 +1466,7 @@ def swiglu_1(x):


def glu_2(x):
tt_lib.tensor.glu(x, -2)
ttnn.glu(x, -2)


def geglu_2(x):
@@ -2192,7 +2192,7 @@ def clone(x):
},
{
"op": glu_1,
"name": "tt_lib.tensor.glu_dim_3",
"name": "ttnn.glu_dim_3",
},
{
"op": geglu_1,
@@ -2208,7 +2208,7 @@ def clone(x):
},
{
"op": glu_2,
"name": "tt_lib.tensor.glu_dim_2",
"name": "ttnn.glu_dim_2",
},
{
"op": geglu_2,
23 changes: 23 additions & 0 deletions tests/ttnn/unit_tests/operations/test_composite.py
@@ -396,6 +396,29 @@ def test_unary_threshold_ttnn(input_shapes, device):
assert comp_pass


@pytest.mark.parametrize(
"input_shapes",
(
(torch.Size([1, 1, 32, 64])),
(torch.Size([1, 1, 320, 384])),
(torch.Size([1, 3, 320, 384])),
),
)
@pytest.mark.parametrize(
"dim",
[-1, 3],
)
def test_unary_glu_ttnn(input_shapes, dim, device):
in_data, input_tensor = data_gen_with_range(input_shapes, -5, 5, device)
golden_fn = ttnn.get_golden_function(ttnn.glu)

output_tensor = ttnn.glu(input_tensor, dim)
golden_tensor = golden_fn(in_data, dim)

comp_pass = compare_pcc([output_tensor], [golden_tensor])
assert comp_pass


@pytest.mark.parametrize(
"input_shapes",
(
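
The new test_unary_glu_ttnn case above gates on compare_pcc, a helper from the repo's test utilities that accepts the result when the Pearson correlation between the device output and the golden output clears a threshold. A minimal sketch of that kind of check, with the threshold and signature assumed rather than taken from the helper itself:

import torch

def pcc_check(actual: torch.Tensor, expected: torch.Tensor, threshold: float = 0.99) -> bool:
    # Flatten both tensors and compute the Pearson correlation coefficient.
    a = actual.flatten().to(torch.float32)
    b = expected.flatten().to(torch.float32)
    pcc = torch.corrcoef(torch.stack([a, b]))[0, 1].item()
    return pcc >= threshold
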
@@ -1879,27 +1879,6 @@ std::vector<Tensor> split_tensor_for_glu(const Tensor& input_a, int32_t dim, con
return t_split;
}

// Gated Linear Unit activation: matmul(split[0],sigmoid(split[1]))
Tensor _glu(
const Tensor& input_a,
int32_t dim /* = -1 */,
const MemoryConfig& output_mem_config /* = operation::DEFAULT_OUTPUT_MEMORY_CONFIG */) {
TT_ASSERT(dim == -1 || dim == 3, "last dim GLU only supported at this time ");
if (dim == -1)
dim = 3;

std::vector<Tensor> ab = split_tensor_for_glu(input_a, dim, output_mem_config);
Tensor sigmoid_b = ttnn::sigmoid(ab[1], output_mem_config);
Tensor glu_result = ttnn::multiply(ab[0], sigmoid_b, std::nullopt, output_mem_config);
return glu_result;
}
Tensor glu(
const Tensor& input_a,
int32_t dim /* = -1 */,
const MemoryConfig& output_mem_config /* = operation::DEFAULT_OUTPUT_MEMORY_CONFIG */) {
return operation::decorate_as_composite(__func__, _glu)(input_a, dim, output_mem_config);
}


// on-device tensor creation with shape and filled with value
Tensor _sfpu_eps(const Shape shape, Layout layout, Device* device, const MemoryConfig& output_mem_config) {
@@ -563,12 +563,6 @@ Tensor logical_ori(
float immediate,
const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

// Gated Linear Unit activation
Tensor glu(
const Tensor& input_a,
int32_t dim = -1,
const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

// on-device tensor creation with shape and filled with value
Tensor sfpu_eps(const Shape shape, Layout layout, Device* device, const MemoryConfig& output_mem_config);

@@ -332,13 +332,6 @@ void TensorModuleCompositeOPs(py::module& m_tensor) {
R"doc(Returns tensor with the polyval of all of elements of the input tensor ``{0}`` with coefficients ``{1}``.)doc",
R"doc("coefficients value with highest degree first", "List of float", "List size > 0")doc");

detail::bind_unary_op_with_param(
m_tensor,
"glu",
&glu,
py::arg("dim") = -1,
R"doc(Applies the Gated Linear Units (GLU) function to the elements of the input tensor ``{0}`` split along dim ``{1}``.)doc",
R"doc(dimension to split)doc");
m_tensor.def(
"prod",
&prod,
@@ -15,7 +15,7 @@
#include "ttnn/run_operation.hpp"
#include "ttnn/types.hpp"
#include "tt_metal/common/bfloat16.hpp"
#include "ttnn/experimental/tt_dnn/op_library/reduce/reduce_op.hpp"
#include "tt_dnn/op_library/reduce/reduce_op.hpp"
#include "ttnn/operations/data_movement/slice/slice.hpp"

namespace ttnn::operations::unary{
@@ -400,7 +400,6 @@ Tensor _normalize(const Tensor& y, const std::optional<MemoryConfig>& output_mem
// PyTorch version:
// hard sigmoid(x) = { x <= -3: 0, x >= +3: +3, x/6 + 0.5 otherwise}
Tensor _hardsigmoid(const Tensor& a, float value_1, float value_2, const std::optional<MemoryConfig>& output_mem_config) {
// std::cout<<"\n\n hit in ttnn hardsigmoid";

Tensor a_t = ttnn::full_like(a,value_1);
Tensor b_t = ttnn::full_like(a,value_2);
@@ -414,7 +413,6 @@ Tensor _hardsigmoid(const Tensor& a, float value_1, float value_2, const std::op
// Ref: PyTorch
// hard swish(x) = x*hardsigmoid(x,scale,shift)
Tensor _hardswish(const Tensor& a, float value_1, float value_2, const std::optional<MemoryConfig>& output_mem_config) {
// std::cout<<"\n\n hit in ttnn hardswish";
Tensor a_sigmoid = _hardsigmoid(a, value_1, value_2, output_mem_config);
Tensor result_sq = ttnn::multiply(a_sigmoid, a, std::nullopt);
return result_sq;
@@ -518,6 +516,16 @@ std::vector<Tensor> split_tensor_for_glu(const Tensor& input_a, int32_t dim, con
return t_split;
}

// Gated Linear Unit activation: mul(split[0], sigmoid(split[1]))
Tensor _glu(const Tensor& input_a, int32_t dim, const std::optional<MemoryConfig>& output_mem_config) {
TT_ASSERT(dim == -1 || dim == 3, "GLU is only supported along the last dimension at this time");
if (dim == -1)
dim = 3;
std::vector<Tensor> ab = split_tensor_for_glu(input_a, dim, output_mem_config);
Tensor sigmoid_b = ttnn::sigmoid(ab[1], output_mem_config);
Tensor glu_result = ttnn::multiply(ab[0], sigmoid_b, std::nullopt, output_mem_config);
return glu_result;
}

// ReLU Gated Linear Unit activation: matmul(split[0],relu(split[1]))
Tensor _reglu(
@@ -533,7 +541,6 @@ Tensor _reglu(
return reglu_result;
}


// Gaussian Error Gated Linear Unit activation: matmul(split[0],gelu(split[1]))
Tensor _geglu(
const Tensor& input_a,
@@ -551,7 +558,6 @@ Tensor _geglu(
return geglu_result;
}


// Swish Gated Linear Unit activation: matmul(split[0],swish(split[1]))
Tensor _swiglu(
const Tensor& input_a,
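
The new _glu composite above follows the same recipe as the other gated variants in this file: split the input in two along the chosen dimension, apply sigmoid to the second half, and multiply it elementwise with the first half. A minimal host-side reference sketch in PyTorch (not the device implementation; the function name is illustrative):

import torch

def glu_reference(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Split the input into two equal halves along `dim`.
    a, b = torch.chunk(x, 2, dim=dim)
    # Gate the first half with the sigmoid of the second half (elementwise).
    return a * torch.sigmoid(b)

For the supported case (the last dimension), this agrees with torch.nn.functional.glu(x, dim=-1), which is also the golden reference attached later in this commit.
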
@@ -38,6 +38,7 @@ enum class UnaryCompositeOpType {
CLAMP,
SELU,
THRESHOLD,
GLU,
REGLU,
GEGLU,
SWIGLU,
@@ -74,7 +75,7 @@ Tensor _clip(const Tensor&, float, float, const std::optional<MemoryConfig>& );
Tensor _clamp(const Tensor&, float, float, const std::optional<MemoryConfig>& );
Tensor _selu(const Tensor&, float, float, const std::optional<MemoryConfig>& );
Tensor _threshold(const Tensor&, float, float, const std::optional<MemoryConfig>& );

Tensor _glu(const Tensor&, int32_t, const std::optional<MemoryConfig>& );
Tensor _reglu(const Tensor&, int32_t, const std::optional<MemoryConfig>& );
Tensor _geglu(const Tensor&, int32_t, const std::optional<MemoryConfig>& );
Tensor _swiglu(const Tensor&, int32_t, const std::optional<MemoryConfig>& );
@@ -288,6 +289,13 @@ struct OpHandler_threshold_value<UnaryCompositeOpType::THRESHOLD> {
};

// GLU variants (glu, reglu, geglu, swiglu) are supported only for the last dimension.
template <>
struct OpHandler_dim<UnaryCompositeOpType::GLU> {
static Tensor handle(const Tensor& t1, int32_t dim, const std::optional<MemoryConfig>& mem_cfg ) {
return _glu(t1, dim, mem_cfg);
}
};

template <>
struct OpHandler_dim<UnaryCompositeOpType::REGLU> {
static Tensor handle(const Tensor& t1, int32_t dim, const std::optional<MemoryConfig>& mem_cfg ) {
1 change: 1 addition & 0 deletions ttnn/cpp/ttnn/operations/eltwise/unary/unary_composite.hpp
@@ -292,6 +292,7 @@ constexpr auto clamp = ttnn::register_operation<operations::unary::ExecuteUnaryC
constexpr auto selu = ttnn::register_operation<operations::unary::ExecuteUnaryCompositeOpWithScaleAlpha<operations::unary::UnaryCompositeOpType::SELU>>("ttnn::selu");
constexpr auto threshold = ttnn::register_operation<operations::unary::ExecuteUnaryCompositeOpWithThresholdValue<operations::unary::UnaryCompositeOpType::THRESHOLD>>("ttnn::threshold");

constexpr auto glu = ttnn::register_operation<operations::unary::ExecuteUnaryCompositeOpWithDim<operations::unary::UnaryCompositeOpType::GLU>>("ttnn::glu");
constexpr auto reglu = ttnn::register_operation<operations::unary::ExecuteUnaryCompositeOpWithDim<operations::unary::UnaryCompositeOpType::REGLU>>("ttnn::reglu");
constexpr auto geglu = ttnn::register_operation<operations::unary::ExecuteUnaryCompositeOpWithDim<operations::unary::UnaryCompositeOpType::GEGLU>>("ttnn::geglu");
constexpr auto swiglu = ttnn::register_operation<operations::unary::ExecuteUnaryCompositeOpWithDim<operations::unary::UnaryCompositeOpType::SWIGLU>>("ttnn::swiglu");
1 change: 1 addition & 0 deletions ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp
@@ -1157,6 +1157,7 @@ void py_module(py::module& module) {


// Unary ops with dim parameter
detail::bind_unary_operation_with_dim_parameter(module, ttnn::glu, "dim", "Dimension to split input tensor. Supported dimension -1 or 3", "Split the tensor into two, apply sigmoid function on second tensor followed by mul op with first tensor");
detail::bind_unary_operation_with_dim_parameter(module, ttnn::reglu, "dim", "Dimension to split input tensor. Supported dimension -1 or 3", "Split the tensor into two, apply relu function on second tensor followed by mul op with first tensor");
detail::bind_unary_operation_with_dim_parameter(module, ttnn::geglu, "dim", "Dimension to split input tensor. Supported dimension -1 or 3", "Split the tensor into two, apply gelu function on second tensor followed by mul op with first tensor");
detail::bind_unary_operation_with_dim_parameter(module, ttnn::swiglu, "dim", "Dimension to split input tensor. Supported dimension -1 or 3", "Split the tensor into two, apply silu function on second tensor followed by mul op with first tensor");
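
Like the other GLU variants bound above, ttnn.glu consumes the full split dimension and returns a tensor with that dimension halved, so a [1, 1, 32, 64] input split along dim 3 (equivalently -1) yields a [1, 1, 32, 32] output. A minimal call sketch, assuming device is an open device handle (shapes are illustrative):

import torch
import ttnn

torch_input = torch.randn(1, 1, 32, 64, dtype=torch.bfloat16)
input_tensor = ttnn.from_torch(torch_input, layout=ttnn.TILE_LAYOUT, device=device)
output_tensor = ttnn.glu(input_tensor, 3)  # same as dim=-1; last dimension 64 -> 32
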
74 changes: 9 additions & 65 deletions ttnn/ttnn/operations/unary.py
@@ -376,6 +376,15 @@ def _golden_function_bitwise_not(input_tensor_a, value, *args, **kwargs):
ttnn.attach_golden_function(ttnn._ttnn.operations.unary.bitwise_not, golden_function=_golden_function_bitwise_not)


def _golden_function_glu(input_tensor_a, dim, *args, **kwargs):
import torch

return torch.nn.functional.glu(input_tensor_a, dim)


ttnn.attach_golden_function(ttnn._ttnn.operations.unary.glu, golden_function=_golden_function_glu)


def _golden_function_reglu(input_tensor_a, dim, *args, **kwargs):
import torch

@@ -567,69 +576,4 @@ def activation_function(
activation_function_name, ttl_activation_function, param1, param2
)


def register_ttl_activation_function_glu(name, ttl_activation_function, param):
def _golden_function(input_tensor: ttnn.Tensor, dim: int = -1, **_):
import torch

name_to_torch_function = {
"glu": torch.nn.functional.glu,
}
torch_function = name_to_torch_function[name]
input_tensor = ttnn.to_torch(input_tensor)

return torch_function(input_tensor, dim=dim)

doc = f"""{(name)}(input_tensor: ttnn.Tensor, dim: int = -1, *, memory_config: ttnn.MemoryConfig = ttnn.DRAM_MEMORY_CONFIG) -> ttnn.Tensor
Applies the {name} function to the elements of the input tensor :attr:`input_tensor` split along :attr:`{param}`.
.. math::
{(name)}(\\mathrm{{input\\_tensor}}_i \\; , \\; {param})
Args:
* :attr:`input_tensor`
* :attr:`{param}`
Example::
>>> tensor = ttnn.from_torch(torch.tensor((32, 64), dtype=torch.bfloat16), device=device)
>>> output = ttnn.{(name)}(tensor, {param})
"""

@ttnn.register_python_operation(name=f"ttnn.{name}", golden_function=_golden_function, doc=doc)
def activation_function(
input_tensor: ttnn.Tensor, dim: int = -1, *, memory_config: ttnn.MemoryConfig = ttnn.DRAM_MEMORY_CONFIG
) -> ttnn.Tensor:
input_shape = tuple(input_tensor.shape)
last_dim = input_shape[-1]
glu_shape = input_shape[:-1] + (int(last_dim / 2),)

input_tensor = ttnn.unsqueeze_to_4D(input_tensor)

if not isinstance(input_tensor, ttnn.Tensor):
raise TypeError("Expected first argument to be a ttnn.Tensor")

if not _is_scalar(dim):
raise TypeError("Expected second argument to be a float")

if not ttnn.is_tensor_storage_on_device(input_tensor):
raise RuntimeError("input_tensor must be on device!")

output_tensor = ttl_activation_function(input_tensor, dim, output_mem_config=memory_config)

output_tensor = ttnn.reshape(output_tensor, ttnn.Shape(glu_shape))
return output_tensor


TTL_ACTIVATION_FUNCTIONS_GLU = [
("glu", ttl.tensor.glu, "dim"), # composite
]


for activation_function_name, ttl_activation_function, param in TTL_ACTIVATION_FUNCTIONS_GLU:
register_ttl_activation_function_glu(activation_function_name, ttl_activation_function, param)


__all__ = []
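
With the old register_ttl_activation_function_glu wrapper removed, host-side validation goes through the golden-function hook attached earlier in this file. A minimal sketch of that flow, with shapes and names chosen for illustration and device assumed to be open:

import torch
import ttnn

torch_input = torch.randn(1, 1, 32, 64, dtype=torch.bfloat16)
tt_input = ttnn.from_torch(torch_input, layout=ttnn.TILE_LAYOUT, device=device)

tt_output = ttnn.to_torch(ttnn.glu(tt_input, -1))
golden_output = ttnn.get_golden_function(ttnn.glu)(torch_input, -1)  # torch.nn.functional.glu under the hood
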
