diff --git a/torchode/adjoints.py b/torchode/adjoints.py index 43b41a3..079eba9 100644 --- a/torchode/adjoints.py +++ b/torchode/adjoints.py @@ -220,17 +220,16 @@ def solve( ) >= 0.0 ) & not_yet_evaluated - if to_be_evaluated.any(): - interpolation = step_method.build_interpolation(interp_data) - nonzero = to_be_evaluated.nonzero() - sample_idx, eval_t_idx = nonzero[:, 0], nonzero[:, 1] - y_eval[sample_idx, eval_t_idx] = interpolation.evaluate( - t_eval[sample_idx, eval_t_idx], sample_idx - ) + interpolation = step_method.build_interpolation(interp_data) + nonzero = to_be_evaluated.nonzero() + sample_idx, eval_t_idx = nonzero[:, 0], nonzero[:, 1] + y_eval[sample_idx, eval_t_idx] = interpolation.evaluate( + t_eval[sample_idx, eval_t_idx], sample_idx + ) - not_yet_evaluated = torch.logical_xor( - to_be_evaluated, not_yet_evaluated - ) + not_yet_evaluated = torch.logical_xor( + to_be_evaluated, not_yet_evaluated + ) ######################## # Update the step size #