Skip to content

Commit

Permalink
Merge pull request #143 from JuliaGeodynamics/adm/fix_conversion
Browse files Browse the repository at this point in the history
Checkpointing
  • Loading branch information
albert-de-montserrat authored Sep 26, 2024
2 parents 37ebdb1 + 2decd8d commit e435861
Show file tree
Hide file tree
Showing 8 changed files with 109 additions and 15 deletions.
21 changes: 17 additions & 4 deletions ext/JustPICAMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,17 @@ import JustPIC: AbstractBackend, AMDGPUBackend

JustPIC.TA(::Type{AMDGPUBackend}) = ROCArray

function AMDGPU.ROCArray(particles::JustPIC.Particles{JustPIC.AMDGPUBackend})
(; coords, index, nxcell, max_xcell, min_xcell, np) = particles
coords_gpu = CuArray.(coords);
return Particles(CUDABackend, coords_gpu, CuArray(index), nxcell, max_xcell, min_xcell, np)
end

function AMDGPU.ROCArray(phase_ratios::JustPIC.PhaseRatios{JustPIC.AMDGPUBackend})
(; vertex, center) = phase_ratios
return JustPIC.PhaseRatios(CUDABackend, CuArray(vertex), CuArray(center))
end

module _2D
using AMDGPU
using ImplicitGlobalGrid
Expand Down Expand Up @@ -36,6 +47,7 @@ module _2D
include(joinpath(@__DIR__, "../src/common.jl"))
include(joinpath(@__DIR__, "../src/AMDGPUExt/CellArrays.jl"))


