Skip to content

Commit

Permalink
CUDA extension
Browse files Browse the repository at this point in the history
  • Loading branch information
albert-de-montserrat committed Apr 2, 2024
1 parent a8e0dc7 commit 915f35f
Show file tree
Hide file tree
Showing 15 changed files with 294 additions and 174 deletions.
8 changes: 7 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ version = "0.1.1"
[deps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
CellArrays = "d35fcfd7-7af4-4c67-b1aa-d78070614af4"
GeoParams = "e018b62d-d9de-4a26-8697-af89c310ae38"
HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
ImplicitGlobalGrid = "4d7a3746-15be-11ea-1130-334b0c4f5fa0"
JustPIC = "10dc771f-8528-4cd9-9d3b-b21b2e693339"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
Expand All @@ -21,6 +21,12 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
WriteVTK = "64499a7a-5c06-52f2-abe2-ccb03c286192"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

[extensions]
JustRelaxCUDAExt = "CUDA"

[compat]
AMDGPU = "0.6, 0.7, 0.8"
Adapt = "3.7.2"
Expand Down
177 changes: 177 additions & 0 deletions ext/JustRelaxCUDAExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
module JustRelaxCUDAExt

using CUDA
using JustRelax: JustRelax
import JustRelax: PTArray

JustRelax.PTArray(::Type{CUDABackend}) = CuArray

module JustRelax2D

using JustRelax: JustRelax
using CUDA
using StaticArrays
using CellArrays
using ParallelStencil, ParallelStencil.FiniteDifferences2D
using ImplicitGlobalGrid
using GeoParams, LinearAlgebra, Printf
using MPI

import JustRelax:
IGG,
BackendTrait,
CPUBackendTrait,
backend,
CPUBackend,
Geometry,
@cell

@init_parallel_stencil(CUDA, Float64, 2)

include("../src/common.jl")
include("../src/stokes/Stokes2D.jl")

# add CUDA traits
struct CUDABackendTrait <: BackendTrait end

@inline backend(::CuArray) = CUDABackendTrait()
@inline backend(::Type{<:CuArray}) = CUDABackendTrait()

# Types
function JustRelax.JustRelax2D.StokesArrays(
::Type{CUDABackend}, ni::Vararg{Integer,N}
) where {N}
return StokesArrays(tuple(ni...))
end

function JustRelax.JustRelax2D.StokesArrays(
::Type{CUDABackend}, ni::NTuple{N,Integer}
) where {N}
return StokesArrays(ni)
end

function JustRelax.JustRelax2D.ThermalArrays(
::Type{CUDABackend}, ni::NTuple{N,Number}
) where {N}
return ThermalArrays(ni...)
end

function JustRelax.JustRelax2D.ThermalArrays(
::Type{CUDABackend}, ni::Vararg{Number,N}
) where {N}
return ThermalArrays(ni...)
end

function JustRelax.JustRelax2D.PhaseRatio(::Type{CUDABackend}, ni, num_phases)
return PhaseRatio(ni, num_phases)
end

# Boundary conditions
function JustRelax.JustRelax2D.flow_bcs!(
::CUDABackendTrait, stokes::StokesArrays, bcs
)
return _flow_bcs!(bcs, @velocity(stokes))
end

function flow_bcs!(
::CUDABackendTrait, stokes::StokesArrays, bcs
)
return _flow_bcs!(bcs, @velocity(stokes))
end

function JustRelax.JustRelax2D.thermal_bcs!(
::CUDABackendTrait, thermal::ThermalArrays, bcs
)
return thermal_bcs!(thermal.T, bcs)
end

function thermal_bcs!(
::CUDABackendTrait, thermal::ThermalArrays, bcs
)
return thermal_bcs!(thermal.T, bcs)
end

# Phases
function JustRelax.JustRelax2D.phase_ratios_center(
::CUDABackendTrait, phase_ratios::PhaseRatio, particles, grid::Geometry, phases
)
return _phase_ratios_center(phase_ratios, particles, grid, phases)
end

# Rheology
## viscosity
function JustRelax.JustRelax2D.compute_viscosity!(
::CUDABackendTrait, stokes, ν, args, rheology, cutoff
)
return _compute_viscosity!(stokes, ν, args, rheology, cutoff)
end
function JustRelax.JustRelax2D.compute_viscosity!(
::CUDABackendTrait, stokes, ν, phase_ratios, args, rheology, cutoff
)
return _compute_viscosity!(stokes, ν, phase_ratios, args, rheology, cutoff)
end
function JustRelax.JustRelax2D.compute_viscosity!(
η, ν, εII::CuArray, args, rheology, cutoff
)
return compute_viscosity!(η, ν, εII, args, rheology, cutoff)
end

function compute_viscosity!(
::CUDABackendTrait, stokes, ν, args, rheology, cutoff
)
return _compute_viscosity!(stokes, ν, args, rheology, cutoff)
end
function compute_viscosity!(
::CUDABackendTrait, stokes, ν, phase_ratios, args, rheology, cutoff
)
return _compute_viscosity!(stokes, ν, phase_ratios, args, rheology, cutoff)
end
function compute_viscosity!(
η, ν, εII::CuArray, args, rheology, cutoff
)
return compute_viscosity!(η, ν, εII, args, rheology, cutoff)
end

## Stress
JustRelax.JustRelax2D.tensor_invariant!(A::SymmetricTensor) = tensor_invariant!(A)

## Buoyancy forces
function JustRelax.JustRelax2D.compute_ρg!(ρg::CuArray, rheology, args)
return compute_ρg!(ρg, rheology, args)
end
function JustRelax.JustRelax2D.compute_ρg!(
ρg::CuArray, phase_ratios::PhaseRatio, rheology, args
)
return compute_ρg!(ρg, phase_ratios, rheology, args)
end

# Interpolations
function JustRelax.JustRelax2D.temperature2center!(
::CUDABackendTrait, thermal::ThermalArrays
)
return _temperature2center!(thermal)
end
function JustRelax.JustRelax2D.vertex2center!(center::T, vertex::T) where {T<:CuArray}
return vertex2center!(center, vertex)
end
function JustRelax.JustRelax2D.center2vertex!(vertex::T, center::T) where {T<:CuArray}
return center2vertex!(vertex, center)
end

function JustRelax.JustRelax2D.center2vertex!(
vertex_yz::T, vertex_xz::T, vertex_xy::T, center_yz::T, center_xz::T, center_xy::T
) where {T<:CuArray}
return center2vertex!(
vertex_yz, vertex_xz, vertex_xy, center_yz, center_xz, center_xy
)
end

# Solvers
JustRelax.JustRelax2D.solve!(::CUDABackendTrait, stokes, args...; kwargs) = _solve!(stokes, args...; kwargs...)

# Utils
JustRelax.JustRelax2D.compute_dt(S::StokesArrays, di, dt_diff, I::IGG) = compute_dt(S, di, dt_diff, I::IGG)

end

end
7 changes: 6 additions & 1 deletion src/Interpolations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@ end

# From cell vertices to cell center

function temperature2center!(thermal::ThermalArrays)
temperature2center!(thermal) = temperature2center!(backend(thermal), thermal)
function temperature2center!(::CPUBackendTrait, thermal::ThermalArrays)
return _temperature2center!(thermal)
end

function _temperature2center!(thermal::ThermalArrays)
@parallel (@idx size(thermal.Tc)...) temperature2center_kernel!(thermal.Tc, thermal.T)
return nothing
end
Expand Down
11 changes: 10 additions & 1 deletion src/JustRelax.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,15 @@ using StaticArrays

function solve!() end

struct CPUBackend end
struct AMDGPUBackend end

PTArray() = Array
PTArray(::Type{CPUBackend}) = Array
PTArray(::T) where {T} = error(ArgumentError("Unknown backend $T"))

export PTArray, CPUBackend, CUDABackend, AMDGPUBackend

include("types/traits.jl")
export BackendTrait, CPUBackendTrait

Expand All @@ -23,6 +32,6 @@ export @cell, element, setelement!, cellnum, cellaxes, new_empty_cell, setindex!

include("JustRelax_CPU.jl")

include("IO/DataIO.jl")
# include("IO/DataIO.jl")

end # module
12 changes: 2 additions & 10 deletions src/JustRelax_CPU.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,10 @@ using ImplicitGlobalGrid
using GeoParams, LinearAlgebra, Printf
using MPI

import JustRelax: IGG, BackendTrait, CPUBackendTrait, backend
import JustRelax: IGG, BackendTrait, CPUBackendTrait, backend, CPUBackend

@init_parallel_stencil(Threads, Float64, 2)

export PTArray

PTArray() = Array

include("common.jl")
include("stokes/Stokes2D.jl")
export solve!
Expand All @@ -32,14 +28,10 @@ using ImplicitGlobalGrid
using GeoParams, LinearAlgebra, Printf
using MPI

import JustRelax: IGG, BackendTrait, CPUBackendTrait, backend
import JustRelax: IGG, BackendTrait, CPUBackendTrait, backend, CPUBackend

@init_parallel_stencil(Threads, Float64, 3)

export PTArray

PTArray() = Array

include("common.jl")
include("stokes/Stokes3D.jl")
export solve!
Expand Down
82 changes: 14 additions & 68 deletions src/boundaryconditions/BoundaryConditions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,14 @@ end
Apply the prescribed heat boundary conditions `bc` on the `T`
"""
function thermal_bcs!(T, bcs::TemperatureBoundaryConditions)
thermal_bcs!(thermal, bcs) = thermal_bcs!(backend(thermal), thermal, bcs)
function thermal_bcs!(
::CPUBackendTrait, thermal::ThermalArrays, bcs::FlowBoundaryConditions
)
return thermal_bcs!(thermal.T, bcs)
end

function thermal_bcs!(T::AbstractArray, bcs::TemperatureBoundaryConditions)
n = bc_index(T)

# no flux boundary conditions
Expand All @@ -64,12 +71,15 @@ end
Apply the prescribed flow boundary conditions `bc` on the `stokes`
"""
flow_bcs!(stokes, bcs::FlowBoundaryConditions) = _flow_bcs!(bcs, @velocity(stokes))
function flow_bcs!(bcs::FlowBoundaryConditions, V::Vararg{T,N}) where {T,N}
flow_bcs!(stokes, bcs) = flow_bcs!(backend(stokes), stokes, bcs)
function flow_bcs!(::CPUBackendTrait, stokes, bcs)
return _flow_bcs!(bcs, @velocity(stokes))
end
function flow_bcs!(bcs, V::Vararg{T,N}) where {T,N}
return _flow_bcs!(bcs, tuple(V...))
end

function _flow_bcs!(bcs::FlowBoundaryConditions, V)
function _flow_bcs!(bcs, V)
n = bc_index(V)
# no slip boundary conditions
do_bc(bcs.no_slip) && (@parallel (@idx n) no_slip!(V..., bcs.no_slip))
Expand Down Expand Up @@ -214,57 +224,6 @@ end
return nothing
end

@parallel_indices (i) function periodic_boundaries!(Ax, Ay, bc)
@inbounds begin
if i size(Ax, 1)
bc.bot && (Ax[i, 1] = Ax[i, end - 1])
bc.top && (Ax[i, end] = Ax[i, 2])
end
if i size(Ay, 2)
bc.left && (Ay[1, i] = Ay[end - 1, i])
bc.right && (Ay[end, i] = Ay[2, i])
end
end
return nothing
end

@parallel_indices (i) function periodic_boundaries!(
T::_T, bc
) where {_T<:AbstractArray{<:Any,2}}
@inbounds begin
if i size(T, 1)
bc.bot && (T[i, 1] = T[i, end - 1])
bc.top && (T[i, end] = T[i, 2])
end
if i size(T, 2)
bc.left && (T[1, i] = T[end - 1, i])
bc.right && (T[end, i] = T[2, i])
end
end
return nothing
end

@parallel_indices (i, j) function periodic_boundaries!(
T::_T, bc
) where {_T<:AbstractArray{<:Any,3}}
nx, ny, nz = size(T)
@inbounds begin
if i nx && j ny
bc.bot && (T[i, j, 1] = T[i, j, end - 1])
bc.top && (T[i, j, end] = T[i, j, 2])
end
if i ny && j nz
bc.left && (T[1, i, j] = T[end - 1, i, j])
bc.right && (T[end, i, j] = T[2, i, j])
end
if i nx && j nz
bc.front && (T[i, 1, j] = T[i, end - 1, j])
bc.back && (T[i, end, j] = T[i, 2, j])
end
end
return nothing
end

function pureshear_bc!(
stokes::StokesArrays, xci::NTuple{2,T}, xvi::NTuple{2,T}, εbg
) where {T}
Expand Down Expand Up @@ -307,19 +266,6 @@ function apply_free_slip!(freeslip::NamedTuple{<:Any,NTuple{2,T}}, Vx, Vy) where
return nothing
end

function thermal_boundary_conditions!(
insulation::NamedTuple, T::AbstractArray{_T,2}
) where {_T}
insulation_x, insulation_y = insulation

nx, ny = size(T)

insulation_x && (@parallel (1:ny) free_slip_x!(T))
insulation_y && (@parallel (1:nx) free_slip_y!(T))

return nothing
end

# 3D KERNELS

@parallel_indices (j, k) function free_slip_x!(A::AbstractArray{T,3}) where {T}
Expand Down
Loading

0 comments on commit 915f35f

Please sign in to comment.