Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor for more multi dispatch #79

Merged
merged 21 commits into from
Feb 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
061562b
moved callable functions into `generate_latent_infs` methods
SamuelBrand1 Feb 26, 2024
5aa43e7
Functional rewrite of latent processes
SamuelBrand1 Feb 26, 2024
44163bb
functional patterns for observation model
SamuelBrand1 Feb 26, 2024
79ced40
Merge branch 'main' into refactor-for-more-multi-dispatch
SamuelBrand1 Feb 26, 2024
74129fd
Default method for generate latent process doesn't need to be generat…
SamuelBrand1 Feb 26, 2024
8f18376
Functional refactor for observation models
SamuelBrand1 Feb 26, 2024
aafacc8
functional refactor for initialisation
SamuelBrand1 Feb 26, 2024
5bb9692
temporary broken state
SamuelBrand1 Feb 26, 2024
d3ea60c
Shift initialisation into epimodel
SamuelBrand1 Feb 26, 2024
813bc21
Merge branch 'main' into refactor-for-more-multi-dispatch
SamuelBrand1 Feb 27, 2024
0cd2efc
Moved the function closure approach of `generate_latent_infs` for `Re…
SamuelBrand1 Feb 27, 2024
2228c4b
new unit tests for overall approach
SamuelBrand1 Feb 27, 2024
a0e91fe
remove undefined exports and fix broken unit test due to default prio…
SamuelBrand1 Feb 27, 2024
5b06d43
Delete accidental script commit
SamuelBrand1 Feb 27, 2024
5b293dd
New model diagram
SamuelBrand1 Feb 27, 2024
6373362
Update toy_model_log_infs_RW.jl
SamuelBrand1 Feb 27, 2024
a7e3d60
reformat and delete commented code
SamuelBrand1 Feb 27, 2024
1f5d461
delete initialisation --- not used
SamuelBrand1 Feb 27, 2024
270c2c2
reformat to latest version of SciMLStyle
SamuelBrand1 Feb 27, 2024
9331c2f
Create safety for rt/Rt modelling so doesn't sample huge epidemics th…
SamuelBrand1 Feb 27, 2024
5533e53
Update test_models.jl
SamuelBrand1 Feb 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 42 additions & 37 deletions EpiAware/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,63 +9,68 @@
- Solid lines indicate implemented features/analysis.
- Dashed lines indicate planned features/analysis.

## Proposed `EpiAware` model diagram
## Current `EpiAware` model diagram
```mermaid
flowchart LR

