diff --git a/src/MetaJustRelax.jl b/src/MetaJustRelax.jl index 9b4ee1cb..21092997 100644 --- a/src/MetaJustRelax.jl +++ b/src/MetaJustRelax.jl @@ -77,6 +77,9 @@ function environment!(model::PS_Setup{T,N}) where {T,N} ViscoElastoPlastic, solve! + include(joinpath(@__DIR__, "array_conversions.jl")) + export Array + include(joinpath(@__DIR__, "Utils.jl")) export @allocate, @add, @idx, @copy export @velocity, diff --git a/src/array_conversions.jl b/src/array_conversions.jl new file mode 100644 index 00000000..7dcdb99e --- /dev/null +++ b/src/array_conversions.jl @@ -0,0 +1,24 @@ + +import Base.Array + +@inline remove_parameters(::T) where T = Base.typename(T).wrapper + +function Array(x::T) where T<:Union{SymmetricTensor, ThermalArrays, Velocity, Residual} + nfields = fieldcount(T) + cpu_fields = ntuple(Val(nfields)) do i + Base.@_inline_meta + Array(getfield(x, i)) + end + T_clean = remove_parameters(x) + return T_clean(cpu_fields...) +end + +function Array(x::StokesArrays{T,A,B,C,M,nDim}) where {T,A,B,C,M,nDim} + nfields = fieldcount(StokesArrays) + cpu_fields = ntuple(Val(nfields)) do i + Base.@_inline_meta + Array(getfield(x, i)) + end + T_clean = remove_parameters(x) + return T_clean(cpu_fields...) +end diff --git a/src/stokes/MetaStokes.jl b/src/stokes/MetaStokes.jl index f7f76c25..7ab8bacc 100644 --- a/src/stokes/MetaStokes.jl +++ b/src/stokes/MetaStokes.jl @@ -220,10 +220,10 @@ function make_stokes_struct!(; name::Symbol=:StokesArrays) ) end - function $(name)(args::Vararg{T,N}) where {T<:AbstractArray,N} + function $(name)(args::Vararg{Any,N}) where {N} return new{ ViscoElastic, - typeof(args[4]), + typeof(args[3]), typeof(args[5]), typeof(args[end]), typeof(args[1]), diff --git a/test/test_arrays_conversions.jl b/test/test_arrays_conversions.jl new file mode 100644 index 00000000..2eaf40d9 --- /dev/null +++ b/test/test_arrays_conversions.jl @@ -0,0 +1,16 @@ +using JustRelax, Test +model = PS_Setup(:Threads, Float64, 2) +environment!(model) + +@testset "Array conversions" begin + ni = 10, 10 + stokes = StokesArrays(ni, ViscoElastic) + thermal = ThermalArrays(ni) + + @test Array(stokes.V) isa Velocity{Matrix{Float64}} + @test Array(stokes.τ) isa SymmetricTensor{Matrix{Float64}} + @test Array(stokes.R) isa Residual{Matrix{Float64}} + @test Array(stokes.P) isa Matrix + @test Array(stokes) isa StokesArrays + @test Array(thermal) isa ThermalArrays{Matrix{Float64}} +end