Skip to content

Commit

Permalink
add timestep
Browse files Browse the repository at this point in the history
  • Loading branch information
aelligp committed Jun 3, 2024
1 parent a9c1b52 commit 431be3f
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 34 deletions.
11 changes: 7 additions & 4 deletions src/IO/H5.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ macro namevar(x)
end

"""
checkpointing_hdf5(dst, stokes, T, η, time)
checkpointing_hdf5(dst, stokes, T, η, time, timestep)
Save necessary data in `dst` as and HDF5 file to restart the model from the state at `time`
"""
function checkpointing_hdf5(dst, stokes, T, time)
function checkpointing_hdf5(dst, stokes, T, time, timestep)
!isdir(dst) && mkpath(dst) # create folder in case it does not exist
fname = joinpath(dst, "checkpoint")

Expand All @@ -27,6 +27,7 @@ function checkpointing_hdf5(dst, stokes, T, time)
tmpfname = joinpath(tmpdir, basename(fname))
h5open("$(tmpfname).h5", "w") do file
write(file, @namevar(time)...)
write(file, @namevar(timestep)...)
write(file, @namevar(stokes.V.Vx)...)
write(file, @namevar(stokes.V.Vy)...)
if !isnothing(stokes.V.Vz)
Expand Down Expand Up @@ -59,14 +60,15 @@ Load the state of the simulation from an .h5 file.
- `Vz`: The loaded state of the z-component of the velocity variable.
- `η`: The loaded state of the viscosity variable.
- `t`: The loaded simulation time.
- `dt`: The loaded simulation time.
# Example
```julia
# Define the path to the .h5 file
file_path = "path/to/your/file.h5"
# Use the load_checkpoint function to load the variables from the file
P, T, Vx, Vy, Vz, η, t = load_checkpoint(file_path)
P, T, Vx, Vy, Vz, η, t, dt = `load_checkpoint(file_path)``
"""
Expand All @@ -83,8 +85,9 @@ function load_checkpoint_hdf5(file_path)
end
η = read(h5file["η"]) # Read the stokes.viscosity.η variable
t = read(h5file["time"]) # Read the t variable
dt = read(h5file["timestep"]) # Read the t variable
close(h5file) # Close the file
return P, T, Vx, Vy, Vz, η, t
return P, T, Vx, Vy, Vz, η, t, dt
end

"""
Expand Down
30 changes: 16 additions & 14 deletions src/IO/JLD2.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
checkpointing_jld2(dst, stokes, thermal, time, igg)
checkpointing_jld2(dst, stokes, thermal, time, timestep, igg)
Save necessary data in `dst` as a jld2 file to restart the model from the state at `time`.
If run in parallel, the file will be named after the corresponidng rank e.g. `checkpoint0000.jld2`
Expand All @@ -13,8 +13,6 @@ by providing a dollar sign and the rank number.
"path/to/dst",
stokes,
thermal,
particles,
pPhases,
t,
igg,
)
Expand All @@ -24,26 +22,32 @@ by providing a dollar sign and the rank number.
checkpoint_name(dst) = "$dst/checkpoint.jld2"
checkpoint_name(dst, igg::IGG) = "$dst/checkpoint" * lpad("$(igg.me)", 4, "0") * ".jld2"

function checkpointing_jld2(dst, stokes, thermal, time)
function checkpointing_jld2(dst, stokes, thermal, time, timestep)
fname = checkpoint_name(dst)
checkpointing_jld2(dst, stokes, thermal, time, fname)
checkpointing_jld2(dst, stokes, thermal, time, timestep, fname)
return nothing
end

function checkpointing_jld2(dst, stokes, thermal, time, igg::IGG)
function checkpointing_jld2(dst, stokes, thermal, time, timestep, igg::IGG)
fname = checkpoint_name(dst, igg)
checkpointing_jld2(dst, stokes, thermal, time, fname)
checkpointing_jld2(dst, stokes, thermal, time, timestep, fname)
return nothing
end

function checkpointing_jld2(dst, stokes, thermal, time, fname::String)
function checkpointing_jld2(dst, stokes, thermal, time, timestep, fname::String)
!isdir(dst) && mkpath(dst) # create folder in case it does not exist

