Skip to content

Commit

Permalink
Out-of-place copy method for JR types (#175)
Browse files Browse the repository at this point in the history
* copy method for JR types

* format

* revert name change

* fix
  • Loading branch information
albert-de-montserrat authored Jun 7, 2024
1 parent ca45acb commit 9dd49c0
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 35 deletions.
4 changes: 2 additions & 2 deletions src/JustRelax.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ include("JustRelax_CPU.jl")

include("IO/DataIO.jl")

include("array_conversions.jl")
export Array
include("types/type_conversions.jl")
export Array, copy

end # module
13 changes: 12 additions & 1 deletion src/array_conversions.jl → src/types/type_conversions.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import Base: Array
import Base: Array, copy

const JR_T = Union{StokesArrays,SymmetricTensor,ThermalArrays,Velocity,Residual,Viscosity}

Expand All @@ -19,3 +19,14 @@ function Array(::GPUBackendTrait, x::T) where {T<:JR_T}
T_clean = remove_parameters(x)
return T_clean(cpu_fields...)
end

function copy(x::T) where {T<:JR_T}
nfields = fieldcount(T)
fields = ntuple(Val(nfields)) do i
Base.@_inline_meta
field = getfield(x, i)
field === nothing ? nothing : copy(field)
end
T_clean = remove_parameters(x)
return T_clean(fields...)
end
52 changes: 20 additions & 32 deletions test/test_arrays_conversions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@ else
CPUBackend
end

@testset "Array conversions" begin
ni = 2, 2
stokes = StokesArrays(backend, ni)
thermal = ThermalArrays(backend, ni)
ni = 2, 2
stokes = StokesArrays(backend, ni)
thermal = ThermalArrays(backend, ni)

@testset "Type conversions" begin
A1 = Array(stokes.V)
A2 = Array(stokes.τ)
A3 = Array(stokes.R)
Expand All @@ -33,31 +34,18 @@ end
@test typeof(A6) <: JustRelax.ThermalArrays{<:Array}
end

# const JR_T = Union{
# JustRelax.StokesArrays,
# JustRelax.SymmetricTensor,
# JustRelax.ThermalArrays,
# JustRelax.Velocity,
# JustRelax.Residual
# }

# @inline remove_parameters(::T) where {T} = Base.typename(T).wrapper

# foo(x::T) where {T} = foo(JustRelax.backend(x), x)
# foo(x::Array) = x
# foo(::Nothing) = nothing

# function foo(::Any, x::T) where {T<:JR_T}
# nfields = fieldcount(T)
# cpu_fields = ntuple(Val(nfields)) do i
# Base.@_inline_meta
# @show i
# foo(getfield(x, i))
# end
# T_clean = remove_parameters(x)
# return T_clean(cpu_fields...)
# end

# foo(stokes.V)

# pot = 29
@testset "Type copy" begin
A1 = copy(stokes.V)
A2 = copy(stokes.τ)
A3 = copy(stokes.R)
A4 = copy(stokes.P)
A5 = copy(stokes)
A6 = copy(thermal)

@test typeof(A1) <: JustRelax.Velocity
@test typeof(A2) <: JustRelax.SymmetricTensor
@test typeof(A3) <: JustRelax.Residual
@test typeof(A4) <: typeof(stokes.P)
@test typeof(A5) <: JustRelax.StokesArrays
@test typeof(A6) <: JustRelax.ThermalArrays
end

0 comments on commit 9dd49c0

Please sign in to comment.