Skip to content

Commit

Permalink
fix Dirichlet on GPUs
Browse files Browse the repository at this point in the history
  • Loading branch information
albert-de-montserrat committed Dec 11, 2024
1 parent 9fb3249 commit 46ac3c0
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 24 deletions.
32 changes: 16 additions & 16 deletions src/boundaryconditions/Dirichlet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@ abstract type AbstractDirichletBoundaryCondition{T,M} end
struct DirichletBoundaryCondition{T,M} <: AbstractDirichletBoundaryCondition{T,M}
value::T
mask::M
end

function DirichletBoundaryCondition(value::T, mask::M) where {T,M}
return new{T,M}(value, mask)
end
# function DirichletBoundaryCondition(value::T, mask::M) where {T,M}
# return DirichletBoundaryCondition{T,M}(value, mask)
# end

function DirichletBoundaryCondition()
return new{Nothing,Nothing}(nothing, nothing)
end
function DirichletBoundaryCondition()
return DirichletBoundaryCondition{Nothing,Nothing}(nothing, nothing)
end

Adapt.@adapt_structure DirichletBoundaryCondition
Expand Down Expand Up @@ -54,16 +54,15 @@ end
struct ConstantDirichletBoundaryCondition{T,M} <: AbstractDirichletBoundaryCondition{T,M}
value::T
mask::M
end

function ConstantDirichletBoundaryCondition(value::N, mask::M) where {N<:Number,M}
v = ConstantArray(value)
T = typeof(v)
return new{T,M}(v, mask)
end
function ConstantDirichletBoundaryCondition(value::T, mask::M) where {T<:Number,M}
v = ConstantArray(value)
return ConstantDirichletBoundaryCondition{typeof(v),M}(v, mask)
end

function ConstantDirichletBoundaryCondition()
return new{Nothing,Nothing}(nothing, nothing)
end
function ConstantDirichletBoundaryCondition()
return ConstantDirichletBoundaryCondition{Nothing,Nothing}(nothing, nothing)
end

Adapt.@adapt_structure ConstantDirichletBoundaryCondition
Expand All @@ -72,13 +71,14 @@ function Base.getindex(x::ConstantDirichletBoundaryCondition, inds::Vararg{Int,N
return x.value * x.mask[inds...]
end
function Base.getindex(
x::ConstantDirichletBoundaryCondition{Nothing,Nothing}, ::Vararg{Int,N}
::ConstantDirichletBoundaryCondition{Nothing,Nothing}, ::Vararg{Int,N}
) where {N}
return 0
end

@inline function apply_dirichlet!(A::AbstractArray, bc::AbstractDirichletBoundaryCondition)
return apply_mask!(A, bc.value, bc.mask)
apply_mask!(A, bc.value, bc.mask)
return nothing
end

@inline function apply_dirichlet!(
Expand Down
16 changes: 8 additions & 8 deletions src/mask/mask.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,21 +35,21 @@ Base.similar(m::Mask) = Mask(size(m)...)

@inline dims(::Mask{A}) where {A<:AbstractArray{T,N}} where {T,N} = N

@inline apply_mask!(A::AbstractArray, B::AbstractArray, m::Mask) =
@inline apply_mask!(A::AbstractArray, B::Any, m::Mask) =
(A .= inv(m) .* A .+ m.mask .* B)
@inline apply_mask!(::AbstractArray, ::AbstractArray, ::Nothing) = nothing
@inline apply_mask!(::AbstractArray, ::Any, ::Nothing) = nothing
@inline apply_mask!(
A::AbstractArray, B::AbstractArray, m::Mask, inds::Vararg{Int,N}
A::AbstractArray, B::Any, m::Mask, inds::Vararg{Int,N}
) where {N} = (A[inds...] = inv(m, inds...) * A[inds...] + m[inds...] * B[inds...])
@inline apply_mask!(
::AbstractArray, ::AbstractArray, ::Nothing, inds::Vararg{Int,N}
::AbstractArray, ::Any, ::Nothing, inds::Vararg{Int,N}
) where {N} = nothing

@inline apply_mask(A::AbstractArray, B::AbstractArray, m::Mask) = inv(m) .* A .+ m.mask .* B
@inline apply_mask(A::AbstractArray, B::Any, m::Mask) = inv(m) .* A .+ m.mask .* B
@inline apply_mask(
A::AbstractArray, B::AbstractArray, m::Mask, inds::Vararg{Int,N}
A::AbstractArray, B::Any, m::Mask, inds::Vararg{Int,N}
) where {N} = @muladd inv(m, inds...) * A[inds...] + m[inds...] * B[inds...]
@inline apply_mask(A::AbstractArray, ::AbstractArray, ::Nothing) = A
@inline apply_mask(A::AbstractArray, ::Any, ::Nothing) = A
@inline apply_mask(
A::AbstractArray, ::AbstractArray, ::Nothing, inds::Vararg{Int,N}
A::AbstractArray, ::Any, ::Nothing, inds::Vararg{Int,N}
) where {N} = A[inds...]

0 comments on commit 46ac3c0

Please sign in to comment.