Skip to content

Commit

Permalink
Dolci/tape recompute count (#172)
Browse files Browse the repository at this point in the history
* Add a tape computed counter.
  • Loading branch information
Ig-dolci authored Nov 6, 2024
1 parent 87862e1 commit 9fbb0b1
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 2 deletions.
1 change: 1 addition & 0 deletions pyadjoint/reduced_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ def __call__(self, values):

self.tape.reset_blocks()
blocks = self.tape.get_blocks()
self.tape._recompute_count += 1
with self.marked_controls():
with stop_annotating():
if self.tape._checkpoint_manager:
Expand Down
9 changes: 8 additions & 1 deletion pyadjoint/tape.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ class Tape(object):
__slots__ = ["_blocks", "_tf_tensors", "_tf_added_blocks", "_nodes",
"_tf_registered_blocks", "_bar", "_package_data",
"_checkpoint_manager", "latest_checkpoint",
"_eagerly_checkpoint_outputs"]
"_eagerly_checkpoint_outputs", "_recompute_count"]

def __init__(self, blocks=None, package_data=None):
# Initialize the list of blocks on the tape.
Expand All @@ -182,6 +182,8 @@ def __init__(self, blocks=None, package_data=None):
self._checkpoint_manager = None
# Whether to store the adjoint dependencies.
self._eagerly_checkpoint_outputs = False
# A counter for the number of tape recomputations.
self._recompute_count = 0

def clear_tape(self):
"""Clear the tape."""
Expand All @@ -196,6 +198,11 @@ def latest_timestep(self):
"""The current time step to which blocks will be added."""
return max(len(self._blocks.steps) - 1, 0)

@property
def recompute_count(self):
"""The number of times the tape has been recomputed."""
return self._recompute_count

def end_timestep(self):
"""Mark the end of a timestep when taping the forward model."""
if self._checkpoint_manager:
Expand Down
4 changes: 3 additions & 1 deletion tests/firedrake_adjoint/test_solving.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def J(f):
J0 = J(f)
rf = ReducedFunctional(J0, Control(f))
assert_approx_equal(rf(f), J0)

assert rf.tape.recompute_count == 1
_test_adjoint(J, f)


Expand Down Expand Up @@ -298,6 +298,8 @@ def test_two_nonlinear_solves():
J = assemble(dot(u1, u1)*dx)
rf = ReducedFunctional(J, c)
assert taylor_test(rf, ui, Constant(0.1)) > 1.95
# Taylor test recomputes the functional 5 times.
assert rf.tape.recompute_count == 5


def convergence_rates(E_values, eps_values):
Expand Down

0 comments on commit 9fbb0b1

Please sign in to comment.