Skip to content

Commit

Permalink
Differentiate for loop condition expression (#746)
Browse files Browse the repository at this point in the history
  • Loading branch information
rohanjulka19 committed Mar 12, 2024
1 parent 5736df6 commit 1cdea66
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 4 deletions.
12 changes: 8 additions & 4 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1058,9 +1058,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

// FIXME: for now we assume that cond has no differentiable effects,
// but it is not generally true, e.g. for (...; (x = y); ...)...
StmtDiff cond;
if (FS->getCond())
cond = Visit(FS->getCond());
StmtDiff condDiff;
StmtDiff condExprDiff;
if (FS->getCond()) {
std::tie(condDiff, condExprDiff) = DifferentiateSingleExpr(FS->getCond());
addToCurrentBlock(unwrapIfSingleStmt(condDiff.getStmt()));
}
const auto* IDRE = dyn_cast<DeclRefExpr>(FS->getInc());
const Expr* inc = IDRE ? Visit(FS->getInc()).getExpr() : FS->getInc();

Expand Down Expand Up @@ -1108,7 +1111,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
/// with function globals and replace initializations with assignments.
/// This is a temporary measure to avoid the bug that arises from
/// overwriting local variables on different loop passes.
Expr* forwardCond = cond.getExpr();
Expr* forwardCond = condExprDiff.getExpr();
/// If there is a declaration in the condition, `cond` will be
/// a DeclRefExpr of the declared variable. There is no point in
/// inserting it since condVarRes.getExpr() represents an assignment with
Expand Down Expand Up @@ -1145,6 +1148,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Forward = endBlock(direction::forward);
addToCurrentBlock(loopCounter.getPop(), direction::reverse);
addToCurrentBlock(initResult.getStmt_dx(), direction::reverse);
addToCurrentBlock(condDiff.getStmt_dx(), direction::reverse);
addToCurrentBlock(Reverse, direction::reverse);
Reverse = endBlock(direction::reverse);
endScope();
Expand Down
59 changes: 59 additions & 0 deletions test/Gradient/Loops.C
Original file line number Diff line number Diff line change
Expand Up @@ -1626,6 +1626,63 @@ double f_loop_init_var(double lower, double upper) {
// CHECK-NEXT: }
// CHECK-NEXT: }

double fn20(double i, double j) {
double res = 0;
for (int c = 0; (res = i * j); ++c) {
if (c == 1)
break;
}
return res;
}

// CHECK: void fn20_grad(double i, double j, clad::array_ref<double> _d_i, clad::array_ref<double> _d_j) {
// CHECK-NEXT: double _d_res = 0;
// CHECK-NEXT: unsigned long _t0;
// CHECK-NEXT: int _d_c = 0;
// CHECK-NEXT: int c = 0;
// CHECK-NEXT: clad::tape<double> _t1 = {};
// CHECK-NEXT: clad::tape<bool> _t3 = {};
// CHECK-NEXT: clad::tape<unsigned long> _t4 = {};
// CHECK-NEXT: double res = 0;
// CHECK-NEXT: _t0 = 0;
// CHECK-NEXT: {
// CHECK-NEXT: clad::push(_t1, res);
// CHECK-NEXT: for (c = 0; (res = i * j); ++c) {
// CHECK-NEXT: _t0++;
// CHECK-NEXT: bool _t2 = c == 1;
// CHECK-NEXT: {
// CHECK-NEXT: if (_t2) {
// CHECK-NEXT: clad::push(_t4, 1UL);
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: clad::push(_t3, _t2);
// CHECK-NEXT: }
// CHECK-NEXT: clad::push(_t4, 2UL);
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: _d_res += 1;
// CHECK-NEXT: {
// CHECK-NEXT: for (; _t0; _t0--)
// CHECK-NEXT: switch (clad::pop(_t4)) {
// CHECK-NEXT: case 2UL:
// CHECK-NEXT: ;
// CHECK-NEXT: --c;
// CHECK-NEXT: if (clad::pop(_t3))
// CHECK-NEXT: case 1UL:
// CHECK-NEXT: ;
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: res = clad::pop(_t1);
// CHECK-NEXT: double _r_d0 = _d_res;
// CHECK-NEXT: _d_res -= _r_d0;
// CHECK-NEXT: * _d_i += _r_d0 * j;
// CHECK-NEXT: * _d_j += i * _r_d0;
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }

#define TEST(F, x) { \
result[0] = 0; \
auto F##grad = clad::gradient(F);\
Expand Down Expand Up @@ -1692,4 +1749,6 @@ int main() {

TEST_GRADIENT(fn19, 1, arr, 5, d_arr);
TEST_2(f_loop_init_var, 1, 2); // CHECK-EXEC: {-1.00, 4.00}
TEST_2(fn20, 3, 5); // CHECK-EXEC: {5.00, 3.00}

}

0 comments on commit 1cdea66

Please sign in to comment.