Skip to content

Commit

Permalink
ext: CUDA trait
Browse files Browse the repository at this point in the history
  • Loading branch information
albert-de-montserrat committed Apr 29, 2024
1 parent 2d0030e commit 1ecccb7
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 7 deletions.
7 changes: 2 additions & 5 deletions ext/JustRelaxCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,13 @@ module JustRelax2D
using MPI

import JustRelax:
IGG, BackendTrait, CPUBackendTrait, backend, CPUBackend, Geometry, @cell
IGG, BackendTrait, CPUBackendTrait, CUDABackendTrait, backend, CPUBackend, Geometry, @cell

@init_parallel_stencil(CUDA, Float64, 2)

include("../src/common.jl")
include("../src/stokes/Stokes2D.jl")

# add CUDA traits
struct CUDABackendTrait <: BackendTrait end

@inline backend(::CuArray) = CUDABackendTrait()
@inline backend(::Type{<:CuArray}) = CUDABackendTrait()

Expand Down Expand Up @@ -144,7 +141,7 @@ module JustRelax2D
end

## Stress
JustRelax.JustRelax2D.tensor_invariant!(A::SymmetricTensor) = tensor_invariant!(A)
JustRelax.JustRelax2D.tensor_invariant!(A::JustRelax.SymmetricTensor) = tensor_invariant!(A)

## Buoyancy forces
function JustRelax.JustRelax2D.compute_ρg!(ρg::CuArray, rheology, args)
Expand Down
4 changes: 2 additions & 2 deletions src/types/traits.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
abstract type BackendTrait end
struct CPUBackendTrait <: BackendTrait end
struct NonCPUBackendTrait <: BackendTrait end
# struct CUDABackendTrait <: BackendTrait end
# struct AMDGPUBackendTrait <: BackendTrait end
struct CUDABackendTrait <: BackendTrait end
struct AMDGPUBackendTrait <: BackendTrait end

@inline backend(::Array) = CPUBackendTrait()
@inline backend(::Type{<:Array}) = CPUBackendTrait()
Expand Down

0 comments on commit 1ecccb7

Please sign in to comment.