Skip to content

Commit

Permalink
NFFT on GPU based on Package Extensions and GPUArrays.jl (#136)
Browse files Browse the repository at this point in the history
* Init working GPU_Plan based on package extension

* Fix parametric struct and functions for GPUNFFTPlan

* Slightly improve deconvolve_transpose! performance for GPU

* Add tests for GPU NFFT Plan

* Increase min. Julia compat to 1.9
  • Loading branch information
nHackel authored Jul 3, 2024
1 parent 9cf01cd commit dd72b2c
Show file tree
Hide file tree
Showing 9 changed files with 240 additions and 76 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1.6' # Replace this with the minimum Julia version that your package supports. E.g. if your package requires Julia 1.5 or higher, change this to '1.5'.
- '1.9' # Replace this with the minimum Julia version that your package supports. E.g. if your package requires Julia 1.5 or higher, change this to '1.5'.
- '1' # Leave this line unchanged. '1' will automatically expand to the latest stable 1.x release of Julia.
# - 'nightly'
os:
Expand Down
17 changes: 13 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,31 +18,40 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[compat]
Adapt = "3, 4"
AbstractNFFTs = "0.8"
BasicInterpolators = "0.6.5, 0.7"
DataFrames = "1.3.1, 1.4.1"
FFTW = "1.5"
FINUFFT = "3.0.1"
FLoops = "0.2"
GPUArrays = "8, 9, 10"
JLArrays = "0.1.2"
Reexport = "1.0"
PrecompileTools = "1"
SpecialFunctions = "0.8, 0.10, 1, 2"
julia = "1.6"
julia = "1.9"
#StaticArrays = "1.4"
Ducc0 = "0.1"

[extras]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CuNFFT = "a9291f20-7f4c-4d50-b30d-4e07b13252e1"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
FINUFFT = "d8beea63-0952-562e-9c6a-8e8ef7364055"
NFFT3 = "53104703-03e8-40a5-ab01-812303a44cae"
NFFTTools = "7424e34d-94f7-41d6-98a0-85abaf1b6c91"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Ducc0 = "47ec601d-2729-4ac9-bed9-2b3ab5fca9ff"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"

[weakdeps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"

[targets]
test = ["Test", "BenchmarkTools", "FINUFFT", "NFFT3", "CuNFFT", "Zygote",
"NFFTTools", "DataFrames", "Ducc0"] # "NFFTTools" "CuNFFT"
test = ["Test", "JLArrays", "BenchmarkTools", "FINUFFT", "NFFT3", "Zygote",
"NFFTTools", "DataFrames", "Ducc0"] # "NFFTTools"

[extensions]
NFFTGPUArraysExt = ["Adapt", "GPUArrays"]
9 changes: 9 additions & 0 deletions ext/NFFTGPUArraysExt/NFFTGPUArraysExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Package extension holding the GPU NFFT plan. Project.toml registers it under
# [extensions] with the weak dependencies Adapt and GPUArrays.
module NFFTGPUArraysExt

# Names from NFFT and from the packages NFFT itself depends on.
using NFFT, NFFT.AbstractNFFTs
using NFFT.SparseArrays, NFFT.LinearAlgebra, NFFT.FFTW
# The weak dependencies that trigger this extension.
using GPUArrays, Adapt

# Plan type, constructor and the transform kernels.
include("implementation.jl")

end
128 changes: 128 additions & 0 deletions ext/NFFTGPUArraysExt/implementation.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
"""
    GPU_NFFTPlan{T,D,arrTc,vecI,FP,BP,INV,SM} <: AbstractNFFTPlan{T,D,1}

NFFT plan whose work buffers and precomputed operators are stored in GPU arrays.

# Fields
- `N`: value returned by `size_in`; also the shape of `tmpVecHat`.
- `NOut`: value returned by `size_out`.
- `J`: taken from `NFFT.initParams` (presumably the number of sampling nodes —
  TODO confirm).
- `k`: the node matrix handed to the constructor, kept as a host `Matrix`.
- `Ñ`: shape of the FFT work buffer `tmpVec` (presumably the oversampled grid
  size — TODO confirm).
- `dims`: dimensions the FFT plans were created for.
- `params`: the `NFFTParams` returned by `NFFT.initParams`.
- `forwardFFT`, `backwardFFT`: in-place FFT / unnormalized inverse FFT plans
  built for `tmpVec`.
- `tmpVec`: work buffer the FFT plans operate on.
- `tmpVecHat`: scratch buffer used by the (de)convolution-index gather/scatter.
- `deconvolveIdx`: `Int32` linear indices into `tmpVec`.
- `windowLinInterp`: first output of `NFFT.precomputation`, kept on the host.
- `windowHatInvLUT`: factors multiplied element-wise in the deconvolution steps.
- `B`: matrix applied with `mul!` in the convolution steps.
"""
mutable struct GPU_NFFTPlan{T,D, arrTc <: AbstractGPUArray{Complex{T}, D}, vecI <: AbstractGPUVector{Int32}, FP, BP, INV, SM} <: AbstractNFFTPlan{T,D,1}
    N::NTuple{D,Int64}
    NOut::NTuple{1,Int64}
    J::Int64
    k::Matrix{T}
    # The field name was missing here; the constructor passes `Ñ` in this position.
    Ñ::NTuple{D,Int64}
    dims::UnitRange{Int64}
    params::NFFTParams{T}
    forwardFFT::FP
    backwardFFT::BP
    tmpVec::arrTc
    tmpVecHat::arrTc
    deconvolveIdx::vecI
    windowLinInterp::Vector{T}
    windowHatInvLUT::INV
    B::SM
end

"""
    plan_nfft(arr::Type{<:AbstractGPUArray}, k::Matrix, N::NTuple, rest...; timing=nothing, kargs...)

Build a `GPU_NFFTPlan` for the GPU array type `arr`.

All positional arguments after `N` and all remaining keyword arguments are
forwarded to the `GPU_NFFTPlan` constructor. When a `TimingStats` object is
passed as `timing`, the time spent constructing the plan is stored in
`timing.pre`.

Returns the constructed plan.
"""
function AbstractNFFTs.plan_nfft(arr::Type{<:AbstractGPUArray}, k::Matrix{T}, N::NTuple{D,Int}, rest...;
    timing::Union{Nothing,TimingStats} = nothing, kargs...) where {T,D}
    t = @elapsed begin
        p = GPU_NFFTPlan(arr, k, N, rest...; kargs...)
    end
    # Identity comparison is the idiomatic (and inferable) way to test for `nothing`.
    if timing !== nothing
        timing.pre = t
    end
    return p
end

"""
    GPU_NFFTPlan(arr, k::Matrix, N::NTuple; dims=1:D, fftflags=nothing, kwargs...)

Construct a GPU NFFT plan whose buffers are moved to the array type `arr` via
`adapt`.

Only `dims == 1:D` is supported; anything else raises an error. `fftflags` is
accepted but not used by this constructor. The remaining keyword arguments are
forwarded to `NFFT.initParams`.
"""
function GPU_NFFTPlan(arr, k::Matrix{T}, N::NTuple{D,Int}; dims::Union{Integer,UnitRange{Int64}}=1:D,
    fftflags=nothing, kwargs...) where {T,D}

    dims == 1:D || error("GPU NFFT does not work along directions right now!")

    params, N, NOut, J, Ñ, dims_ = NFFT.initParams(k, N, dims; kwargs...)
    # The GPU plan only works with stored deconvolution indices and full precomputation.
    params.storeDeconvolutionIdx = true
    params.precompute = NFFT.FULL

    # Work buffer for the FFT; both plans are in-place plans for this buffer.
    tmpVec = adapt(arr, zeros(Complex{T}, Ñ))
    forwardPlan = plan_fft!(tmpVec, dims_)
    backwardPlan = plan_bfft!(tmpVec, dims_)

    # The second output (polynomial window interpolation) is not needed here.
    windowLinInterp, _, windowHatInvLUT, deconvolveIdx, B =
        NFFT.precomputation(k, N[dims_], Ñ[dims_], params)

    hatSize = params.storeDeconvolutionIdx ? N : ntuple(d -> 0, D)
    tmpVecHat = adapt(arr, zeros(Complex{T}, hatSize))

    # Move the precomputed data to the device and fix the element types.
    idxDevice = Int32.(adapt(arr, deconvolveIdx))
    lutDevice = Complex{T}.(adapt(arr, windowHatInvLUT[1]))
    BDevice = Complex{T}.(adapt(arr, B)) # Bit hacky

    PlanType = GPU_NFFTPlan{T,D,typeof(tmpVec),typeof(idxDevice),typeof(forwardPlan),
        typeof(backwardPlan),typeof(lutDevice),typeof(BDevice)}
    return PlanType(N, NOut, J, k, Ñ, dims_, params, forwardPlan, backwardPlan,
        tmpVec, tmpVecHat, idxDevice, windowLinInterp, lutDevice, BDevice)
end

"""
    size_in(p::GPU_NFFTPlan)

Return the `N` field of the plan.
"""
function AbstractNFFTs.size_in(p::GPU_NFFTPlan)
    return p.N
end

"""
    size_out(p::GPU_NFFTPlan)

Return the `NOut` field of the plan.
"""
function AbstractNFFTs.size_out(p::GPU_NFFTPlan)
    return p.NOut
end


"""
    convolve!(p::GPU_NFFTPlan, g, fHat)

Overwrite `fHat` with the product of the transposed matrix `p.B` and the
flattened grid `g`. Returns `nothing`.
"""
function AbstractNFFTs.convolve!(p::GPU_NFFTPlan{T,D, arrTc}, g::arrTc, fHat::arrH) where {D,T,arr<: AbstractGPUArray, arrTc <: arr, arrH <: arr}
    Bt = transpose(p.B)
    gFlat = vec(g)
    mul!(fHat, Bt, gFlat)
    return nothing
end

"""
    convolve_transpose!(p::GPU_NFFTPlan, fHat, g)

Overwrite the grid `g` (viewed as a vector) with the product of `p.B` and
`fHat`. Returns `nothing`.
"""
function AbstractNFFTs.convolve_transpose!(p::GPU_NFFTPlan{T,D, arrTc}, fHat::arrH, g::arrTc) where {D,T,arr<: AbstractGPUArray, arrTc <: arr, arrH <: arr}
    gFlat = vec(g)
    mul!(gFlat, p.B, fHat)
    return nothing
end

"""
    deconvolve!(p::GPU_NFFTPlan, f, g)

Scale `f` element-wise by `p.windowHatInvLUT` into the scratch buffer
`p.tmpVecHat`, then write the scaled values into `g` at the linear positions
`p.deconvolveIdx`. Returns `nothing`.
"""
function AbstractNFFTs.deconvolve!(p::GPU_NFFTPlan{T,D, arrTc}, f::arrF, g::arrTc) where {D,T,arr<: AbstractGPUArray, arrTc <: arr, arrF <: arr}
    scratch = p.tmpVecHat
    # Step 1: apply the inverse window factors.
    scratch[:] .= vec(f) .* p.windowHatInvLUT
    # Step 2: scatter onto the grid.
    g[p.deconvolveIdx] = scratch
    return nothing
end

"""
    deconvolve_transpose!(p::GPU_NFFTPlan, g, f)

Gather the entries of `g` at the linear positions `p.deconvolveIdx` into the
scratch buffer `p.tmpVecHat`, then store them, scaled element-wise by
`p.windowHatInvLUT`, in `f`. Returns `nothing`.
"""
function AbstractNFFTs.deconvolve_transpose!(p::GPU_NFFTPlan{T,D, arrTc}, g::arrTc, f::arrF) where {D,T,arr<: AbstractGPUArray, arrTc <: arr, arrF <: arr}
    # The gather is expressed as a broadcast over the index vector.
    gathered = broadcast(idx -> g[idx], p.deconvolveIdx)
    p.tmpVecHat[:] .= gathered
    f[:] .= vec(p.tmpVecHat) .* p.windowHatInvLUT
    return nothing
end

""" in-place NFFT on the GPU"""
function LinearAlgebra.mul!(fHat::arrH, p::GPU_NFFTPlan{T,D, arrT}, f::arrF;
    verbose=false, timing::Union{Nothing,TimingStats} = nothing) where {T,D,arr<: AbstractGPUArray, arrT <: arr, arrH <: arr, arrF <: arr}
    # Keywords:
    #   verbose - log the three stage timings with @info when true
    #   timing  - optional TimingStats; receives the stage timings in
    #             `deconv`, `fft` and `conv`
    # Returns `fHat`, overwritten with the result.
    NFFT.consistencyCheck(p, f, fHat)

    # The work buffer is cleared first; deconvolve! only writes at p.deconvolveIdx.
    fill!(p.tmpVec, zero(Complex{T}))
    # NOTE(review): GPU backends may execute asynchronously, in which case these
    # @elapsed values would not include kernel run time — confirm per backend.
    t1 = @elapsed @inbounds deconvolve!(p, f, p.tmpVec)
    # forwardFFT was created with plan_fft!, so this transforms tmpVec in place.
    t2 = @elapsed p.forwardFFT * p.tmpVec
    t3 = @elapsed @inbounds convolve!(p, p.tmpVec, fHat)
    if verbose
        @info "Timing: deconv=$t1 fft=$t2 conv=$t3"
    end
    # Identity comparison is the idiomatic way to test for `nothing`.
    if timing !== nothing
        timing.conv = t3
        timing.fft = t2
        timing.deconv = t1
    end

    return fHat
end

""" in-place adjoint NFFT on the GPU"""
function LinearAlgebra.mul!(f::arrF, pl::Adjoint{Complex{T},<:GPU_NFFTPlan{T,D, arrT}}, fHat::arrH;
    verbose=false, timing::Union{Nothing,TimingStats} = nothing) where {T,D,arr<: AbstractGPUArray, arrT <: arr, arrH <: arr, arrF <: arr}
    # Keywords:
    #   verbose - log the three stage timings with @info when true
    #   timing  - optional TimingStats; receives the stage timings in
    #             `conv_adjoint`, `fft_adjoint` and `deconv_adjoint`
    # Returns `f`, overwritten with the result.
    p = pl.parent
    NFFT.consistencyCheck(p, f, fHat)

    # NOTE(review): GPU backends may execute asynchronously, in which case these
    # @elapsed values would not include kernel run time — confirm per backend.
    t1 = @elapsed @inbounds convolve_transpose!(p, fHat, p.tmpVec)
    # backwardFFT was created with plan_bfft!, so this transforms tmpVec in place.
    t2 = @elapsed p.backwardFFT * p.tmpVec
    t3 = @elapsed @inbounds deconvolve_transpose!(p, p.tmpVec, f)
    if verbose
        @info "Timing: conv=$t1 fft=$t2 deconv=$t3"
    end
    # Identity comparison is the idiomatic way to test for `nothing`.
    if timing !== nothing
        timing.conv_adjoint = t1
        timing.fft_adjoint = t2
        timing.deconv_adjoint = t3
    end

    return f
end

61 changes: 0 additions & 61 deletions test/cuda.jl

This file was deleted.

58 changes: 58 additions & 0 deletions test/gpu.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Kernel size parameter and oversampling factor shared by the accuracy tests below.
m = 5
σ = 2.0

# GPU plan tests, run once for every array type listed in `arrayTypes`
# (which must be defined before this file is included).
@testset "GPU NFFT Plans" begin
    for arrayType in arrayTypes

        # Compare the GPU plan against the exact NDFT for every window type.
        @testset "GPU_NFFT in multiple dimensions" begin
            windows = [:kaiser_bessel, :gauss, :kaiser_bessel_rev, :spline]
            # Allowed relative error per window, in the same order as `windows`.
            tolerances = [1e-7, 1e-3, 1e-6, 1e-4]
            for N in [(256,), (32, 32), (12, 12, 12)]
                for (window, tol) in zip(windows, tolerances)
                    D = length(N)
                    @info "Testing $arrayType in $D dimensions using $window window"

                    J = prod(N)
                    k = rand(Float64, D, J) .- 0.5
                    p = plan_nfft(Array, k, N; m, σ, window, precompute=NFFT.FULL,
                        fftflags=FFTW.ESTIMATE)
                    p_d = plan_nfft(arrayType, k, N; m, σ, window, precompute=NFFT.FULL)
                    pNDFT = NDFTPlan(k, N)

                    # Adjoint transform: GPU result vs. exact adjoint NDFT.
                    fHat = rand(Float64, J) + rand(Float64, J) * im
                    f = adjoint(pNDFT) * fHat
                    fHat_d = arrayType(fHat)
                    fApprox_d = adjoint(p_d) * fHat_d
                    fApprox = Array(fApprox_d)
                    e = norm(f[:] - fApprox[:]) / norm(f[:])
                    @debug "error adjoint nfft " e
                    @test e < tol

                    # Forward transform: GPU result vs. exact NDFT.
                    gHat = pNDFT * f
                    gHatApprox = Array(p_d * arrayType(f))
                    e = norm(gHat[:] - gHatApprox[:]) / norm(gHat[:])
                    @debug "error nfft " e
                    @test e < tol
                end
            end
        end

        @testset "GPU_NFFT Sampling Density" begin

            # create a 10x10 grid of unit spaced sampling points
            N = 10
            g = (0:(N-1)) ./ N .- 0.5
            x = vec(ones(N) * g')
            y = vec(g * ones(N)')
            nodes = cat(x', y', dims=1)

            # approximate the density weights
            p = plan_nfft(arrayType, nodes, (N, N); m=5, σ=2.0)
            weights = Array(sdc(p, iters=5))

            @info extrema(vec(weights))

            # Every weight should match the uniform value 1/N^2. The comparison
            # operator had been lost here, leaving a call on an empty tuple.
            @test all(isapprox.(vec(weights), 1 / (N * N); rtol=1e-7))

        end
    end
end
5 changes: 5 additions & 0 deletions test/gpu/cuda.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
using CUDA

# Select the CUDA array type, then hand over to the main test runner.
# runtests.jl checks whether `arrayTypes` is already defined and, if so,
# runs only the GPU tests.
arrayTypes = [CuArray]

include(joinpath(@__DIR__, "..", "runtests.jl"))
5 changes: 5 additions & 0 deletions test/gpu/rocm.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
using AMDGPU

# Select the AMDGPU array type, then hand over to the main test runner.
# runtests.jl checks whether `arrayTypes` is already defined and, if so,
# runs only the GPU tests.
arrayTypes = [ROCArray]

include(joinpath(@__DIR__, "..", "runtests.jl"))
31 changes: 21 additions & 10 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,27 @@ using LinearAlgebra
using FFTW
using NFFTTools
using Zygote
using JLArrays

Random.seed!(123)
# A backend driver script may have set `arrayTypes` before including this file.
# Record that, and fall back to JLArray otherwise.
areTypesDefined = @isdefined arrayTypes
if !areTypesDefined
    arrayTypes = [JLArray]
end

include("issues.jl")
include("accuracy.jl")
include("constructors.jl")
include("performance.jl")
include("testToeplitz.jl")
include("samplingDensity.jl")
include("cuda.jl")
include("chainrules.jl")
# Need to run after the other tests since they overload plan_*
include("wrappers.jl")
@testset "NFFT" begin
    if areTypesDefined
        # Array types were supplied from outside: run only the GPU related tests.
        include("gpu.jl")
    else
        # No array types were supplied: run everything.
        for testFile in ("issues.jl", "accuracy.jl", "constructors.jl", "performance.jl",
                         "testToeplitz.jl", "samplingDensity.jl", "gpu.jl", "chainrules.jl")
            include(testFile)
        end
        # Need to run after the other tests since they overload plan_*
        include("wrappers.jl")
    end
end

0 comments on commit dd72b2c

Please sign in to comment.