Skip to content

Commit

Permalink
included everything from branch refactorPredict
Browse files Browse the repository at this point in the history
  • Loading branch information
behinger committed Mar 21, 2024
1 parent 3b3c81e commit 822a791
Showing 1 changed file with 114 additions and 81 deletions.
195 changes: 114 additions & 81 deletions src/predict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@ function StatsBase.predict(model::UnfoldModel, events)
timesVec = gen_timeev(times(model), size(newevents, 1))
else
@debug size(formulas), typeof(model)
fromTo, timesVec, eff = yhat(model, formulas, newevents)


eff = yhat(model, formulas, newevents)
@debug typeof(model)
timesVec = yhat_timevec(model, formulas, newevents)
fromTo = yhat_nranges(model, formulas, newevents) # formerly fromTo == n_ranges
end


Expand All @@ -42,17 +47,19 @@ function StatsBase.predict(model::UnfoldModel, events)
# shift variable to keep track of multiple basisfunctions
shift = 0
# for each basis function

for (bIx, basisfun) in enumerate(fromTo)
@debug "fromTo" fromTo
for (bIx, n_range) in enumerate(fromTo)
@debug "n_range" typeof(n_range[1:end]), size(events)
#basistime = range(1, step = n_range, length = size(events, 1))

# go through all predictors
for (i, fstart) in enumerate(basisfun[1:end])

fend = basisfun[i] + basisfun.step - 1
for (i, fstart) in enumerate(n_range[1:end])
fend = n_range[i] + n_range.step .- 1
#fend = fstart + basistime.step - 1

# couldn't figure out how to broadcast everything directly (i.e out[fstart:fend,names(newevents)] .= newevents[i,:])
# copy the correct metadata

@debug fstart, fend
for j = fstart:fend
metaData[j+shift, names(newevents)] = newevents[i, :]
end
Expand All @@ -62,7 +69,7 @@ function StatsBase.predict(model::UnfoldModel, events)

end
# the next meta data has to be at the end
shift += basisfun[end] - 1 + basisfun.step
shift += n_range[end] - 1 + n_range.step

end
end
Expand Down Expand Up @@ -119,92 +126,118 @@ yhat(model::UnfoldLinearModelContinuousTime, formulas::FormulaTerm, newevents) =
end


yhat(model::UnfoldLinearModelContinuousTime, formulas::MatrixTerm, events) =
yhat(model, formulas.terms, events)
@traitfn function yhat(
m::T,
f::AbstractArray,
events,
) where {T <: UnfoldModel; ContinuousTimeTrait{T}}
@debug f
yhat(m, blockdiag(yhat.(Ref(m), f, Ref(events))...)) # blockdiag([Xsingle1,Xsingle2,Xsingle3]...)
end

@traitfn yhat(
model::T,
f::MatrixTerm,
events,
) where {T <: UnfoldModel; ContinuousTimeTrait{T}} = yhat(m, f.terms, events) # terms[1] or not?!?

@traitfn yhat(
model::T,
f::FormulaTerm,
events,
) where {T <: UnfoldModel; ContinuousTimeTrait{T}} = yhat(m, f.rhs, events)



@traitfn function yhat(
model::T,
formulas,
f::TimeExpandedTerm,
events,
) where {T <: UnfoldModel; ContinuousTimeTrait{T}}
@debug "Yes ContinuousTime yhat"
X = AbstractArray[]
fromTo = []
timesVec = []
for f in formulas

# due to many reasons, there can be several different options of how formula arrives here - not easy to catch via multiple dispatch
if !(isa(f, TimeExpandedTerm))
if !(isa(f, MatrixTerm))
f = f.rhs
#elseif !(isa(f,Effects.TypicalTerm))

else
f = f.terms[1]
end

end
# find out how long each designmatrix is
n_range = length(times(f.basisfunction))
# find out how much to shift so that X[1,:] is the the first "sample"
n_negative = f.basisfunction.shift_onset
# generate the correct eventfields (default: latencies)
events[:, f.eventfields[1]] =
range(-n_negative + 1, step = n_range, length = size(events, 1))



# get the model
@debug f, first(events)
Xsingle = modelcols(f, events)

timesSingle = times(f.basisfunction)

# this pertains only to FIR-models

