Infer arraytype for plan_nfft from k #142

Status: Open. Wants to merge 4 commits into base: master (changes from all commits shown below).
21 changes: 15 additions & 6 deletions AbstractNFFTs/src/derived.jl
@@ -4,25 +4,34 @@
##########################


+# From: https://github.com/JuliaLang/julia/issues/35543
+# To dispatch between CPU/GPU plans we need to strip the type parameters of the array type.
+# For example Array{Float32,2} -> Array, Array{Complex{Float32},2} -> Array, CuArray{Float32,2,...} -> CuArray
+# At the moment there is no stable API for this, so we need to use the following workaround:
+strip_type_parameters(T) = Base.typename(T).wrapper
+# This may change in future Julia versions, so we need to check whether the workaround is still needed/working.

for op in [:nfft, :nfct, :nfst]
planfunc = Symbol("plan_"*"$op")
@eval begin

-# The following automatically call the plan_* version for type Array
+# The following methods try to find a plan function with the correct array type parameters

+$(planfunc)(k::arrT, args...; kargs...) where {arrT <: AbstractArray} =
+$(planfunc)(strip_type_parameters(arrT), k, args...; kargs...)

-$(planfunc)(k::AbstractArray, N::Union{Integer,NTuple{D,Int}}, args...; kargs...) where {D} =
-$(planfunc)(Array, k, N, args...; kargs...)
+$(planfunc)(k::arrL, args...; kargs...) where {T, arrT <: AbstractArray{T}, arrL <: Union{Adjoint{T, arrT}, Transpose{T, arrT}}} =
+$(planfunc)(strip_type_parameters(arrT), k, args...; kargs...)

-$(planfunc)(k::AbstractArray, y::AbstractArray, args...; kargs...) =
-$(planfunc)(Array, k, y, args...; kargs...)
$(planfunc)(k::AbstractRange, args...; kargs...) = $(planfunc)(collect(k), args...; kargs...)

# The following convert 1D parameters into the format required by the plan

$(planfunc)(Q::Type, k::AbstractVector, N::Integer, rest...; kwargs...) =
$(planfunc)(Q, collect(reshape(k,1,length(k))), (N,), rest...; kwargs...)

$(planfunc)(Q::Type, k::AbstractVector, N::NTuple{D,Int}, rest...; kwargs...) where {D} =
$(planfunc)(Q, collect(reshape(k,1,length(k))), N, rest...; kwargs...)

$(planfunc)(Q::Type, k::AbstractMatrix, N::NTuple{D,Int}, rest...; kwargs...) where {D} =
$(planfunc)(Q, collect(k), N, rest...; kwargs...)
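To make the new dispatch concrete, here is a minimal, self-contained sketch of the pattern this hunk introduces. plan_demo and DemoPlan are hypothetical stand-ins for plan_nfft and the concrete plan types; they are not part of AbstractNFFTs.

```julia
# Sketch of the "strip the array type parameters, then dispatch" pattern.
strip_type_parameters(T) = Base.typename(T).wrapper   # Array{Float64,2} -> Array

struct DemoPlan{A}
    k::A
end

# Backend method keyed on the unparameterized array type, analogous to plan_nfft(::Type{Array}, k, N).
plan_demo(::Type{Array}, k::AbstractMatrix, N::NTuple) = DemoPlan(collect(k))

# Generic front end: infer the backend from the concrete type of k, as the PR does for plan_nfft.
plan_demo(k::arrT, args...; kwargs...) where {arrT <: AbstractArray} =
    plan_demo(strip_type_parameters(arrT), k, args...; kwargs...)

k = rand(2, 16)
p = plan_demo(k, (8, 8))     # no explicit Array argument needed
@assert p isa DemoPlan{Matrix{Float64}}
```

A GPU backend would add its own method such as plan_demo(::Type{CuArray}, ...), and the same front end would then route CuArray nodes to it without the caller naming the array type.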
2 changes: 2 additions & 0 deletions ext/NFFTGPUArraysExt/implementation.jl
@@ -16,6 +16,8 @@ mutable struct GPU_NFFTPlan{T,D, arrTc <: AbstractGPUArray{Complex{T}, D}, vecI
B::SM
end

# At the moment initParams does not support k types other than Array, so when arr is inferred from k::GPUArray we need to convert k to an Array first
AbstractNFFTs.plan_nfft(arr::Type{<:AbstractGPUArray}, k::AbstractMatrix, args...; kwargs...) = AbstractNFFTs.plan_nfft(arr, Array(k), args...; kwargs...)
function AbstractNFFTs.plan_nfft(arr::Type{<:AbstractGPUArray}, k::Matrix{T}, N::NTuple{D,Int}, rest...;
timing::Union{Nothing,TimingStats} = nothing, kargs...) where {T,D}
t = @elapsed begin
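The method added above only copies the sampling nodes back to the host before the usual plan construction runs. Below is a hedged, self-contained sketch of that convert-then-forward pattern; DemoDeviceArray and build_plan are hypothetical stand-ins for an AbstractGPUArray and the Array-only initialization path, not the package's API.

```julia
# Toy "device" array that wraps host data, standing in for an AbstractGPUArray.
struct DemoDeviceArray{T,N} <: AbstractArray{T,N}
    data::Array{T,N}
end
Base.size(a::DemoDeviceArray) = size(a.data)
Base.getindex(a::DemoDeviceArray, i::Int...) = a.data[i...]
Base.Array(a::DemoDeviceArray) = copy(a.data)          # "download" to the host

build_plan(k::Matrix) = (:plan, size(k))               # host-only path, like initParams
build_plan(k::AbstractMatrix) = build_plan(Array(k))   # convert to host, then forward

k_dev = DemoDeviceArray(rand(2, 16))
@assert build_plan(k_dev) == (:plan, (2, 16))
```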
13 changes: 13 additions & 0 deletions test/gpu.jl
@@ -16,19 +16,27 @@ m = 5
p = plan_nfft(Array, k, N; m, σ, window, precompute=NFFT.FULL,
fftflags=FFTW.ESTIMATE)
p_d = plan_nfft(arrayType, k, N; m, σ, window, precompute=NFFT.FULL)
p_d_infer = plan_nfft(arrayType(k), N; m, σ, window, precompute=NFFT.FULL)
pNDFT = NDFTPlan(k, N)

fHat = rand(Float64, J) + rand(Float64, J) * im
f = adjoint(pNDFT) * fHat
fHat_d = arrayType(fHat)

# GPU NFFT
fApprox_d = adjoint(p_d) * fHat_d
fApprox_d_infer = adjoint(p_d_infer) * fHat_d
@test fApprox_d ≈ fApprox_d_infer

fApprox = Array(fApprox_d)
e = norm(f[:] - fApprox[:]) / norm(f[:])
@debug "error adjoint nfft " e
@test e < eps[l]

gHat = pNDFT * f
gHatApprox = Array(p_d * arrayType(f))
gHatApproxInfer = Array(p_d_infer * arrayType(f))
@test gHatApprox ≈ gHatApproxInfer
e = norm(gHat[:] - gHatApprox[:]) / norm(gHat[:])
@debug "error nfft " e
@test e < eps[l]
@@ -49,6 +57,11 @@ m = 5
p = plan_nfft(arrayType, nodes, (N, N); m=5, σ=2.0)
weights = Array(sdc(p, iters=5))

# Infer the correct plan_nfft type
p_infer = plan_nfft(arrayType(nodes), (N, N); m=5, σ=2.0)
weights_infer = Array(sdc(p_infer, iters=5))
@test weights ≈ weights_infer

@info extrema(vec(weights))

@test all((≈).(vec(weights), 1 / (N * N), rtol=1e-7))
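The user-visible effect of the PR is that the array type no longer has to be passed explicitly. A hedged usage sketch follows (it assumes NFFT.jl is installed; the GPU lines additionally assume CUDA.jl and are left commented out, since they are not verified here):

```julia
using NFFT

k1 = rand(64) .- 0.5              # 1D nodes in [-0.5, 0.5)
p1 = plan_nfft(k1, 16)            # Array inferred from k1; 1D input reshaped internally

k2 = rand(2, 64) .- 0.5           # 2D nodes
p2 = plan_nfft(k2, (16, 16))      # equivalent to plan_nfft(Array, k2, (16, 16))

# With this PR, a GPU plan can be requested the same way:
# using CUDA
# p_gpu = plan_nfft(CuArray(k2), (16, 16))   # CuArray inferred from k2
```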