diff --git a/torchode/adjoints.py b/torchode/adjoints.py index a31d963..43b41a3 100644 --- a/torchode/adjoints.py +++ b/torchode/adjoints.py @@ -1,6 +1,5 @@ from typing import Any, Dict, Optional -import functorch import torch import torch.nn as nn @@ -449,22 +448,24 @@ def __init__( self.vmap_args_dims = vmap_args_dims self.vmap_randomness = vmap_randomness - self.func, self.params, self.buffers = functorch.make_functional_with_buffers( - self.term.f - ) + f = self.term.f + orig_params = dict(f.named_parameters()) + buffers = dict(f.named_buffers()) + params_names = orig_params.keys() def vjp_single_sample(t_i, y_i, adj_y_i, arg_i): def wrapper(params, t_, y_): + args = (t_, y_) if self.term.with_args: - return self.func(params, self.buffers, t_, y_, arg_i) - else: - return self.func(params, self.buffers, t_, y_) + args = args + (arg_i,) + params_dict = {name: value for name, value in zip(params_names, params)} + return torch.func.functional_call(f, (params_dict, buffers), args) - dy, vjp = functorch.vjp(wrapper, self.params, t_i, y_i) + dy, vjp = torch.func.vjp(wrapper, tuple(orig_params.values()), t_i, y_i) vjp_params, vjp_t, vjp_y = vjp(-adj_y_i) return dy, vjp_t, vjp_y, vjp_params - self.vjp_vf = functorch.vmap( + self.vjp_vf = torch.func.vmap( vjp_single_sample, in_dims=(0, 0, 0, self.vmap_args_dims), randomness=self.vmap_randomness,