Skip to content

Commit

Permalink
renamed bunch, predict works for some cases now
Browse files Browse the repository at this point in the history
  • Loading branch information
behinger committed Mar 21, 2024
1 parent cd9a931 commit ee25c3c
Show file tree
Hide file tree
Showing 7 changed files with 82 additions and 173 deletions.
12 changes: 6 additions & 6 deletions ext/UnfoldMixedModelsExt/fit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ function StatsModels.modelmatrix(
modelmatrix2 = Unfold.modelmatrices(Xs[k])

@debug typeof(modelmatrix1), typeof(modelmatrix2)
Xcomb_temp = Unfold.equalize_lengths(modelmatrix1, modelmatrix2)
Xcomb_temp = Unfold.extend_to_larger(modelmatrix1, modelmatrix2)
@debug "tmp" typeof(Xcomb_temp)
Xcomb = lmm_combine_modelmatrices!(Xcomb_temp, Xcomb, Xs[k])
@debug "Xcomb" typeof(Xcomb)
Expand Down Expand Up @@ -76,8 +76,8 @@ function StatsModels.fit!(
nchan = size(data, 1)


Xs = (Unfold.equalize_lengths(Xs[1]), Xs[2:end]...)#(Unfold.equalize_lengths(Xs[1]), Xs[2:end]...)
_, data = Unfold.zeropad(Xs[1], data)
Xs = (Unfold.extend_to_larger(Xs[1]), Xs[2:end]...)#(Unfold.extend_to_larger(Xs[1]), Xs[2:end]...)
_, data = Unfold.equalize_size(Xs[1], data)
# get a un-fitted mixed model object

Xs = (disallowmissing(Xs[1]), Xs[2:end]...)
Expand Down Expand Up @@ -197,15 +197,15 @@ function LinearMixedModel_wrapper(
) where {TData<:Number}
# function LinearMixedModel_wrapper(form,data::Array{<:Union{Missing,TData},1},Xs;wts = []) where {TData<:Number}
@debug "LMM wrapper, $(typeof(Xs))"
Xs = (Unfold.equalize_lengths(Xs[1]), Xs[2:end]...)
# XXX Push this to utilities zeropad
Xs = (Unfold.extend_to_larger(Xs[1]), Xs[2:end]...)
# XXX Push this to utilities equalize_size
# Make sure X & y are the same size
@assert isa(Xs[1], AbstractMatrix) & isa(Xs[2], ReMat) "Xs[1] was a $(typeof(Xs[1])), should be a AbstractMatrix, and Xs[2] was a $(typeof(Xs[2])) should be a ReMat"
m = size(Xs[1])[1]


if m != size(data)[1]
fe, data = Unfold.zeropad(Xs[1], data)
fe, data = Unfold.equalize_size(Xs[1], data)

Xs = change_modelmatrix_size!(size(data)[1], fe, Xs[2:end])
end
Expand Down
22 changes: 9 additions & 13 deletions src/basisfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ $(FIELDS)
julia> b = FIRBasis(range(0,1,length=10),"basisA",-1)
```
"""
struct FIRBasis <: BasisFunction
" vector of times along rows of kernel-output (in seconds)"
mutable struct FIRBasis <: BasisFunction
"vector of times along rows of kernel-output (in seconds)"
times::Vector
"name of the event, random 1:1000 if unspecified"
name::String
"name of the event, should be the actual eventName in `eventcolumn` of the dataframes later"
name::Any
"by how many samples do we need to shift the event onsets? This number is determined by how many 'negative' timepoints the basisfunction defines"
shift_onset::Int64
end
Expand All @@ -48,7 +48,7 @@ collabel(basis::FIRBasis) = :time
colnames(basis::FIRBasis) = basis.times[1:end-1]


struct SplineBasis <: BasisFunction
mutable struct SplineBasis <: BasisFunction
kernel::Function

"vector of names along columns of kernel-output"
Expand All @@ -62,7 +62,7 @@ struct SplineBasis <: BasisFunction
end


struct HRFBasis <: BasisFunction
mutable struct HRFBasis <: BasisFunction
kernel::Function
"vector of names along columns of kernel-output"
colnames::AbstractVector
Expand Down Expand Up @@ -91,7 +91,7 @@ julia> f(103.3)
```
"""
function firbasis(τ, sfreq, name::String = "basis_" * string(rand(1:10000)))
function firbasis(τ, sfreq, name::String = "")
τ = round_times(τ, sfreq)
times = range(τ[1], stop = τ[2] + 1 ./ sfreq, step = 1 ./ sfreq) # stop + 1 step, because we support fractional event-timings

Expand All @@ -101,7 +101,7 @@ function firbasis(τ, sfreq, name::String = "basis_" * string(rand(1:10000)))
end
# cant multiple dispatch on optional arguments
#firbasis(;τ,sfreq) = firbasis(τ,sfreq)
firbasis(; τ, sfreq, name = "basis_" * string(rand(1:10000))) = firbasis(τ, sfreq, name)
firbasis(; τ, sfreq) = firbasis(τ, sfreq)


"""
Expand Down Expand Up @@ -164,11 +164,7 @@ julia> f(103.3,4.1)
"""
function hrfbasis(
TR::Float64;
parameters = [6.0 16.0 1.0 1.0 6.0 0.0 32.0],
name::String = "basis_" * string(rand(1:10000)),
)
function hrfbasis(TR::Float64; parameters = [6.0 16.0 1.0 1.0 6.0 0.0 32.0], name = "")
# Haemodynamic response function adapted from SPM12b "spm_hrf.m"
# Parameters:
# defaults
Expand Down
47 changes: 32 additions & 15 deletions src/designmatrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ function designmatrix(
tbl,
basisfunction;
contrasts = Dict{Symbol,Any}(),
eventname = Any,
kwargs...,
)
@debug("generating DesignMatrix")
Expand Down Expand Up @@ -95,8 +96,12 @@ function designmatrix(
@debug "applying schema $unfoldmodeltype"
form = unfold_apply_schema(unfoldmodeltype, f, schema(f, tbl_nomissing, contrasts))

form =
apply_basisfunction(form, basisfunction, get(Dict(kwargs), :eventfields, nothing))
form = apply_basisfunction(
form,
basisfunction,
get(Dict(kwargs), :eventfields, nothing),
eventname,
)

# Evaluate the designmatrix

Expand Down Expand Up @@ -201,14 +206,21 @@ function designmatrix(
eventTbl,
collect(f[bIx])[1];
contrasts = contrasts,
eventname = eventname,
kwargs...,
)
else
# normal way
@debug f
X =
X +
designmatrix(typeof(uf), f[fIx], eventTbl; contrasts = contrasts, kwargs...)
X + designmatrix(
typeof(uf),
f[fIx],
eventTbl;
contrasts = contrasts,
eventname = eventname,
kwargs...,
)
end
end
return X
Expand All @@ -222,12 +234,17 @@ $(SIGNATURES)
timeexpand the rhs-term of the formula with the basisfunction
"""
function apply_basisfunction(form, basisfunction::BasisFunction, eventfields)
function apply_basisfunction(form, basisfunction::BasisFunction, eventfields, eventname)
@debug("apply_basisfunction")
if basisfunction.name == ""
basisfunction.name = eventname
elseif basisfunction.name != eventname
@error "since unfold 0.7 basisfunction names need to be equivalent to the event.name (or =\"\" for autofilling)."
end
return FormulaTerm(form.lhs, TimeExpandedTerm(form.rhs, basisfunction, eventfields))
end

function apply_basisfunction(form, basisfunction::Nothing, eventfields)
function apply_basisfunction(form, basisfunction::Nothing, eventfields, eventname)
# in case of no basisfunctin, do nothing
return form
end
Expand Down Expand Up @@ -256,21 +273,21 @@ function StatsModels.modelmatrix(uf::UnfoldLinearModel, basisfunction)
end

# catch all case
equalize_lengths(modelmatrix::AbstractMatrix) = modelmatrix
extend_to_larger(modelmatrix::AbstractMatrix) = modelmatrix

# UnfoldLinearMixedModelContinuousTime case
equalize_lengths(modelmatrix::Tuple) =
(equalize_lengths(modelmatrix[1]), modelmatrix[2:end]...)
extend_to_larger(modelmatrix::Tuple) =
(extend_to_larger(modelmatrix[1]), modelmatrix[2:end]...)

# UnfoldLinearModel - they have to be equal already
equalize_lengths(modelmatrix::Vector{<:AbstractMatrix}) = modelmatrix
extend_to_larger(modelmatrix::Vector{<:AbstractMatrix}) = modelmatrix

#UnfoldLinearModelContinuousTime
equalize_lengths(modelmatrix::Vector{<:SparseMatrixCSC}) = equalize_lengths(modelmatrix...)
equalize_lengths(modelmatrix1::SparseMatrixCSC, modelmatrix2::SparseMatrixCSC, args...) =
equalize_lengths(equalize_lengths(modelmatrix1, modelmatrix2), args...)
extend_to_larger(modelmatrix::Vector{<:SparseMatrixCSC}) = extend_to_larger(modelmatrix...)
extend_to_larger(modelmatrix1::SparseMatrixCSC, modelmatrix2::SparseMatrixCSC, args...) =
extend_to_larger(extend_to_larger(modelmatrix1, modelmatrix2), args...)

function equalize_lengths(modelmatrix1::SparseMatrixCSC, modelmatrix2::SparseMatrixCSC)
function extend_to_larger(modelmatrix1::SparseMatrixCSC, modelmatrix2::SparseMatrixCSC)
sX1 = size(modelmatrix1, 1)
sX2 = size(modelmatrix2, 1)

Expand Down Expand Up @@ -316,7 +333,7 @@ modelcols_nobasis(f::FormulaTerm, tbl::AbstractDataFrame) = modelcols(f.rhs.term
StatsModels.modelmatrix(uf::UnfoldModel) = modelmatrix(designmatrix(uf))#modelmatrix(uf.design,uf.designmatrix.events)
StatsModels.modelmatrix(d::AbstractDesignMatrix) = modelmatrices(d)
StatsModels.modelmatrix(d::Vector{<:AbstractDesignMatrix}) =
equalize_lengths(modelmatrices.(d))
extend_to_larger(modelmatrices.(d))


"""
Expand Down
2 changes: 1 addition & 1 deletion src/effects.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ function effects(design::AbstractDict, model::T; typical = mean) where {T<:Unfol
@debug form_typical
#@debug "type form_typical[1]", typeof(form_typical[1])
eff = predict(model, form_typical, reference_grid)

@debug "eff" typeof(eff), typeof(form_typical)
# because coefficients are 2D/3D arry, we have to cast it correctly to one big dataframe
if isa(eff, Tuple)
# TimeContinuous Model, we also get back other things like times & fromWhereToWhere a BasisFunction goes
Expand Down
2 changes: 1 addition & 1 deletion src/fit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ function StatsModels.fit!(


# mass univariate, data = ch x times x epochs
X, data = zeropad(X, data)
X, data = equalize_size(X, data)
@debug typeof(uf.modelfit), typeof(T), typeof(X), typeof(data)
uf.modelfit = solver(X, data)
return uf
Expand Down
Loading

0 comments on commit ee25c3c

Please sign in to comment.