# Create a temporary directory
mktempdir() do tmpdir
# Save the checkpoint file in the temporary directory
tmpfname = joinpath(tmpdir, basename(fname))
jldsave(tmpfname; stokes=Array(stokes), thermal=Array(thermal), time=time)
jldsave(
tmpfname;
stokes=Array(stokes),
thermal=Array(thermal),
time=time,
timestep=timestep,
)
# Move the checkpoint file from the temporary directory to the destination directory
mv(tmpfname, fname; force=true)
end
Expand All @@ -61,16 +65,14 @@ Load the state of the simulation from a .jld2 file.
# Returns
- `stokes`: The loaded state of the stokes variable.
- `thermal`: The loaded state of the thermal variable.
- `particles`: The loaded state of the particles variable.
- `phases`: The loaded state of the phases variable.
- `time`: The loaded simulation time.
- `timestep`: The loaded time step.
"""
function load_checkpoint_jld2(file_path)
restart = load(file_path) # Load the file
stokes = restart["stokes"] # Read the stokes variable
thermal = restart["thermal"] # Read the thermal variable
time = restart["time"] # Read the time variable
return stokes, thermal, time
timestep = restart["timestep"] # Read the timestep variable
return stokes, thermal, time, timestep
end

# Use the function
41 changes: 25 additions & 16 deletions test/test_checkpointing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ const backend = CPUBackend # Options: CPUBackend, CUDABackend, AMDGPUBackend
# Load script dependencies
using GeoParams

@testset "Test checkpointing" begin
@testset "Test Checkpointing and Metadata" begin
@suppress begin
# Set up mock data
# Physical domain ------------------------------------
Expand All @@ -42,38 +42,46 @@ using GeoParams
# temperature
pT, pPhases = init_cell_arrays(particles, Val(2))
time = 1.0
dt = 0.1

stokes.viscosity.η .= fill(1.0)
stokes.V.Vy .= fill(10)
thermal.T .= fill(100)
stokes.viscosity.η .= @fill(1.0)
stokes.V.Vy .= @fill(10)
thermal.T .= @fill(100)

# Save metadata to directory
metadata(pwd(), dst, "test_traits.jl", "test_types.jl")
@test isfile(joinpath(dst, "test_traits.jl"))
@test isfile(joinpath(dst, "test_types.jl"))
@test isfile(joinpath(dst, "Manifest.toml"))
@test isfile(joinpath(dst, "Project.toml"))

# Call the function
checkpointing_jld2(dst, stokes, thermal, time, igg)
checkpointing_jld2(dst, stokes, thermal, time)
checkpointing_jld2(dst, stokes, thermal, time, dt, igg)
checkpointing_jld2(dst, stokes, thermal, time, dt)

# Check that the file was created
fname = joinpath(dst, "checkpoint" * lpad("$(igg.me)", 4, "0") * ".jld2")
@test isfile(fname)

# Load the data from the file
load_checkpoint_jld2(fname)
stokes1, thermal1, t, dt1 = load_checkpoint_jld2(fname)

@test stokes.viscosity.η[1] == 1.0
@test stokes.V.Vy[1] == 10
@test thermal.T[1] == 100
@test stokes1.viscosity.η[1] == 1.0
@test stokes1.V.Vy[1] == 10
@test thermal1.T[1] == 100
@test isnothing(stokes.V.Vz)
@test dt1 == 0.1


# check the if the hdf5 function also works
checkpointing_hdf5(dst, stokes, thermal.T, time)
checkpointing_hdf5(dst, stokes, thermal.T, time, dt)

# Check that the file was created
fname = joinpath(dst, "checkpoint.h5")
@test isfile(fname)

# Load the data from the file
P, T, Vx, Vy, Vz, η, t = load_checkpoint_hdf5(fname)
P, T, Vx, Vy, Vz, η, t, dt = load_checkpoint_hdf5(fname)

stokes.viscosity.η .= η
stokes.V.Vy .= Vy
Expand All @@ -82,6 +90,7 @@ using GeoParams
@test stokes.V.Vy[1] == 10
@test thermal.T[1] == 100
@test isnothing(Vz)
@test dt == 0.1

# 3D case
stokes = StokesArrays(backend_JR, (nx,ny,1))
Expand All @@ -100,8 +109,8 @@ using GeoParams


# Call the function
checkpointing_jld2(dst, stokes, thermal, time, igg)
checkpointing_jld2(dst, stokes, thermal, time)
checkpointing_jld2(dst, stokes, thermal, time, dt, igg)
checkpointing_jld2(dst, stokes, thermal, time, dt)

# Check that the file was created
fname = joinpath(dst, "checkpoint" * lpad("$(igg.me)", 4, "0") * ".jld2")
Expand All @@ -117,14 +126,14 @@ using GeoParams


# check the if the hdf5 function also works
checkpointing_hdf5(dst, stokes, thermal.T, time)
checkpointing_hdf5(dst, stokes, thermal.T, time, dt)

# Check that the file was created
fname = joinpath(dst, "checkpoint.h5")
@test isfile(fname)

# Load the data from the file
P, T, Vx, Vy, Vz, η, t = load_checkpoint_hdf5(fname)
P, T, Vx, Vy, Vz, η, t, dt = load_checkpoint_hdf5(fname)

stokes.viscosity.η .= η
stokes.V.Vy .= Vy
Expand Down

0 comments on commit 431be3f

Please sign in to comment.