Skip to content

Commit

Permalink
Issue 465: Add an infection generating model for ODE problems (#510)
Browse files Browse the repository at this point in the history
  • Loading branch information
SamuelBrand1 authored Nov 18, 2024
1 parent a5ff3b4 commit 0ddac54
Show file tree
Hide file tree
Showing 21 changed files with 1,633 additions and 54 deletions.
2 changes: 2 additions & 0 deletions EpiAware/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Pathfinder = "b1d3bc72-d0e7-4279-b92f-7fa5d6d2d454"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand All @@ -36,6 +37,7 @@ FillArrays = "1.11"
LinearAlgebra = ">= 1.9"
LogExpFunctions = "0.3"
MCMCChains = "6.0"
OrdinaryDiffEq = "6.89.0"
Pathfinder = "0.8, 0.9"
QuadGK = "2.9"
Random = ">= 1.9"
Expand Down
2 changes: 1 addition & 1 deletion EpiAware/src/EpiAwareUtils/prefix_submodel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ submodel = prefix_submodel(FixedIntercept(0.1), generate_latent, string(1), 2)
We can now draw a sample from the submodel.
```julia
```@example
rand(submodel)
```
"
Expand Down
10 changes: 8 additions & 2 deletions EpiAware/src/EpiInfModels/EpiInfModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,15 @@ module EpiInfModels
using ..EpiAwareBase
using ..EpiAwareUtils

using Turing, Distributions, DocStringExtensions, LinearAlgebra, LogExpFunctions
using LogExpFunctions: xexpy

using Turing, Distributions, DocStringExtensions, LinearAlgebra, OrdinaryDiffEq

#Export parameter helpers
export EpiData

#Export models
export EpiData, DirectInfections, ExpGrowthRate, Renewal
export DirectInfections, ExpGrowthRate, Renewal, ODEProcess

#Export functions
export R_to_r, r_to_R, expected_Rt
Expand All @@ -20,6 +25,7 @@ include("DirectInfections.jl")
include("ExpGrowthRate.jl")
include("RenewalSteps.jl")
include("Renewal.jl")
include("ODEProcess.jl")
include("utils.jl")

end
306 changes: 306 additions & 0 deletions EpiAware/src/EpiInfModels/ODEProcess.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,306 @@
@doc raw"""
A structure representing an infection process modeled by an Ordinary Differential Equation (ODE).
At a high level, an `ODEProcess` struct object combines:
- An `AbstractTuringParamModel` which defines the ODE model in terms of `OrdinaryDiffEq` types,
the parameters of the ODE model and a method to generate the parameters.
- A technique for solving and interpreting the ODE model using the `SciML` ecosystem. This includes
the solver used in the ODE solution, keyword arguments to send to the solver and a function
to map the `ODESolution` solution object to latent infections.
# Constructors
- `ODEProcess(prob::ODEProblem; ts, solver, sol2infs)`: Create an `ODEProcess`
object with the ODE problem `prob`, time points `ts`, solver `solver`, and function `sol2infs`.
# Predefined ODE models
Two basic ODE models are provided in the `EpiAware` package: `SIRParams` and `SEIRParams`.
In both cases these are defined in terms of the proportions of the population in each compartment
of the SIR and SEIR models respectively.
## SIR model
```math
\begin{aligned}
\frac{dS}{dt} &= -\beta SI \\
\frac{dI}{dt} &= \beta SI - \gamma I \\
\frac{dR}{dt} &= \gamma I
\end{aligned}
```
Where `S` is the proportion of the population that is susceptible, `I` is the proportion of the
population that is infected and `R` is the proportion of the population that is recovered. The
parameters are the infectiousness `β` and the recovery rate `γ`.
```jldoctest sirexample; output = false
using EpiAware, OrdinaryDiffEq, Distributions
# Create an instance of SIRParams
sirparams = SIRParams(
tspan = (0.0, 100.0),
infectiousness = LogNormal(log(0.3), 0.05),
recovery_rate = LogNormal(log(0.1), 0.05),
initial_prop_infected = Beta(1, 99)
)
nothing
# output
```
## SEIR model
```math
\begin{aligned}
\frac{dS}{dt} &= -\beta SI \\
\frac{dE}{dt} &= \beta SI - \alpha E \\
\frac{dI}{dt} &= \alpha E - \gamma I \\
\frac{dR}{dt} &= \gamma I
\end{aligned}
```
Where `S` is the proportion of the population that is susceptible, `E` is the proportion of the
population that is exposed, `I` is the proportion of the population that is infected and `R` is
the proportion of the population that is recovered. The parameters are the infectiousness `β`,
the incubation rate `α` and the recovery rate `γ`.
```jldoctest; output = false
using EpiAware, OrdinaryDiffEq, Distributions, Random
Random.seed!(1234)
# Create an instance of SIRParams
seirparams = SEIRParams(
tspan = (0.0, 100.0),
infectiousness = LogNormal(log(0.3), 0.05),
incubation_rate = LogNormal(log(0.1), 0.05),
recovery_rate = LogNormal(log(0.1), 0.05),
initial_prop_infected = Beta(1, 99)
)
nothing
# output
```
# Usage example with `ODEProcess` and predefined SIR model
In this example we define an `ODEProcess` object using the predefined `SIRParams` model from
above. We then generate latent infections using the `generate_latent_infs` function, and refit
the model using a `Turing` model.
We assume that the latent infections are observed with a Poisson likelihood around their
ODE model prediction. The population size is `N = 1000`, which we put into the `sol2infs` function,
which maps the ODE solution to the number of infections. Recall that the `EpiAware` default SIR
implementation assumes the model is in density/proportion form. Also, note that since the `sol2infs`
function is a link function that maps the ODE solution to the expected number of infections we also
apply the `LogExpFunctions.softplus` function to ensure that the expected number of infections is non-negative.
Note that the `softplus` function is a smooth approximation to the ReLU function `x -> max(0, x)`.
The utility of this approach is that small negative output from the ODE solver (e.g. ~ -1e-10) will be
mapped to small positive values, without needing to use strict positivity constraints in the model.
First, we define the `ODEProcess` object which combines the SIR model with the `sol2infs` link
function and the solver options.
```jldoctest sirexample; output = false
using Turing, LogExpFunctions
N = 1000.0
sir_process = ODEProcess(
params = sirparams,
sol2infs = sol -> softplus.(N .* sol[2, :]),
solver_options = Dict(:verbose => false, :saveat => 1.0)
)
nothing
# output
```
Second, we define a `PoissionError` observation model for linking the the number of infections.
```jldoctest sirexample; output = false
pois_obs = PoissonError()
nothing
# output
```
Next, we create a `Turing` model for the full generative process: this solves the ODE model for
the latent infections and then samples the observed infections from a Poisson distribution with this
as the average.
NB: The `nothing` argument is a dummy latent process, e.g. a log-Rt time series, that is not
used in the SIR model, but might be used in other models.
```jldoctest sirexample; output = false
@model function fit_ode_model(data)
@submodel I_t = generate_latent_infs(sir_process, nothing)
@submodel y_t = generate_observations(pois_obs, data, I_t)
return y_t
end
nothing
# output
```
We can generate some test data from the model by passing `missing` as the argument to the model.
This tells `Turing` that there is no data to condition on, so it will sample from the prior parameters
and then generate infections. In this case, we do it in a way where we cache the sampled parameters
as `θ` for later use.
```jldoctest sirexample; output = false
# Sampled parameters
gen_mdl = fit_ode_model(missing)
θ = rand(gen_mdl)
test_data = (gen_mdl | θ)()
nothing
# output
```
Now, we can refit the model but this time we condition on the test data. We suppress the
output of the sampling process to keep the output clean, but you can remove the `@suppress` macro.
```jldoctest sirexample; output = false
using Suppressor
inference_mdl = fit_ode_model(test_data)
chn = Suppressor.@suppress sample(inference_mdl, NUTS(), 2_000)
summarize(chn)
nothing
# output
```
We can compare the summarized chain to the sampled parameters in `θ` to see that the model is
fitting the data well and recovering a credible interval containing the true parameters.
# Custom ODE models
To define a custom ODE model, you need to define:
- Some `CustomModel <: AbstractTuringLatentModel` struct
that contains the ODE problem as a field called `prob`, as well as sufficient fields to
define or sample the parameters of the ODE model.
- A method for `EpiAwareBase.generate_latent(params::CustomModel, Z_t)` that generates the
initial condition and parameters of the ODE model, potentially conditional on a sample from a latent process `Z_t`.
This method must return a `Tuple` `(u0, p)` where `u0` is the initial condition and `p` is the parameters.
Here is an example of a simple custom ODE model for _specified_ exponential growth:
```jldoctest customexample; output = false
using EpiAware, Turing, OrdinaryDiffEq
# Define a simple exponential growth model for testing
function expgrowth(du, u, p, t)
du[1] = p[1] * u[1]
end
r = log(2) / 7 # Growth rate corresponding to 7 day doubling time
# Define the ODE problem using SciML
prob = ODEProblem(expgrowth, [1.0], (0.0, 10.0), [r])
# Define the custom parameters struct
struct CustomModel <: AbstractTuringLatentModel
prob::ODEProblem
r::Float64
u0::Float64
end
custom_ode = CustomModel(prob, r, 1.0)
# Define the custom generate_latent function
@model function EpiAwareBase.generate_latent(params::CustomModel, n)
return ([params.u0], [params.r])
end
nothing
# output
```
This model is not random! But we can still use it to generate latent infections.
```jldoctest customexample; output = false
# Define the ODEProcess
expgrowth_model = ODEProcess(
params = custom_ode,
sol2infs = sol -> sol[1, :]
)
infs = generate_latent_infs(expgrowth_model, nothing)()
nothing
# output
```
"""
@kwdef struct ODEProcess{
P <: AbstractTuringLatentModel, S, F <: Function, D <:
Union{Dict, NamedTuple}} <:
EpiAwareBase.AbstractTuringEpiModel
"The ODE problem and parameters, where `P` is a subtype of `AbstractTuringLatentModel`."
params::P
"The solver used for the ODE problem. Default is `AutoVern7(Rodas5())`, which is an auto
switching solver aimed at medium/low tolerances."
solver::S = AutoVern7(Rodas5())
"A function that maps the solution object of the ODE to infection counts."
sol2infs::F
"The extra solver options for the ODE problem. Can be either a `Dict` or a `NamedTuple`
containing the solver options."
solver_options::D = Dict(:verbose => false, :saveat => 1.0)
end

@doc raw"""
Implement the `generate_latent_infs` function for the `ODEProcess` model.
This function remakes the ODE problem with the provided initial conditions and parameters,
solves it using the specified solver, and then transforms the solution into latent infections
using the `sol2infs` function.
# Example usage with predefined SIR model
In this example we define an `ODEProcess` object using the predefined `SIRParams` model and
generate an expected infection time series using SIR model parameters sampled from their priors.
```jldoctest; output = false
using EpiAware, OrdinaryDiffEq, Distributions, Turing, LogExpFunctions
# Create an instance of SIRParams
sirparams = SIRParams(
tspan = (0.0, 100.0),
infectiousness = LogNormal(log(0.3), 0.05),
recovery_rate = LogNormal(log(0.1), 0.05),
initial_prop_infected = Beta(1, 99)
)
#Population size
N = 1000.0
sir_process = ODEProcess(
params = sirparams,
sol2infs = sol -> softplus.(N .* sol[2, :]),
solver_options = Dict(:verbose => false, :saveat => 1.0)
)
generated_It = generate_latent_infs(sir_process, nothing)()
nothing
# output
```
"""
@model function EpiAwareBase.generate_latent_infs(epi_model::ODEProcess, Z_t)
prob, solver, sol2infs, solver_options = epi_model.params.prob,
epi_model.solver, epi_model.sol2infs, epi_model.solver_options
n = isnothing(Z_t) ? 0 : size(Z_t, 1)

@submodel u0, p = generate_latent(epi_model.params, n)

_prob = remake(prob; u0 = u0, p = p)
sol = solve(_prob, solver; solver_options...)
I_t = sol2infs(sol)

return I_t
end
10 changes: 8 additions & 2 deletions EpiAware/src/EpiLatentModels/EpiLatentModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,15 @@ using LogExpFunctions: softmax

using FillArrays: Fill

using Turing, Distributions, DocStringExtensions, LinearAlgebra
using Turing, Distributions, DocStringExtensions, LinearAlgebra, SparseArrays,
OrdinaryDiffEq

#Export models
export FixedIntercept, Intercept, RandomWalk, AR, HierarchicalNormal

#Export ODE definitions
export SIRParams, SEIRParams

# Export tools for manipulating latent models
export CombineLatentModels, ConcatLatentModels, BroadcastLatentModel

Expand All @@ -29,10 +33,13 @@ export broadcast_rule, broadcast_dayofweek, broadcast_weekly, equal_dimensions
export DiffLatentModel, TransformLatentModel, PrefixLatentModel, RecordExpectedLatent

include("docstrings.jl")
include("utils.jl")
include("models/Intercept.jl")
include("models/RandomWalk.jl")
include("models/AR.jl")
include("models/HierarchicalNormal.jl")
include("odemodels/SIRParams.jl")
include("odemodels/SEIRParams.jl")
include("modifiers/DiffLatentModel.jl")
include("modifiers/TransformLatentModel.jl")
include("modifiers/PrefixLatentModel.jl")
Expand All @@ -42,6 +49,5 @@ include("manipulators/ConcatLatentModels.jl")
include("manipulators/broadcast/LatentModel.jl")
include("manipulators/broadcast/rules.jl")
include("manipulators/broadcast/helpers.jl")
include("utils.jl")

end
Loading

0 comments on commit 0ddac54

Please sign in to comment.