Skip to content

Commit

Permalink
fix bug in phase ratios
Browse files Browse the repository at this point in the history
  • Loading branch information
albert-de-montserrat authored May 6, 2024
1 parent 49542d2 commit a785464
Showing 1 changed file with 17 additions and 43 deletions.
60 changes: 17 additions & 43 deletions src/phases/phases.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,50 +89,24 @@ end
function phase_ratio_weights(
pxi::NTuple{NP,C}, ph::SVector{N1,T}, cell_center, di, ::Val{NC}
) where {N1,NC,NP,T,C}
if @generated
quote
Base.@_inline_meta
# Initiaze phase ratio weights (note: can't use ntuple() here because of the @generated function)
Base.@nexprs $NC i -> w_i = zero($T)
w = Base.@ncall $NC tuple w

# initialie sum of weights
sumw = zero($T)
Base.@nexprs $N1 i -> begin
# bilinear weight (1-(xᵢ-xc)/dx)*(1-(yᵢ-yc)/dy)
x = bilinear_weight(cell_center, getindex.(pxi, i), di)
sumw += x # reduce
ph_local = ph[i]
# this is doing sum(w * δij(i, phase)), where δij is the Kronecker delta
# Base.@nexprs $NC j -> tmp_j = w[j] + x * δ(Int(ph_local), j)
Base.@nexprs $NC j -> tmp_j = w[j] + x * (ph_local == j)
w = Base.@ncall $NC tuple tmp
end

# return phase ratios weights w = sum(w * δij(i, phase)) / sum(w)
_sumw = inv(sum(w))
Base.@nexprs $NC i -> w_i = w[i] * _sumw
w = Base.@ncall $NC tuple w
return w
end
else
# Initiaze phase ratio weights (note: can't use ntuple() here because of the @generated function)
w = ntuple(_ -> zero(T), Val(NC))
# initialie sum of weights
sumw = zero(T)

for i in eachindex(pxi)
# bilinear weight (1-(xᵢ-xc)/dx)*(1-(yᵢ-yc)/dy)
x = @inline bilinear_weight(cell_center, getindex.(pxi, i), di)
sumw += x # reduce
ph_local = ph[i]
# this is doing sum(w * δij(i, phase)), where δij is the Kronecker delta
# w = w .+ x .* ntuple(j -> δ(Int(ph_local), j), Val(NC))
w = w .+ x .* ntuple(j -> (ph_local == j), Val(NC))
end
w = w .* inv(sum(w))
return w

# Initiaze phase ratio weights (note: can't use ntuple() here because of the @generated function)
w = ntuple(_ -> zero(T), Val(NC))
sumw = zero(T)

for i in eachindex(ph)
# bilinear weight (1-(xᵢ-xc)/dx)*(1-(yᵢ-yc)/dy)
p = getindex.(pxi, i)
isnan(first(p)) && continue
x = @inline bilinear_weight(cell_center, p, di)
sumw += x # reduce
ph_local = ph[i]
# this is doing sum(w * δij(i, phase)), where δij is the Kronecker delta
# w = w .+ x .* ntuple(j -> δ(Int(ph_local), j), Val(NC))
w = w .+ x .* ntuple(j -> (ph_local == j), Val(NC))
end
w = w .* inv(sumw)
return w
end

@generated function bilinear_weight(
Expand Down

0 comments on commit a785464

Please sign in to comment.