From da96097218f82965f698e017f819fafb3cad7886 Mon Sep 17 00:00:00 2001
From: mrazomej
Date: Wed, 15 Jan 2025 08:45:08 -0800
Subject: [PATCH] fix stats module and include tests for it

---
 src/stats.jl        | 154 ++++-------------------
 test/runtests.jl    |   1 +
 test/stats_tests.jl | 296 ++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 318 insertions(+), 133 deletions(-)
 create mode 100644 test/stats_tests.jl

diff --git a/src/stats.jl b/src/stats.jl
index f91d10e..eaf409e 100644
--- a/src/stats.jl
+++ b/src/stats.jl
@@ -1029,10 +1029,6 @@ least the following columns:
 - `neutral_col::Symbol=:neutral`: Name of the column in `data` defining whether
 the barcode belongs to a neutral lineage or not. The column must contain
 entries of type `Bool`.
-- `rm_T0::Bool=false`: Optional argument to remove the first time point from the
-inference. Commonly, the data from this first time point is of much lower
-quality. Therefore, removing this first time point might result in a better
-inference.
 - `pseudocount::Int=1`: Pseudo count number to add to all counts. This is
 useful to avoid divisions by zero.
 
@@ -1047,20 +1043,11 @@ function naive_fitness(
     time_col::Symbol=:time,
     count_col::Symbol=:count,
     neutral_col::Symbol=:neutral,
-    rm_T0::Bool=false,
     pseudocount::Int=1
 )
     # Keep only the needed data to work with
     data = data[:, [id_col, time_col, count_col, neutral_col]]
 
-    # Extract unique time points
-    timepoints = sort(unique(data[:, time_col]))
-
-    # Remove T0 if indicated
-    if rm_T0
-        data = data[.!(data[:, time_col] .== first(timepoints)), :]
-    end # if
-
     # Add pseudo-count to each barcode to avoid division by zero
     data[:, count_col] .+= pseudocount
 
@@ -1118,85 +1105,6 @@ function naive_fitness(
     )
 end # function
 
-# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
-# Define full-rank normal distribution for variational inference
-# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
-
-@doc raw"""
-Function to build a full-rank distribution to be used for ADVI optimization.
-The code in this function comes from (`Turing.jl
-tutorial`)[https://turinglang.org/v0.28/tutorials/09-variational-inference/]
-
-# Arguments
-- `dim::Int`: Dimensionality of parameter space.
-- `model::DynamicPPL.model`: Turing model to be fit using ADVI.
-
-# Returns
-Initialized distribution to be used when fitting a full-rank variational model.
-"""
-function build_getq(dim, model)
-    # Define base distribution as standard normal.
-    base_dist = Turing.DistributionsAD.TuringDiagMvNormal(zeros(dim), ones(dim))
-
-    # From Turing.jl:
-    # > bijector(model::Turing.Model) is defined by Turing, and will return a
-    # bijector which takes you from the space of the latent variables to the
-    # real space. We're interested in using a normal distribution as a
-    # base-distribution and transform samples to the latent space, thus we need
-    # the inverse mapping from the reals to the latent space.
-    constrained_dist = Bijectors.inverse(Bijectors.bijector(model))
-
-    # Define proto array with parameters for full-rank normal distribution.
-    # Note: Using the ComponentArray makes things much simpler to work with.
-    proto_arr = ComponentArrays.ComponentArray(;
-        L=zeros(dim, dim), b=zeros(dim)
-    )
-
-    # Get Axes for proto-array. This basically returns the shape of each element
-    # in proto_arr
-    proto_axes = ComponentArrays.getaxes(proto_arr)
-    # Define number of parameters
-    num_params = length(proto_arr)
-
-    # Define getq function to be returned with specific dimensions for full-rank
-    # variational problem.
-    function getq(θ)
-        # Unpack parameters. This is where the combination of
-        # `ComponentArrays.jl` and `UnPack.jl` become handy.
-        L, b = begin
-            # Unpack covariance matrix and mean array
-            UnPack.@unpack L, b = ComponentArrays.ComponentArray(θ, proto_axes)
-            # Convert covariance matrix to lower diagonal covariance matrix to
-            # use Cholesky decomposition.
-            LinearAlgebra.LowerTriangular(L), b
-        end
-
-        # From Turing.jl:
-        # > For this to represent a covariance matrix we need to ensure that the
-        # diagonal is positive. We can enforce this by zeroing out the diagonal
-        # and then adding back the diagonal exponentiated.
-
-        # 1. Extract diagonal elements of matrix L
-        D = LinearAlgebra.Diagonal(LinearAlgebra.diag(L))
-        # 2. Subtract diagonal elements to make the L diagonal be all zeros,
-        # then, add the exponential of the diagonal to ensure positivit.
-        A = L - D + exp(D)
-
-        # Define unconstrained parameters by composing the constrained
-        # distribution with the bijectors Shift and Scale. The ∘ operator means
-        # to compose functions (f∘g)(x) = f(g(x)). NOTE: I do not fully
-        # follow how this works, I am using Turing.jl's example.
-
-        b = constrained_dist ∘ Bijectors.Shift(b) ∘ Bijectors.Scale(A)
-
-        return Turing.transformed(base_dist, b)
-    end
-
-    # Return resulting getq function
-    return getq
-
-end # function
-
 # %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
 # Computing naive parameter priors based on neutrals data
 # %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #
@@ -1245,10 +1153,6 @@ used to sample from the population mean fitness posterior distribution.
   experimental replicates each point belongs to.
 - `pseudocount::Int=1`: Pseudo counts to add to raw counts to avoid dividing
 by zero. This is useful if some of the barcodes go extinct.
-- `rm_T0::Bool=false`: Optional argument to remove the first time point from the
-inference. Commonly, the data from this first time point is of much lower
-quality. Therefore, removing this first time point might result in a better
-inference.
 
 # Returns
 - `prior_params::Dict`: Dictionary with **two** entries:
@@ -1276,22 +1180,7 @@ function naive_prior(
     neutral_col::Symbol=:neutral,
     rep_col::Union{Nothing,Symbol}=nothing,
     pseudocount::Int=1,
-    rm_T0::Bool=false
 )
-    # Deep copy data to avoid any changes to original
-    data = deepcopy(data)
-
-    # Extract unique time points
-    timepoints = sort(unique(data[:, time_col]))
-
-    # Remove T0 if indicated
-    if rm_T0
-        data = data[.!(data[:, time_col] .== first(timepoints)), :]
-    end # if
-
-    # Re-extract unique time points
-    timepoints = sort(unique(data[:, time_col]))
-
     # Add pseudocount to count column
     data[:, count_col] = data[:, count_col] .+ pseudocount
 
@@ -1303,15 +1192,14 @@ function naive_prior(
         count_col=count_col,
         neutral_col=neutral_col,
         rep_col=rep_col,
-        verbose=false
     )
 
     if typeof(rep_col) <: Nothing
         # Compute frequencies
-        data_mats[:freq] = data_mats[:bc_count] ./ data_mats[:bc_total]
+        bc_freq = data_mats.bc_count ./ data_mats.bc_total
 
         # Extract neutral lineages frequencies
-        neutral_freq = data_mats[:freq][:, 1:data_mats[:n_neutral]]
+        neutral_freq = bc_freq[:, 1:data_mats.n_neutral]
 
         # Compute log-frequency ratios
         neutral_logfreq = log.(
@@ -1319,14 +1207,14 @@ function naive_prior(
             neutral_freq[2:end, :] ./ neutral_freq[1:end-1, :]
         )
     elseif typeof(rep_col) <: Symbol
         # Check if all replicates had the same number of time points
-        if typeof(data_mats[:bc_count]) <: Array{Int64,3}
+        if typeof(data_mats.bc_count) <: Array{<:Int,3}
             # Initialize array to save frequencies
-            freqs = Array{Float64}(undef, size(data_mats[:bc_count])...)
+            freqs = Array{Float64}(undef, size(data_mats.bc_count)...)
 
             # Compute frequencies
             freqlist = [
-                x ./ data_mats[:bc_total]
-                for x in eachslice(data_mats[:bc_count]; dims=2)
+                x ./ data_mats.bc_total
+                for x in eachslice(data_mats.bc_count; dims=2)
             ]
 
             # Loop through each slice of freqs
@@ -1335,27 +1223,27 @@ function naive_prior(
             end # for
 
             # Assign frequencies
-            data_mats[:freq] = freqs
+            bc_freq = freqs
 
             # Extract neutral lineages frequencies
-            neutral_freq = data_mats[:freq][:, 1:data_mats[:n_neutral], :]
+            neutral_freq = bc_freq[:, 1:data_mats.n_neutral, :]
 
             # Compute log-frequency ratios
             neutral_logfreq = log.(
                 neutral_freq[2:end, :, :] ./ neutral_freq[1:end-1, :, :]
             )
-        elseif typeof(data_mats[:bc_count]) <: Vector{Matrix{Int64}}
+        elseif typeof(data_mats.bc_count) <: Vector{<:Matrix{<:Int}}
             # Define number of replicates
-            global n_rep = length(data_mats[:bc_count])
+            global n_rep = data_mats.n_rep
 
             # Compute frequencies
-            data_mats[:freq] = [
-                data_mats[:bc_count][rep] ./ data_mats[:bc_total][rep]
+            bc_freq = [
+                data_mats.bc_count[rep] ./ data_mats.bc_total[rep]
                 for rep in 1:n_rep
             ]
             # Extract neutral lineages frequencies
             neutral_freq = [
-                data_mats[:freq][rep][:, 1:data_mats[:n_neutral]]
+                bc_freq[rep][:, 1:data_mats.n_neutral]
                 for rep = 1:n_rep
             ]
             # Compute log-frequency ratios
@@ -1379,7 +1267,7 @@ function naive_prior(
         )
     elseif typeof(rep_col) <: Symbol
         # Check if all replicates have the same number of time points
-        if typeof(data_mats[:bc_count]) <: Array{Int64,3}
+        if typeof(data_mats.bc_count) <: Array{<:Int,3}
             # Initialize array to save means
             logfreq_mean = Matrix{Float64}(
                 undef, size(neutral_logfreq)[1], size(neutral_logfreq)[3]
             )
@@ -1396,7 +1284,7 @@ function naive_prior(
                 )
             end # for
         end # for
-        elseif typeof(data_mats[:bc_count]) <: Vector{Matrix{Int64}}
+        elseif typeof(data_mats.bc_count) <: Vector{<:Matrix{<:Int}}
             # Compute mean per time point per replicate
             logfreq_mean = vcat([
                 StatsBase.mean.(
@@ -1419,7 +1307,7 @@ function naive_prior(
         )
     elseif typeof(rep_col) <: Symbol
         # Check if all replicates have the same number of time points
-        if typeof(data_mats[:bc_count]) <: Array{Int64,3}
+        if typeof(data_mats.bc_count) <: Array{<:Int,3}
             # Initialize array to save means
             logfreq_std = Matrix{Float64}(
                 undef, size(neutral_logfreq)[1], size(neutral_logfreq)[3]
             )
@@ -1436,7 +1324,7 @@ function naive_prior(
                 )
             end # for
         end # for
-        elseif typeof(data_mats[:bc_count]) <: Vector{Matrix{Int64}}
+        elseif typeof(data_mats.bc_count) <: Vector{<:Matrix{<:Int}}
             # Compute mean per time point per replicate
             logfreq_std = vcat([
                 StatsBase.std.(
@@ -1455,11 +1343,11 @@ function naive_prior(
     # Compute mean per time point for approximate mean fitness making sure we to
     # remove infinities.
     if (typeof(rep_col) <: Nothing) |
-        (typeof(data_mats[:bc_count]) <: Array{Int64,3})
-        logλ_prior = log.(data_mats[:bc_count])[:]
-    elseif typeof(data_mats[:bc_count]) <: Vector{Matrix{Int64}}
+        (typeof(data_mats.bc_count) <: Array{<:Int,3})
+        logλ_prior = log.(data_mats.bc_count)[:]
+    elseif typeof(data_mats.bc_count) <: Vector{<:Matrix{<:Int}}
         logλ_prior = vcat(
-            [log.(data_mats[:bc_count][rep])[:] for rep = 1:n_rep]...
+            [log.(data_mats.bc_count[rep])[:] for rep = 1:n_rep]...
         )
     end # if
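
The `rm_T0` keyword is removed from both `naive_fitness` and `naive_prior` above,
so callers that relied on it now need to drop the first time point themselves
before calling either function. A minimal sketch of that pre-filtering, assuming
the default `:time` column and one of the CSV files used by the test suite below:

    using CSV
    using DataFrames
    using BarBay

    # Load a dataset (path taken from the test suite; adjust as needed).
    data = CSV.read("data/data001_single.csv", DataFrame)

    # Drop the first time point manually, mirroring what `rm_T0=true` used to do.
    t0 = minimum(data.time)
    data_no_t0 = data[data.time .!= t0, :]

    # Call the updated functions without the removed keyword.
    fitness_df = BarBay.stats.naive_fitness(data_no_t0)
    prior_params = BarBay.stats.naive_prior(data_no_t0)
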
diff --git a/test/runtests.jl b/test/runtests.jl
index 60c5fc1..5afb72f 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,6 +1,7 @@
 @testset "BarBay.jl" begin
     include("utils_tests.jl")
     include("vi_tests.jl")
+    include("stats_tests.jl")
 end
 
 # ---------------------------------------------------------------------------- #
diff --git a/test/stats_tests.jl b/test/stats_tests.jl
new file mode 100644
index 0000000..93eb1af
--- /dev/null
+++ b/test/stats_tests.jl
@@ -0,0 +1,296 @@
+using BarBay
+using Test
+using CSV
+using DataFrames
+using StatsBase
+using Turing
+using Distributions
+using LinearAlgebra
+using AdvancedVI
+using ReverseDiff
+
+## -----------------------------------------------------------------------------
+
+@testset "stats tests" begin
+    # ========================================================================
+    # Test matrix_quantile_range function
+    # ========================================================================
+    @testset "matrix_quantile_range" begin
+        # Create test matrix
+        test_matrix = randn(10, 5)
+
+        # Test basic functionality
+        quantiles = [0.95]
+        result = BarBay.stats.matrix_quantile_range(quantiles, test_matrix)
+
+        # Test output shape
+        @test size(result) == (size(test_matrix, 2), length(quantiles), 2)
+
+        # Test quantile values are reasonable
+        @test all(result[:, 1, 2] .>= result[:, 1, 1]) # upper bounds > lower bounds
+
+        # Test error for invalid quantiles
+        @test_throws ErrorException BarBay.stats.matrix_quantile_range([1.5], test_matrix)
+        @test_throws ErrorException BarBay.stats.matrix_quantile_range([-0.5], test_matrix)
+
+        # Test error for invalid dimensions
+        @test_throws ErrorException BarBay.stats.matrix_quantile_range(quantiles, test_matrix, dims=3)
+    end
+
+    # ========================================================================
+    # Test posterior predictive check functions
+    # ========================================================================
+    @testset "Posterior Predictive Checks" begin
+        # Load test data
+        data = CSV.read("data/data001_single.csv", DataFrame)
+
+        @testset "freq_bc_ppc" begin
+            # Create test DataFrame with required columns
+            test_df = DataFrame(
+                :s⁽ᵐ⁾ => randn(100),
+                :σ⁽ᵐ⁾ => abs.(randn(100)),
+                Symbol("f̲⁽ᵐ⁾[1]") => abs.(randn(100)),
+                :s̲ₜ₁ => randn(100),
+                :s̲ₜ₂ => randn(100)
+            )
+
+            # Test basic functionality
+            result = BarBay.stats.freq_bc_ppc(test_df, 10)
+            @test size(result, 2) == 3 # initial + 2 timepoints
+
+            # Test with non-default parameters
+            custom_params = Dict(
+                :bc_mean_fitness => :custom_mean,
+                :bc_std_fitness => :custom_std,
+                :bc_freq => :custom_freq,
+                :population_mean_fitness => :custom_pop
+            )
+
+            # Add custom columns
+            test_df.custom_mean = test_df.s⁽ᵐ⁾
+            test_df.custom_std = test_df.σ⁽ᵐ⁾
+            test_df.custom_freq = test_df[!, Symbol("f̲⁽ᵐ⁾[1]")]
+            test_df.custom_pop₁ = test_df.s̲ₜ₁
+            test_df.custom_pop₂ = test_df.s̲ₜ₂
+
+            result_custom = BarBay.stats.freq_bc_ppc(test_df, 10, param=custom_params)
+            @test size(result_custom) == size(result)
+        end
+
+        @testset "logfreq_ratio_bc_ppc" begin
+            # Create test DataFrame
+            test_df = DataFrame(
+                s⁽ᵐ⁾=randn(100),
+                σ⁽ᵐ⁾=abs.(randn(100)),
+                s̲ₜ₁=randn(100),
+                s̲ₜ₂=randn(100)
+            )
+
+            # Test basic functionality
+            result = BarBay.stats.logfreq_ratio_bc_ppc(test_df, 10)
+            @test size(result, 2) == 2 # 2 timepoints
+
+            # Test non-flattened output
+            result_nonflat = BarBay.stats.logfreq_ratio_bc_ppc(test_df, 10, flatten=false)
+            @test size(result_nonflat, 3) == 10 # 10 samples
+        end
+
+        @testset "logfreq_ratio_popmean_ppc" begin
+            # Create test DataFrame
+            test_df = DataFrame(
+                sₜ₁=randn(100),
+                sₜ₂=randn(100),
+                σₜ₁=abs.(randn(100)),
+                σₜ₂=abs.(randn(100))
+            )
+
+            # Test basic functionality
+            result = BarBay.stats.logfreq_ratio_popmean_ppc(test_df, 10)
+            @test size(result, 2) == 2 # 2 timepoints
+
+            # Test non-default parameters
+            custom_params = Dict(
+                :population_mean_fitness => :custom_mean,
+                :population_std_fitness => :custom_std
+            )
+
+            # Add custom columns
+            test_df.custom_mean₁ = test_df.sₜ₁
+            test_df.custom_mean₂ = test_df.sₜ₂
+            test_df.custom_std₁ = test_df.σₜ₁
+            test_df.custom_std₂ = test_df.σₜ₂
+
+            result_custom = BarBay.stats.logfreq_ratio_popmean_ppc(
+                test_df, 10, param=custom_params
+            )
+            @test size(result_custom) == size(result)
+        end
+    end
+
+    # ========================================================================
+    # Test naive estimation functions
+    # ========================================================================
+    @testset "Naive Estimation Functions" begin
+        # Load test data
+        data = CSV.read("data/data001_single.csv", DataFrame)
+
+        @testset "naive_fitness" begin
+            # Test basic functionality
+            result = BarBay.stats.naive_fitness(data)
+            @test result isa DataFrame
+            @test :fitness in propertynames(result)
+            @test :barcode in propertynames(result)
+
+            # Test with custom parameters
+            result_custom = BarBay.stats.naive_fitness(
+                data,
+                id_col=:barcode,
+                time_col=:time,
+                count_col=:count,
+                neutral_col=:neutral,
+                pseudocount=2
+            )
+            @test result_custom isa DataFrame
+        end
+
+    end
+
+    @testset "Expanded naive_prior tests" begin
+        # ========================================================================
+        # Test naive_prior with data001_single.csv (single condition)
+        # ========================================================================
+        @testset "Single Condition Data (data001)" begin
+            # Load test data
+            data = CSV.read("data/data001_single.csv", DataFrame)
+
+            # Test basic functionality
+            result = BarBay.stats.naive_prior(data)
+
+            # Test result properties
+            @test result isa Dict
+            @test :s_pop_prior in keys(result)
+            @test :logσ_pop_prior in keys(result)
+            @test :logλ_prior in keys(result)
+
+            # Test that the priors have reasonable values
+            @test !any(isnan.(result[:s_pop_prior]))
+            @test !any(isnan.(result[:logσ_pop_prior]))
+            @test !any(isnan.(result[:logλ_prior]))
+
+            # Test that λ prior dimensions match data dimensions
+            expected_λ_length = size(unique(data.time), 1) *
+                                size(unique(data.barcode), 1)
+            @test length(result[:logλ_prior]) == expected_λ_length
+        end
+
+        # ========================================================================
+        # Test naive_prior with data002_hier-rep.csv (hierarchical replicates)
+        # ========================================================================
+        @testset "Hierarchical Replicates Data (data002)" begin
+            # Load test data
+            data = CSV.read("data/data002_hier-rep.csv", DataFrame)
+
+            # Test with replicates - all timepoints equal
+            result = BarBay.stats.naive_prior(data, rep_col=:rep)
+
+            # Test result properties
+            @test result isa Dict
+            @test all(k in keys(result) for k in [:s_pop_prior, :logσ_pop_prior, :logλ_prior])
+
+            # Test dimensions with replicates
+            n_reps = length(unique(data.rep))
+            n_times = length(unique(data.time))
+            n_barcodes = length(unique(data.barcode))
+
+            # Test that population parameters account for all replicates
+            @test length(result[:s_pop_prior]) == (n_times - 1) * n_reps
+            @test length(result[:logσ_pop_prior]) == (n_times - 1) * n_reps
+
+            # Test with uneven timepoints between replicates
+            data_uneven = data[
+                (data.rep.!=maximum(data.rep)).|(data.time.!=maximum(data.time)),
+                :]
+            result_uneven = BarBay.stats.naive_prior(data_uneven, rep_col=:rep)
+
+            # Test result properties for uneven timepoints
+            @test result_uneven isa Dict
+            @test all(k in keys(result_uneven) for k in [:s_pop_prior, :logσ_pop_prior, :logλ_prior])
+
+            # Verify dimensions for uneven timepoints
+            n_reps_uneven = length(unique(data_uneven.rep))
+            times_per_rep = [length(unique(group.time)) for group in groupby(data_uneven, :rep)]
+            expected_params = sum(times_per_rep .- 1)
+            @test length(result_uneven[:s_pop_prior]) == expected_params
+            @test length(result_uneven[:logσ_pop_prior]) == expected_params
+        end
+
+        # ========================================================================
+        # Test naive_prior with data003_multienv.csv (multiple environments)
+        # ========================================================================
+        @testset "Multiple Environments Data (data003)" begin
+            # Load test data
+            data = CSV.read("data/data003_multienv.csv", DataFrame)
+
+            # Test with environments
+            result = BarBay.stats.naive_prior(data)
+
+            # Test result properties
+            @test result isa Dict
+            @test all(k in keys(result) for k in [:s_pop_prior, :logσ_pop_prior, :logλ_prior])
+
+            # Test dimensions with environments
+            n_times = length(unique(data.time))
+            n_envs = length(unique(data.env))
+            n_barcodes = length(unique(data.barcode))
+
+            # Test that population parameters account for environment transitions
+            @test length(result[:s_pop_prior]) == n_times - 1
+            @test length(result[:logσ_pop_prior]) == n_times - 1
+
+            # Test proper initialization with environment changes
+            env_changes = diff(data[sortperm(data.time), :env])
+            @test !any(isnan.(result[:s_pop_prior]))
+            @test !any(isnan.(result[:logσ_pop_prior]))
+        end
+
+        # ========================================================================
+        # Test naive_prior with data004_multigen.csv (multiple genotypes)
+        # ========================================================================
+        @testset "Multiple Genotypes Data (data004)" begin
+            # Load test data
+            data = CSV.read("data/data004_multigen.csv", DataFrame)
+
+            # Test with genotypes
+            result = BarBay.stats.naive_prior(data)
+
+            # Test result properties
+            @test result isa Dict
+            @test all(k in keys(result) for k in [:s_pop_prior, :logσ_pop_prior, :logλ_prior])
+
+            # Test dimensions with genotypes
+            n_times = length(unique(data.time))
+            n_genotypes = length(unique(data.genotype))
+            n_barcodes = length(unique(data.barcode))
+
+            # Test that population parameters are computed correctly
+            @test length(result[:s_pop_prior]) == n_times - 1
+            @test length(result[:logσ_pop_prior]) == n_times - 1
+
+            # Test that priors are reasonable for each genotype
+            @test !any(isnan.(result[:s_pop_prior]))
+            @test !any(isnan.(result[:logσ_pop_prior]))
+        end
+
+        # ========================================================================
+        # Test error handling and edge cases
+        # ========================================================================
+        @testset "Error Handling and Edge Cases" begin
+            # Load base data
+            data = CSV.read("data/data001_single.csv", DataFrame)
+
+            # Test with an incomplete dataset (one observation removed)
+            data_missing = data[2:end, :]
+            @test_throws ErrorException BarBay.stats.naive_prior(data_missing)
+        end
+    end
+end
\ No newline at end of file
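
The new file is wired into test/runtests.jl above, so it runs as part of the
package's normal Pkg.test() run. For reference, a short sketch of how the two
naive estimators exercised by these tests fit together; the file path and the
returned dictionary keys are the ones asserted in the test suite, everything
else is illustrative:

    using CSV
    using DataFrames
    using BarBay

    # One of the datasets shipped with the tests (path relative to test/).
    data = CSV.read("data/data001_single.csv", DataFrame)

    # Point estimates of fitness: a DataFrame with :barcode and :fitness columns.
    fitness_df = BarBay.stats.naive_fitness(data)

    # Naive priors for the Bayesian model: a Dict with :s_pop_prior,
    # :logσ_pop_prior, and :logλ_prior entries, as checked above.
    priors = BarBay.stats.naive_prior(data)

    # Population mean-fitness priors: the tests above check one entry per
    # time-point transition when no replicate column is given.
    println(length(priors[:s_pop_prior]), " population fitness prior entries")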