Commit 1be643c
#9322: Cleanup code using unnecessary functions
eyonland committed Jul 24, 2024
1 parent 698bbdd commit 1be643c
Showing 5 changed files with 38 additions and 57 deletions.
9 changes: 3 additions & 6 deletions ttnn/cpp/ttnn/operations/eltwise/binary/binary_composite.hpp
@@ -21,8 +21,7 @@ struct ExecuteBinaryCompositeOps
         const Tensor& input_tensor_a,
         const Tensor& input_tensor_b,
         const std::optional<MemoryConfig>& memory_config = std::nullopt) {
-        auto op_type = get_function_type0<binary_comp_op_type>();
-        return op_type(input_tensor_a, input_tensor_b, memory_config);
+        return OpHandler<binary_comp_op_type>::handle(input_tensor_a, input_tensor_b, memory_config);
     }
 };

@@ -34,8 +33,7 @@ struct ExecuteBinaryCompositeOpsFloat
         const Tensor& input_tensor_a,
         const Tensor& input_tensor_b,
         float alpha,
         const std::optional<MemoryConfig>& memory_config = std::nullopt) {
-        auto op_type = get_function_type1<binary_comp_op_type>();
-        return op_type(input_tensor_a, input_tensor_b, alpha, memory_config);
+        return OpHandler<binary_comp_op_type>::handle(input_tensor_a, input_tensor_b, alpha, memory_config);
     }
 };

@@ -49,8 +47,7 @@ struct ExecuteBinaryCompositeOpsIsClose
         float rtol,
         float atol,
         const bool equal_nan,
         const std::optional<MemoryConfig>& memory_config = std::nullopt) {
-        auto op_type = get_function_type2<binary_comp_op_type>();
-        return op_type(input_tensor_a, input_tensor_b, rtol, atol, equal_nan, memory_config);
+        return OpHandler<binary_comp_op_type>::handle(input_tensor_a, input_tensor_b, rtol, atol, equal_nan, memory_config);
     }
 };

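The change in this file is purely mechanical: each invoke body drops the function-pointer getter and names the handler's static method directly. A minimal standalone sketch of the before and after, using stand-in Tensor and MemoryConfig types rather than the real ttnn ones:

#include <optional>
#include <string>

// Stand-ins for the real ttnn types (illustration only).
struct Tensor { std::string name; };
struct MemoryConfig {};

enum class BinaryCompositeOpType { LOGICAL_XOR };

template <BinaryCompositeOpType OpType>
struct OpHandler;

template <>
struct OpHandler<BinaryCompositeOpType::LOGICAL_XOR> {
    static Tensor handle(const Tensor& a, const Tensor& b,
                         const std::optional<MemoryConfig>&) {
        return Tensor{"logical_xor(" + a.name + ", " + b.name + ")"};
    }
};

// Before this commit: fetch a pointer to the static member, then call it.
template <BinaryCompositeOpType OpType>
auto get_function_type0() {
    return &OpHandler<OpType>::handle;
}

Tensor invoke_before(const Tensor& a, const Tensor& b) {
    auto op_type = get_function_type0<BinaryCompositeOpType::LOGICAL_XOR>();
    return op_type(a, b, std::nullopt);
}

// After: the getter adds nothing, so the call site names the handler directly.
Tensor invoke_after(const Tensor& a, const Tensor& b) {
    return OpHandler<BinaryCompositeOpType::LOGICAL_XOR>::handle(a, b, std::nullopt);
}

int main() {
    Tensor a{"a"}, b{"b"};
    invoke_before(a, b);
    invoke_after(a, b);
}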
(second changed file; file path not captured in this excerpt)
@@ -40,10 +40,10 @@ template <BinaryCompositeOpType OpType>
 struct OpHandler;

 template <BinaryCompositeOpType OpType>
-struct OpHandler_Float;
+struct OpHandler;

 template <BinaryCompositeOpType OpType>
-struct OpHandler_IsClose;
+struct OpHandler;


 template <>
@@ -96,41 +96,26 @@ struct OpHandler<BinaryCompositeOpType::LOGICAL_XOR> {
 };

 template <>
-struct OpHandler_Float<BinaryCompositeOpType::ADDALPHA> {
+struct OpHandler<BinaryCompositeOpType::ADDALPHA> {
     static Tensor handle(const Tensor& t1, const Tensor& t2, float alpha, const std::optional<MemoryConfig>& mem_cfg) {
         return _addalpha(t1, t2, alpha, mem_cfg);
     }
 };

 template <>
-struct OpHandler_Float<BinaryCompositeOpType::SUBALPHA> {
+struct OpHandler<BinaryCompositeOpType::SUBALPHA> {
     static Tensor handle(const Tensor& t1, const Tensor& t2, float alpha, const std::optional<MemoryConfig>& mem_cfg) {
         return _subalpha(t1, t2, alpha, mem_cfg);
     }
 };


 template <>
-struct OpHandler_IsClose<BinaryCompositeOpType::ISCLOSE> {
+struct OpHandler<BinaryCompositeOpType::ISCLOSE> {
     static Tensor handle(const Tensor& t1, const Tensor& t2, float rtol, float atol, const bool equal_nan, const std::optional<MemoryConfig>& mem_cfg) {
         return _isclose(t1, t2, rtol, atol, equal_nan, mem_cfg);
     }
 };

-// Template functions to get the function pointers
-template <BinaryCompositeOpType OpType>
-auto get_function_type0() {
-    return &OpHandler<OpType>::handle;
-}
-
-template <BinaryCompositeOpType OpType>
-auto get_function_type1() {
-    return &OpHandler_Float<OpType>::handle;
-}
-
-template <BinaryCompositeOpType OpType>
-auto get_function_type2() {
-    return &OpHandler_IsClose<OpType>::handle;
-}
-
 }
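What makes the consolidation legal is a basic property of explicit specialization: each OpHandler<...> specialization is an independent type, so its static handle may take whatever parameters that op needs. A standalone sketch of that property, with stand-in types and op names taken from the diff above:

#include <optional>

// Stand-ins for the real ttnn types (illustration only).
struct Tensor {};
struct MemoryConfig {};

enum class BinaryCompositeOpType { LOGICAL_XOR, ADDALPHA, ISCLOSE };

// One primary template, used only through its explicit specializations.
template <BinaryCompositeOpType OpType>
struct OpHandler;

// Two-tensor form.
template <>
struct OpHandler<BinaryCompositeOpType::LOGICAL_XOR> {
    static Tensor handle(const Tensor&, const Tensor&,
                         const std::optional<MemoryConfig>&) { return {}; }
};

// Two tensors plus an alpha: a different signature is fine, because each
// explicit specialization is an unrelated type with its own members.
template <>
struct OpHandler<BinaryCompositeOpType::ADDALPHA> {
    static Tensor handle(const Tensor&, const Tensor&, float,
                         const std::optional<MemoryConfig>&) { return {}; }
};

// Six parameters for isclose; still the same primary template.
template <>
struct OpHandler<BinaryCompositeOpType::ISCLOSE> {
    static Tensor handle(const Tensor&, const Tensor&, float, float, bool,
                         const std::optional<MemoryConfig>&) { return {}; }
};

int main() {
    Tensor a, b;
    OpHandler<BinaryCompositeOpType::LOGICAL_XOR>::handle(a, b, std::nullopt);
    OpHandler<BinaryCompositeOpType::ADDALPHA>::handle(a, b, 0.5f, std::nullopt);
    OpHandler<BinaryCompositeOpType::ISCLOSE>::handle(a, b, 1e-5f, 1e-8f, false, std::nullopt);
}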
(third changed file; file path not captured in this excerpt)
@@ -30,26 +30,26 @@ Tensor _addcdiv(const Tensor&, const Tensor&, const Tensor&, float, const std::o


 template <TernaryCompositeOpType OpType>
-struct OpHandler_Float;
+struct OpHandler;


 template <>
-struct OpHandler_Float<TernaryCompositeOpType::ADDCMUL> {
+struct OpHandler<TernaryCompositeOpType::ADDCMUL> {
     static Tensor handle(const Tensor& t1, const Tensor& t2, const Tensor& t3, float value, const std::optional<MemoryConfig>& mem_cfg) {
         return _addcmul(t1, t2, t3, value, mem_cfg);
     }
 };

 template <>
-struct OpHandler_Float<TernaryCompositeOpType::ADDCDIV> {
+struct OpHandler<TernaryCompositeOpType::ADDCDIV> {
     static Tensor handle(const Tensor& t1, const Tensor& t2, const Tensor& t3, float value, const std::optional<MemoryConfig>& mem_cfg) {
         return _addcdiv(t1, t2, t3, value, mem_cfg);
     }
 };

 template <TernaryCompositeOpType OpType>
 auto get_ternary_fn_float() {
-    return &OpHandler_Float<OpType>::handle;
+    return &OpHandler<OpType>::handle;
 }

 }
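The same rename applies to the ternary handlers, and get_ternary_fn_float keeps working unchanged because it returns whatever &OpHandler<OpType>::handle deduces to. A sketch that spells out the deduced type explicitly, again with stand-in types:

#include <optional>

// Stand-ins for the real ttnn types (illustration only).
struct Tensor {};
struct MemoryConfig {};

enum class TernaryCompositeOpType { ADDCMUL };

template <TernaryCompositeOpType OpType>
struct OpHandler;

template <>
struct OpHandler<TernaryCompositeOpType::ADDCMUL> {
    static Tensor handle(const Tensor&, const Tensor&, const Tensor&, float,
                         const std::optional<MemoryConfig>&) { return {}; }
};

template <TernaryCompositeOpType OpType>
auto get_ternary_fn_float() {
    return &OpHandler<OpType>::handle;
}

// Spelled out, the deduced return type is a plain function pointer:
using TernaryFloatFn = Tensor (*)(const Tensor&, const Tensor&, const Tensor&,
                                  float, const std::optional<MemoryConfig>&);

int main() {
    TernaryFloatFn fn = get_ternary_fn_float<TernaryCompositeOpType::ADDCMUL>();
    Tensor a, b, c;
    fn(a, b, c, 0.1f, std::nullopt);
}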
(fourth changed file; file path not captured in this excerpt)
@@ -92,22 +92,22 @@ template <UnaryCompositeOpType OpType>
 struct OpHandler;

 template <UnaryCompositeOpType OpType>
-struct OpHandler_Power;
+struct OpHandler;

 template <UnaryCompositeOpType OpType>
-struct OpHandler_scale_shift;
+struct OpHandler;

 template <UnaryCompositeOpType OpType>
-struct OpHandler_scale_alpha;
+struct OpHandler;

 template <UnaryCompositeOpType OpType>
-struct OpHandler_low_high;
+struct OpHandler;

 template <UnaryCompositeOpType OpType>
-struct OpHandler_threshold_value;
+struct OpHandler;

 template <UnaryCompositeOpType OpType>
-struct OpHandler_dim;
+struct OpHandler;

 template <>
 struct OpHandler<UnaryCompositeOpType::DEG2RAD> {
@@ -250,92 +250,92 @@ struct OpHandler<UnaryCompositeOpType::NORMALIZE_HW> {
 };

 template <>
-struct OpHandler_scale_shift<UnaryCompositeOpType::HARDSWISH> {
+struct OpHandler<UnaryCompositeOpType::HARDSWISH> {
     static Tensor handle(const Tensor& t1, float scale, float shift, const std::optional<MemoryConfig>& mem_cfg) {
         return _hardswish(t1, scale, shift, mem_cfg);
     }
 };

 template <>
-struct OpHandler_scale_shift<UnaryCompositeOpType::HARDSIGMOID> {
+struct OpHandler<UnaryCompositeOpType::HARDSIGMOID> {
     static Tensor handle(const Tensor& t1, float scale, float shift, const std::optional<MemoryConfig>& mem_cfg) {
         return _hardsigmoid(t1, scale, shift, mem_cfg);
     }
 };

 template <>
-struct OpHandler_low_high<UnaryCompositeOpType::HARDTANH> {
+struct OpHandler<UnaryCompositeOpType::HARDTANH> {
     static Tensor handle(const Tensor& t1, float low, float high, const std::optional<MemoryConfig>& mem_cfg) {
         return _hardtanh(t1, low, high, mem_cfg);
     }
 };

 template <>
-struct OpHandler_low_high<UnaryCompositeOpType::CLIP> {
+struct OpHandler<UnaryCompositeOpType::CLIP> {
     static Tensor handle(const Tensor& t1, float low, float high, const std::optional<MemoryConfig>& mem_cfg) {
         return _clip(t1, low, high, mem_cfg);
     }
 };

 template <>
-struct OpHandler_low_high<UnaryCompositeOpType::CLAMP> {
+struct OpHandler<UnaryCompositeOpType::CLAMP> {
     static Tensor handle(const Tensor& t1, float low, float high, const std::optional<MemoryConfig>& mem_cfg) {
         return _clamp(t1, low, high, mem_cfg);
     }
 };

 template <>
-struct OpHandler_scale_alpha<UnaryCompositeOpType::SELU> {
+struct OpHandler<UnaryCompositeOpType::SELU> {
     static Tensor handle(const Tensor& t1, float scale, float alpha, const std::optional<MemoryConfig>& mem_cfg) {
         return _selu(t1, scale, alpha, mem_cfg);
     }
 };

 template <>
-struct OpHandler_threshold_value<UnaryCompositeOpType::THRESHOLD> {
+struct OpHandler<UnaryCompositeOpType::THRESHOLD> {
     static Tensor handle(const Tensor& t1, float threshold, float value, const std::optional<MemoryConfig>& mem_cfg) {
         return _threshold(t1, threshold, value, mem_cfg);
     }
 };

 // glu (geglu, reglu, swiglu, glu) variants are supported only for the last dimension.
 template <>
-struct OpHandler_dim<UnaryCompositeOpType::GLU> {
+struct OpHandler<UnaryCompositeOpType::GLU> {
     static Tensor handle(const Tensor& t1, int32_t dim, const std::optional<MemoryConfig>& mem_cfg) {
         return _glu(t1, dim, mem_cfg);
     }
 };

 template <>
-struct OpHandler_dim<UnaryCompositeOpType::REGLU> {
+struct OpHandler<UnaryCompositeOpType::REGLU> {
     static Tensor handle(const Tensor& t1, int32_t dim, const std::optional<MemoryConfig>& mem_cfg) {
         return _reglu(t1, dim, mem_cfg);
     }
 };

 template <>
-struct OpHandler_dim<UnaryCompositeOpType::GEGLU> {
+struct OpHandler<UnaryCompositeOpType::GEGLU> {
     static Tensor handle(const Tensor& t1, int32_t dim, const std::optional<MemoryConfig>& mem_cfg) {
         return _geglu(t1, dim, mem_cfg);
     }
 };

 template <>
-struct OpHandler_dim<UnaryCompositeOpType::SWIGLU> {
+struct OpHandler<UnaryCompositeOpType::SWIGLU> {
     static Tensor handle(const Tensor& t1, int32_t dim, const std::optional<MemoryConfig>& mem_cfg) {
         return _swiglu(t1, dim, mem_cfg);
     }
 };

 template <>
-struct OpHandler_Power<UnaryCompositeOpType::POWER_FP> {
+struct OpHandler<UnaryCompositeOpType::POWER_FP> {
     static Tensor handle(uint8_t q_id, const Tensor& input, float exponent, const std::optional<MemoryConfig>& mem_cfg, std::optional<Tensor> output) {
         return _power(q_id, input, exponent, mem_cfg, output);
     }
 };

 template <>
-struct OpHandler_Power<UnaryCompositeOpType::POWER_INT> {
+struct OpHandler<UnaryCompositeOpType::POWER_INT> {
     static Tensor handle(uint8_t q_id, const Tensor& input, uint32_t exponent, const std::optional<MemoryConfig>& mem_cfg, std::optional<Tensor> output) {
         return _power(q_id, input, exponent, mem_cfg, output);
     }
@@ -349,31 +349,31 @@ auto get_function_type1() {

 template <UnaryCompositeOpType OpType>
 auto get_function_type2() {
-    return &OpHandler_scale_shift<OpType>::handle;
+    return &OpHandler<OpType>::handle;
 }

 template <UnaryCompositeOpType OpType>
 auto get_function_type3() {
-    return &OpHandler_low_high<OpType>::handle;
+    return &OpHandler<OpType>::handle;
 }

 template <UnaryCompositeOpType OpType>
 auto get_function_type4() {
-    return &OpHandler_scale_alpha<OpType>::handle;
+    return &OpHandler<OpType>::handle;
 }

 template <UnaryCompositeOpType OpType>
 auto get_function_type5() {
-    return &OpHandler_threshold_value<OpType>::handle;
+    return &OpHandler<OpType>::handle;
 }

 template <UnaryCompositeOpType OpType>
 auto get_glu_fn() {
-    return &OpHandler_dim<OpType>::handle;
+    return &OpHandler<OpType>::handle;
 }

 template <UnaryCompositeOpType OpType>
 auto get_power_fn() {
-    return &OpHandler_Power<OpType>::handle;
+    return &OpHandler<OpType>::handle;
 }
 }
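After the rename, the getters above have textually identical bodies and differ only in name; what actually varies per op is the function-pointer type that auto deduces from each specialization's handle signature. A sketch of that deduction, with a single hypothetical get_handler_fn standing in for the whole family (stand-in types; the name is not in the source):

#include <cstdint>
#include <optional>

// Stand-ins for the real ttnn types (illustration only).
struct Tensor {};
struct MemoryConfig {};

enum class UnaryCompositeOpType { HARDSWISH, GLU };

template <UnaryCompositeOpType OpType>
struct OpHandler;

template <>
struct OpHandler<UnaryCompositeOpType::HARDSWISH> {
    static Tensor handle(const Tensor&, float /*scale*/, float /*shift*/,
                         const std::optional<MemoryConfig>&) { return {}; }
};

template <>
struct OpHandler<UnaryCompositeOpType::GLU> {
    static Tensor handle(const Tensor&, int32_t /*dim*/,
                         const std::optional<MemoryConfig>&) { return {}; }
};

// One getter covers every case: auto deduces a different function-pointer
// type for each OpType from that specialization's handle signature.
template <UnaryCompositeOpType OpType>
auto get_handler_fn() {
    return &OpHandler<OpType>::handle;
}

int main() {
    Tensor t;
    auto hardswish = get_handler_fn<UnaryCompositeOpType::HARDSWISH>();
    auto glu = get_handler_fn<UnaryCompositeOpType::GLU>();
    hardswish(t, 1.0f / 6.0f, 0.5f, std::nullopt);
    glu(t, -1, std::nullopt);
}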
3 changes: 1 addition & 2 deletions ttnn/cpp/ttnn/operations/eltwise/unary/unary_composite.hpp
@@ -73,8 +73,7 @@ struct ExecuteUnaryCompositeOpWithScaleShift
         const Tensor& input_tensor,
         float scale,
         float shift,
         const std::optional<MemoryConfig>& memory_config = std::nullopt) {
-        auto op_type = get_function_type2<unary_comp_op_type>();
-        return op_type(input_tensor, scale, shift, memory_config);
+        return OpHandler<unary_comp_op_type>::handle(input_tensor, scale, shift, memory_config);
     }
 };

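For context, the Execute* structs touched throughout this commit share one shape: a thin wrapper whose static entry point forwards to the handler, with memory_config defaulted at the public boundary. A sketch of that wrapper under the assumption that the entry point is named invoke (the excerpt does not show the name), using stand-in types:

#include <optional>

// Stand-ins for the real ttnn types (illustration only).
struct Tensor {};
struct MemoryConfig {};

enum class UnaryCompositeOpType { HARDSWISH };

template <UnaryCompositeOpType OpType>
struct OpHandler;

template <>
struct OpHandler<UnaryCompositeOpType::HARDSWISH> {
    static Tensor handle(const Tensor& t, float /*scale*/, float /*shift*/,
                         const std::optional<MemoryConfig>&) { return t; }
};

// Shape of the wrapper after this commit: invoke forwards straight to the
// handler. The name `invoke` is an assumption, not shown in the diff above.
template <UnaryCompositeOpType unary_comp_op_type>
struct ExecuteUnaryCompositeOpWithScaleShift {
    static Tensor invoke(const Tensor& input_tensor, float scale, float shift,
                         const std::optional<MemoryConfig>& memory_config = std::nullopt) {
        return OpHandler<unary_comp_op_type>::handle(input_tensor, scale, shift, memory_config);
    }
};

int main() {
    Tensor t;
    ExecuteUnaryCompositeOpWithScaleShift<UnaryCompositeOpType::HARDSWISH>::invoke(t, 1.0f / 6.0f, 0.5f);
}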
