#6537: #6677: #6687: Refactor split steps into module
umadevimcw committed Apr 10, 2024
1 parent 9a015a2 commit c419751
Showing 1 changed file with 30 additions and 46 deletions.
76 changes: 30 additions & 46 deletions tt_eager/tt_dnn/op_library/composite/composite_ops.cpp
@@ -1236,14 +1236,9 @@ Tensor outer(Tensor& a, Tensor& b, const MemoryConfig& output_mem_config) {
return operation::decorate_as_composite(__func__, _outer)(a, b, output_mem_config);
}

-// Gated Linear Unit activation: matmul(split[0],sigmoid(split[1]))
-Tensor _glu(
-    const Tensor& input_a,
-    int32_t dim /* = -1 */,
-    const MemoryConfig& output_mem_config /* = operation::DEFAULT_OUTPUT_MEMORY_CONFIG */) {
-    TT_ASSERT(dim == -1 || dim == 3, "last dim GLU only supported at this time ");
-    if (dim == -1)
-        dim = 3;
+std::vector<Tensor> split_tensor_for_glu(const Tensor& input_a, int32_t dim, const MemoryConfig& output_mem_config)
+{
+    std::vector<Tensor> t_split;
Shape inshape = input_a.get_legacy_shape();
TT_FATAL(((inshape[dim] / 2 )% TILE_WIDTH == 0),
"Split tensor dimension should be in full tile");
@@ -1256,8 +1251,24 @@ Tensor _glu(
Tensor t_a = unpad(input_a, s_a, e_a, output_mem_config);
Tensor t_b = unpad(input_a, s_b, e_b, output_mem_config);

-    Tensor sigmoid_b = sigmoid(t_b, output_mem_config);
-    Tensor glu_result = mul(t_a, sigmoid_b, std::nullopt, output_mem_config);
+    t_split.emplace_back(t_a);
+    t_split.emplace_back(t_b);
+
+    return t_split;
+}
+
+// Gated Linear Unit activation: matmul(split[0],sigmoid(split[1]))
+Tensor _glu(
+    const Tensor& input_a,
+    int32_t dim /* = -1 */,
+    const MemoryConfig& output_mem_config /* = operation::DEFAULT_OUTPUT_MEMORY_CONFIG */) {
+    TT_ASSERT(dim == -1 || dim == 3, "last dim GLU only supported at this time ");
+    if (dim == -1)
+        dim = 3;
+
+    std::vector<Tensor> ab = split_tensor_for_glu(input_a, dim, output_mem_config);
+    Tensor sigmoid_b = sigmoid(ab[1], output_mem_config);
+    Tensor glu_result = mul(ab[0], sigmoid_b, std::nullopt, output_mem_config);
return glu_result;
}
Tensor glu(
@@ -1275,19 +1286,9 @@ Tensor _reglu(
TT_ASSERT(dim == -1 || dim == 3, "last dim REGLU only supported at this time ");
if (dim == -1)
dim = 3;
-    Shape inshape = input_a.get_legacy_shape();
-    TT_FATAL(((inshape[dim] / 2 )% TILE_WIDTH == 0),
-        "Split tensor dimension should be in full tile");
-    Shape s_a = {0, 0, 0, 0};
-    Shape e_a = {inshape[0]-1, inshape[1]-1, inshape[2]-1, inshape[3]/2 - 1 };
-
-    Shape s_b = {0, 0, 0, inshape[3]/2 };
-    Shape e_b = {inshape[0]-1, inshape[1]-1, inshape[2]-1, inshape[3] - 1 };
-
-    Tensor t_a = unpad(input_a, s_a, e_a, output_mem_config);
-    Tensor t_b = unpad(input_a, s_b, e_b, output_mem_config);
-    Tensor relu_b = relu(t_b, output_mem_config);
-    Tensor reglu_result = mul(t_a, relu_b, std::nullopt, output_mem_config);
+    std::vector<Tensor> ab = split_tensor_for_glu(input_a, dim, output_mem_config);
+    Tensor relu_b = relu(ab[1], output_mem_config);
+    Tensor reglu_result = mul(ab[0], relu_b, std::nullopt, output_mem_config);
return reglu_result;
}
Tensor reglu(
@@ -1305,20 +1306,12 @@ Tensor _geglu(
TT_ASSERT(dim == -1 || dim == 3, "last dim GEGLU only supported at this time ");
if (dim == -1)
dim = 3;
-    Shape inshape = input_a.get_legacy_shape();
-    TT_FATAL(((inshape[dim] / 2 )% TILE_WIDTH == 0),
-        "Split tensor dimension should be in full tile");
-    Shape s_a = {0, 0, 0, 0};
-    Shape e_a = {inshape[0]-1, inshape[1]-1, inshape[2]-1, inshape[3]/2 - 1 };
-
-    Shape s_b = {0, 0, 0, inshape[3]/2 };
-    Shape e_b = {inshape[0]-1, inshape[1]-1, inshape[2]-1, inshape[3] - 1 };
+    std::vector<Tensor> ab = split_tensor_for_glu(input_a, dim, output_mem_config);

-    Tensor t_a = unpad(input_a, s_a, e_a, output_mem_config);
-    Tensor t_b = unpad(input_a, s_b, e_b, output_mem_config);
constexpr bool fast_appx = true;
-    Tensor gelu_b = gelu(t_b, fast_appx, output_mem_config);
-    Tensor geglu_result = mul(t_a, gelu_b, std::nullopt, output_mem_config);
+    Tensor gelu_b = gelu(ab[1], fast_appx, output_mem_config);
+    Tensor geglu_result = mul(ab[0], gelu_b, std::nullopt, output_mem_config);
return geglu_result;
}
Tensor geglu(
@@ -1336,20 +1329,11 @@ Tensor _swiglu(
TT_ASSERT(dim == -1 || dim == 3, "last dim SWIGLU only supported at this time ");
if (dim == -1)
dim = 3;
-    Shape inshape = input_a.get_legacy_shape();
-    TT_FATAL(((inshape[dim] / 2 )% TILE_WIDTH == 0),
-        "Split tensor dimension should be in full tile");
-    Shape s_a = {0, 0, 0, 0};
-    Shape e_a = {inshape[0]-1, inshape[1]-1, inshape[2]-1, inshape[3]/2 - 1 };
-
-    Shape s_b = {0, 0, 0, inshape[3]/2 };
-    Shape e_b = {inshape[0]-1, inshape[1]-1, inshape[2]-1, inshape[3] - 1 };
-
-    Tensor t_a = unpad(input_a, s_a, e_a, output_mem_config);
-    Tensor t_b = unpad(input_a, s_b, e_b, output_mem_config);
+    std::vector<Tensor> ab = split_tensor_for_glu(input_a, dim, output_mem_config);

-    Tensor swish_b = swish(t_b, output_mem_config);
-    Tensor swiglu_result = mul(t_a, swish_b, std::nullopt, output_mem_config);
+    Tensor swish_b = swish(ab[1], output_mem_config);
+    Tensor swiglu_result = mul(ab[0], swish_b, std::nullopt, output_mem_config);
return swiglu_result;
}
Tensor swiglu(

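For context on the refactor: every op touched by this commit follows one shared pattern, which is what the new split_tensor_for_glu helper factors out — split the last dimension in half, apply an activation to the second half, and multiply the halves elementwise (GLU uses sigmoid, ReGLU relu, GeGLU gelu, SwiGLU swish). The sketch below is not part of the commit and does not use the tt_eager Tensor API; it illustrates the pattern on plain std::vector buffers, with split_last_dim and gated_unit as hypothetical stand-ins for split_tensor_for_glu and the per-op bodies.

```cpp
// Illustrative sketch only (not part of this commit): the GLU-family pattern on
// flat row-major buffers, assuming C++17.
#include <cmath>
#include <cstddef>
#include <functional>
#include <iostream>
#include <utility>
#include <vector>

// Split a row-major buffer whose innermost dimension is `width` into two halves
// along that dimension (analogous to what split_tensor_for_glu does with unpad).
static std::pair<std::vector<float>, std::vector<float>>
split_last_dim(const std::vector<float>& x, std::size_t width) {
    std::vector<float> a, b;
    for (std::size_t i = 0; i < x.size(); i += width) {
        a.insert(a.end(), x.begin() + i, x.begin() + i + width / 2);
        b.insert(b.end(), x.begin() + i + width / 2, x.begin() + i + width);
    }
    return {a, b};
}

// Generic gated unit: out = a * act(b); GLU, ReGLU, GeGLU and SwiGLU differ only
// in the activation applied to the second half.
static std::vector<float> gated_unit(const std::vector<float>& x, std::size_t width,
                                     const std::function<float(float)>& act) {
    auto [a, b] = split_last_dim(x, width);
    std::vector<float> out(a.size());
    for (std::size_t i = 0; i < a.size(); ++i) out[i] = a[i] * act(b[i]);
    return out;
}

int main() {
    std::vector<float> x = {1.f, 2.f, 3.f, 4.f};  // one row, innermost width 4
    auto glu = gated_unit(x, 4, [](float v) { return 1.f / (1.f + std::exp(-v)); });
    for (float v : glu) std::cout << v << ' ';    // prints 0.952574 1.96403
    std::cout << '\n';
}
```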