Skip to content

Commit

Permalink
rework and test displacement
Browse files Browse the repository at this point in the history
  • Loading branch information
albert-de-montserrat committed Jun 17, 2024
1 parent 7419f5a commit 4798ceb
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/types/constructors/stokes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ function Displacement(nx::Integer, ny::Integer)
nUy = (nx + 2, ny + 1)

Ux, Uy = @zeros(nUx...), @zeros(nUy)
return JustRelax.Displacement(Vx, Vy, nothing)
return JustRelax.Displacement(Ux, Uy, nothing)
end

function Displacement(nx::Integer, ny::Integer, nz::Integer)
Expand Down
57 changes: 57 additions & 0 deletions src/types/displacement.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
function velocity2displacement!(stokes::JustRelax.StokesArrays, dt)
velocity2displacement!(stokes, backend(stokes), dt)
return nothing
end

function velocity2displacement!(stokes::JustRelax.StokesArrays, ::CPUBackendTrait, dt)
_velocity2displacement!(stokes, dt)
end

function _velocity2displacement!(stokes::JustRelax.StokesArrays, dt)
ni = size(stokes.P)
(; V, U) = stokes
@parallel (@idx ni.+2) _velocity2displacement!(V.Vx, V.Vy, V.Vz, U.Ux, U.Uy, U.Uz, 1 / dt)
return nothing
end

@parallel_indices (I...) function _velocity2displacement!(Vx, Vy, Vz, Ux, Uy, Uz, _dt)
if all(I .≤ size(Ux))
Ux[I...] = Vx[I...] * _dt
end
if all(I .≤ size(Uy))
Uy[I...] = Vy[I...] * _dt
end
if !isnothing(Vz) && all(I .≤ size(Uz))
Uz[I...] = Vz[I...] * _dt
end
return nothing
end

function displacement2velocity!(stokes::JustRelax.StokesArrays, dt)
displacement2velocity!(stokes, backend(stokes), dt)
return nothing
end

function displacement2velocity!(stokes::JustRelax.StokesArrays, ::CPUBackendTrait, dt)
_displacement2velocity!(stokes, dt)
end

function _displacement2velocity!(stokes::JustRelax.StokesArrays, dt)
ni = size(stokes.P)
(; V, U) = stokes
@parallel (@idx ni.+2) _displacement2velocity!(U.Ux, U.Uy, U.Uz, V.Vx, V.Vy, V.Vz, dt)
return nothing
end

@parallel_indices (I...) function _displacement2velocity!(Ux, Uy, Uz, Vx, Vy, Vz, dt)
if all(I .≤ size(Ux))
Vx[I...] = Ux[I...] * dt
end
if all(I .≤ size(Uy))
Vy[I...] = Uy[I...] * dt
end
if !isnothing(Vz) && all(I .≤ size(Uz))
Vz[I...] = Uz[I...] * dt
end
return nothing
end
2 changes: 1 addition & 1 deletion src/types/stokes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ struct Displacement{T}
Displacement(Ux::T, Uy::T, Uz::Union{T,Nothing}) where {T} = new{T}(Ux, Uy, Uz)
end

Displacement(Vx::T, Vy::T) where {T} = Displacement(Ux, Uy, nothing)
Displacement(Ux::T, Uy::T) where {T} = Displacement(Ux, Uy, nothing)

Displacement(ni::NTuple{N,Number}) where {N} = Displacement(ni...)
function Displacement(::Number, ::Number)
Expand Down
1 change: 1 addition & 0 deletions src/types/traits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ struct AMDGPUBackendTrait <: GPUBackendTrait end

# Custom struct's
@inline backend(::JustRelax.Velocity{T}) where {T} = backend(T)
@inline backend(::JustRelax.Displacement{T}) where {T} = backend(T)
@inline backend(::JustRelax.SymmetricTensor{T}) where {T} = backend(T)
@inline backend(::JustRelax.Residual{T}) where {T} = backend(T)
@inline backend(::JustRelax.Viscosity{T}) where {T} = backend(T)
Expand Down
31 changes: 31 additions & 0 deletions test/test_types.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
ENV["JULIA_JUSTRELAX_BACKEND"] = "CUDA"

@static if ENV["JULIA_JUSTRELAX_BACKEND"] === "AMDGPU"
using AMDGPU
elseif ENV["JULIA_JUSTRELAX_BACKEND"] === "CUDA"
Expand Down Expand Up @@ -80,6 +82,20 @@ const BackendArray = PTArray(backend)
@test_throws MethodError JR2.StokesArrays(backend, 10.0, 10.0)
end

@testset "2D Displacement" begin
ni = nx, ny = (2, 2)
stokes = JR2.StokesArrays(backend, ni)

stokes.V.Vx .= 1.0
stokes.V.Vy .= 1.0

JR2.velocity2displacement!(stokes, 10)
@test all(stokes.U.Ux.==0.1)

JR2.displacement2velocity!(stokes, 5)
@test all(stokes.V.Vx.==0.5)
end

@testset "3D allocators" begin
ni = nx, ny, nz = (2, 2, 2)

Expand Down Expand Up @@ -148,3 +164,18 @@ end

@test_throws MethodError JR3.StokesArrays(backend, 10.0, 10.0)
end

@testset "3D Displacement" begin
ni = nx, ny, nz = (2, 2, 2)
stokes = JR3.StokesArrays(backend, ni)

stokes.V.Vx .= 1.0
stokes.V.Vy .= 1.0
stokes.V.Vz .= 1.0

JR3.velocity2displacement!(stokes, 10)
@test all(stokes.U.Ux.==0.1)

JR3.displacement2velocity!(stokes, 5)
@test all(stokes.V.Vx.==0.5)
end

0 comments on commit 4798ceb

Please sign in to comment.