diff --git a/src/stokes/PressureKernels.jl b/src/stokes/PressureKernels.jl index fe966fdf..500ab55c 100644 --- a/src/stokes/PressureKernels.jl +++ b/src/stokes/PressureKernels.jl @@ -59,11 +59,12 @@ function compute_P!( r, θ_dτ; ΔTc=nothing, + ϕ=nothing, kwargs..., ) where {N} ni = size(P) @parallel (@idx ni) compute_P_kernel!( - P, P0, RP, ∇V, η, rheology, phase_ratio, dt, r, θ_dτ, ΔTc + P, P0, RP, ∇V, η, rheology, phase_ratio, dt, r, θ_dτ, ΔTc, ϕ ) return nothing end @@ -80,6 +81,7 @@ end r, θ_dτ, ::Nothing, + ::Nothing, ) where {N,C<:JustRelax.CellArray} K = fn_ratio(get_bulk_modulus, rheology, phase_ratio[I...]) RP[I...], P[I...] = _compute_P!(P[I...], P0[I...], ∇V[I...], η[I...], K, dt, r, θ_dτ) @@ -87,7 +89,7 @@ end end @parallel_indices (I...) function compute_P_kernel!( - P, P0, RP, ∇V, η, rheology::NTuple{N,MaterialParams}, phase_ratio::C, dt, r, θ_dτ, ΔTc + P, P0, RP, ∇V, η, rheology::NTuple{N,MaterialParams}, phase_ratio::C, dt, r, θ_dτ, ΔTc, ::Nothing ) where {N,C<:JustRelax.CellArray} K = fn_ratio(get_bulk_modulus, rheology, phase_ratio[I...]) α = fn_ratio(get_thermal_expansion, rheology, phase_ratio[I...]) @@ -97,6 +99,17 @@ end return nothing end +@parallel_indices (I...) function compute_P_kernel!( + P, P0, RP, ∇V, η, rheology::NTuple{N,MaterialParams}, phase_ratio::C, dt, r, θ_dτ, ΔTc, ϕ +) where {N,C<:JustRelax.CellArray} + K = fn_ratio(get_bulk_modulus, rheology, phase_ratio[I...]) + α = fn_ratio(get_thermal_expansion, rheology, phase_ratio[I...], (;ϕ=ϕ[I...])) + RP[I...], P[I...] = _compute_P!( + P[I...], P0[I...], ∇V[I...], ΔTc[I...], α, η[I...], K, dt, r, θ_dτ + ) + return nothing +end + # Pressure innermost kernels function _compute_P!(P, ∇V, η, r, θ_dτ) @@ -108,9 +121,9 @@ end function _compute_P!(P, P0, ∇V, η, K, dt, r, θ_dτ) _Kdt = inv(K * dt) RP = fma(-(P - P0), _Kdt, -∇V) - ψ = 1 / (1.0 / (r / θ_dτ * η) + 1.0 * _Kdt) - # P = ((fma(P0, _Kdt, -∇V)) * ψ + P) / (1 + _Kdt * ψ) - P += RP / (1.0 / (r / θ_dτ * η) + 1.0 * _Kdt) + ψ = inv(inv(r / θ_dτ * η) + _Kdt) + P = ((fma(P0, _Kdt, -∇V)) * ψ + P) / (1 + _Kdt * ψ) + # P += RP / (1.0 / (r / θ_dτ * η) + 1.0 * _Kdt) return RP, P end @@ -118,7 +131,7 @@ function _compute_P!(P, P0, ∇V, ΔTc, α, η, K, dt, r, θ_dτ) _Kdt = inv(K * dt) _dt = inv(dt) RP = fma(-(P - P0), _Kdt, (-∇V + (α * (ΔTc * _dt)))) - ψ = 1 / (1.0 / (r / θ_dτ * η) + 1.0 * _Kdt) + ψ = inv(inv(r / θ_dτ * η) + _Kdt) P = ((fma(P0, _Kdt, (-∇V + (α * (ΔTc * _dt))))) * ψ + P) / (1 + _Kdt * ψ) # P += RP / (1.0 / (r / θ_dτ * η) + 1.0 * _Kdt)