Skip to content

Commit

Permalink
fixes 3D extension backend mixup
Browse files Browse the repository at this point in the history
  • Loading branch information
aelligp committed May 23, 2024
1 parent 372d5be commit f7eaeea
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 31 deletions.
36 changes: 18 additions & 18 deletions src/ext/AMDGPU/3D.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import JustRelax:
IGG,
BackendTrait,
CPUBackendTrait,
CUDABackendTrait,
AMDGPUBackendTrait,
backend,
CPUBackend,
Geometry,
Expand All @@ -29,24 +29,24 @@ include("../../common.jl")
include("../../stokes/Stokes3D.jl")

# Types
function JR3D.StokesArrays(::Type{CUDABackend}, ni::NTuple{N,Integer}) where {N}
function JR3D.StokesArrays(::Type{AMDGPUBackend}, ni::NTuple{N,Integer}) where {N}
return StokesArrays(ni)
end

function JR3D.ThermalArrays(::Type{CUDABackend}, ni::NTuple{N,Number}) where {N}
function JR3D.ThermalArrays(::Type{AMDGPUBackend}, ni::NTuple{N,Number}) where {N}
return ThermalArrays(ni...)
end

function JR3D.ThermalArrays(::Type{CUDABackend}, ni::Vararg{Number,N}) where {N}
function JR3D.ThermalArrays(::Type{AMDGPUBackend}, ni::Vararg{Number,N}) where {N}
return ThermalArrays(ni...)
end

function JR3D.PhaseRatio(::Type{CUDABackend}, ni, num_phases)
function JR3D.PhaseRatio(::Type{AMDGPUBackend}, ni, num_phases)
return PhaseRatio(ni, num_phases)
end

