Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve array_expression/array/array_ref operators #1138

Merged
merged 2 commits into from
Nov 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 0 additions & 100 deletions include/clad/Differentiator/Array.h
Original file line number Diff line number Diff line change
Expand Up @@ -325,15 +325,6 @@ template <typename T> class array {
*this);
}

/// Subtracts the number from every element in the array and returns a new
/// array, when the number is on the left side.
template <typename U, typename std::enable_if<std::is_arithmetic<U>::value,
int>::type = 0>
CUDA_HOST_DEVICE friend array_expression<U, BinarySub, const array<T>&>
operator-(U n, const array<T>& arr) {
return array_expression<U, BinarySub, const array<T>&>(n, arr);
}

/// Implicitly converts from clad::array to pointer to an array of type T
CUDA_HOST_DEVICE operator T*() const { return m_arr; }
}; // class array
Expand All @@ -355,97 +346,6 @@ template <typename T> CUDA_HOST_DEVICE array<T> zero_vector(std::size_t n) {
return array<T>(n);
}

/// Overloaded operators for clad::array which return a new array.

/// Multiplies the number to every element in the array and returns an array
/// expression.
template <typename T, typename U,
typename std::enable_if<std::is_arithmetic<U>::value, int>::type = 0>
CUDA_HOST_DEVICE array_expression<const array<T>&, BinaryMul, U>
operator*(const array<T>& arr, U n) {
return array_expression<const array<T>&, BinaryMul, U>(arr, n);
}

/// Multiplies the number to every element in the array and returns an array
/// expression, when the number is on the left side.
template <typename T, typename U,
typename std::enable_if<std::is_arithmetic<U>::value, int>::type = 0>
CUDA_HOST_DEVICE array_expression<const array<T>&, BinaryMul, U>
operator*(U n, const array<T>& arr) {
return array_expression<const array<T>&, BinaryMul, U>(arr, n);
}

/// Divides the number from every element in the array and returns an array
/// expression.
template <typename T, typename U,
typename std::enable_if<std::is_arithmetic<U>::value, int>::type = 0>
CUDA_HOST_DEVICE array_expression<const array<T>&, BinaryDiv, U>
operator/(const array<T>& arr, U n) {
return array_expression<const array<T>&, BinaryDiv, U>(arr, n);
}

/// Adds the number to every element in the array and returns a new array
template <typename T, typename U,
typename std::enable_if<std::is_arithmetic<U>::value, int>::type = 0>
CUDA_HOST_DEVICE array_expression<const array<T>&, BinaryAdd, U>
operator+(const array<T>& arr, U n) {
return array_expression<const array<T>&, BinaryAdd, U>(arr, n);
}

/// Adds the number to every element in the array and returns an array
/// expression, when the number is on the left side.
template <typename T, typename U,
typename std::enable_if<std::is_arithmetic<U>::value, int>::type = 0>
CUDA_HOST_DEVICE array_expression<const array<T>&, BinaryAdd, U>
operator+(U n, const array<T>& arr) {
return array_expression<const array<T>&, BinaryAdd, U>(arr, n);
}

/// Subtracts the number from every element in the array and returns an array
/// expression.
template <typename T, typename U,
typename std::enable_if<std::is_arithmetic<U>::value, int>::type = 0>
CUDA_HOST_DEVICE array_expression<const array<T>&, BinarySub, U>
operator-(const array<T>& arr, U n) {
return array_expression<const array<T>&, BinarySub, U>(arr, n);
}

/// Function to define element wise adding of two arrays.
template <typename T, typename U>
CUDA_HOST_DEVICE array_expression<const array<T>&, BinaryAdd, const array<U>&>
operator+(const array<T>& arr1, const array<U>& arr2) {
assert(arr1.size() == arr2.size());
return array_expression<const array<T>&, BinaryAdd, const array<U>&>(arr1,
arr2);
}

/// Function to define element wise subtraction of two arrays.
template <typename T, typename U>
CUDA_HOST_DEVICE array_expression<const array<T>&, BinarySub, const array<U>&>
operator-(const array<T>& arr1, const array<U>& arr2) {
assert(arr1.size() == arr2.size());
return array_expression<const array<T>&, BinarySub, const array<U>&>(arr1,
arr2);
}

/// Function to define element wise multiplication of two arrays.
template <typename T, typename U>
CUDA_HOST_DEVICE array_expression<const array<T>&, BinaryMul, const array<U>&>
operator*(const array<T>& arr1, const array<U>& arr2) {
assert(arr1.size() == arr2.size());
return array_expression<const array<T>&, BinaryMul, const array<U>&>(arr1,
arr2);
}

/// Function to define element wise division of two arrays.
template <typename T, typename U>
CUDA_HOST_DEVICE array_expression<const array<T>&, BinaryDiv, const array<U>&>
operator/(const array<T>& arr1, const array<U>& arr2) {
assert(arr1.size() == arr2.size());
return array_expression<const array<T>&, BinaryDiv, const array<U>&>(arr1,
arr2);
}

} // namespace clad

#endif // CLAD_ARRAY_H
119 changes: 51 additions & 68 deletions include/clad/Differentiator/ArrayExpression.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,82 +81,65 @@ class array_expression {
}

