diff --git a/src/JustRelax.jl b/src/JustRelax.jl index 0322619f..1d39792b 100644 --- a/src/JustRelax.jl +++ b/src/JustRelax.jl @@ -47,7 +47,7 @@ include("JustRelax_CPU.jl") include("IO/DataIO.jl") -include("array_conversions.jl") -export Array +include("types/type_conversions.jl") +export Array, copy end # module diff --git a/src/array_conversions.jl b/src/types/type_conversions.jl similarity index 64% rename from src/array_conversions.jl rename to src/types/type_conversions.jl index 78fa2794..015091a5 100644 --- a/src/array_conversions.jl +++ b/src/types/type_conversions.jl @@ -1,4 +1,4 @@ -import Base: Array +import Base: Array, copy const JR_T = Union{StokesArrays,SymmetricTensor,ThermalArrays,Velocity,Residual,Viscosity} @@ -19,3 +19,14 @@ function Array(::GPUBackendTrait, x::T) where {T<:JR_T} T_clean = remove_parameters(x) return T_clean(cpu_fields...) end + +function copy(x::T) where {T<:JR_T} + nfields = fieldcount(T) + fields = ntuple(Val(nfields)) do i + Base.@_inline_meta + field = getfield(x, i) + field === nothing ? nothing : copy(field) + end + T_clean = remove_parameters(x) + return T_clean(fields...) +end diff --git a/test/test_arrays_conversions.jl b/test/test_arrays_conversions.jl index 0c77a36e..ffcf8a87 100644 --- a/test/test_arrays_conversions.jl +++ b/test/test_arrays_conversions.jl @@ -14,10 +14,11 @@ else CPUBackend end -@testset "Array conversions" begin - ni = 2, 2 - stokes = StokesArrays(backend, ni) - thermal = ThermalArrays(backend, ni) +ni = 2, 2 +stokes = StokesArrays(backend, ni) +thermal = ThermalArrays(backend, ni) + +@testset "Type conversions" begin A1 = Array(stokes.V) A2 = Array(stokes.τ) A3 = Array(stokes.R) @@ -33,31 +34,18 @@ end @test typeof(A6) <: JustRelax.ThermalArrays{<:Array} end -# const JR_T = Union{ -# JustRelax.StokesArrays, -# JustRelax.SymmetricTensor, -# JustRelax.ThermalArrays, -# JustRelax.Velocity, -# JustRelax.Residual -# } - -# @inline remove_parameters(::T) where {T} = Base.typename(T).wrapper - -# foo(x::T) where {T} = foo(JustRelax.backend(x), x) -# foo(x::Array) = x -# foo(::Nothing) = nothing - -# function foo(::Any, x::T) where {T<:JR_T} -# nfields = fieldcount(T) -# cpu_fields = ntuple(Val(nfields)) do i -# Base.@_inline_meta -# @show i -# foo(getfield(x, i)) -# end -# T_clean = remove_parameters(x) -# return T_clean(cpu_fields...) -# end - -# foo(stokes.V) - -# pot = 29 +@testset "Type copy" begin + A1 = copy(stokes.V) + A2 = copy(stokes.τ) + A3 = copy(stokes.R) + A4 = copy(stokes.P) + A5 = copy(stokes) + A6 = copy(thermal) + + @test typeof(A1) <: JustRelax.Velocity + @test typeof(A2) <: JustRelax.SymmetricTensor + @test typeof(A3) <: JustRelax.Residual + @test typeof(A4) <: typeof(stokes.P) + @test typeof(A5) <: JustRelax.StokesArrays + @test typeof(A6) <: JustRelax.ThermalArrays +end