function JR3D.PTThermalCoeffs(
::Type{CUDABackend},
::Type{AMDGPUBackend},
rheology,
phase_ratios,
args,
Expand All @@ -61,7 +61,7 @@ function JR3D.PTThermalCoeffs(
end

function JR3D.PTThermalCoeffs(
::Type{CUDABackend},
::Type{AMDGPUBackend},
rheology,
args,
dt,
Expand All @@ -75,25 +75,25 @@ function JR3D.PTThermalCoeffs(
end

# Boundary conditions
function JR3D.flow_bcs!(::CUDABackendTrait, stokes::JustRelax.StokesArrays, bcs)
function JR3D.flow_bcs!(::AMDGPUBackendTrait, stokes::JustRelax.StokesArrays, bcs)
return _flow_bcs!(bcs, @velocity(stokes))
end

function flow_bcs!(::CUDABackendTrait, stokes::JustRelax.StokesArrays, bcs)
function flow_bcs!(::AMDGPUBackendTrait, stokes::JustRelax.StokesArrays, bcs)
return _flow_bcs!(bcs, @velocity(stokes))
end

function JR3D.thermal_bcs!(::CUDABackendTrait, thermal::JustRelax.ThermalArrays, bcs)
function JR3D.thermal_bcs!(::AMDGPUBackendTrait, thermal::JustRelax.ThermalArrays, bcs)
return thermal_bcs!(thermal.T, bcs)
end

function thermal_bcs!(::CUDABackendTrait, thermal::JustRelax.ThermalArrays, bcs)
function thermal_bcs!(::AMDGPUBackendTrait, thermal::JustRelax.ThermalArrays, bcs)
return thermal_bcs!(thermal.T, bcs)
end

# Phases
function JR3D.phase_ratios_center!(
::CUDABackendTrait,
::AMDGPUBackendTrait,
phase_ratios::JustRelax.PhaseRatio,
particles,
grid::Geometry,
Expand Down Expand Up @@ -134,7 +134,7 @@ function compute_viscosity!(η, ν, εII::RocArray, args, rheology, cutoff)
end

## Stress
function JR3D.tensor_invariant!(::CUDABackendTrait, A::JustRelax.SymmetricTensor)
function JR3D.tensor_invariant!(::AMDGPUBackendTrait, A::JustRelax.SymmetricTensor)
return _tensor_invariant!(A)
end

Expand All @@ -147,11 +147,11 @@ function JR3D.compute_ρg!(ρg::RocArray, phase_ratios::JustRelax.PhaseRatio, rh
end

# Interpolations
function JR3D.temperature2center!(::CUDABackendTrait, thermal::JustRelax.ThermalArrays)
function JR3D.temperature2center!(::AMDGPUBackendTrait, thermal::JustRelax.ThermalArrays)
return _temperature2center!(thermal)
end

function temperature2center!(::CUDABackendTrait, thermal::JustRelax.ThermalArrays)
function temperature2center!(::AMDGPUBackendTrait, thermal::JustRelax.ThermalArrays)
return _temperature2center!(thermal)
end

Expand All @@ -170,16 +170,16 @@ function JR3D.center2vertex!(
end

# Solvers
function JR3D.solve!(::CUDABackendTrait, stokes, args...; kwargs)
function JR3D.solve!(::AMDGPUBackendTrait, stokes, args...; kwargs)
return _solve!(stokes, args...; kwargs...)
end

function JR3D.heatdiffusion_PT!(::CUDABackendTrait, thermal, args...; kwargs)
function JR3D.heatdiffusion_PT!(::AMDGPUBackendTrait, thermal, args...; kwargs)
return _heatdiffusion_PT!(thermal, args...; kwargs...)
end

# Utils
function JR3D.compute_dt(::CUDABackendTrait, S::JustRelax.StokesArrays, args...)
function JR3D.compute_dt(::AMDGPUBackendTrait, S::JustRelax.StokesArrays, args...)
return _compute_dt(S, args...)
end

Expand Down
26 changes: 13 additions & 13 deletions src/ext/CUDA/3D.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,31 +105,31 @@ end
# Rheology

## viscosity
function JR3D.compute_viscosity!(::AMDGPUBackendTrait, stokes, ν, args, rheology, cutoff)
function JR3D.compute_viscosity!(::CUDABackendTrait, stokes, ν, args, rheology, cutoff)
return _compute_viscosity!(stokes, ν, args, rheology, cutoff)
end

function JR3D.compute_viscosity!(
::AMDGPUBackendTrait, stokes, ν, phase_ratios, args, rheology, cutoff
::CUDABackendTrait, stokes, ν, phase_ratios, args, rheology, cutoff
)
return _compute_viscosity!(stokes, ν, phase_ratios, args, rheology, cutoff)
end

function JR2D.compute_viscosity!(η, ν, εII::RocArray, args, rheology, cutoff)
function JR3D.compute_viscosity!(η, ν, εII::CuArray, args, rheology, cutoff)
return compute_viscosity!(η, ν, εII, args, rheology, cutoff)
end

function compute_viscosity!(::AMDGPUBackendTrait, stokes, ν, args, rheology, cutoff)
function compute_viscosity!(::CUDABackendTrait, stokes, ν, args, rheology, cutoff)
return _compute_viscosity!(stokes, ν, args, rheology, cutoff)
end

function compute_viscosity!(
::AMDGPUBackendTrait, stokes, ν, phase_ratios, args, rheology, cutoff
::CUDABackendTrait, stokes, ν, phase_ratios, args, rheology, cutoff
)
return _compute_viscosity!(stokes, ν, phase_ratios, args, rheology, cutoff)
end

function compute_viscosity!(η, ν, εII::RocArray, args, rheology, cutoff)
function compute_viscosity!(η, ν, εII::CuArray, args, rheology, cutoff)
return compute_viscosity!(η, ν, εII, args, rheology, cutoff)
end

Expand All @@ -139,10 +139,10 @@ function JR3D.tensor_invariant!(::CUDABackendTrait, A::JustRelax.SymmetricTensor
end

## Buoyancy forces
function JR3D.compute_ρg!(ρg::RocArray, rheology, args)
function JR3D.compute_ρg!(ρg::CuArray, rheology, args)
return compute_ρg!(ρg, rheology, args)
end
function JR3D.compute_ρg!(ρg::RocArray, phase_ratios::JustRelax.PhaseRatio, rheology, args)
function JR3D.compute_ρg!(ρg::CuArray, phase_ratios::JustRelax.PhaseRatio, rheology, args)
return compute_ρg!(ρg, phase_ratios, rheology, args)
end

Expand All @@ -155,17 +155,17 @@ function temperature2center!(::CUDABackendTrait, thermal::JustRelax.ThermalArray
return _temperature2center!(thermal)
end

function JR3D.vertex2center!(center::T, vertex::T) where {T<:RocArray}
function JR3D.vertex2center!(center::T, vertex::T) where {T<:CuArray}
return vertex2center!(center, vertex)
end

function JR3D.center2vertex!(vertex::T, center::T) where {T<:RocArray}
function JR3D.center2vertex!(vertex::T, center::T) where {T<:CuArray}
return center2vertex!(vertex, center)
end

function JR3D.center2vertex!(
vertex_yz::T, vertex_xz::T, vertex_xy::T, center_yz::T, center_xz::T, center_xy::T
) where {T<:RocArray}
) where {T<:CuArray}
return center2vertex!(vertex_yz, vertex_xz, vertex_xy, center_yz, center_xz, center_xy)
end

Expand All @@ -186,7 +186,7 @@ end
function JR3D.subgrid_characteristic_time!(
subgrid_arrays,
particles,
dt₀::RocArray,
dt₀::CuArray,
phases::JustRelax.PhaseRatio,
rheology,
thermal::JustRelax.ThermalArrays,
Expand All @@ -204,7 +204,7 @@ end
function JR3D.subgrid_characteristic_time!(
subgrid_arrays,
particles,
dt₀::RocArray,
dt₀::CuArray,
phases::AbstractArray{Int,N},
rheology,
thermal::JustRelax.ThermalArrays,
Expand Down

0 comments on commit f7eaeea

Please sign in to comment.