From 0bf8e213d6ea85bb77c329820b0216c4c4baab46 Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Wed, 13 Nov 2024 19:10:58 +0100 Subject: [PATCH] Use a single template for every operator between clad::array_expression, clad::array, clad::array_ref --- include/clad/Differentiator/Array.h | 100 --------------- include/clad/Differentiator/ArrayExpression.h | 119 ++++++++---------- include/clad/Differentiator/ArrayRef.h | 108 ---------------- 3 files changed, 51 insertions(+), 276 deletions(-) diff --git a/include/clad/Differentiator/Array.h b/include/clad/Differentiator/Array.h index eef7de54e..d5ee05b68 100644 --- a/include/clad/Differentiator/Array.h +++ b/include/clad/Differentiator/Array.h @@ -325,15 +325,6 @@ template class array { *this); } - /// Subtracts the number from every element in the array and returns a new - /// array, when the number is on the left side. - template ::value, - int>::type = 0> - CUDA_HOST_DEVICE friend array_expression&> - operator-(U n, const array& arr) { - return array_expression&>(n, arr); - } - /// Implicitly converts from clad::array to pointer to an array of type T CUDA_HOST_DEVICE operator T*() const { return m_arr; } }; // class array @@ -355,97 +346,6 @@ template CUDA_HOST_DEVICE array zero_vector(std::size_t n) { return array(n); } -/// Overloaded operators for clad::array which return a new array. - -/// Multiplies the number to every element in the array and returns an array -/// expression. -template ::value, int>::type = 0> -CUDA_HOST_DEVICE array_expression&, BinaryMul, U> -operator*(const array& arr, U n) { - return array_expression&, BinaryMul, U>(arr, n); -} - -/// Multiplies the number to every element in the array and returns an array -/// expression, when the number is on the left side. -template ::value, int>::type = 0> -CUDA_HOST_DEVICE array_expression&, BinaryMul, U> -operator*(U n, const array& arr) { - return array_expression&, BinaryMul, U>(arr, n); -} - -/// Divides the number from every element in the array and returns an array -/// expression. -template ::value, int>::type = 0> -CUDA_HOST_DEVICE array_expression&, BinaryDiv, U> -operator/(const array& arr, U n) { - return array_expression&, BinaryDiv, U>(arr, n); -} - -/// Adds the number to every element in the array and returns a new array -template ::value, int>::type = 0> -CUDA_HOST_DEVICE array_expression&, BinaryAdd, U> -operator+(const array& arr, U n) { - return array_expression&, BinaryAdd, U>(arr, n); -} - -/// Adds the number to every element in the array and returns an array -/// expression, when the number is on the left side. -template ::value, int>::type = 0> -CUDA_HOST_DEVICE array_expression&, BinaryAdd, U> -operator+(U n, const array& arr) { - return array_expression&, BinaryAdd, U>(arr, n); -} - -/// Subtracts the number from every element in the array and returns an array -/// expression. -template ::value, int>::type = 0> -CUDA_HOST_DEVICE array_expression&, BinarySub, U> -operator-(const array& arr, U n) { - return array_expression&, BinarySub, U>(arr, n); -} - -/// Function to define element wise adding of two arrays. -template -CUDA_HOST_DEVICE array_expression&, BinaryAdd, const array&> -operator+(const array& arr1, const array& arr2) { - assert(arr1.size() == arr2.size()); - return array_expression&, BinaryAdd, const array&>(arr1, - arr2); -} - -/// Function to define element wise subtraction of two arrays. -template -CUDA_HOST_DEVICE array_expression&, BinarySub, const array&> -operator-(const array& arr1, const array& arr2) { - assert(arr1.size() == arr2.size()); - return array_expression&, BinarySub, const array&>(arr1, - arr2); -} - -/// Function to define element wise multiplication of two arrays. -template -CUDA_HOST_DEVICE array_expression&, BinaryMul, const array&> -operator*(const array& arr1, const array& arr2) { - assert(arr1.size() == arr2.size()); - return array_expression&, BinaryMul, const array&>(arr1, - arr2); -} - -/// Function to define element wise division of two arrays. -template -CUDA_HOST_DEVICE array_expression&, BinaryDiv, const array&> -operator/(const array& arr1, const array& arr2) { - assert(arr1.size() == arr2.size()); - return array_expression&, BinaryDiv, const array&>(arr1, - arr2); -} - } // namespace clad #endif // CLAD_ARRAY_H diff --git a/include/clad/Differentiator/ArrayExpression.h b/include/clad/Differentiator/ArrayExpression.h index c420f9177..a26ba99f2 100644 --- a/include/clad/Differentiator/ArrayExpression.h +++ b/include/clad/Differentiator/ArrayExpression.h @@ -81,82 +81,65 @@ class array_expression { } std::size_t size() const { return std::max(get_size(l), get_size(r)); } +}; - // Operator overload for addition. - template - array_expression&, - BinaryAdd, RE> - operator+(const RE& r) const { - return array_expression< - const array_expression&, BinaryAdd, RE>( - *this, r); - } +// A template class to determine whether a given type is array_expression, array +// or array_ref. +template class array; +template class array_ref; - // Operator overload for multiplication. - template - array_expression&, - BinaryMul, RE> - operator*(const RE& r) const { - return array_expression< - const array_expression&, BinaryMul, RE>( - *this, r); - } - - // Operator overload for subtraction. - template - array_expression&, - BinarySub, RE> - operator-(const RE& r) const { - return array_expression< - const array_expression&, BinarySub, RE>( - *this, r); - } +template struct is_clad_type : std::false_type {}; - // Operator overload for division. - template - array_expression&, - BinaryDiv, RE> - operator/(const RE& r) const { - return array_expression< - const array_expression&, BinaryDiv, RE>( - *this, r); - } -}; +template +struct is_clad_type> + : std::true_type {}; + +template struct is_clad_type> : std::true_type {}; + +template struct is_clad_type> : std::true_type {}; + +// Operator overload for addition, when one of the operands is array_expression, +// array or array_ref. +template < + typename T1, typename T2, + typename std::enable_if::value || is_clad_type::value, + int>::type = 0> +array_expression operator+(const T1& l, + const T2& r) { + return {l, r}; +} -// Operator overload for addition, when the right operand is an array_expression -// and the left operand is a scalar. -template ::value, int>::type = 0> -array_expression&> -operator+(const T& l, const array_expression& r) { - return array_expression&>( - l, r); +// Operator overload for multiplication, when one of the operands is +// array_expression, array or array_ref. +template < + typename T1, typename T2, + typename std::enable_if::value || is_clad_type::value, + int>::type = 0> +array_expression operator*(const T1& l, + const T2& r) { + return {l, r}; } -// Operator overload for multiplication, when the right operand is an -// array_expression and the left operand is a scalar. -template ::value, int>::type = 0> -array_expression&> -operator*(const T& l, const array_expression& r) { - return array_expression&>( - l, r); +// Operator overload for subtraction, when one of the operands is +// array_expression, array or array_ref. +template < + typename T1, typename T2, + typename std::enable_if::value || is_clad_type::value, + int>::type = 0> +array_expression operator-(const T1& l, + const T2& r) { + return {l, r}; } -// Operator overload for subtraction, when the right operand is an -// array_expression and the left operand is a scalar. -template ::value, int>::type = 0> -array_expression&> -operator-(const T& l, const array_expression& r) { - return array_expression&>( - l, r); +// Operator overload for division, when one of the operands is array_expression, +// array or array_ref. +template < + typename T1, typename T2, + typename std::enable_if::value || is_clad_type::value, + int>::type = 0> +array_expression operator/(const T1& l, + const T2& r) { + return {l, r}; } } // namespace clad // NOLINTEND(*-pointer-arithmetic) diff --git a/include/clad/Differentiator/ArrayRef.h b/include/clad/Differentiator/ArrayRef.h index 42ee89c00..80ec4d679 100644 --- a/include/clad/Differentiator/ArrayRef.h +++ b/include/clad/Differentiator/ArrayRef.h @@ -216,114 +216,6 @@ template class array_ref { } }; -/// Overloaded operators for clad::array_ref which returns an array -/// expression. - -/// Multiplies the arrays element wise -template -constexpr CUDA_HOST_DEVICE - array_expression&, BinaryMul, const array_ref&> - operator*(const array_ref& Ar, const array_ref& Br) { - assert(Ar.size() == Br.size() && - "Size of both the array_refs must be equal for carrying out " - "multiplication assignment"); - return array_expression&, BinaryMul, const array_ref&>( - Ar, Br); -} - -/// Adds the arrays element wise -template -constexpr CUDA_HOST_DEVICE - array_expression&, BinaryAdd, const array_ref&> - operator+(const array_ref& Ar, const array_ref& Br) { - assert(Ar.size() == Br.size() && - "Size of both the array_refs must be equal for carrying out addition " - "assignment"); - return array_expression&, BinaryAdd, const array_ref&>( - Ar, Br); -} - -/// Subtracts the arrays element wise -template -constexpr CUDA_HOST_DEVICE - array_expression&, BinarySub, const array_ref&> - operator-(const array_ref& Ar, const array_ref& Br) { - assert( - Ar.size() == Br.size() && - "Size of both the array_refs must be equal for carrying out subtraction " - "assignment"); - return array_expression&, BinarySub, const array_ref&>( - Ar, Br); -} - -/// Divides the arrays element wise -template -constexpr CUDA_HOST_DEVICE - array_expression&, BinaryDiv, const array_ref&> - operator/(const array_ref& Ar, const array_ref& Br) { - assert(Ar.size() == Br.size() && - "Size of both the array_refs must be equal for carrying out division " - "assignment"); - return array_expression&, BinaryDiv, const array_ref&>( - Ar, Br); -} - -/// Multiplies array_ref by a scalar -template ::value, int>::type = 0> -constexpr CUDA_HOST_DEVICE array_expression&, BinaryMul, U> -operator*(const array_ref& Ar, U a) { - return array_expression&, BinaryMul, U>(Ar, a); -} - -/// Multiplies array_ref by a scalar (reverse order) -template ::value, int>::type = 0> -constexpr CUDA_HOST_DEVICE array_expression&, BinaryMul, U> -operator*(U a, const array_ref& Ar) { - return array_expression&, BinaryMul, U>(Ar, a); -} - -/// Divides array_ref by a scalar -template ::value, int>::type = 0> -constexpr CUDA_HOST_DEVICE array_expression&, BinaryDiv, U> -operator/(const array_ref& Ar, U a) { - return array_expression&, BinaryDiv, U>(Ar, a); -} - -/// Adds array_ref by a scalar -template ::value, int>::type = 0> -constexpr CUDA_HOST_DEVICE array_expression&, BinaryAdd, U> -operator+(const array_ref& Ar, U a) { - return array_expression&, BinaryAdd, U>(Ar, a); -} - -/// Adds array_ref by a scalar (reverse order) -template ::value, int>::type = 0> -constexpr CUDA_HOST_DEVICE array_expression&, BinaryAdd, U> -operator+(U a, const array_ref& Ar) { - return array_expression&, BinaryAdd, U>(Ar, a); -} - -/// Subtracts array_ref by a scalar -template ::value, int>::type = 0> -constexpr CUDA_HOST_DEVICE array_expression&, BinarySub, U> -operator-(const array_ref& Ar, U a) { - return array_expression&, BinarySub, U>(Ar, a); -} - -/// Subtracts array_ref by a scalar (reverse order) -template ::value, int>::type = 0> -constexpr CUDA_HOST_DEVICE array_expression&> -operator-(U a, const array_ref& Ar) { - return array_expression&>(a, Ar); -} - /// `array_ref` specialisation is created to be used as a placeholder /// type in the overloaded derived function. All `array_ref` types are /// implicitly convertible to `array_ref` type.