diff --git a/src/predict.jl b/src/predict.jl index a5f28358..c5a4d680 100644 --- a/src/predict.jl +++ b/src/predict.jl @@ -16,7 +16,12 @@ function StatsBase.predict(model::UnfoldModel, events) timesVec = gen_timeev(times(model), size(newevents, 1)) else @debug size(formulas), typeof(model) - fromTo, timesVec, eff = yhat(model, formulas, newevents) + + + eff = yhat(model, formulas, newevents) + @debug typeof(model) + timesVec = yhat_timevec(model, formulas, newevents) + fromTo = yhat_nranges(model, formulas, newevents) # formerly fromTo == n_ranges end @@ -42,17 +47,19 @@ function StatsBase.predict(model::UnfoldModel, events) # shift variable to keep track of multiple basisfunctions shift = 0 # for each basis function - - for (bIx, basisfun) in enumerate(fromTo) + @debug "fromTo" fromTo + for (bIx, n_range) in enumerate(fromTo) + @debug "n_range" typeof(n_range[1:end]), size(events) + #basistime = range(1, step = n_range, length = size(events, 1)) # go through all predictors - for (i, fstart) in enumerate(basisfun[1:end]) - - fend = basisfun[i] + basisfun.step - 1 + for (i, fstart) in enumerate(n_range[1:end]) + fend = n_range[i] + n_range.step .- 1 + #fend = fstart + basistime.step - 1 # couldn't figure out how to broadcast everything directly (i.e out[fstart:fend,names(newevents)] .= newevents[i,:]) # copy the correct metadata - + @debug fstart, fend for j = fstart:fend metaData[j+shift, names(newevents)] = newevents[i, :] end @@ -62,7 +69,7 @@ function StatsBase.predict(model::UnfoldModel, events) end # the next meta data has to be at the end - shift += basisfun[end] - 1 + basisfun.step + shift += n_range[end] - 1 + n_range.step end end @@ -119,92 +126,118 @@ yhat(model::UnfoldLinearModelContinuousTime, formulas::FormulaTerm, newevents) = end -yhat(model::UnfoldLinearModelContinuousTime, formulas::MatrixTerm, events) = - yhat(model, formulas.terms, events) +@traitfn function yhat( + m::T, + f::AbstractArray, + events, +) where {T <: UnfoldModel; ContinuousTimeTrait{T}} + @debug f + yhat(m, blockdiag(yhat.(Ref(m), f, Ref(events))...)) # blockdiag([Xsingle1,Xsingle2,Xsingle3]...) +end + +@traitfn yhat( + model::T, + f::MatrixTerm, + events, +) where {T <: UnfoldModel; ContinuousTimeTrait{T}} = yhat(m, f.terms, events) # terms[1] or not?!? + +@traitfn yhat( + model::T, + f::FormulaTerm, + events, +) where {T <: UnfoldModel; ContinuousTimeTrait{T}} = yhat(m, f.rhs, events) + + @traitfn function yhat( model::T, - formulas, + f::TimeExpandedTerm, events, ) where {T <: UnfoldModel; ContinuousTimeTrait{T}} - @debug "Yes ContinuousTime yhat" - X = AbstractArray[] - fromTo = [] - timesVec = [] - for f in formulas - - # due to many reasons, there can be several different options of how formula arrives here - not easy to catch via multiple dispatch - if !(isa(f, TimeExpandedTerm)) - if !(isa(f, MatrixTerm)) - f = f.rhs - #elseif !(isa(f,Effects.TypicalTerm)) - - else - f = f.terms[1] - end - end - # find out how long each designmatrix is - n_range = length(times(f.basisfunction)) - # find out how much to shift so that X[1,:] is the the first "sample" - n_negative = f.basisfunction.shift_onset - # generate the correct eventfields (default: latencies) - events[:, f.eventfields[1]] = - range(-n_negative + 1, step = n_range, length = size(events, 1)) - - - - # get the model - @debug f, first(events) - Xsingle = modelcols(f, events) - - timesSingle = times(f.basisfunction) - - # this pertains only to FIR-models - - # remove the last time-point because it is attached due to non-integer latency/eventonsets. - # e.g. x denotes a sample - # x- - -x- - -x- - -x - # e- - - -f- - - -g- - # - # e is aligned (integer) with the sample - # f&g are between two samples, thus the design matrix would interpolate between them. Thus has as a result, that the designmatrix is +1 longer than what would naively be expected - # - # because in "predict" we define where samples onset, we can remove the last sample, it s always 0 anyway, but to be sure we test it - - if typeof(f.basisfunction) <: FIRBasis - - keep = fill(true, size(Xsingle, 1)) - keep[range( - length(timesSingle), - size(Xsingle, 1), - step = length(timesSingle), - )] .= false - - Xsingle = Xsingle[keep, :] # unfortunately @view is not compatible with blockdiag of SparseArrays. - timesSingle = timesSingle[1:end-1] - n_range = n_range - 1 - end - # combine designmats - append!(X, [Xsingle]) + # find out how long each designmatrix is + n_range = length(times(f.basisfunction)) + # find out how much to shift so that X[1,:] is the the first "sample" + n_negative = f.basisfunction.shift_onset + # generate the correct eventfields (default: latencies) + events = deepcopy(events) + + events[:, f.eventfields[1]] = + range(-n_negative + 1, step = n_range, length = size(events, 1)) + # get the model + Xsingle = modelcols(f, events) + + timesSingle = times(f.basisfunction) + + # this pertains only to FIR-models + + # remove the last time-point because it is attached due to non-integer latency/eventonsets. + # e.g. x denotes a sample + # x- - -x- - -x- - -x + # e- - - -f- - - -g- + # + # e is aligned (integer) with the sample + # f&g are between two samples, thus the design matrix would interpolate between them. Thus has as a result, that the designmatrix is +1 longer than what would naively be expected + # + # because in "predict" we define where samples onset, we can remove the last sample, it s always 0 anyway, but to be sure we test it + + if typeof(f.basisfunction) <: FIRBasis + keep = ones(size(Xsingle, 1)) + keep[range(length(timesSingle), size(Xsingle, 1), step = length(timesSingle))] .= 0 + Xsingle = Xsingle[keep.==1, :] + end + return Xsingle + # don't include, via broadcast: append!(X, [Xsingle]) +end - # keep track of how long each event is - append!(fromTo, [range(1, step = n_range, length = size(events, 1))]) - # keep track of the times - append!(timesVec, repeat(timesSingle, size(events, 1))) - end - # Concat them, but without introducing any overlap, we want the isolated responses - Xconcat = blockdiag(X...) - # calculate yhat - eff = yhat(model, Xconcat) - # output is a bit ugly, but we need the other two vectors as well. Should be refactored at some point - but not right now ;-) #XXX - return fromTo, timesVec, eff +@traitfn yhat_nranges( + model::T, + formulas::Array{<:FormulaTerm}, + events, +) where {T <: UnfoldModel; ContinuousTimeTrait{T}} = + yhat_nranges.(Ref(model), formulas, Ref(events)) # todo: add hcat? vcat? who knows + + +@traitfn function yhat_nranges( + model::T, + f::FormulaTerm, + events, +) where {T <: UnfoldModel; ContinuousTimeTrait{T}} + n_range = length(times(f.rhs.basisfunction)) + if typeof(f.rhs.basisfunction) <: FIRBasis + n_range = n_range - 1 + end + return range(1, step = n_range, length = size(events, 1)) +end + +@traitfn yhat_timevec( + model::T, + formulas::Array{<:FormulaTerm}, + events, +) where {T <: UnfoldModel; ContinuousTimeTrait{T}} = + yhat_timevec.(Ref(model), formulas, Ref(events)) + +@traitfn function yhat_timevec( + model::T, + f::FormulaTerm, + events, +) where {T <: UnfoldModel; ContinuousTimeTrait{T}} + @debug typeof(f) + # keep track of the times + timesSingle = times(f.rhs.basisfunction) + # see yhat blabla why this is necessary + if typeof(f.rhs.basisfunction) <: FIRBasis + timesSingle = timesSingle[1:end-1] + end + return repeat(timesSingle, size(events, 1)) end + + function yhat( model::UnfoldLinearModelContinuousTime, X::AbstractArray{T,2};