Skip to content

Commit

Permalink
fix stats module and include tests for it
Browse files Browse the repository at this point in the history
  • Loading branch information
mrazomej committed Jan 15, 2025
1 parent 124daa4 commit da96097
Show file tree
Hide file tree
Showing 3 changed files with 318 additions and 133 deletions.
154 changes: 21 additions & 133 deletions src/stats.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1029,10 +1029,6 @@ least the following columns:
- `neutral_col::Symbol=:neutral`: Name of the column in `data` defining whether
the barcode belongs to a neutral lineage or not. The column must contain
entries of type `Bool`.
- `rm_T0::Bool=false`: Optional argument to remove the first time point from the
inference. Commonly, the data from this first time point is of much lower
quality. Therefore, removing this first time point might result in a better
inference.
- `pseudocount::Int=1`: Pseudo count number to add to all counts. This is
useful to avoid divisions by zero.
Expand All @@ -1047,20 +1043,11 @@ function naive_fitness(
time_col::Symbol=:time,
count_col::Symbol=:count,
neutral_col::Symbol=:neutral,
rm_T0::Bool=false,
pseudocount::Int=1
)
# Keep only the needed data to work with
data = data[:, [id_col, time_col, count_col, neutral_col]]

# Extract unique time points
timepoints = sort(unique(data[:, time_col]))

# Remove T0 if indicated
if rm_T0
data = data[.!(data[:, time_col] .== first(timepoints)), :]
end # if

# Add pseudo-count to each barcode to avoid division by zero
data[:, count_col] .+= pseudocount

Expand Down Expand Up @@ -1118,85 +1105,6 @@ function naive_fitness(
)
end # function

# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
# Define full-rank normal distribution for variational inference
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #

@doc raw"""
Function to build a full-rank distribution to be used for ADVI optimization.
The code in this function comes from (`Turing.jl
tutorial`)[https://turinglang.org/v0.28/tutorials/09-variational-inference/]
# Arguments
- `dim::Int`: Dimensionality of parameter space.
- `model::DynamicPPL.model`: Turing model to be fit using ADVI.
# Returns
Initialized distribution to be used when fitting a full-rank variational model.
"""
function build_getq(dim, model)
# Define base distribution as standard normal.
base_dist = Turing.DistributionsAD.TuringDiagMvNormal(zeros(dim), ones(dim))

# From Turing.jl:
# > bijector(model::Turing.Model) is defined by Turing, and will return a
# bijector which takes you from the space of the latent variables to the
# real space. We're interested in using a normal distribution as a
# base-distribution and transform samples to the latent space, thus we need
# the inverse mapping from the reals to the latent space.
constrained_dist = Bijectors.inverse(Bijectors.bijector(model))

# Define proto array with parameters for full-rank normal distribution.
# Note: Using the ComponentArray makes things much simpler to work with.
proto_arr = ComponentArrays.ComponentArray(;
L=zeros(dim, dim), b=zeros(dim)
)

# Get Axes for proto-array. This basically returns the shape of each element
# in proto_arr
proto_axes = ComponentArrays.getaxes(proto_arr)
# Define number of parameters
num_params = length(proto_arr)

# Define getq function to be returned with specific dimensions for full-rank
# variational problem.
function getq(θ)
# Unpack parameters. This is where the combination of
# `ComponentArrays.jl` and `UnPack.jl` become handy.
L, b = begin
# Unpack covariance matrix and mean array
UnPack.@unpack L, b = ComponentArrays.ComponentArray(θ, proto_axes)
# Convert covariance matrix to lower diagonal covariance matrix to
# use Cholesky decomposition.
LinearAlgebra.LowerTriangular(L), b
end

# From Turing.jl:
# > For this to represent a covariance matrix we need to ensure that the
# diagonal is positive. We can enforce this by zeroing out the diagonal
# and then adding back the diagonal exponentiated.

# 1. Extract diagonal elements of matrix L
D = LinearAlgebra.Diagonal(LinearAlgebra.diag(L))
# 2. Subtract diagonal elements to make the L diagonal be all zeros,
# then, add the exponential of the diagonal to ensure positivit.
A = L - D + exp(D)

# Define unconstrained parameters by composing the constrained
# distribution with the bijectors Shift and Scale. The ∘ operator means
# to compose functions (f∘g)(x) = f(g(x)). NOTE: I do not fully
# follow how this works, I am using Turing.jl's example.

b = constrained_dist Bijectors.Shift(b) Bijectors.Scale(A)

return Turing.transformed(base_dist, b)
end

# Return resulting getq function
return getq

end # function

# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
# Computing naive parameter priors based on neutrals data
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
Expand Down Expand Up @@ -1245,10 +1153,6 @@ used to sample from the population mean fitness posterior distribution.
experimental replicates each point belongs to.
- `pseudocount::Int=1`: Pseudo counts to add to raw counts to avoid dividing by
zero. This is useful if some of the barcodes go extinct.
- `rm_T0::Bool=false`: Optional argument to remove the first time point from the
inference. Commonly, the data from this first time point is of much lower
quality. Therefore, removing this first time point might result in a better
inference.
# Returns
- `prior_params::Dict`: Dictionary with **two** entries:
Expand Down Expand Up @@ -1276,22 +1180,7 @@ function naive_prior(
neutral_col::Symbol=:neutral,
rep_col::Union{Nothing,Symbol}=nothing,
pseudocount::Int=1,
rm_T0::Bool=false
)
# Deep copy data to avoid any changes to original
data = deepcopy(data)

# Extract unique time points
timepoints = sort(unique(data[:, time_col]))

# Remove T0 if indicated
if rm_T0
data = data[.!(data[:, time_col] .== first(timepoints)), :]
end # if

# Re-extract unique time points
timepoints = sort(unique(data[:, time_col]))

