Skip to content

Commit

Permalink
Refactor PERK2 time integration method to use fixed time step (trixi-…
Browse files Browse the repository at this point in the history
…framework#1958)

* add fixed time step and test

* fmt

* fix test values

* Update test_tree_1d_advection.jl

* adjustment so that fixed timestep work with savesolcallback when dt is specified

* attempt to avoid argument error

* test with save solution

* add save_solution to elixir

* update test value

---------

Co-authored-by: Daniel Doehring <[email protected]>
  • Loading branch information
warisa-r and DanielDoehring authored Jul 8, 2024
1 parent 4bce711 commit ea2dd3a
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 14 deletions.
8 changes: 7 additions & 1 deletion examples/tree_1d_dgsem/elixir_advection_perk2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,15 @@ stepsize_callback = StepsizeCallback(cfl = 2.5)

alive_callback = AliveCallback(alive_interval = analysis_interval)

save_solution = SaveSolutionCallback(dt = 0.1,
save_initial_solution = true,
save_final_solution = true,
solution_variables = cons2prim)

# Create a CallbackSet to collect all callbacks such that they can be passed to the ODE solver
callbacks = CallbackSet(summary_callback,
alive_callback,
save_solution,
analysis_callback,
stepsize_callback)

Expand All @@ -59,7 +65,7 @@ callbacks = CallbackSet(summary_callback,
ode_algorithm = Trixi.PairedExplicitRK2(6, tspan, semi)

sol = Trixi.solve(ode, ode_algorithm,
dt = 1.0, # solve needs some value here but it will be overwritten by the stepsize_callback
dt = 1.0, # Manual time step value, will be overwritten by the stepsize_callback when it is specified.
save_everystep = false, callback = callbacks);

# Print the timer summary
Expand Down
78 changes: 67 additions & 11 deletions src/time_integration/paired_explicit_runge_kutta/methods_PERK2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,17 +180,25 @@ function PairedExplicitRK2(num_stages, tspan, eig_vals::Vector{ComplexF64};
end

# This struct is needed to fake https://github.com/SciML/OrdinaryDiffEq.jl/blob/0c2048a502101647ac35faabd80da8a5645beac7/src/integrators/type.jl#L1
mutable struct PairedExplicitRKOptions{Callback}
mutable struct PairedExplicitRKOptions{Callback, TStops}
callback::Callback # callbacks; used in Trixi
adaptive::Bool # whether the algorithm is adaptive; ignored
adaptive::Bool # whether the algorithm is adaptive
dtmax::Float64 # ignored
maxiters::Int # maximal number of time steps
tstops::Vector{Float64} # tstops from https://diffeq.sciml.ai/v6.8/basics/common_solver_opts/#Output-Control-1; ignored
tstops::TStops # tstops from https://diffeq.sciml.ai/v6.8/basics/common_solver_opts/#Output-Control-1; ignored
end

function PairedExplicitRKOptions(callback, tspan; maxiters = typemax(Int), kwargs...)
PairedExplicitRKOptions{typeof(callback)}(callback, false, Inf, maxiters,
[last(tspan)])
tstops_internal = BinaryHeap{eltype(tspan)}(FasterForward())
# We add last(tspan) to make sure that the time integration stops at the end time
push!(tstops_internal, last(tspan))
# We add 2 * last(tspan) because add_tstop!(integrator, t) is only called by DiffEqCallbacks.jl if tstops contains a time that is larger than t
# (https://github.com/SciML/DiffEqCallbacks.jl/blob/025dfe99029bd0f30a2e027582744528eb92cd24/src/iterative_and_periodic.jl#L92)
push!(tstops_internal, 2 * last(tspan))
PairedExplicitRKOptions{typeof(callback), typeof(tstops_internal)}(callback,
false, Inf,
maxiters,
tstops_internal)
end

abstract type PairedExplicitRK end
Expand All @@ -207,20 +215,42 @@ mutable struct PairedExplicitRK2Integrator{RealT <: Real, uType, Params, Sol, F,
du::uType
u_tmp::uType
t::RealT
tdir::RealT
dt::RealT # current time step
dtcache::RealT # ignored
dtcache::RealT # manually set time step
iter::Int # current number of time steps (iteration)
p::Params # will be the semidiscretization from Trixi
sol::Sol # faked
f::F
alg::Alg # This is our own class written above; Abbreviation for ALGorithm
opts::PairedExplicitRKOptions
finalstep::Bool # added for convenience
dtchangeable::Bool
force_stepfail::Bool
# PairedExplicitRK2 stages:
k1::uType
k_higher::uType
end

"""
add_tstop!(integrator::PairedExplicitRK2Integrator, t)
Add a time stop during the time integration process.
This function is called after the periodic SaveSolutionCallback to specify the next stop to save the solution.
"""
function add_tstop!(integrator::PairedExplicitRK2Integrator, t)
integrator.tdir * (t - integrator.t) < zero(integrator.t) &&
error("Tried to add a tstop that is behind the current time. This is strictly forbidden")
# We need to remove the first entry of tstops when a new entry is added.
# Otherwise, the simulation gets stuck at the previous tstop and dt is adjusted to zero.
if length(integrator.opts.tstops) > 1
pop!(integrator.opts.tstops)
end
push!(integrator.opts.tstops, integrator.tdir * t)
end

has_tstop(integrator::PairedExplicitRK2Integrator) = !isempty(integrator.opts.tstops)
first_tstop(integrator::PairedExplicitRK2Integrator) = first(integrator.opts.tstops)

# Forward integrator.stats.naccept to integrator.iter (see GitHub PR#771)
function Base.getproperty(integrator::PairedExplicitRK, field::Symbol)
if field === :stats
Expand All @@ -241,15 +271,16 @@ function init(ode::ODEProblem, alg::PairedExplicitRK2;
k_higher = zero(u0)

t0 = first(ode.tspan)
tdir = sign(ode.tspan[end] - ode.tspan[1])
iter = 0

integrator = PairedExplicitRK2Integrator(u0, du, u_tmp, t0, dt, zero(dt), iter,
integrator = PairedExplicitRK2Integrator(u0, du, u_tmp, t0, tdir, dt, dt, iter,
ode.p,
(prob = ode,), ode.f, alg,
PairedExplicitRKOptions(callback,
ode.tspan;
kwargs...),
false,
false, true, false,
k1, k_higher)

# initialize callbacks
Expand Down Expand Up @@ -301,6 +332,8 @@ function step!(integrator::PairedExplicitRK2Integrator)
error("time step size `dt` is NaN")
end

modify_dt_for_tstops!(integrator)

# if the next iteration would push the simulation beyond the end time, set dt accordingly
if integrator.t + integrator.dt > t_end ||
isapprox(integrator.t + integrator.dt, t_end)
Expand Down Expand Up @@ -383,17 +416,40 @@ u_modified!(integrator::PairedExplicitRK, ::Bool) = false

# used by adaptive timestepping algorithms in DiffEq
function set_proposed_dt!(integrator::PairedExplicitRK, dt)
integrator.dt = dt
(integrator.dt = dt; integrator.dtcache = dt)
end

function get_proposed_dt(integrator::PairedExplicitRK)
return integrator.dt
return ifelse(integrator.opts.adaptive, integrator.dt, integrator.dtcache)
end

# stop the time integration
function terminate!(integrator::PairedExplicitRK)
integrator.finalstep = true
empty!(integrator.opts.tstops)
end

"""
modify_dt_for_tstops!(integrator::PairedExplicitRK)
Modify the time-step size to match the time stops specified in integrator.opts.tstops.
To avoid adding OrdinaryDiffEq to Trixi's dependencies, this routine is a copy of
https://github.com/SciML/OrdinaryDiffEq.jl/blob/d76335281c540ee5a6d1bd8bb634713e004f62ee/src/integrators/integrator_utils.jl#L38-L54
"""
function modify_dt_for_tstops!(integrator::PairedExplicitRK)
if has_tstop(integrator)
tdir_t = integrator.tdir * integrator.t
tdir_tstop = first_tstop(integrator)
if integrator.opts.adaptive
integrator.dt = integrator.tdir *
min(abs(integrator.dt), abs(tdir_tstop - tdir_t)) # step! to the end
elseif iszero(integrator.dtcache) && integrator.dtchangeable
integrator.dt = integrator.tdir * abs(tdir_tstop - tdir_t)
elseif integrator.dtchangeable && !integrator.force_stepfail
# always try to step! with dtcache, but lower if a tstop
# however, if force_stepfail then don't set to dtcache, and no tstop worry
integrator.dt = integrator.tdir *
min(abs(integrator.dtcache), abs(tdir_tstop - tdir_t)) # step! to the end
end
end
end

# used for AMR (Adaptive Mesh Refinement)
Expand Down
24 changes: 22 additions & 2 deletions test/test_tree_1d_advection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,28 @@ end

@trixi_testset "elixir_advection_perk2.jl" begin
@test_trixi_include(joinpath(EXAMPLES_DIR, "elixir_advection_perk2.jl"),
l2=[0.014139242834192841],
linf=[0.01999756655819429])
l2=[0.011288030389423475],
linf=[0.01596735472556976])
# Ensure that we do not have excessive memory allocations
# (e.g., from type instabilities)
let
t = sol.t[end]
u_ode = sol.u[end]
du_ode = similar(u_ode)
@test (@allocated Trixi.rhs!(du_ode, u_ode, semi, t)) < 8000
end
end

# Testing the second-order paired explicit Runge-Kutta (PERK) method without stepsize callback
@trixi_testset "elixir_advection_perk2.jl(fixed time step)" begin
@test_trixi_include(joinpath(EXAMPLES_DIR, "elixir_advection_perk2.jl"),
dt=2.0e-3,
tspan=(0.0, 20.0),
save_solution=SaveSolutionCallback(dt = 0.1 + 1.0e-8),
callbacks=CallbackSet(summary_callback, save_solution,
analysis_callback, alive_callback),
l2=[9.886271430207691e-6],
linf=[3.729460413781638e-5])
# Ensure that we do not have excessive memory allocations
# (e.g., from type instabilities)
let
Expand Down

0 comments on commit ea2dd3a

Please sign in to comment.