diff --git a/src/types/constructors/stokes.jl b/src/types/constructors/stokes.jl index e0252a52..da75e475 100644 --- a/src/types/constructors/stokes.jl +++ b/src/types/constructors/stokes.jl @@ -24,7 +24,7 @@ function Displacement(nx::Integer, ny::Integer) nUy = (nx + 2, ny + 1) Ux, Uy = @zeros(nUx...), @zeros(nUy) - return JustRelax.Displacement(Vx, Vy, nothing) + return JustRelax.Displacement(Ux, Uy, nothing) end function Displacement(nx::Integer, ny::Integer, nz::Integer) diff --git a/src/types/displacement.jl b/src/types/displacement.jl new file mode 100644 index 00000000..9489a8d5 --- /dev/null +++ b/src/types/displacement.jl @@ -0,0 +1,57 @@ +function velocity2displacement!(stokes::JustRelax.StokesArrays, dt) + velocity2displacement!(stokes, backend(stokes), dt) + return nothing +end + +function velocity2displacement!(stokes::JustRelax.StokesArrays, ::CPUBackendTrait, dt) + _velocity2displacement!(stokes, dt) +end + +function _velocity2displacement!(stokes::JustRelax.StokesArrays, dt) + ni = size(stokes.P) + (; V, U) = stokes + @parallel (@idx ni.+2) _velocity2displacement!(V.Vx, V.Vy, V.Vz, U.Ux, U.Uy, U.Uz, 1 / dt) + return nothing +end + +@parallel_indices (I...) function _velocity2displacement!(Vx, Vy, Vz, Ux, Uy, Uz, _dt) + if all(I .≤ size(Ux)) + Ux[I...] = Vx[I...] * _dt + end + if all(I .≤ size(Uy)) + Uy[I...] = Vy[I...] * _dt + end + if !isnothing(Vz) && all(I .≤ size(Uz)) + Uz[I...] = Vz[I...] * _dt + end + return nothing +end + +function displacement2velocity!(stokes::JustRelax.StokesArrays, dt) + displacement2velocity!(stokes, backend(stokes), dt) + return nothing +end + +function displacement2velocity!(stokes::JustRelax.StokesArrays, ::CPUBackendTrait, dt) + _displacement2velocity!(stokes, dt) +end + +function _displacement2velocity!(stokes::JustRelax.StokesArrays, dt) + ni = size(stokes.P) + (; V, U) = stokes + @parallel (@idx ni.+2) _displacement2velocity!(U.Ux, U.Uy, U.Uz, V.Vx, V.Vy, V.Vz, dt) + return nothing +end + +@parallel_indices (I...) function _displacement2velocity!(Ux, Uy, Uz, Vx, Vy, Vz, dt) + if all(I .≤ size(Ux)) + Vx[I...] = Ux[I...] * dt + end + if all(I .≤ size(Uy)) + Vy[I...] = Uy[I...] * dt + end + if !isnothing(Vz) && all(I .≤ size(Uz)) + Vz[I...] = Uz[I...] * dt + end + return nothing +end \ No newline at end of file diff --git a/src/types/stokes.jl b/src/types/stokes.jl index 690ec33f..2f088815 100644 --- a/src/types/stokes.jl +++ b/src/types/stokes.jl @@ -28,7 +28,7 @@ struct Displacement{T} Displacement(Ux::T, Uy::T, Uz::Union{T,Nothing}) where {T} = new{T}(Ux, Uy, Uz) end -Displacement(Vx::T, Vy::T) where {T} = Displacement(Ux, Uy, nothing) +Displacement(Ux::T, Uy::T) where {T} = Displacement(Ux, Uy, nothing) Displacement(ni::NTuple{N,Number}) where {N} = Displacement(ni...) function Displacement(::Number, ::Number) diff --git a/src/types/traits.jl b/src/types/traits.jl index 0ae5bc93..7a9d6fe3 100644 --- a/src/types/traits.jl +++ b/src/types/traits.jl @@ -14,6 +14,7 @@ struct AMDGPUBackendTrait <: GPUBackendTrait end # Custom struct's @inline backend(::JustRelax.Velocity{T}) where {T} = backend(T) +@inline backend(::JustRelax.Displacement{T}) where {T} = backend(T) @inline backend(::JustRelax.SymmetricTensor{T}) where {T} = backend(T) @inline backend(::JustRelax.Residual{T}) where {T} = backend(T) @inline backend(::JustRelax.Viscosity{T}) where {T} = backend(T) diff --git a/test/test_types.jl b/test/test_types.jl index 44931da0..b06123ad 100644 --- a/test/test_types.jl +++ b/test/test_types.jl @@ -1,3 +1,5 @@ +ENV["JULIA_JUSTRELAX_BACKEND"] = "CUDA" + @static if ENV["JULIA_JUSTRELAX_BACKEND"] === "AMDGPU" using AMDGPU elseif ENV["JULIA_JUSTRELAX_BACKEND"] === "CUDA" @@ -80,6 +82,20 @@ const BackendArray = PTArray(backend) @test_throws MethodError JR2.StokesArrays(backend, 10.0, 10.0) end +@testset "2D Displacement" begin + ni = nx, ny = (2, 2) + stokes = JR2.StokesArrays(backend, ni) + + stokes.V.Vx .= 1.0 + stokes.V.Vy .= 1.0 + + JR2.velocity2displacement!(stokes, 10) + @test all(stokes.U.Ux.==0.1) + + JR2.displacement2velocity!(stokes, 5) + @test all(stokes.V.Vx.==0.5) +end + @testset "3D allocators" begin ni = nx, ny, nz = (2, 2, 2) @@ -148,3 +164,18 @@ end @test_throws MethodError JR3.StokesArrays(backend, 10.0, 10.0) end + +@testset "3D Displacement" begin + ni = nx, ny, nz = (2, 2, 2) + stokes = JR3.StokesArrays(backend, ni) + + stokes.V.Vx .= 1.0 + stokes.V.Vy .= 1.0 + stokes.V.Vz .= 1.0 + + JR3.velocity2displacement!(stokes, 10) + @test all(stokes.U.Ux.==0.1) + + JR3.displacement2velocity!(stokes, 5) + @test all(stokes.V.Vx.==0.5) +end \ No newline at end of file