Skip to content

Commit

Permalink
Merge pull request #25882 from dschwen/error_propagation_2
Browse files Browse the repository at this point in the history
Add comparison operators, conditional, and min/max
  • Loading branch information
dschwen authored Oct 31, 2023
2 parents 5efcc88 + bb638b9 commit 5d465f5
Show file tree
Hide file tree
Showing 2 changed files with 199 additions and 63 deletions.
232 changes: 169 additions & 63 deletions framework/include/utils/CompileTimeDerivatives.h
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,62 @@ class CTBinary : public CTBase
const R _right;
};

template <typename C, typename L, typename R>
auto conditional(const C &, const L &, const R &);

/**
* Base class for a ternary functions
*/
template <typename C, typename L, typename R>
class CTConditional : public CTBinary<L, R>
{
public:
CTConditional(C condition, L left, R right) : CTBinary<L, R>(left, right), _condition(condition)
{
}
using typename CTBinary<L, R>::ResultType;

auto operator()() const { return _condition() ? _left() : _right(); }
template <CTTag dtag>
auto D() const
{
return conditional(_condition, _left.template D<dtag>(), _right.template D<dtag>());
}

template <typename Self>
std::string print() const
{
return "conditional(" + _condition.print() + ", " + _left.print() + ", " + _right.print() + ")";
}

protected:
const C _condition;

using CTBinary<L, R>::_left;
using CTBinary<L, R>::_right;
};

template <typename C, typename L, typename R>
auto
conditional(const C & condition, const L & left, const R & right)
{
return CTConditional<C, L, R>(condition, left, right);
}

template <typename L, typename R>
auto
min(const L & left, const R & right)
{
return CTConditional<decltype(left < right), L, R>(left < right, left, right);
}

template <typename L, typename R>
auto
max(const L & left, const R & right)
{
return CTConditional<decltype(left > right), L, R>(left > right, left, right);
}

/**
* Constant value
*/
Expand All @@ -230,16 +286,16 @@ class CTValue : public CTBase
{
public:
CTValue(const T value) : _value(value) {}
auto operator()() const { return _value; }
std::string print() const { return Moose::stringify(_value); }
typedef T ResultType;

auto operator()() const { return _value; }
template <CTTag dtag>
auto D() const
{
return CTNull<ResultType>();
}

typedef T ResultType;
std::string print() const { return Moose::stringify(_value); }

protected:
T _value;
Expand Down Expand Up @@ -527,6 +583,83 @@ class CTDiv : public CTBinary<L, R>
using CTBinary<L, R>::_right;
};

enum class CTComparisonEnum
{
Less,
LessEqual,
Greater,
GreaterEqual,
Equal,
Unequal
};

/**
* Binary comparison operator node
*/
template <CTComparisonEnum C, typename L, typename R>
class CTCompare : public CTBinary<L, R>
{
public:
CTCompare(L left, R right) : CTBinary<L, R>(left, right) {}
typedef bool ResultType;

ResultType operator()() const
{
if constexpr (C == CTComparisonEnum::Less)
return _left() < _right();
if constexpr (C == CTComparisonEnum::LessEqual)
return _left() <= _right();
if constexpr (C == CTComparisonEnum::Greater)
return _left() > _right();
if constexpr (C == CTComparisonEnum::GreaterEqual)
return _left() >= _right();
if constexpr (C == CTComparisonEnum::Equal)
return _left() == _right();
if constexpr (C == CTComparisonEnum::Unequal)
return _left() != _right();
}
std::string print() const
{
if constexpr (C == CTComparisonEnum::Less)
return this->printParens(this, "<");
if constexpr (C == CTComparisonEnum::LessEqual)
return this->printParens(this, "<=");
if constexpr (C == CTComparisonEnum::Greater)
return this->printParens(this, ">");
if constexpr (C == CTComparisonEnum::GreaterEqual)
return this->printParens(this, ">=");
if constexpr (C == CTComparisonEnum::Equal)
return this->printParens(this, "==");
if constexpr (C == CTComparisonEnum::Unequal)
return this->printParens(this, "!=");
}
constexpr static int precedence() { return 9; }
constexpr static bool leftAssociative() { return true; }

template <CTTag dtag>
auto D() const
{
return CTNull<ResultType>();
}

using CTBinary<L, R>::_left;
using CTBinary<L, R>::_right;
};

