diff --git a/lib/cufft/fft.jl b/lib/cufft/fft.jl index c431453d1f..1bbe5ec688 100644 --- a/lib/cufft/fft.jl +++ b/lib/cufft/fft.jl @@ -455,3 +455,12 @@ function Base.:(*)(p::cCuFFTPlan{T,K,false,N}, x::DenseCuArray{T,M}) where {T,K, unsafe_execute_trailing!(p,x, y) y end + +## support adjoints of FFT plans + +AbstractFFTs.AdjointStyle(::cCuFFTPlan) = AbstractFFTs.FFTAdjointStyle() +AbstractFFTs.AdjointStyle(::rCuFFTPlan{T, CUFFT_FORWARD}) where {T} = AbstractFFTs.RFFTAdjointStyle() +AbstractFFTs.AdjointStyle(p::rCuFFTPlan{T, CUFFT_INVERSE}) where {T} = AbstractFFTs.IRFFTAdjointStyle(p.osz[first(p.region)]) + +# manually resolve ambiguity in adjoint plan application +Base.:(*)(p::AbstractFFTs.AdjointPlan{T}, x::CuArray) where T = Base.invoke(*, Tuple{AbstractFFTs.AdjointPlan, AbstractArray}, p, x) \ No newline at end of file diff --git a/test/libraries/cufft.jl b/test/libraries/cufft.jl index 6c3cbfa200..e45c2ab6ce 100644 --- a/test/libraries/cufft.jl +++ b/test/libraries/cufft.jl @@ -342,3 +342,10 @@ end @test Array(dy) ≈ y end + +## AbstractFFTs tests, which also test adjoint functionality of CUFFT plans. + +@testset "AbstractFFTs FFT backend tests" begin + AbstractFFTs.TestUtils.test_complex_ffts(CuArray; test_wrappers=false) + AbstractFFTs.TestUtils.test_real_ffts(CuArray; copy_input=true, test_wrappers=false) +end