Skip to content

Commit

Permalink
gpu2cpu conversions
Browse files Browse the repository at this point in the history
  • Loading branch information
albert-de-montserrat committed Apr 26, 2024
1 parent f4518bc commit f1a4336
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 2 deletions.
3 changes: 3 additions & 0 deletions src/MetaJustRelax.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ function environment!(model::PS_Setup{T,N}) where {T,N}
ViscoElastoPlastic,
solve!

include(joinpath(@__DIR__, "array_conversions.jl"))
export Array

include(joinpath(@__DIR__, "Utils.jl"))
export @allocate, @add, @idx, @copy
export @velocity,
Expand Down
24 changes: 24 additions & 0 deletions src/array_conversions.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@

import Base.Array

@inline remove_parameters(::T) where T = Base.typename(T).wrapper

function Array(x::T) where T<:Union{SymmetricTensor, ThermalArrays, Velocity, Residual}
nfields = fieldcount(T)
cpu_fields = ntuple(Val(nfields)) do i
Base.@_inline_meta
Array(getfield(x, i))
end
T_clean = remove_parameters(x)
return T_clean(cpu_fields...)
end

function Array(x::StokesArrays{T,A,B,C,M,nDim}) where {T,A,B,C,M,nDim}
nfields = fieldcount(StokesArrays)
cpu_fields = ntuple(Val(nfields)) do i
Base.@_inline_meta
Array(getfield(x, i))
end
T_clean = remove_parameters(x)
return T_clean(cpu_fields...)
end
4 changes: 2 additions & 2 deletions src/stokes/MetaStokes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -220,10 +220,10 @@ function make_stokes_struct!(; name::Symbol=:StokesArrays)
)
end

function $(name)(args::Vararg{T,N}) where {T<:AbstractArray,N}
function $(name)(args::Vararg{Any,N}) where {N}
return new{
ViscoElastic,
typeof(args[4]),
typeof(args[3]),
typeof(args[5]),
typeof(args[end]),
typeof(args[1]),
Expand Down
16 changes: 16 additions & 0 deletions test/test_arrays_conversions.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
using JustRelax, Test
model = PS_Setup(:Threads, Float64, 2)
environment!(model)

@testset "Array conversions" begin
ni = 10, 10
stokes = StokesArrays(ni, ViscoElastic)
thermal = ThermalArrays(ni)

@test Array(stokes.V) isa Velocity{Matrix{Float64}}
@test Array(stokes.τ) isa SymmetricTensor{Matrix{Float64}}
@test Array(stokes.R) isa Residual{Matrix{Float64}}
@test Array(stokes.P) isa Matrix
@test Array(stokes) isa StokesArrays
@test Array(thermal) isa ThermalArrays{Matrix{Float64}}
end

0 comments on commit f1a4336

Please sign in to comment.