/// template aliases for the comparison operator nodes
template <typename L, typename R>
using CTCompareLess = CTCompare<CTComparisonEnum::Less, L, R>;
template <typename L, typename R>
using CTCompareLessEqual = CTCompare<CTComparisonEnum::LessEqual, L, R>;
template <typename L, typename R>
using CTCompareGreater = CTCompare<CTComparisonEnum::Greater, L, R>;
template <typename L, typename R>
using CTCompareGreaterEqual = CTCompare<CTComparisonEnum::GreaterEqual, L, R>;
template <typename L, typename R>
using CTCompareEqual = CTCompare<CTComparisonEnum::Equal, L, R>;
template <typename L, typename R>
using CTCompareUnequal = CTCompare<CTComparisonEnum::Unequal, L, R>;

/**
* Power operator where both base and exponent can be arbitrary operators.
*/
Expand Down Expand Up @@ -656,40 +789,32 @@ pow(const B & base)
#define CT_OPERATOR_BINARY(op, OP) \
template <typename L, \
typename R, \
class = std::enable_if_t<std::is_base_of<CTBase, L>::value && \
class = std::enable_if_t<std::is_base_of<CTBase, L>::value || \
std::is_base_of<CTBase, R>::value>> \
auto operator op(const L & left, const R & right) \
{ \
return OP(left, right); \
}

#define CT_OPERATOR_BINARY_MIX(op, OP, ot) \
template <typename L, class = std::enable_if_t<std::is_base_of<CTBase, L>::value>> \
auto operator op(const L & left, const ot & right) \
{ \
return OP(left, makeValue(right)); \
} \
template <typename R, class = std::enable_if_t<std::is_base_of<CTBase, R>::value>> \
auto operator op(const ot & left, const R & right) \
{ \
return OP(makeValue(left), right); \
/* We need a template arguments here because: */ \
/* alias template deduction is only available with '-std=c++2a' or '-std=gnu++2a' */ \
if constexpr (std::is_base_of<CTBase, L>::value && std::is_base_of<CTBase, R>::value) \
return OP<L, R>(left, right); \
else if constexpr (std::is_base_of<CTBase, L>::value) \
return OP<L, decltype(makeValue(right))>(left, makeValue(right)); \
else if constexpr (std::is_base_of<CTBase, R>::value) \
return OP<decltype(makeValue(left)), R>(makeValue(left), right); \
else \
static_assert(always_false<L>, "This should not be instantiated."); \
}

CT_OPERATOR_BINARY(+, CTAdd)
CT_OPERATOR_BINARY(-, CTSub)
CT_OPERATOR_BINARY(*, CTMul)
CT_OPERATOR_BINARY(/, CTDiv)

#define CT_OPERATORS_BINARY_MIX(ot) \
CT_OPERATOR_BINARY_MIX(+, CTAdd, ot) \
CT_OPERATOR_BINARY_MIX(-, CTSub, ot) \
CT_OPERATOR_BINARY_MIX(*, CTMul, ot) \
CT_OPERATOR_BINARY_MIX(/, CTDiv, ot)

// Add entries here to support other types in CTD expressions
CT_OPERATORS_BINARY_MIX(double)
CT_OPERATORS_BINARY_MIX(float)
CT_OPERATORS_BINARY_MIX(int)
CT_OPERATOR_BINARY(<, CTCompareLess)
CT_OPERATOR_BINARY(<=, CTCompareLessEqual)
CT_OPERATOR_BINARY(>, CTCompareGreater)
CT_OPERATOR_BINARY(>=, CTCompareGreaterEqual)
CT_OPERATOR_BINARY(==, CTCompareEqual)
CT_OPERATOR_BINARY(!=, CTCompareUnequal)

