Add a doc page about the different adjoints
martenlienen committed Aug 31, 2024
1 parent 8ec5bef commit a0afd16
Showing 2 changed files with 25 additions and 0 deletions.
24 changes: 24 additions & 0 deletions docs/backprop.md
@@ -0,0 +1,24 @@
# Gradients & Backpropagation

There are two ways to compute gradients of the solution of an ODE with respect to its
dynamics, i.e. the neural network in the case of neural ODEs. The first is to
backpropagate straight through the solver. After all, an ODE solver is just a series of
simple operations that define a dynamic computation graph, which PyTorch's autograd can
backpropagate through. This is implemented in `to.AutoDiffAdjoint`, so called because it
uses the autodiff/autograd mechanism. In general, this is the preferred method as long as
you have enough memory, because it is fast and gives accurate gradients.
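
The following is a minimal sketch of this pattern, adapted from the usage example in
torchode's README; the linear layer is just a stand-in for a real model, and the
tolerances are arbitrary.

```python
import torch
import torchode as to

# A tiny stand-in for the neural network of a neural ODE
net = torch.nn.Linear(1, 1)

def f(t, y):
    return net(y)

y0 = torch.tensor([[1.2], [5.0]])
t_eval = torch.stack((torch.linspace(0, 5, 10), torch.linspace(3, 4, 10)))

term = to.ODETerm(f)
step_method = to.Dopri5(term=term)
controller = to.IntegralController(atol=1e-6, rtol=1e-3, term=term)
solver = to.AutoDiffAdjoint(step_method, controller)

sol = solver.solve(to.InitialValueProblem(y0=y0, t_eval=t_eval))

# Backpropagation runs straight through the solver's computation graph
sol.ys.sum().backward()
print(net.weight.grad)
```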

If you run out of memory, you can compute gradients by solving the so-called adjoint
equations, which essentially solve the ODE backwards in time and accumulate gradients
along the way. This is implemented in `to.BacksolveAdjoint`. Solving the adjoint
equations requires gradients of the model at different points in time, which
`to.BacksolveAdjoint` computes with `torch.func`.
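
For illustration, a hedged sketch of switching to the adjoint-equation backward pass,
reusing the names from the example above. Whether `to.BacksolveAdjoint` takes exactly
these constructor arguments is an assumption here; consult the API reference for the
precise signature.

```python
# Assumption: to.BacksolveAdjoint accepts the term plus the same step method
# and controller as to.AutoDiffAdjoint; verify against the API reference.
solver = to.BacksolveAdjoint(term, step_method, controller)
sol = solver.solve(to.InitialValueProblem(y0=y0, t_eval=t_eval))

# Gradients now come from solving the adjoint ODE backwards in time,
# trading compute for memory
sol.ys.sum().backward()
```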
If your model is not compatible with `torch.func` and you get errors because of this,
you can fall back to `to.JointBacksolveAdjoint`. This computes the model gradients with
PyTorch's usual autograd and should always work, but comes with two caveats. First, to
make this work, `to.JointBacksolveAdjoint` needs to solve the `n` independent adjoint
equations jointly as one system with a shared discretization. This breaks with
torchode's approach of solving each ODE completely independently, because the joint
discretization introduces a subtle coupling between the solutions of your batch of ODEs.
Second, it is only applicable if all ODEs in your batch share the same evaluation
points. Therefore, `to.JointBacksolveAdjoint` should be your backpropagation choice of
last resort.
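
Along the same lines, a hedged sketch of the last-resort fallback; the constructor
arguments are again an assumption. Note the shared evaluation points: every row of
`t_eval` must be identical.

```python
# All ODEs in the batch must share the same evaluation points
t_eval = torch.linspace(0.0, 5.0, 10).repeat(2, 1)

# Assumption: the constructor mirrors to.BacksolveAdjoint's; check the API reference
solver = to.JointBacksolveAdjoint(term, step_method, controller)
sol = solver.solve(to.InitialValueProblem(y0=y0, t_eval=t_eval))
sol.ys.sum().backward()
```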
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -12,6 +12,7 @@ plugins:
nav:
- "Introduction": README.md
- step-size-controllers.md
- backprop.md
- jit.ipynb
- extra-args.md
- extra-stats.ipynb