std::size_t size() const { return std::max(get_size(l), get_size(r)); }
};

// Operator overload for addition.
template <typename RE>
array_expression<const array_expression<LeftExp, BinaryOp, RightExp>&,
BinaryAdd, RE>
operator+(const RE& r) const {
return array_expression<
const array_expression<LeftExp, BinaryOp, RightExp>&, BinaryAdd, RE>(
*this, r);
}
// A template class to determine whether a given type is array_expression, array
// or array_ref.
template <typename T> class array;
template <typename T> class array_ref;

// Operator overload for multiplication.
template <typename RE>
array_expression<const array_expression<LeftExp, BinaryOp, RightExp>&,
BinaryMul, RE>
operator*(const RE& r) const {
return array_expression<
const array_expression<LeftExp, BinaryOp, RightExp>&, BinaryMul, RE>(
*this, r);
}

// Operator overload for subtraction.
template <typename RE>
array_expression<const array_expression<LeftExp, BinaryOp, RightExp>&,
BinarySub, RE>
operator-(const RE& r) const {
return array_expression<
const array_expression<LeftExp, BinaryOp, RightExp>&, BinarySub, RE>(
*this, r);
}
template <typename T> struct is_clad_type : std::false_type {};

// Operator overload for division.
template <typename RE>
array_expression<const array_expression<LeftExp, BinaryOp, RightExp>&,
BinaryDiv, RE>
operator/(const RE& r) const {
return array_expression<
const array_expression<LeftExp, BinaryOp, RightExp>&, BinaryDiv, RE>(
*this, r);
}
};
template <typename LeftExp, typename BinaryOp, typename RightExp>
struct is_clad_type<array_expression<LeftExp, BinaryOp, RightExp>>
: std::true_type {};

template <typename T> struct is_clad_type<array<T>> : std::true_type {};

template <typename T> struct is_clad_type<array_ref<T>> : std::true_type {};

// Operator overload for addition, when one of the operands is array_expression,
// array or array_ref.
template <
typename T1, typename T2,
typename std::enable_if<is_clad_type<T1>::value || is_clad_type<T2>::value,
int>::type = 0>
array_expression<const T1&, BinaryAdd, const T2&> operator+(const T1& l,
const T2& r) {
return {l, r};
}

// Operator overload for addition, when the right operand is an array_expression
// and the left operand is a scalar.
template <typename T, typename LeftExp, typename BinaryOp, typename RightExp,
typename std::enable_if<std::is_arithmetic<T>::value, int>::type = 0>
array_expression<T, BinaryAdd,
const array_expression<LeftExp, BinaryOp, RightExp>&>
operator+(const T& l, const array_expression<LeftExp, BinaryOp, RightExp>& r) {
return array_expression<T, BinaryAdd,
const array_expression<LeftExp, BinaryOp, RightExp>&>(
l, r);
// Operator overload for multiplication, when one of the operands is
// array_expression, array or array_ref.
template <
typename T1, typename T2,
typename std::enable_if<is_clad_type<T1>::value || is_clad_type<T2>::value,
int>::type = 0>
array_expression<const T1&, BinaryMul, const T2&> operator*(const T1& l,
const T2& r) {
return {l, r};
}

// Operator overload for multiplication, when the right operand is an
// array_expression and the left operand is a scalar.
template <typename T, typename LeftExp, typename BinaryOp, typename RightExp,
typename std::enable_if<std::is_arithmetic<T>::value, int>::type = 0>
array_expression<T, BinaryMul,
const array_expression<LeftExp, BinaryOp, RightExp>&>
operator*(const T& l, const array_expression<LeftExp, BinaryOp, RightExp>& r) {
return array_expression<T, BinaryMul,
const array_expression<LeftExp, BinaryOp, RightExp>&>(
l, r);
// Operator overload for subtraction, when one of the operands is
// array_expression, array or array_ref.
template <
typename T1, typename T2,
typename std::enable_if<is_clad_type<T1>::value || is_clad_type<T2>::value,
int>::type = 0>
array_expression<const T1&, BinarySub, const T2&> operator-(const T1& l,
const T2& r) {
return {l, r};
}

// Operator overload for subtraction, when the right operand is an
// array_expression and the left operand is a scalar.
template <typename T, typename LeftExp, typename BinaryOp, typename RightExp,
typename std::enable_if<std::is_arithmetic<T>::value, int>::type = 0>
array_expression<T, BinarySub,
const array_expression<LeftExp, BinaryOp, RightExp>&>
operator-(const T& l, const array_expression<LeftExp, BinaryOp, RightExp>& r) {
return array_expression<T, BinarySub,
const array_expression<LeftExp, BinaryOp, RightExp>&>(
l, r);
// Operator overload for division, when one of the operands is array_expression,
// array or array_ref.
template <
typename T1, typename T2,
typename std::enable_if<is_clad_type<T1>::value || is_clad_type<T2>::value,
int>::type = 0>
array_expression<const T1&, BinaryDiv, const T2&> operator/(const T1& l,
const T2& r) {
return {l, r};
}
} // namespace clad
// NOLINTEND(*-pointer-arithmetic)
Expand Down
Loading
Loading