/**
* Macro for implementing a simple unary function overload. No function specific optimizations are
Expand All @@ -701,26 +826,15 @@ CT_OPERATORS_BINARY_MIX(int)
class CTF##name : public CTUnary<T> \
{ \
public: \
CTF##name(T arg) : CTUnary<T>(arg) \
{ \
} \
auto operator()() const \
{ \
return std::name(_arg()); \
} \
CTF##name(T arg) : CTUnary<T>(arg) {} \
auto operator()() const { return std::name(_arg()); } \
template <CTTag dtag> \
auto D() const \
{ \
return derivative; \
} \
std::string print() const \
{ \
return #name "(" + _arg.print() + ")"; \
} \
constexpr static int precedence() \
{ \
return 2; \
} \
std::string print() const { return #name "(" + _arg.print() + ")"; } \
constexpr static int precedence() { return 2; } \
using typename CTUnary<T>::ResultType; \
using CTUnary<T>::_arg; \
}; \
Expand Down Expand Up @@ -752,35 +866,25 @@ CT_SIMPLE_UNARY_FUNCTION(atan, 1.0 / (pow<2>(_arg) + 1.0) * _arg.template D<dtag
* possible. The parameters are the function name and the expression that returns the derivative
* of the function.
*/
#define CT_SIMPLE_BINARY_FUNCTION(name, derivative) \
#define CT_SIMPLE_BINARY_FUNCTION_CLASS(name, derivative) \
template <typename L, typename R> \
class CTF##name : public CTBinary<L, R> \
{ \
public: \
CTF##name(L left, R right) : CTBinary<L, R>(left, right) \
{ \
} \
auto operator()() const \
{ \
return std::name(_left(), _right()); \
} \
CTF##name(L left, R right) : CTBinary<L, R>(left, right) {} \
auto operator()() const { return std::name(_left(), _right()); } \
template <CTTag dtag> \
auto D() const \
{ \
return derivative; \
} \
std::string print() const \
{ \
return #name "(" + _left.print() + ", " + _right.print() + ")"; \
} \
constexpr static int precedence() \
{ \
return 2; \
} \
std::string print() const { return #name "(" + _left.print() + ", " + _right.print() + ")"; } \
constexpr static int precedence() { return 2; } \
using typename CTBinary<L, R>::ResultType; \
using CTBinary<L, R>::_left; \
using CTBinary<L, R>::_right; \
}; \
};
#define CT_SIMPLE_BINARY_FUNCTION_FUNC(name) \
template <typename L, typename R> \
auto name(const L & l, const R & r) \
{ \
Expand All @@ -795,9 +899,11 @@ CT_SIMPLE_UNARY_FUNCTION(atan, 1.0 / (pow<2>(_arg) + 1.0) * _arg.template D<dtag
return CTF##name(makeValue(l), makeValue(r)); \
}

CT_SIMPLE_BINARY_FUNCTION(atan2,
(-_left * _right.template D<dtag>() + _left.template D<dtag>() * _right) /
(pow<2>(_left) + pow<2>(_right)))
CT_SIMPLE_BINARY_FUNCTION_CLASS(atan2,
(-_left * _right.template D<dtag>() +
_left.template D<dtag>() * _right) /
(pow<2>(_left) + pow<2>(_right)))
CT_SIMPLE_BINARY_FUNCTION_FUNC(atan2)

template <typename T, int N, int M>
class CTMatrix
Expand Down
30 changes: 30 additions & 0 deletions unit/src/CompileTimeDerivativesTest.C
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,17 @@ TEST(CompileTimeDerivativesTest, evaluate)

CTD_EVALTEST(x * (1.0 + x * (3.0 - x * (2.0 + x * (5.0 - x)))), -10, 10, 0.63)

CTD_EVALTEST(x * -1.0 > x * -2.0, -10, 10, 0.63)
CTD_EVALTEST(x * -1.0 < x * -2.0, -10, 10, 0.63)
CTD_EVALTEST(x * -1.0 >= x * -2.0, -10, 10, 1)
CTD_EVALTEST(x * -1.0 <= x * -2.0, -10, 10, 1)
CTD_EVALTEST(0.0 < x, -1, 1, 1)
CTD_EVALTEST(0.0 > x, -1, 1, 1)
CTD_EVALTEST(0.0 <= x, -1, 1, 1)
CTD_EVALTEST(0.0 >= x, -1, 1, 1)
CTD_EVALTEST(0.0 == x, -1, 1, 1)
CTD_EVALTEST(0.0 != x, -1, 1, 1)

using namespace std;
CTD_EVALTEST(sin(x), -10, 10, 0.72)
CTD_EVALTEST(cos(x), -10, 10, 0.72)
Expand All @@ -101,6 +112,9 @@ TEST(CompileTimeDerivativesTest, evaluate)
CTD_EVALTEST(cosh(x), -4, 4, 0.1)
CTD_EVALTEST(atan(x), -4, 4, 0.1)

CTD_EVALTEST(min(x, x * x), -2, 2, 0.271)
CTD_EVALTEST(max(x, x * x), -2, 2, 0.271)

const auto v = makeValue(0.5);
const auto r1 = CompileTimeDerivatives::atan2(1.0, 2.0);
const auto r2 = CompileTimeDerivatives::atan2(v, 2.0);
Expand Down Expand Up @@ -293,3 +307,19 @@ TEST(CompileTimeDerivativesTest, makeStandardDeviation)

EXPECT_NEAR(std_dev(), 0.6133922073192649, 1e-15);
}

TEST(CompileTimeDerivativesTest, conditional)
{
Real vx = 0.0;
const auto x = makeRef(vx);

const auto result = conditional(x < 3, 2 * x, 5 * x);

for (vx = 0.0; vx < 6.0; vx += 0.31)
{
if (vx < 3)
EXPECT_EQ(result(), 2 * vx);
else
EXPECT_EQ(result(), 5 * vx);
}
}

0 comments on commit 5d465f5

Please sign in to comment.