# remove the last time-point because it is attached due to non-integer latency/eventonsets.
# e.g. x denotes a sample
# x- - -x- - -x- - -x
# e- - - -f- - - -g-
#
# e is aligned (integer) with the sample
# f&g are between two samples, thus the design matrix would interpolate between them. Thus has as a result, that the designmatrix is +1 longer than what would naively be expected
#
# because in "predict" we define where samples onset, we can remove the last sample, it s always 0 anyway, but to be sure we test it

if typeof(f.basisfunction) <: FIRBasis

keep = fill(true, size(Xsingle, 1))
keep[range(
length(timesSingle),
size(Xsingle, 1),
step = length(timesSingle),
)] .= false

Xsingle = Xsingle[keep, :] # unfortunately @view is not compatible with blockdiag of SparseArrays.
timesSingle = timesSingle[1:end-1]
n_range = n_range - 1
end
# combine designmats
append!(X, [Xsingle])
# find out how long each designmatrix is
n_range = length(times(f.basisfunction))
# find out how much to shift so that X[1,:] is the the first "sample"
n_negative = f.basisfunction.shift_onset
# generate the correct eventfields (default: latencies)
events = deepcopy(events)

events[:, f.eventfields[1]] =
range(-n_negative + 1, step = n_range, length = size(events, 1))
# get the model
Xsingle = modelcols(f, events)

timesSingle = times(f.basisfunction)

# this pertains only to FIR-models

# remove the last time-point because it is attached due to non-integer latency/eventonsets.
# e.g. x denotes a sample
# x- - -x- - -x- - -x
# e- - - -f- - - -g-
#
# e is aligned (integer) with the sample
# f&g are between two samples, thus the design matrix would interpolate between them. Thus has as a result, that the designmatrix is +1 longer than what would naively be expected
#
# because in "predict" we define where samples onset, we can remove the last sample, it s always 0 anyway, but to be sure we test it

if typeof(f.basisfunction) <: FIRBasis
keep = ones(size(Xsingle, 1))
keep[range(length(timesSingle), size(Xsingle, 1), step = length(timesSingle))] .= 0
Xsingle = Xsingle[keep.==1, :]
end
return Xsingle
# don't include, via broadcast: append!(X, [Xsingle])
end

# keep track of how long each event is
append!(fromTo, [range(1, step = n_range, length = size(events, 1))])
# keep track of the times
append!(timesVec, repeat(timesSingle, size(events, 1)))

end

# Concat them, but without introducing any overlap, we want the isolated responses
Xconcat = blockdiag(X...)

# calculate yhat
eff = yhat(model, Xconcat)

# output is a bit ugly, but we need the other two vectors as well. Should be refactored at some point - but not right now ;-) #XXX
return fromTo, timesVec, eff
@traitfn yhat_nranges(
model::T,
formulas::Array{<:FormulaTerm},
events,
) where {T <: UnfoldModel; ContinuousTimeTrait{T}} =
yhat_nranges.(Ref(model), formulas, Ref(events)) # todo: add hcat? vcat? who knows


@traitfn function yhat_nranges(
model::T,
f::FormulaTerm,
events,
) where {T <: UnfoldModel; ContinuousTimeTrait{T}}
n_range = length(times(f.rhs.basisfunction))
if typeof(f.rhs.basisfunction) <: FIRBasis
n_range = n_range - 1
end
return range(1, step = n_range, length = size(events, 1))
end

@traitfn yhat_timevec(
model::T,
formulas::Array{<:FormulaTerm},
events,
) where {T <: UnfoldModel; ContinuousTimeTrait{T}} =
yhat_timevec.(Ref(model), formulas, Ref(events))


@traitfn function yhat_timevec(
model::T,
f::FormulaTerm,
events,
) where {T <: UnfoldModel; ContinuousTimeTrait{T}}
@debug typeof(f)
# keep track of the times
timesSingle = times(f.rhs.basisfunction)
# see yhat blabla why this is necessary
if typeof(f.rhs.basisfunction) <: FIRBasis
timesSingle = timesSingle[1:end-1]
end
return repeat(timesSingle, size(events, 1))
end


function yhat(
model::UnfoldLinearModelContinuousTime,
X::AbstractArray{T,2};
Expand Down

0 comments on commit 822a791

Please sign in to comment.