# Add pseudocount to count column
data[:, count_col] = data[:, count_col] .+ pseudocount

Expand All @@ -1303,30 +1192,29 @@ function naive_prior(
count_col=count_col,
neutral_col=neutral_col,
rep_col=rep_col,
verbose=false
)

if typeof(rep_col) <: Nothing
# Compute frequencies
data_mats[:freq] = data_mats[:bc_count] ./ data_mats[:bc_total]
bc_freq = data_mats.bc_count ./ data_mats.bc_total

# Extract neutral lineages frequencies
neutral_freq = data_mats[:freq][:, 1:data_mats[:n_neutral]]
neutral_freq = bc_freq[:, 1:data_mats.n_neutral]

# Compute log-frequency ratios
neutral_logfreq = log.(
neutral_freq[2:end, :] ./ neutral_freq[1:end-1, :]
)
elseif typeof(rep_col) <: Symbol
# Check if all replicates had the same number of time points
if typeof(data_mats[:bc_count]) <: Array{Int64,3}
if typeof(data_mats.bc_count) <: Array{<:Int,3}
# Initialize array to save frequencies
freqs = Array{Float64}(undef, size(data_mats[:bc_count])...)
freqs = Array{Float64}(undef, size(data_mats.bc_count)...)

# Compute frequencies
freqlist = [
x ./ data_mats[:bc_total]
for x in eachslice(data_mats[:bc_count]; dims=2)
x ./ data_mats.bc_total
for x in eachslice(data_mats.bc_count; dims=2)
]

# Loop through each slice of freqs
Expand All @@ -1335,27 +1223,27 @@ function naive_prior(
end # for

# Assign frequencies
data_mats[:freq] = freqs
bc_freq = freqs

# Extract neutral lineages frequencies
neutral_freq = data_mats[:freq][:, 1:data_mats[:n_neutral], :]
neutral_freq = bc_freq[:, 1:data_mats.n_neutral, :]
# Compute log-frequency ratios
neutral_logfreq = log.(
neutral_freq[2:end, :, :] ./ neutral_freq[1:end-1, :, :]
)
elseif typeof(data_mats[:bc_count]) <: Vector{Matrix{Int64}}
elseif typeof(data_mats.bc_count) <: Vector{<:Matrix{<:Int}}
# Define number of replicates
global n_rep = length(data_mats[:bc_count])
global n_rep = data_mats.n_rep

# Compute frequencies
data_mats[:freq] = [
data_mats[:bc_count][rep] ./ data_mats[:bc_total][rep]
bc_freq = [
data_mats.bc_count[rep] ./ data_mats.bc_total[rep]
for rep in 1:n_rep
]

# Extract neutral lineages frequencies
neutral_freq = [
data_mats[:freq][rep][:, 1:data_mats[:n_neutral]]
bc_freq[rep][:, 1:data_mats.n_neutral]
for rep = 1:n_rep
]
# Compute log-frequency ratios
Expand All @@ -1379,7 +1267,7 @@ function naive_prior(
)
elseif typeof(rep_col) <: Symbol
# Check if all replicates have the same number of time points
if typeof(data_mats[:bc_count]) <: Array{Int64,3}
if typeof(data_mats.bc_count) <: Array{<:Int,3}
# Initialize array to save means
logfreq_mean = Matrix{Float64}(
undef, size(neutral_logfreq)[1], size(neutral_logfreq)[3]
Expand All @@ -1396,7 +1284,7 @@ function naive_prior(
)
end # for
end # for
elseif typeof(data_mats[:bc_count]) <: Vector{Matrix{Int64}}
elseif typeof(data_mats.bc_count) <: Vector{<:Matrix{<:Int}}
# Compute mean per time point per replicate
logfreq_mean = vcat([
StatsBase.mean.(
Expand All @@ -1419,7 +1307,7 @@ function naive_prior(
)
elseif typeof(rep_col) <: Symbol
# Check if all replicates have the same number of time points
if typeof(data_mats[:bc_count]) <: Array{Int64,3}
if typeof(data_mats.bc_count) <: Array{<:Int,3}
# Initialize array to save means
logfreq_std = Matrix{Float64}(
undef, size(neutral_logfreq)[1], size(neutral_logfreq)[3]
Expand All @@ -1436,7 +1324,7 @@ function naive_prior(
)
end # for
end # for
elseif typeof(data_mats[:bc_count]) <: Vector{Matrix{Int64}}
elseif typeof(data_mats.bc_count) <: Vector{<:Matrix{<:Int}}
# Compute mean per time point per replicate
logfreq_std = vcat([
StatsBase.std.(
Expand All @@ -1455,11 +1343,11 @@ function naive_prior(
# Compute mean per time point for approximate mean fitness making sure we to
# remove infinities.
if (typeof(rep_col) <: Nothing) |
(typeof(data_mats[:bc_count]) <: Array{Int64,3})
logλ_prior = log.(data_mats[:bc_count])[:]
elseif typeof(data_mats[:bc_count]) <: Vector{Matrix{Int64}}
(typeof(data_mats.bc_count) <: Array{Int64,3})
logλ_prior = log.(data_mats.bc_count)[:]
elseif typeof(data_mats.bc_count) <: Vector{<:Matrix{<:Int}}
logλ_prior = vcat(
[log.(data_mats[:bc_count][rep])[:] for rep = 1:n_rep]...
[log.(data_mats.bc_count[rep])[:] for rep = 1:n_rep]...
)
end # if

Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
@testset "BarBay.jl" begin
include("utils_tests.jl")
include("vi_tests.jl")
include("stats_tests.jl")
end

# ---------------------------------------------------------------------------- #
Expand Down
Loading

0 comments on commit da96097

Please sign in to comment.