From 46ac3c0aca0182a81dab39b752e2bd6aef38127d Mon Sep 17 00:00:00 2001 From: albert-de-montserrat Date: Wed, 11 Dec 2024 18:17:16 +0100 Subject: [PATCH] fix Dirichlet on GPUs --- src/boundaryconditions/Dirichlet.jl | 32 ++++++++++++++--------------- src/mask/mask.jl | 16 +++++++-------- 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/src/boundaryconditions/Dirichlet.jl b/src/boundaryconditions/Dirichlet.jl index 52978161..14dbd251 100644 --- a/src/boundaryconditions/Dirichlet.jl +++ b/src/boundaryconditions/Dirichlet.jl @@ -2,14 +2,14 @@ abstract type AbstractDirichletBoundaryCondition{T,M} end struct DirichletBoundaryCondition{T,M} <: AbstractDirichletBoundaryCondition{T,M} value::T mask::M +end - function DirichletBoundaryCondition(value::T, mask::M) where {T,M} - return new{T,M}(value, mask) - end +# function DirichletBoundaryCondition(value::T, mask::M) where {T,M} +# return DirichletBoundaryCondition{T,M}(value, mask) +# end - function DirichletBoundaryCondition() - return new{Nothing,Nothing}(nothing, nothing) - end +function DirichletBoundaryCondition() + return DirichletBoundaryCondition{Nothing,Nothing}(nothing, nothing) end Adapt.@adapt_structure DirichletBoundaryCondition @@ -54,16 +54,15 @@ end struct ConstantDirichletBoundaryCondition{T,M} <: AbstractDirichletBoundaryCondition{T,M} value::T mask::M +end - function ConstantDirichletBoundaryCondition(value::N, mask::M) where {N<:Number,M} - v = ConstantArray(value) - T = typeof(v) - return new{T,M}(v, mask) - end +function ConstantDirichletBoundaryCondition(value::T, mask::M) where {T<:Number,M} + v = ConstantArray(value) + return ConstantDirichletBoundaryCondition{typeof(v),M}(v, mask) +end - function ConstantDirichletBoundaryCondition() - return new{Nothing,Nothing}(nothing, nothing) - end +function ConstantDirichletBoundaryCondition() + return ConstantDirichletBoundaryCondition{Nothing,Nothing}(nothing, nothing) end Adapt.@adapt_structure ConstantDirichletBoundaryCondition @@ -72,13 +71,14 @@ function Base.getindex(x::ConstantDirichletBoundaryCondition, inds::Vararg{Int,N return x.value * x.mask[inds...] end function Base.getindex( - x::ConstantDirichletBoundaryCondition{Nothing,Nothing}, ::Vararg{Int,N} + ::ConstantDirichletBoundaryCondition{Nothing,Nothing}, ::Vararg{Int,N} ) where {N} return 0 end @inline function apply_dirichlet!(A::AbstractArray, bc::AbstractDirichletBoundaryCondition) - return apply_mask!(A, bc.value, bc.mask) + apply_mask!(A, bc.value, bc.mask) + return nothing end @inline function apply_dirichlet!( diff --git a/src/mask/mask.jl b/src/mask/mask.jl index 6147cefe..8370a6f4 100644 --- a/src/mask/mask.jl +++ b/src/mask/mask.jl @@ -35,21 +35,21 @@ Base.similar(m::Mask) = Mask(size(m)...) @inline dims(::Mask{A}) where {A<:AbstractArray{T,N}} where {T,N} = N -@inline apply_mask!(A::AbstractArray, B::AbstractArray, m::Mask) = +@inline apply_mask!(A::AbstractArray, B::Any, m::Mask) = (A .= inv(m) .* A .+ m.mask .* B) -@inline apply_mask!(::AbstractArray, ::AbstractArray, ::Nothing) = nothing +@inline apply_mask!(::AbstractArray, ::Any, ::Nothing) = nothing @inline apply_mask!( - A::AbstractArray, B::AbstractArray, m::Mask, inds::Vararg{Int,N} + A::AbstractArray, B::Any, m::Mask, inds::Vararg{Int,N} ) where {N} = (A[inds...] = inv(m, inds...) * A[inds...] + m[inds...] * B[inds...]) @inline apply_mask!( - ::AbstractArray, ::AbstractArray, ::Nothing, inds::Vararg{Int,N} + ::AbstractArray, ::Any, ::Nothing, inds::Vararg{Int,N} ) where {N} = nothing -@inline apply_mask(A::AbstractArray, B::AbstractArray, m::Mask) = inv(m) .* A .+ m.mask .* B +@inline apply_mask(A::AbstractArray, B::Any, m::Mask) = inv(m) .* A .+ m.mask .* B @inline apply_mask( - A::AbstractArray, B::AbstractArray, m::Mask, inds::Vararg{Int,N} + A::AbstractArray, B::Any, m::Mask, inds::Vararg{Int,N} ) where {N} = @muladd inv(m, inds...) * A[inds...] + m[inds...] * B[inds...] -@inline apply_mask(A::AbstractArray, ::AbstractArray, ::Nothing) = A +@inline apply_mask(A::AbstractArray, ::Any, ::Nothing) = A @inline apply_mask( - A::AbstractArray, ::AbstractArray, ::Nothing, inds::Vararg{Int,N} + A::AbstractArray, ::Any, ::Nothing, inds::Vararg{Int,N} ) where {N} = A[inds...]