Skip to content

reverse the order of individual FFTs in rfftn #2524

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* Aligned the license expression with `PEP-639` [#2511](https://github.com/IntelPython/dpnp/pull/2511)
* Bumped oneMKL version up to `v0.8` [#2514](https://github.com/IntelPython/dpnp/pull/2514)
* Removed the use of class template argument deduction for alias template to conform to the C++17 standard [#2517](https://github.com/IntelPython/dpnp/pull/2517)
* Changed th order of individual FFTs over `axes` for `dpnp.fft.irfftn` to be in forward order [#2524](https://github.com/IntelPython/dpnp/pull/2524)

### Deprecated

Expand Down
33 changes: 27 additions & 6 deletions dpnp/fft/dpnp_utils_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,14 +110,28 @@ def _commit_descriptor(a, forward, in_place, c2c, a_strides, index, batch_fft):
return dsc, out_strides


def _complex_nd_fft(a, s, norm, out, forward, in_place, c2c, axes, batch_fft):
def _complex_nd_fft(
a,
s,
norm,
out,
forward,
in_place,
c2c,
axes,
batch_fft,
*,
reversed_axes=True,
):
"""Computes complex-to-complex FFT of the input N-D array."""

len_axes = len(axes)
# OneMKL supports up to 3-dimensional FFT on GPU
# repeated axis in OneMKL FFT is not allowed
if len_axes > 3 or len(set(axes)) < len_axes:
axes_chunk, shape_chunk = _extract_axes_chunk(axes, s, chunk_size=3)
axes_chunk, shape_chunk = _extract_axes_chunk(
axes, s, chunk_size=3, reversed_axes=reversed_axes
)
for i, (s_chunk, a_chunk) in enumerate(zip(shape_chunk, axes_chunk)):
a = _truncate_or_pad(a, shape=s_chunk, axes=a_chunk)
# if out is used in an intermediate step, it will have memory
Expand Down Expand Up @@ -291,7 +305,7 @@ def _copy_array(x, complex_input):
return x, copy_flag


def _extract_axes_chunk(a, s, chunk_size=3):
def _extract_axes_chunk(a, s, chunk_size=3, reversed_axes=True):
"""
Classify the first input into a list of lists with each list containing
only unique values in reverse order and its length is at most `chunk_size`.
Expand Down Expand Up @@ -362,7 +376,10 @@ def _extract_axes_chunk(a, s, chunk_size=3):
a_chunks.append(a_current_chunk[::-1])
s_chunks.append(s_current_chunk[::-1])

return a_chunks[::-1], s_chunks[::-1]
if reversed_axes:
return a_chunks[::-1], s_chunks[::-1]

return a_chunks, s_chunks


def _fft(a, norm, out, forward, in_place, c2c, axes, batch_fft=True):
Expand Down Expand Up @@ -531,9 +548,12 @@ def _validate_out_keyword(a, out, s, axes, c2c, c2r, r2c):
expected_shape[axes[-1]] = s[-1] // 2 + 1
elif c2c:
expected_shape[axes[-1]] = s[-1]
for s_i, axis in zip(s[-2::-1], axes[-2::-1]):
expected_shape[axis] = s_i
if r2c or c2c:
for s_i, axis in zip(s[-2::-1], axes[-2::-1]):
expected_shape[axis] = s_i
if c2r:
for s_i, axis in zip(s[:-1], axes[:-1]):
expected_shape[axis] = s_i
expected_shape[axes[-1]] = s[-1]

if out.shape != tuple(expected_shape):
Expand Down Expand Up @@ -717,6 +737,7 @@ def dpnp_fftn(a, forward, real, s=None, axes=None, norm=None, out=None):
c2c=True,
axes=axes[:-1],
batch_fft=a.ndim != len_axes - 1,
reversed_axes=False,
)
a = _truncate_or_pad(a, (s[-1],), (axes[-1],))
if c2r:
Expand Down
10 changes: 5 additions & 5 deletions dpnp/tests/test_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,13 +401,13 @@ def test_repeated_axes(self, axes):
result = dpnp.fft.fftn(ia, axes=axes)
# Intel NumPy ignores repeated axes (mkl_fft-gh-104), handle it one by one
expected = a
for ii in axes:
for ii in axes[::-1]:
expected = numpy.fft.fft(expected, axis=ii)
assert_dtype_allclose(result, expected)

# inverse FFT
result = dpnp.fft.ifftn(result, axes=axes)
for ii in axes:
for ii in axes[::-1]:
expected = numpy.fft.ifft(expected, axis=ii)
assert_dtype_allclose(result, expected)

Expand Down Expand Up @@ -893,7 +893,7 @@ def test_repeated_axes(self, axes):

# inverse FFT
result = dpnp.fft.irfftn(result, axes=axes)
for ii in axes[-2::-1]:
for ii in axes[:-1]:
expected = numpy.fft.ifft(expected, axis=ii)
expected = numpy.fft.irfft(expected, axis=axes[-1])
assert_dtype_allclose(result, expected)
Expand All @@ -912,7 +912,7 @@ def test_repeated_axes_with_s(self, axes, s):
assert_dtype_allclose(result, expected)

result = dpnp.fft.irfftn(result, s=s, axes=axes)
for jj, ii in zip(s[-2::-1], axes[-2::-1]):
for jj, ii in zip(s[:-1], axes[:-1]):
expected = numpy.fft.ifft(expected, n=jj, axis=ii)
expected = numpy.fft.irfft(expected, n=s[-1], axis=axes[-1])
assert_dtype_allclose(result, expected)
Expand All @@ -934,7 +934,7 @@ def test_out(self, axes, s):
assert_dtype_allclose(result, expected)

# inverse FFT
for jj, ii in zip(s[-2::-1], axes[-2::-1]):
for jj, ii in zip(s[:-1], axes[:-1]):
expected = numpy.fft.ifft(expected, n=jj, axis=ii)
expected = numpy.fft.irfft(expected, n=s[-1], axis=axes[-1])
out = dpnp.empty(expected.shape, dtype=numpy.float32)
Expand Down
Loading