-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a doc page about the different adjoints
- Loading branch information
1 parent
8ec5bef
commit a0afd16
Showing
2 changed files
with
25 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
# Gradients & Backpropagation | ||
|
||
There are two ways to compute gradients of the dynamics of an ODE, so the neural network | ||
in the case of neural ODEs, with respect to the solution of the ODE. The first is to | ||
backpropagate straight through the solver. After all, an ODE solver is just a series of | ||
simple operations that define a dynamic computation graph that can be backpropagated | ||
through with pytorch's autograd. This is implemented in `to.AutoDiffAdjoint`, so called | ||
because it uses the autodiff/autograd mechanism. In general, this is the preferred method | ||
as long as you have enough memory, because it is fast and gives accurate gradients. | ||
|
||
If you run out of memory, you can compute gradients by solving the so called adjoint | ||
equations, which basically solve the ODE backwards and track gradients along the way. This | ||
is implemented in `to.BacksolveAdjoint`. Solving the adjoint equations requires the | ||
computation of gradients of the model at different steps in time, which | ||
`to.BacksolveAdjoint` implements with `torch.func`. If your model is not compatible and | ||
you get errors because of this, you can fall back to `to.JointBacksolveAdjoint`. This | ||
computes the model gradients with pytorch's usual autograd and should always work but | ||
comes with two caveats. However, to make this work, `to.JointBacksolveAdjoint` needs to | ||
solve the `n` independent adjoint equations jointly as one joint system that is jointly | ||
discretized. This breaks with torchode's approach of solving each ODE completely | ||
independently, because the joint discretization introduces a subtle coupling between the | ||
solutions of your batch of ODEs. Therefore, `to.JointBacksolveAdjoint` should be your | ||
backpropagation choice of last resort. Furthermore, it is only applicable if all ODEs in | ||
your batch have the same evaluation points. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters