diff --git a/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp b/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp
index 324fba24dc5..c87a6b1f2e6 100644
--- a/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp
+++ b/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp
@@ -1236,14 +1236,9 @@ Tensor outer(Tensor& a, Tensor& b, const MemoryConfig& output_mem_config) {
     return operation::decorate_as_composite(__func__, _outer)(a, b, output_mem_config);
 }
 
-// Gated Linear Unit activation: matmul(split[0],sigmoid(split[1]))
-Tensor _glu(
-    const Tensor& input_a,
-    int32_t dim /* = -1 */,
-    const MemoryConfig& output_mem_config /* = operation::DEFAULT_OUTPUT_MEMORY_CONFIG */) {
-    TT_ASSERT(dim == -1 || dim == 3, "last dim GLU only supported at this time ");
-    if (dim == -1)
-        dim = 3;
+std::vector<Tensor> split_tensor_for_glu(const Tensor& input_a, int32_t dim, const MemoryConfig& output_mem_config)
+{
+    std::vector<Tensor> t_split;
     Shape inshape = input_a.get_legacy_shape();
     TT_FATAL(((inshape[dim] / 2 )% TILE_WIDTH == 0),
         "Split tensor dimension should be in full tile");
@@ -1256,8 +1251,24 @@ Tensor _glu(
     Tensor t_a = unpad(input_a, s_a, e_a, output_mem_config);
     Tensor t_b = unpad(input_a, s_b, e_b, output_mem_config);
 
-    Tensor sigmoid_b = sigmoid(t_b, output_mem_config);
-    Tensor glu_result = mul(t_a, sigmoid_b, std::nullopt, output_mem_config);
+    t_split.emplace_back(t_a);
+    t_split.emplace_back(t_b);
+
+    return t_split;
+}
+
+// Gated Linear Unit activation: mul(split[0], sigmoid(split[1]))
+Tensor _glu(
+    const Tensor& input_a,
+    int32_t dim /* = -1 */,
+    const MemoryConfig& output_mem_config /* = operation::DEFAULT_OUTPUT_MEMORY_CONFIG */) {
+    TT_ASSERT(dim == -1 || dim == 3, "last dim GLU only supported at this time ");
+    if (dim == -1)
+        dim = 3;
+
+    std::vector<Tensor> ab = split_tensor_for_glu(input_a, dim, output_mem_config);
+    Tensor sigmoid_b = sigmoid(ab[1], output_mem_config);
+    Tensor glu_result = mul(ab[0], sigmoid_b, std::nullopt, output_mem_config);
     return glu_result;
 }
 Tensor glu(
@@ -1275,19 +1286,9 @@ Tensor _reglu(
     TT_ASSERT(dim == -1 || dim == 3, "last dim REGLU only supported at this time ");
     if (dim == -1)
         dim = 3;
-    Shape inshape = input_a.get_legacy_shape();
-    TT_FATAL(((inshape[dim] / 2 )% TILE_WIDTH == 0),
-        "Split tensor dimension should be in full tile");
-    Shape s_a = {0, 0, 0, 0};
-    Shape e_a = {inshape[0]-1, inshape[1]-1, inshape[2]-1, inshape[3]/2 - 1 };
-
-    Shape s_b = {0, 0, 0, inshape[3]/2 };
-    Shape e_b = {inshape[0]-1, inshape[1]-1, inshape[2]-1, inshape[3] - 1 };
-
-    Tensor t_a = unpad(input_a, s_a, e_a, output_mem_config);
-    Tensor t_b = unpad(input_a, s_b, e_b, output_mem_config);
-    Tensor relu_b = relu(t_b, output_mem_config);
-    Tensor reglu_result = mul(t_a, relu_b, std::nullopt, output_mem_config);
+    std::vector<Tensor> ab = split_tensor_for_glu(input_a, dim, output_mem_config);
+    Tensor relu_b = relu(ab[1], output_mem_config);
+    Tensor reglu_result = mul(ab[0], relu_b, std::nullopt, output_mem_config);
     return reglu_result;
 }
 Tensor reglu(
@@ -1305,20 +1306,12 @@ Tensor _geglu(
     TT_ASSERT(dim == -1 || dim == 3, "last dim GEGLU only supported at this time ");
     if (dim == -1)
         dim = 3;
-    Shape inshape = input_a.get_legacy_shape();
-    TT_FATAL(((inshape[dim] / 2 )% TILE_WIDTH == 0),
-        "Split tensor dimension should be in full tile");
-    Shape s_a = {0, 0, 0, 0};
-    Shape e_a = {inshape[0]-1, inshape[1]-1, inshape[2]-1, inshape[3]/2 - 1 };
-    Shape s_b = {0, 0, 0, inshape[3]/2 };
-    Shape e_b = {inshape[0]-1, inshape[1]-1, inshape[2]-1, inshape[3] - 1 };
+    std::vector<Tensor> ab = split_tensor_for_glu(input_a, dim, output_mem_config);
 
-    Tensor t_a = unpad(input_a, s_a, e_a, output_mem_config);
-    Tensor t_b = unpad(input_a, s_b, e_b, output_mem_config);
 
     constexpr bool fast_appx = true;
-    Tensor gelu_b = gelu(t_b, fast_appx, output_mem_config);
-    Tensor geglu_result = mul(t_a, gelu_b, std::nullopt, output_mem_config);
+    Tensor gelu_b = gelu(ab[1], fast_appx, output_mem_config);
+    Tensor geglu_result = mul(ab[0], gelu_b, std::nullopt, output_mem_config);
     return geglu_result;
 }
 Tensor geglu(
@@ -1336,20 +1329,11 @@ Tensor _swiglu(
     TT_ASSERT(dim == -1 || dim == 3, "last dim SWIGLU only supported at this time ");
     if (dim == -1)
         dim = 3;
-    Shape inshape = input_a.get_legacy_shape();
-    TT_FATAL(((inshape[dim] / 2 )% TILE_WIDTH == 0),
-        "Split tensor dimension should be in full tile");
-    Shape s_a = {0, 0, 0, 0};
-    Shape e_a = {inshape[0]-1, inshape[1]-1, inshape[2]-1, inshape[3]/2 - 1 };
-
-    Shape s_b = {0, 0, 0, inshape[3]/2 };
-    Shape e_b = {inshape[0]-1, inshape[1]-1, inshape[2]-1, inshape[3] - 1 };
-    Tensor t_a = unpad(input_a, s_a, e_a, output_mem_config);
-    Tensor t_b = unpad(input_a, s_b, e_b, output_mem_config);
+    std::vector<Tensor> ab = split_tensor_for_glu(input_a, dim, output_mem_config);
 
 
-    Tensor swish_b = swish(t_b, output_mem_config);
-    Tensor swiglu_result = mul(t_a, swish_b, std::nullopt, output_mem_config);
+    Tensor swish_b = swish(ab[1], output_mem_config);
+    Tensor swiglu_result = mul(ab[0], swish_b, std::nullopt, output_mem_config);
     return swiglu_result;
 }
 Tensor swiglu(
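
Note: the four ops touched above now share one pattern, which split_tensor_for_glu factors out: split the last dimension in half, apply the gate activation (sigmoid / relu / gelu / swish) to the second half, and multiply it elementwise with the first half. The sketch below is a minimal, standalone illustration of that math on plain std::vector<float>; it deliberately does not use the tt_eager Tensor / unpad / mul APIs, and the helper names split_half and gated are hypothetical, not part of this repo.

// Standalone sketch of the GLU-family math consolidated by split_tensor_for_glu.
#include <cmath>
#include <cstdio>
#include <cstddef>
#include <functional>
#include <utility>
#include <vector>

// Split a flat row of even length into its first and second halves.
std::pair<std::vector<float>, std::vector<float>> split_half(const std::vector<float>& x) {
    const std::size_t h = x.size() / 2;
    return {std::vector<float>(x.begin(), x.begin() + h),
            std::vector<float>(x.begin() + h, x.end())};
}

// Generic gated linear unit: out[i] = a[i] * act(b[i]).
std::vector<float> gated(const std::vector<float>& x, const std::function<float(float)>& act) {
    auto [a, b] = split_half(x);
    std::vector<float> out(a.size());
    for (std::size_t i = 0; i < a.size(); ++i) out[i] = a[i] * act(b[i]);
    return out;
}

int main() {
    const std::vector<float> x = {1.0f, -2.0f, 3.0f, 4.0f, 0.5f, -1.0f, 2.0f, 0.0f};
    auto sigmoid = [](float v) { return 1.0f / (1.0f + std::exp(-v)); };

    // Same split-activate-multiply pattern, different gate activations.
    std::vector<float> glu    = gated(x, sigmoid);                                      // GLU:    a * sigmoid(b)
    std::vector<float> reglu  = gated(x, [](float v) { return v > 0.0f ? v : 0.0f; });  // ReGLU:  a * relu(b)
    std::vector<float> geglu  = gated(x, [](float v) {                                  // GeGLU:  a * gelu(b), exact erf form
        return 0.5f * v * (1.0f + std::erf(v / std::sqrt(2.0f)));
    });
    std::vector<float> swiglu = gated(x, [&](float v) { return v * sigmoid(v); });      // SwiGLU: a * swish(b)

    std::printf("glu[0]=%f reglu[0]=%f geglu[0]=%f swiglu[0]=%f\n", glu[0], reglu[0], geglu[0], swiglu[0]);
    return 0;
}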