Skip to content

Commit

Permalink
fix conversion of structs from GPU to CPU (#137)
Browse files Browse the repository at this point in the history
  • Loading branch information
albert-de-montserrat authored Apr 29, 2024
1 parent 8afbd4b commit f9c943d
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 10 deletions.
12 changes: 6 additions & 6 deletions src/array_conversions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,24 @@ struct CPUDeviceTrait <: DeviceTrait end
struct NonCPUDeviceTrait <: DeviceTrait end

@inline iscpu(::Array) = CPUDeviceTrait()
@inline iscpu(::AbstractArray) = NonCPUDeviceTrait()
@inline iscpu(::T) where {T<:AbstractArray} = NonCPUDeviceTrait()
@inline iscpu(::T) where {T} = throw(ArgumentError("Unknown device"))

@inline iscpu(::Velocity{Array{T,N}}) where {T,N} = CPUDeviceTrait()
@inline iscpu(::Velocity{AbstractArray{T,N}}) where {T,N} = NonCPUDeviceTrait()
@inline iscpu(::Velocity{<:AbstractArray{T,N}}) where {T,N} = NonCPUDeviceTrait()

@inline iscpu(::SymmetricTensor{Array{T,N}}) where {T,N} = CPUDeviceTrait()
@inline iscpu(::SymmetricTensor{AbstractArray{T,N}}) where {T,N} = NonCPUDeviceTrait()
@inline iscpu(::SymmetricTensor{<:AbstractArray{T,N}}) where {T,N} = NonCPUDeviceTrait()

@inline iscpu(::Residual{Array{T,N}}) where {T,N} = CPUDeviceTrait()
@inline iscpu(::Residual{AbstractArray{T,N}}) where {T,N} = NonCPUDeviceTrait()
@inline iscpu(::Residual{<:AbstractArray{T,N}}) where {T,N} = NonCPUDeviceTrait()

@inline iscpu(::ThermalArrays{Array{T,N}}) where {T,N} = CPUDeviceTrait()
@inline iscpu(::ThermalArrays{AbstractArray{T,N}}) where {T,N} = NonCPUDeviceTrait()
@inline iscpu(::ThermalArrays{<:AbstractArray{T,N}}) where {T,N} = NonCPUDeviceTrait()

@inline iscpu(::StokesArrays{M,A,B,C,Array{T,N},nDim}) where {M,A,B,C,T,N,nDim} =
CPUDeviceTrait()
@inline iscpu(::StokesArrays{M,A,B,C,AbstractArray{T,N},nDim}) where {M,A,B,C,T,N,nDim} =
@inline iscpu(::StokesArrays{M,A,B,C,<:AbstractArray{T,N},nDim}) where {M,A,B,C,T,N,nDim} =
NonCPUDeviceTrait()

## Conversion of structs to CPU
Expand Down
6 changes: 3 additions & 3 deletions src/stokes/MetaStokes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ function make_velocity_struct!(ndim::Integer; name::Symbol=:Velocity)
return new{$PTArray}(@zeros(ni[1]...), @zeros(ni[2]...), @zeros(ni[3]...))
end

$(name)(args::Vararg{T,N}) where {T<:AbstractArray,N} = new{$PTArray}(args...)
$(name)(args::Vararg{T,N}) where {T<:AbstractArray,N} = new{T}(args...)
end
end
end
Expand Down Expand Up @@ -83,7 +83,7 @@ function make_symmetrictensor_struct!(nDim::Integer; name::Symbol=:SymmetricTens
)
end

$(name)(args::Vararg{T,N}) where {T<:AbstractArray,N} = new{$PTArray}(args...)
$(name)(args::Vararg{T,N}) where {T<:AbstractArray,N} = new{T}(args...)
end
end
end
Expand Down Expand Up @@ -111,7 +111,7 @@ function make_residual_struct!(ndim; name::Symbol=:Residual)
return new{typeof(Rx)}(Rx, Ry, Rz, RP)
end

$(name)(args::Vararg{T,N}) where {T<:AbstractArray,N} = new{$PTArray}(args...)
$(name)(args::Vararg{T,N}) where {T<:AbstractArray,N} = new{T}(args...)
end
end
end
Expand Down
2 changes: 1 addition & 1 deletion src/thermal_diffusion/MetaDiffusion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ function make_thermal_arrays!(ndim)
ResT::_T

function ThermalArrays(args::Vararg{T,N}) where {T<:AbstractArray,N}
return new{$PTArray}(args...)
return new{T}(args...)
end

function ThermalArrays(ni::NTuple{1,Integer})
Expand Down

0 comments on commit f9c943d

Please sign in to comment.