diff --git a/src/Utils.jl b/src/Utils.jl index 1d779b9d..5cc2f92f 100644 --- a/src/Utils.jl +++ b/src/Utils.jl @@ -415,6 +415,11 @@ macro unpack(x) end end +""" + compute_dt(S::JustRelax.StokesArrays, args...) + +Compute the time step `dt` for the simulation. +""" function compute_dt(S::JustRelax.StokesArrays, args...) return compute_dt(backend(S), S, args...) end @@ -423,51 +428,21 @@ function compute_dt(::CPUBackendTrait, S::JustRelax.StokesArrays, args...) return _compute_dt(S, args...) end -""" - compute_dt(S::JustRelax.StokesArrays, di) - -Compute the time step `dt` for the velocity field `S.V` for a regular grid with grid spacing `di`. -""" -@inline _compute_dt(S::JustRelax.StokesArrays, di) = _compute_dt(@velocity(S), di, Inf) - -""" - compute_dt(S::JustRelax.StokesArrays, di, dt_diff) +@inline _compute_dt(S::JustRelax.StokesArrays, di) = + _compute_dt(@velocity(S), di, Inf, maximum) -Compute the time step `dt` for the velocity field `S.V` and the diffusive maximum time step -`dt_diff` for a regular gridwith grid spacing `di`. -""" @inline _compute_dt(S::JustRelax.StokesArrays, di, dt_diff) = - _compute_dt(@velocity(S), di, dt_diff) + _compute_dt(@velocity(S), di, dt_diff, maximum) -@inline function _compute_dt(V::NTuple, di, dt_diff) - n = inv(length(V) + 0.1) - dt_adv = mapreduce(x -> x[1] * inv(maximum(abs.(x[2]))), min, zip(di, V)) * n - return min(dt_diff, dt_adv) -end -""" - compute_dt(S::JustRelax.StokesArrays, di, igg) - -Compute the time step `dt` for the velocity field `S.V` for a regular gridwith grid spacing `di`. -The implicit global grid variable `I` implies that the time step is calculated globally and not -separately on each block. -""" -@inline _compute_dt(S::JustRelax.StokesArrays, di, I::IGG) = - _compute_dt(@velocity(S), di, Inf, I::IGG) +@inline _compute_dt(S::JustRelax.StokesArrays, di, dt_diff, ::IGG) = + _compute_dt(@velocity(S), di, dt_diff, maximum_mpi) -""" - compute_dt(S::JustRelax.StokesArrays, di, dt_diff) - -Compute the time step `dt` for the velocity field `S.V` and the diffusive maximum time step -`dt_diff` for a regular gridwith grid spacing `di`. The implicit global grid variable `I` -implies that the time step is calculated globally and not separately on each block. -""" -@inline function _compute_dt(S::JustRelax.StokesArrays, di, dt_diff, I::IGG) - return _compute_dt(@velocity(S), di, dt_diff, I::IGG) -end +@inline _compute_dt(S::JustRelax.StokesArrays, di, ::IGG) = + _compute_dt(@velocity(S), di, Inf, maximum_mpi) -@inline function _compute_dt(V::NTuple, di, dt_diff, I::IGG) +@inline function _compute_dt(V::NTuple, di, dt_diff, max_fun::F) where {F<:Function} n = inv(length(V) + 0.1) - dt_adv = mapreduce(x -> x[1] * inv(maximum_mpi(abs.(x[2]))), max, zip(di, V)) * n + dt_adv = mapreduce(x -> x[1] * inv(max_fun(abs.(x[2]))), max, zip(di, V)) * n return min(dt_diff, dt_adv) end