diff --git a/docs/backprop.md b/docs/backprop.md new file mode 100644 index 0000000..c67d61b --- /dev/null +++ b/docs/backprop.md @@ -0,0 +1,24 @@ +# Gradients & Backpropagation + +There are two ways to compute gradients of the dynamics of an ODE, so the neural network +in the case of neural ODEs, with respect to the solution of the ODE. The first is to +backpropagate straight through the solver. After all, an ODE solver is just a series of +simple operations that define a dynamic computation graph that can be backpropagated +through with pytorch's autograd. This is implemented in `to.AutoDiffAdjoint`, so called +because it uses the autodiff/autograd mechanism. In general, this is the preferred method +as long as you have enough memory, because it is fast and gives accurate gradients. + +If you run out of memory, you can compute gradients by solving the so called adjoint +equations, which basically solve the ODE backwards and track gradients along the way. This +is implemented in `to.BacksolveAdjoint`. Solving the adjoint equations requires the +computation of gradients of the model at different steps in time, which +`to.BacksolveAdjoint` implements with `torch.func`. If your model is not compatible and +you get errors because of this, you can fall back to `to.JointBacksolveAdjoint`. This +computes the model gradients with pytorch's usual autograd and should always work but +comes with two caveats. However, to make this work, `to.JointBacksolveAdjoint` needs to +solve the `n` independent adjoint equations jointly as one joint system that is jointly +discretized. This breaks with torchode's approach of solving each ODE completely +independently, because the joint discretization introduces a subtle coupling between the +solutions of your batch of ODEs. Therefore, `to.JointBacksolveAdjoint` should be your +backpropagation choice of last resort. Furthermore, it is only applicable if all ODEs in +your batch have the same evaluation points. diff --git a/mkdocs.yml b/mkdocs.yml index 91bc887..37d9fc0 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -12,6 +12,7 @@ plugins: nav: - "Introduction": README.md - step-size-controllers.md + - backprop.md - jit.ipynb - extra-args.md - extra-stats.ipynb