diff --git a/AbstractNFFTs/src/derived.jl b/AbstractNFFTs/src/derived.jl index 0501843..b1ce5c5 100644 --- a/AbstractNFFTs/src/derived.jl +++ b/AbstractNFFTs/src/derived.jl @@ -4,17 +4,26 @@ ########################## +# From: https://github.com/JuliaLang/julia/issues/35543 +# To dispatch between CPU/GPU plans we need to strip the type parameters of the array type. +# For example Array{Float32,2} -> Array, Array{Complex{Float32},2} -> Array, CuArray{Float32,2, ...} -> CuArray +# At the moment there is no stable API for this, so we need to use the following workaround: +strip_type_parameters(T) = Base.typename(T).wrapper +# This can change with future Julia versions, so we need to check if the workaround is still needed/working + for op in [:nfft, :nfct, :nfst] planfunc = Symbol("plan_"*"$op") @eval begin -# The following automatically call the plan_* version for type Array +# The following try to find a plan function with the correct array type parameters + +$(planfunc)(k::arrT, args...; kargs...) where {arrT <: AbstractArray} = + $(planfunc)(strip_type_parameters(arrT), k, args...; kargs...) -$(planfunc)(k::AbstractArray, N::Union{Integer,NTuple{D,Int}}, args...; kargs...) where {D} = - $(planfunc)(Array, k, N, args...; kargs...) +$(planfunc)(k::arrL, args...; kargs...) where {T, arrT <: AbstractArray{T}, arrL <: Union{Adjoint{T, arrT}, Transpose{T, arrT}}} = + $(planfunc)(strip_type_parameters(arrT), k, y, args...; kargs...) -$(planfunc)(k::AbstractArray, y::AbstractArray, args...; kargs...) = - $(planfunc)(Array, k, y, args...; kargs...) +$(planfunc)(k::AbstractRange, args...; kargs...) = $(planfunc)(collect(k), args...; kargs...) # The follow convert 1D parameters into the format required by the plan @@ -22,7 +31,7 @@ $(planfunc)(Q::Type, k::AbstractVector, N::Integer, rest...; kwargs...) = $(planfunc)(Q, collect(reshape(k,1,length(k))), (N,), rest...; kwargs...) $(planfunc)(Q::Type, k::AbstractVector, N::NTuple{D,Int}, rest...; kwargs...) where {D} = - $(planfunc)(Q, collect(reshape(k,1,length(k))), N, rest...; kwargs...) + $(planfunc)(Q, collect(reshape(k,1,length(k))), N, rest...; kwargs...) $(planfunc)(Q::Type, k::AbstractMatrix, N::NTuple{D,Int}, rest...; kwargs...) where {D} = $(planfunc)(Q, collect(k), N, rest...; kwargs...) diff --git a/ext/NFFTGPUArraysExt/implementation.jl b/ext/NFFTGPUArraysExt/implementation.jl index 4367929..0e8f5ac 100644 --- a/ext/NFFTGPUArraysExt/implementation.jl +++ b/ext/NFFTGPUArraysExt/implementation.jl @@ -16,6 +16,8 @@ mutable struct GPU_NFFTPlan{T,D, arrTc <: AbstractGPUArray{Complex{T}, D}, vecI B::SM end +# Atm initParams is not supported for k != Array, so in the case of inferred arr from k::GPUArray we need to convert k to Array +AbstractNFFTs.plan_nfft(arr::Type{<:AbstractGPUArray}, k::AbstractMatrix, args...; kwargs...) = AbstractNFFTs.plan_nfft(arr, Array(k), args...; kwargs...) function AbstractNFFTs.plan_nfft(arr::Type{<:AbstractGPUArray}, k::Matrix{T}, N::NTuple{D,Int}, rest...; timing::Union{Nothing,TimingStats} = nothing, kargs...) where {T,D} t = @elapsed begin diff --git a/test/gpu.jl b/test/gpu.jl index c8903c2..2d74fff 100644 --- a/test/gpu.jl +++ b/test/gpu.jl @@ -16,12 +16,18 @@ m = 5 p = plan_nfft(Array, k, N; m, σ, window, precompute=NFFT.FULL, fftflags=FFTW.ESTIMATE) p_d = plan_nfft(arrayType, k, N; m, σ, window, precompute=NFFT.FULL) + p_d_infer = plan_nfft(arrayType(k), N; m, σ, window, precompute=NFFT.FULL) pNDFT = NDFTPlan(k, N) fHat = rand(Float64, J) + rand(Float64, J) * im f = adjoint(pNDFT) * fHat fHat_d = arrayType(fHat) + + # GPU NFFT fApprox_d = adjoint(p_d) * fHat_d + fApprox_d_infer = adjoint(p_d_infer) * fHat_d + @test fApprox_d ≈ fApprox_d_infer + fApprox = Array(fApprox_d) e = norm(f[:] - fApprox[:]) / norm(f[:]) @debug "error adjoint nfft " e @@ -29,6 +35,8 @@ m = 5 gHat = pNDFT * f gHatApprox = Array(p_d * arrayType(f)) + gHatApproxInfer = Array(p_d_infer * arrayType(f)) + @test gHatApprox ≈ gHatApproxInfer e = norm(gHat[:] - gHatApprox[:]) / norm(gHat[:]) @debug "error nfft " e @test e < eps[l] @@ -49,6 +57,11 @@ m = 5 p = plan_nfft(arrayType, nodes, (N, N); m=5, σ=2.0) weights = Array(sdc(p, iters=5)) + # Infer the correct plan_nfft type + p_infer = plan_nfft(arrayType(nodes), (N, N); m=5, σ=2.0) + weights_infer = Array(sdc(p_infer, iters=5)) + @test weights ≈ weights_infer + @info extrema(vec(weights)) @test all((≈).(vec(weights), 1 / (N * N), rtol=1e-7))