From 551942fc3c52a5b5de33820c56e858ff10c8289a Mon Sep 17 00:00:00 2001 From: ilykuleshov Date: Wed, 2 Oct 2024 17:27:55 +0000 Subject: [PATCH] remove syncing at each eval point --- torchode/adjoints.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/torchode/adjoints.py b/torchode/adjoints.py index 43b41a3..079eba9 100644 --- a/torchode/adjoints.py +++ b/torchode/adjoints.py @@ -220,17 +220,16 @@ def solve( ) >= 0.0 ) & not_yet_evaluated - if to_be_evaluated.any(): - interpolation = step_method.build_interpolation(interp_data) - nonzero = to_be_evaluated.nonzero() - sample_idx, eval_t_idx = nonzero[:, 0], nonzero[:, 1] - y_eval[sample_idx, eval_t_idx] = interpolation.evaluate( - t_eval[sample_idx, eval_t_idx], sample_idx - ) + interpolation = step_method.build_interpolation(interp_data) + nonzero = to_be_evaluated.nonzero() + sample_idx, eval_t_idx = nonzero[:, 0], nonzero[:, 1] + y_eval[sample_idx, eval_t_idx] = interpolation.evaluate( + t_eval[sample_idx, eval_t_idx], sample_idx + ) - not_yet_evaluated = torch.logical_xor( - to_be_evaluated, not_yet_evaluated - ) + not_yet_evaluated = torch.logical_xor( + to_be_evaluated, not_yet_evaluated + ) ######################## # Update the step size #