Skip to content

Commit

Permalink
Replace functorch with torch.func
Browse files Browse the repository at this point in the history
  • Loading branch information
martenlienen committed Nov 10, 2023
1 parent 6883722 commit 1218aea
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions torchode/adjoints.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Any, Dict, Optional

import functorch
import torch
import torch.nn as nn

Expand Down Expand Up @@ -449,22 +448,24 @@ def __init__(
self.vmap_args_dims = vmap_args_dims
self.vmap_randomness = vmap_randomness

self.func, self.params, self.buffers = functorch.make_functional_with_buffers(
self.term.f
)
f = self.term.f
orig_params = dict(f.named_parameters())
buffers = dict(f.named_buffers())
params_names = orig_params.keys()

def vjp_single_sample(t_i, y_i, adj_y_i, arg_i):
def wrapper(params, t_, y_):
args = (t_, y_)
if self.term.with_args:
return self.func(params, self.buffers, t_, y_, arg_i)
else:
return self.func(params, self.buffers, t_, y_)
args = args + (arg_i,)
params_dict = {name: value for name, value in zip(params_names, params)}
return torch.func.functional_call(f, (params_dict, buffers), args)

dy, vjp = functorch.vjp(wrapper, self.params, t_i, y_i)
dy, vjp = torch.func.vjp(wrapper, tuple(orig_params.values()), t_i, y_i)
vjp_params, vjp_t, vjp_y = vjp(-adj_y_i)
return dy, vjp_t, vjp_y, vjp_params

self.vjp_vf = functorch.vmap(
self.vjp_vf = torch.func.vmap(
vjp_single_sample,
in_dims=(0, 0, 0, self.vmap_args_dims),
randomness=self.vmap_randomness,
Expand Down

0 comments on commit 1218aea

Please sign in to comment.