function JustPIC._2D.Particles(
coords,
index::CellArray{StaticArraysCore.SVector{N1,Bool},2,0,ROCArray{Bool,N2}},
Expand Down Expand Up @@ -232,7 +244,7 @@ module _2D
function JustPIC._2D.PhaseRatios(
::Type{AMDGPUBackend}, nphases::Integer, ni::NTuple{N,Integer}
) where {N}
return PhaseRatios(Float64, AMDGPUBackend, nphases, ni)
return JustPIC._2D.PhaseRatios(Float64, AMDGPUBackend, nphases, ni)
end

function JustPIC._2D.PhaseRatios(
Expand All @@ -241,7 +253,7 @@ module _2D
center = cell_array(0.0, (nphases,), ni)
vertex = cell_array(0.0, (nphases,), ni .+ 1)

return PhaseRatios{B,typeof(center)}(center, vertex)
return JustPIC.PhaseRatios(AMDGPUBackend, center, vertex)
end

function JustPIC._2D.phase_ratios_center!(
Expand Down Expand Up @@ -297,6 +309,7 @@ module _3D
include(joinpath(@__DIR__, "../src/common.jl"))
include(joinpath(@__DIR__, "../src/AMDGPUExt/CellArrays.jl"))


function JustPIC._3D.Particles(
coords,
index::CellArray{StaticArraysCore.SVector{N1,Bool},3,0,ROCArray{Bool,N2}},
Expand Down Expand Up @@ -470,7 +483,7 @@ module _3D
function JustPIC._3D.PhaseRatios(
::Type{AMDGPUBackend}, nphases::Integer, ni::NTuple{N,Integer}
) where {N}
return PhaseRatios(Float64, AMDGPUBackend, nphases, ni)
return JustPIC._3D.PhaseRatios(Float64, AMDGPUBackend, nphases, ni)
end

function JustPIC._3D.PhaseRatios(
Expand All @@ -479,7 +492,7 @@ module _3D
center = cell_array(0.0, (nphases,), ni)
vertex = cell_array(0.0, (nphases,), ni .+ 1)

return PhaseRatios{B,typeof(center)}(center, vertex)
return JustPIC.PhaseRatios(AMDGPUBackend, center, vertex)
end

function JustPIC._3D.phase_ratios_center!(
Expand Down
23 changes: 19 additions & 4 deletions ext/JustPICCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,17 @@ using JustPIC

JustPIC.TA(::Type{CUDABackend}) = CuArray

function CUDA.CuArray(particles::JustPIC.Particles{JustPIC.CPUBackend})
(; coords, index, nxcell, max_xcell, min_xcell, np) = particles
coords_gpu = CuArray.(coords);
return Particles(CUDABackend, coords_gpu, CuArray(index), nxcell, max_xcell, min_xcell, np)
end

function CUDA.CuArray(phase_ratios::JustPIC.PhaseRatios{JustPIC.CPUBackend})
(; vertex, center) = phase_ratios
return JustPIC.PhaseRatios(CUDABackend, CuArray(vertex), CuArray(center))
end

module _2D
using CUDA
using ImplicitGlobalGrid
Expand Down Expand Up @@ -36,6 +47,8 @@ module _2D
include(joinpath(@__DIR__, "../src/common.jl"))
include(joinpath(@__DIR__, "../src/CUDAExt/CellArrays.jl"))

# Conversions

function JustPIC._2D.Particles(
coords,
index::CellArray{StaticArraysCore.SVector{N1,Bool},2,0,CuArray{Bool,N2}},
Expand Down Expand Up @@ -226,7 +239,7 @@ module _2D
function JustPIC._2D.PhaseRatios(
::Type{CUDABackend}, nphases::Integer, ni::NTuple{N,Integer}
) where {N}
return PhaseRatios(Float64, CUDABackend, nphases, ni)
return JustPIC._2D.PhaseRatios(Float64, CUDABackend, nphases, ni)
end

function JustPIC._2D.PhaseRatios(
Expand All @@ -235,7 +248,7 @@ module _2D
center = cell_array(0.0, (nphases,), ni)
vertex = cell_array(0.0, (nphases,), ni .+ 1)

return PhaseRatios{B,typeof(center)}(center, vertex)
return JustPIC.PhaseRatios(CUDABackend, center, vertex)
end

function JustPIC._2D.phase_ratios_center!(
Expand Down Expand Up @@ -291,6 +304,8 @@ module _3D

include(joinpath(@__DIR__, "../src/common.jl"))
include(joinpath(@__DIR__, "../src/CUDAExt/CellArrays.jl"))

# Conversions

function JustPIC._3D.Particles(
coords,
Expand Down Expand Up @@ -465,7 +480,7 @@ module _3D
function JustPIC._3D.PhaseRatios(
::Type{CUDABackend}, nphases::Integer, ni::NTuple{N,Integer}
) where {N}
return PhaseRatios(Float64, CUDABackend, nphases, ni)
return JustPIC._3D.PhaseRatios(Float64, CUDABackend, nphases, ni)
end

function JustPIC._3D.PhaseRatios(
Expand All @@ -474,7 +489,7 @@ module _3D
center = cell_array(0.0, (nphases,), ni)
vertex = cell_array(0.0, (nphases,), ni .+ 1)

return PhaseRatios{B,typeof(center)}(center, vertex)
return JustPIC.PhaseRatios(CUDABackend, center, vertex)
end

function JustPIC._3D.phase_ratios_center!(
Expand Down
4 changes: 2 additions & 2 deletions src/CellArrays/conversion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import Base: Array, copy
@inline remove_parameters(::T) where {T} = Base.typename(T).wrapper

# detect if we are on the CPU (`Val{false}`) or GPU (`Val{true}`)
@inline isdevice(::Type{Array}) = Val(false)
@inline isdevice(::Type{Array{T, N}}) where {T,N} = Val(false)
@inline isdevice(::Type{T}) where {T<:AbstractArray} = Val(true) # this is a big assumption but still
@inline isdevice(::T) where {T} =
throw(ArgumentError("$(T) is not a supported CellArray type."))
Expand Down Expand Up @@ -33,8 +33,8 @@ function Array(x::T) where {T<:AbstractParticles}
return T_clean(CPUBackend, cpu_fields...)
end

_Array(x) = x
_Array(::Nothing) = nothing
_Array(::T) where {T} = T
_Array(x::AbstractArray) = Array(x)
_Array(x::NTuple{N,T}) where {N,T} = ntuple(i -> _Array(x[i]), Val(N))

Expand Down
6 changes: 3 additions & 3 deletions src/JustPIC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ function CA end
TA() = Array
TA(::Type{CPUBackend}) = Array

include("PhaseRatios/PhaseRatios.jl")
export nphases, numphases

include("particles.jl")
export AbstractParticles, Particles, MarkerChain, PassiveMarkers, cell_index, cell_length

include("PhaseRatios/PhaseRatios.jl")
export nphases, numphases

include("Advection/types.jl")
export AbstractAdvectionIntegrator, Euler, RungeKutta2

Expand Down
7 changes: 6 additions & 1 deletion src/PhaseRatios/PhaseRatios.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
struct PhaseRatios{Backend,T}
struct PhaseRatios{Backend,T} <: AbstractParticles
center::T
vertex::T

function PhaseRatios(::Type{B}, center::T, vertex::T) where {B, T}
return new{B, T}(center, vertex)
end
end


"""
nphases(x::PhaseRatios)
Expand Down
6 changes: 5 additions & 1 deletion src/PhaseRatios/constructors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@ function PhaseRatios(
center = cell_array(0.0, (nphases,), ni)
vertex = cell_array(0.0, (nphases,), ni .+ 1)

return JustPIC.PhaseRatios{B,typeof(center)}(center, vertex)
return JustPIC.PhaseRatios(B, center, vertex)
end

# function PhaseRatios(::Type{B}, center, vertex) where {B}
# return JustPIC.PhaseRatios{B, typeof(center)}(center, vertex)
# end

function PhaseRatios(nphases::Integer, ni::NTuple{N,Integer}) where {N}
return PhaseRatios(Float64, CPUBackend, nphases, ni)
end
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[deps]
CellArrays = "d35fcfd7-7af4-4c67-b1aa-d78070614af4"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
ParallelStencil = "94395366-693c-11ea-3b26-d9b7aac5d958"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Expand Down
56 changes: 56 additions & 0 deletions test/test_save_load.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
using JLD2, JustPIC, JustPIC._2D

@testset "Save and load" begin
# Initialize particles -------------------------------
nxcell, max_xcell, min_xcell = 6, 6, 6
n = 128
nx = ny = n-1
ni = nx, ny
Lx = Ly = 1.0
# nodal vertices
xvi = xv, yv = range(0, Lx, length=n), range(0, Ly, length=n)

particles = init_particles(backend, nxcell, max_xcell, min_xcell, xvi...,);
phases, = init_cell_arrays(particles, Val(1));
phase_ratios = PhaseRatios(backend, 2, ni);
phase_ratios = PhaseRatios(JustPIC.CPUBackend, 2, ni);

jldsave(
"particles.jld2";
particles = Array(particles),
phases = Array(phases),
phase_ratios = Array(phase_ratios)
)

data = load("particles.jld2")
particles2 = data["particles"]
phases2 = data["phases"]
phase_ratios2 = data["phase_ratios"]

@test Array(particles.coords[1].data) == particles2.coords[1].data
@test Array(particles.coords[2].data) == particles2.coords[2].data
@test Array(particles.index.data) == particles2.index.data
@test Array(phase_ratios.center.data) == phase_ratios2.center.data
@test Array(phase_ratios.vertex.data) == phase_ratios2.vertex.data
@test Array(phases.data) == phases2.data

if isdefined(Main, :CUDA)
particles_cuda = CuArray(particles2)
phase_ratios_cuda = CuArray(phase_ratios2)
phases_cuda = CuArray(phases2)

@test particles_cuda isa JustPIC.Particles{CUDABackend}
@test phase_ratios_cuda isa JustPIC.PhaseRatios{CUDABackend}
@test phases_cuda isa CuArray

elseif isdefined(Main, :AMDGPU)
particles_amdgpu = ROCArray(particles2)
phase_ratios_amdgpu = ROCArray(phase_ratios2)
phases_amdgpu = ROCArray(phases2)

@test particles_cuda isa JustPIC.Particles{AMDGPUBackend}
@test phase_ratios_cuda isa JustPIC.PhaseRatios{AMDGPUBackend}
@test phases_cuda isa ROCArray

end
end

0 comments on commit e435861

Please sign in to comment.