A["Underlying dists.
and specify length of sims
---------------------
EpiData"]
A["Underlying GI
Bijector"]

EpiModel["AbstractEpiModel
----------------------
Choice of target
for latent process:

B["Choice of target
for latent process
---------------------
DirectInfections
ExpGrowthRate
Renewal"]

C["Observational Data
InitModel["Priors for
initial scale of incidence"]

DataW[Data wrangling and QC]


ObsData["Observational Data
---------------------
Obs. cases y_t"]
D["Latent processes

LatentProcPriors["Latent process priors"]

LatentProc["AbstractLatentProcess
---------------------
RandomWalkLatentProcess"]

ObsModelPriors["Observation model priors
choice of delayed obs. model"]

ObsModel["AbstractObservationModel
---------------------
random_walk"]
DelayObservations"]

E["Turing model constructor
---------------------
make_epi_inference_model"]
F["Latent Process priors
---------------------
default_rw_priors"]

G[Posterior draws]
H[Posterior checking]
I[Post-processing]
DataW[Data wrangling and QC]
J["Observation models
---------------------
delay_observations"]
K["Observation model priors
---------------------
default_delay_obs_priors"]
ObservationModel["ObservationModel
---------------------
delay_observations_model"]
LatentProcess["LatentProcess
---------------------
random_walk_process"]

A --> EpiModel
B --> EpiModel


A --> EpiData
EpiData --> EpiModel
InitModel --> EpiModel
EpiModel -->E
C-->E
D-->LatentProcess
F-->LatentProcess
J-->ObservationModel
K-->ObservationModel
LatentProcess-->E
ObservationModel-->E
ObsData-->E
DataW-.->ObsData
LatentProcPriors-->LatentProc
LatentProc-->E
ObsModelPriors-->ObsModel
ObsModel-->E


E-->|sample...NUTS...| G
G-.->H
H-.->I
DataW-.->C
```
7 changes: 2 additions & 5 deletions EpiAware/src/EpiAware.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,18 @@ using Distributions,
DataFramesMeta

# Exported utilities
export create_discrete_pmf, default_rw_priors, default_delay_obs_priors,
default_initialisation_prior, spread_draws
export create_discrete_pmf, spread_draws

# Exported types
export EpiData, Renewal, ExpGrowthRate, DirectInfections

# Exported Turing model constructors
export make_epi_inference_model, delay_observations_model, random_walk_process,
initialize_incidence
export make_epi_inference_model

include("epimodel.jl")
include("utilities.jl")
include("latent-processes.jl")
include("observation-processes.jl")
include("initialisation.jl")
include("models.jl")

end
128 changes: 77 additions & 51 deletions EpiAware/src/epimodel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,91 +2,117 @@ abstract type AbstractEpiModel end

struct EpiData{T <: Real, F <: Function}
gen_int::Vector{T}
delay_int::Vector{T}
delay_kernel::SparseMatrixCSC{T, Integer}
cluster_coeff::T
len_gen_int::Integer
len_delay_int::Integer
time_horizon::Integer
transformation::F

#Inner constructors for EpiData object
function EpiData(
gen_int,
delay_int,
cluster_coeff,
time_horizon::Integer,
transformation::Function
)
function EpiData(gen_int,
transformation::Function)
@assert all(gen_int .>= 0) "Generation interval must be non-negative"
@assert all(delay_int .>= 0) "Delay interval must be non-negative"
@assert sum(gen_int)≈1 "Generation interval must sum to 1"
@assert sum(delay_int)≈1 "Delay interval must sum to 1"

K = generate_observation_kernel(delay_int, time_horizon)

new{eltype(gen_int), typeof(transformation)}(
gen_int,
delay_int,
K,
cluster_coeff,
new{eltype(gen_int), typeof(transformation)}(gen_int,
length(gen_int),
length(delay_int),
time_horizon,
transformation
)
transformation)
end

function EpiData(
gen_distribution::ContinuousDistribution,
delay_distribution::ContinuousDistribution,
cluster_coeff,
time_horizon::Integer;
function EpiData(gen_distribution::ContinuousDistribution;
D_gen,
D_delay,
Δd = 1.0,
transformation::Function = exp
)
transformation::Function = exp)
gen_int = create_discrete_pmf(gen_distribution, Δd = Δd, D = D_gen) |>
p -> p[2:end] ./ sum(p[2:end])
delay_int = create_discrete_pmf(delay_distribution, Δd = Δd, D = D_delay)

return EpiData(gen_int, delay_int, cluster_coeff, time_horizon, transformation)
return EpiData(gen_int, transformation)
end
end

struct DirectInfections <: AbstractEpiModel
struct DirectInfections{S <: Sampleable} <: AbstractEpiModel
data::EpiData
initialisation_prior::S
end

function (epimodel::DirectInfections)(_It, init)
epimodel.data.transformation.(init .+ _It)
struct ExpGrowthRate{S <: Sampleable} <: AbstractEpiModel
data::EpiData
initialisation_prior::S
end

struct ExpGrowthRate <: AbstractEpiModel
struct Renewal{S <: Sampleable} <: AbstractEpiModel
data::EpiData
initialisation_prior::S
end

function (epimodel::ExpGrowthRate)(rt, init)
init .+ cumsum(rt) .|> exp
"""
function (epimodel::Renewal)(recent_incidence, Rt)

Compute new incidence based on recent incidence and Rt.

This is a callable function on `Renewal` structs, that encodes new incidence prediction
seabbs marked this conversation as resolved.
Show resolved Hide resolved
given recent incidence and Rt according to basic renewal process.

```math
I_t = R_t \\sum_{i=1}^{n-1} I_{t-i} g_i
```

where `I_t` is the new incidence, `R_t` is the reproduction number, `I_{t-i}` is the recent incidence
and `g_i` is the generation interval.


# Arguments
- `recent_incidence`: Array of recent incidence values.
- `Rt`: Reproduction number.

# Returns
- Tuple containing the updated incidence array and the new incidence value.

"""
function (epimodel::Renewal)(recent_incidence, Rt)
new_incidence = Rt * dot(recent_incidence, epimodel.data.gen_int)
return ([new_incidence; recent_incidence[1:(epimodel.data.len_gen_int - 1)]],
new_incidence)
end

struct Renewal <: AbstractEpiModel
data::EpiData
function generate_latent_infs(epimodel::AbstractEpiModel, latent_process)
@info "No concrete implementation for `generate_latent_infs` is defined."
return nothing
end

function (epimodel::Renewal)(_Rt, init)
I₀ = epimodel.data.transformation(init)
@model function generate_latent_infs(epimodel::DirectInfections, _It)
init_incidence ~ epimodel.initialisation_prior
return epimodel.data.transformation.(init_incidence .+ _It)
seabbs marked this conversation as resolved.
Show resolved Hide resolved
end

@model function generate_latent_infs(epimodel::ExpGrowthRate, rt)
init_incidence ~ epimodel.initialisation_prior
return exp.(init_incidence .+ cumsum(rt))
end

"""
generate_latent_infs(epimodel::Renewal, _Rt)
seabbs marked this conversation as resolved.
Show resolved Hide resolved

`Turing` model constructor for latent infections using the `Renewal` object `epimodel` and time-varying unconstrained reproduction number `_Rt`.

`generate_latent_infs` creates a `Turing` model for sampling latent infections with given unconstrainted
reproduction number `_Rt` but random initial incidence scale. The initial incidence pre-time one is given as
a scale on top of an exponential growing process with exponential growth rate given by `R_to_r`applied to the
first value of `Rt`.

# Arguments
- `epimodel::Renewal`: The epidemiological model.
- `_Rt`: Time-varying unconstrained (e.g. log-) reproduction number.

# Returns
- `I_t`: Array of latent infections over time.

"""
@model function generate_latent_infs(epimodel::Renewal, _Rt)
init_incidence ~ epimodel.initialisation_prior
I₀ = epimodel.data.transformation(init_incidence)
Rt = epimodel.data.transformation.(_Rt)

r_approx = R_to_r(Rt[1], epimodel)
init = I₀ * [exp(-r_approx * t) for t in 0:(epimodel.data.len_gen_int - 1)]

function generate_infs(recent_incidence, Rt)
new_incidence = Rt * dot(recent_incidence, epimodel.data.gen_int)
[new_incidence; recent_incidence[1:(epimodel.data.len_gen_int - 1)]], new_incidence
end

I_t, _ = scan(generate_infs, init, Rt)
I_t, _ = scan(epimodel, init, Rt)
return I_t
end
30 changes: 0 additions & 30 deletions EpiAware/src/initialisation.jl

This file was deleted.

Loading
Loading