Skip to content

Commit

Permalink
#10384: Migrate swiglu op to TTNN
Browse files Browse the repository at this point in the history
  • Loading branch information
umadevimcw authored and mouliraj-mcw committed Jul 24, 2024
1 parent d6cd2ce commit f86227f
Show file tree
Hide file tree
Showing 13 changed files with 82 additions and 73 deletions.
2 changes: 0 additions & 2 deletions docs/source/ttnn/ttnn/dependencies/tt_lib.rst
Original file line number Diff line number Diff line change
Expand Up @@ -610,8 +610,6 @@ Other Operations

.. autofunction:: tt_lib.tensor.glu

.. autofunction:: tt_lib.tensor.swiglu

.. autofunction:: tt_lib.tensor.embeddings

.. autofunction:: tt_lib.tensor.nextafter
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@
@pytest.mark.parametrize("input_mem_config", input_mem_cfgs)
@pytest.mark.parametrize("output_mem_config", output_mem_cfgs)
class TestGLUVariants:
@pytest.mark.parametrize("fn_kind", ["glu", "swiglu"])
@pytest.mark.parametrize(
"fn_kind",
["glu", "reglu", "geglu", "swiglu"],
)
def test_all_glu_ops(
self,
input_shapes,
Expand Down
6 changes: 3 additions & 3 deletions tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2855,7 +2855,7 @@ def activation_glu(x, *args, device, dtype, layout, input_mem_config, output_mem
def activation_geglu(x, *args, device, dtype, layout, input_mem_config, output_mem_config, **kwargs):
dim = kwargs.get("dim", -1)
t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
t1 = ttnn.geglu(t0, dim=dim, memory_config=output_mem_config)
t1 = ttnn.geglu(t0, dim, memory_config=output_mem_config)

return tt2torch_tensor(t1)

Expand All @@ -2864,7 +2864,7 @@ def activation_geglu(x, *args, device, dtype, layout, input_mem_config, output_m
def activation_reglu(x, *args, device, dtype, layout, input_mem_config, output_mem_config, **kwargs):
dim = kwargs.get("dim", -1)
t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
t1 = ttnn.reglu(t0, dim=dim, memory_config=output_mem_config)
t1 = ttnn.reglu(t0, dim, memory_config=output_mem_config)

return tt2torch_tensor(t1)

Expand All @@ -2873,7 +2873,7 @@ def activation_reglu(x, *args, device, dtype, layout, input_mem_config, output_m
def activation_swiglu(x, *args, device, dtype, layout, input_mem_config, output_mem_config, **kwargs):
dim = kwargs.get("dim", -1)
t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
t1 = ttl.tensor.swiglu(t0, dim, output_mem_config=output_mem_config)
t1 = ttnn.swiglu(t0, dim, memory_config=output_mem_config)

return tt2torch_tensor(t1)

Expand Down
8 changes: 4 additions & 4 deletions tests/ttnn/profiling/ops_for_profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1462,7 +1462,7 @@ def reglu_1(x):


def swiglu_1(x):
tt_lib.tensor.swiglu(x, -1)
ttnn.swiglu(x, -1)


def glu_2(x):
Expand All @@ -1478,7 +1478,7 @@ def reglu_2(x):


def swiglu_2(x):
tt_lib.tensor.swiglu(x, -2)
ttnn.swiglu(x, -2)


def repeat(x):
Expand Down Expand Up @@ -2204,7 +2204,7 @@ def clone(x):
},
{
"op": swiglu_1,
"name": "tt_lib.tensor.swiglu_dim_3",
"name": "ttnn.swiglu_dim_3",
},
{
"op": glu_2,
Expand All @@ -2220,7 +2220,7 @@ def clone(x):
},
{
"op": swiglu_2,
"name": "tt_lib.tensor.swiglu_dim_2",
"name": "ttnn.swiglu_dim_2",
},
{
"op": repeat,
Expand Down
31 changes: 27 additions & 4 deletions tests/ttnn/unit_tests/operations/test_composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,8 +412,8 @@ def test_unary_reglu_ttnn(input_shapes, dim, device):
in_data, input_tensor = data_gen_with_range(input_shapes, -5, 5, device)
golden_fn = ttnn.get_golden_function(ttnn.reglu)

output_tensor = ttnn.reglu(input_tensor, dim=dim)
golden_tensor = golden_fn(in_data, dim=dim)
output_tensor = ttnn.reglu(input_tensor, dim)
golden_tensor = golden_fn(in_data, dim)

comp_pass = compare_pcc([output_tensor], [golden_tensor])
assert comp_pass
Expand All @@ -435,8 +435,31 @@ def test_unary_geglu_ttnn(input_shapes, dim, device):
in_data, input_tensor = data_gen_with_range(input_shapes, -5, 5, device)
golden_fn = ttnn.get_golden_function(ttnn.geglu)

output_tensor = ttnn.geglu(input_tensor, dim=dim)
golden_tensor = golden_fn(in_data, dim=dim)
output_tensor = ttnn.geglu(input_tensor, dim)
golden_tensor = golden_fn(in_data, dim)

comp_pass = compare_pcc([output_tensor], [golden_tensor])
assert comp_pass


@pytest.mark.parametrize(
"input_shapes",
(
(torch.Size([1, 1, 32, 64])),
(torch.Size([1, 1, 320, 384])),
(torch.Size([1, 3, 320, 384])),
),
)
@pytest.mark.parametrize(
"dim",
[-1, 3],
)
def test_unary_swiglu_ttnn(input_shapes, dim, device):
in_data, input_tensor = data_gen_with_range(input_shapes, -5, 5, device)
golden_fn = ttnn.get_golden_function(ttnn.swiglu)

output_tensor = ttnn.swiglu(input_tensor, dim)
golden_tensor = golden_fn(in_data, dim)

comp_pass = compare_pcc([output_tensor], [golden_tensor])
assert comp_pass
Original file line number Diff line number Diff line change
Expand Up @@ -1901,28 +1901,6 @@ Tensor glu(
}


// Swish Gated Linear Unit activation: matmul(split[0],swish(split[1]))
Tensor _swiglu(
const Tensor& input_a,
int32_t dim /* = -1 */,
const MemoryConfig& output_mem_config /* = operation::DEFAULT_OUTPUT_MEMORY_CONFIG */) {
TT_ASSERT(dim == -1 || dim == 3, "last dim SWIGLU only supported at this time ");
if (dim == -1)
dim = 3;

std::vector<Tensor> ab = split_tensor_for_glu(input_a, dim, output_mem_config);

Tensor swish_b = swish(ab[1], output_mem_config);
Tensor swiglu_result = ttnn::multiply(ab[0], swish_b, std::nullopt, output_mem_config);
return swiglu_result;
}
Tensor swiglu(
const Tensor& input_a,
int32_t dim /* = -1 */,
const MemoryConfig& output_mem_config /* = operation::DEFAULT_OUTPUT_MEMORY_CONFIG */) {
return operation::decorate_as_composite(__func__, _swiglu)(input_a, dim, output_mem_config);
}

// on-device tensor creation with shape and filled with value
Tensor _sfpu_eps(const Shape shape, Layout layout, Device* device, const MemoryConfig& output_mem_config) {
float value = device->sfpu_eps();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -569,12 +569,6 @@ Tensor glu(
int32_t dim = -1,
const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

// Swish based GLU
Tensor swiglu(
const Tensor& input_a,
int32_t dim = -1,
const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

// on-device tensor creation with shape and filled with value
Tensor sfpu_eps(const Shape shape, Layout layout, Device* device, const MemoryConfig& output_mem_config);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -363,13 +363,6 @@ void TensorModuleCompositeOPs(py::module& m_tensor) {
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
)doc");

detail::bind_unary_op_with_param(
m_tensor,
"swiglu",
&swiglu,
py::arg("dim") = -1,
R"doc(Applies the Swish Gated Linear Units (SwiGLU) function to the elements of the input tensor ``{0}`` split along dim ``{1}``.)doc",
R"doc(dimension to split)doc");
detail::bind_unary_op_with_param(
m_tensor,
"logical_andi",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -551,4 +551,21 @@ Tensor _geglu(
return geglu_result;
}


// Swish Gated Linear Unit activation: matmul(split[0],swish(split[1]))
Tensor _swiglu(
const Tensor& input_a,
int32_t dim,
const std::optional<MemoryConfig>& output_mem_config ) {
TT_ASSERT(dim == -1 || dim == 3, "last dim SWIGLU only supported at this time ");
if (dim == -1)
dim = 3;

std::vector<Tensor> ab = split_tensor_for_glu(input_a, dim, output_mem_config);

Tensor swish_b = _swish(ab[1], output_mem_config);
Tensor swiglu_result = ttnn::multiply(ab[0], swish_b, std::nullopt, output_mem_config);
return swiglu_result;
}

} // namespace ttnn::operations::unary
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ enum class UnaryCompositeOpType {
THRESHOLD,
REGLU,
GEGLU,
SWIGLU,
};

Tensor _tanhshrink (const Tensor&, const std::optional<MemoryConfig>&);
Expand Down Expand Up @@ -76,6 +77,7 @@ Tensor _threshold(const Tensor&, float, float, const std::optional<MemoryConfig>

Tensor _reglu(const Tensor&, int32_t, const std::optional<MemoryConfig>& );
Tensor _geglu(const Tensor&, int32_t, const std::optional<MemoryConfig>& );
Tensor _swiglu(const Tensor&, int32_t, const std::optional<MemoryConfig>& );

// OpHandler struct template
template <UnaryCompositeOpType OpType>
Expand Down Expand Up @@ -300,6 +302,13 @@ struct OpHandler_dim<UnaryCompositeOpType::GEGLU> {
}
};

template <>
struct OpHandler_dim<UnaryCompositeOpType::SWIGLU> {
static Tensor handle(const Tensor& t1, int32_t dim, const std::optional<MemoryConfig>& mem_cfg ) {
return _swiglu(t1, dim, mem_cfg);
}
};

// Template functions to get the function pointers
template <UnaryCompositeOpType OpType>
auto get_function_type1() {
Expand Down Expand Up @@ -327,7 +336,7 @@ auto get_function_type5() {
}

template <UnaryCompositeOpType OpType>
auto get_function_type6() {
auto get_glu_fn() {
return &OpHandler_dim<OpType>::handle;
}
}
3 changes: 2 additions & 1 deletion ttnn/cpp/ttnn/operations/eltwise/unary/unary_composite.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ struct ExecuteUnaryCompositeOpWithDim
int32_t dim,
const std::optional<MemoryConfig>& memory_config = std::nullopt)
{
auto op_type = get_function_type6<unary_comp_op_type>();
auto op_type = get_glu_fn<unary_comp_op_type>();
return op_type(input_tensor, dim, memory_config);
}
};
Expand Down Expand Up @@ -294,5 +294,6 @@ constexpr auto threshold = ttnn::register_operation<operations::unary::ExecuteUn

constexpr auto reglu = ttnn::register_operation<operations::unary::ExecuteUnaryCompositeOpWithDim<operations::unary::UnaryCompositeOpType::REGLU>>("ttnn::reglu");
constexpr auto geglu = ttnn::register_operation<operations::unary::ExecuteUnaryCompositeOpWithDim<operations::unary::UnaryCompositeOpType::GEGLU>>("ttnn::geglu");
constexpr auto swiglu = ttnn::register_operation<operations::unary::ExecuteUnaryCompositeOpWithDim<operations::unary::UnaryCompositeOpType::SWIGLU>>("ttnn::swiglu");

} // namespace ttnn
1 change: 1 addition & 0 deletions ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1159,6 +1159,7 @@ void py_module(py::module& module) {
// Unary ops with dim parameter
detail::bind_unary_operation_with_dim_parameter(module, ttnn::reglu, "dim", "Dimenstion to split input tensor. Supported dimension -1 or 3", "Split the tensor into two, apply relu function on second tensor followed by mul op with first tensor");
detail::bind_unary_operation_with_dim_parameter(module, ttnn::geglu, "dim", "Dimenstion to split input tensor. Supported dimension -1 or 3", "Split the tensor into two, apply gelu function on second tensor followed by mul op with first tensor");
detail::bind_unary_operation_with_dim_parameter(module, ttnn::swiglu, "dim", "Dimenstion to split input tensor. Supported dimension -1 or 3", "Split the tensor into two, apply silu function on second tensor followed by mul op with first tensor");

// Other unaries (unary chain operations)
detail::bind_softplus(module, ttnn::softplus);
Expand Down
36 changes: 14 additions & 22 deletions ttnn/ttnn/operations/unary.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,20 @@ def _golden_function_geglu(input_tensor_a, dim, *args, **kwargs):
ttnn.attach_golden_function(ttnn._ttnn.operations.unary.geglu, golden_function=_golden_function_geglu)


def _golden_function_swiglu(input_tensor_a, dim, *args, **kwargs):
import torch

assert isinstance(dim, int), "dim must be an integer"
assert dim in [-1, 3], "dim must be -1 or 3"
split_size = input_tensor_a.size(-1) // 2
split_tensors = torch.split(input_tensor_a, split_size_or_sections=[split_size, split_size], dim=dim)
tensA, tensB = split_tensors[0], split_tensors[1]
return tensA * torch.nn.functional.silu(tensB)


ttnn.attach_golden_function(ttnn._ttnn.operations.unary.swiglu, golden_function=_golden_function_swiglu)


def _is_scalar(value):
return isinstance(value, (int, float))

Expand Down Expand Up @@ -554,33 +568,12 @@ def activation_function(
)


def torch_reglu(input_tensor, *args, **kwargs):
import torch

split_size = input_tensor.size(-1) // 2
split_tensors = torch.split(input_tensor, split_size_or_sections=[split_size, split_size], dim=-1)
tensA, tensB = split_tensors[0], split_tensors[1]
return tensA * torch.nn.functional.relu(tensB)


def torch_swiglu(input_tensor, *args, **kwargs):
import torch

split_size = input_tensor.size(-1) // 2
split_tensors = torch.split(input_tensor, split_size_or_sections=[split_size, split_size], dim=-1)
tensA, tensB = split_tensors[0], split_tensors[1]
return tensA * torch.nn.functional.silu(tensB)


def register_ttl_activation_function_glu(name, ttl_activation_function, param):
def _golden_function(input_tensor: ttnn.Tensor, dim: int = -1, **_):
import torch

name_to_torch_function = {
"glu": torch.nn.functional.glu,
"reglu": torch_reglu,
"swiglu": torch_swiglu,
"geglu": torch_geglu,
}
torch_function = name_to_torch_function[name]
input_tensor = ttnn.to_torch(input_tensor)
Expand Down Expand Up @@ -632,7 +625,6 @@ def activation_function(

TTL_ACTIVATION_FUNCTIONS_GLU = [
("glu", ttl.tensor.glu, "dim"), # composite
("swiglu", ttl.tensor.swiglu, "dim"), # composite
]


Expand Down

0 comments on commit f86227f

Please sign in to comment.