Skip to content

add wrappers for np.fft #123

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 114 additions & 16 deletions torch_np/fft.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,128 @@
def fft():
raise NotImplementedError
from __future__ import annotations

import functools

def ifft():
raise NotImplementedError
import torch

from . import _dtypes_impl, _util
from ._normalizations import ArrayLike, normalizer

def fftn():
raise NotImplementedError

def upcast(func):
"""NumPy fft casts inputs to 64 bit and *returns 64-bit results*."""

def ifftn():
raise NotImplementedError
@functools.wraps(func)
def wrapped(tensor, *args, **kwds):
target_dtype = (
_dtypes_impl.default_dtypes.complex_dtype
if tensor.is_complex()
else _dtypes_impl.default_dtypes.float_dtype
)
tensor = _util.cast_if_needed(tensor, target_dtype)
return func(tensor, *args, **kwds)

return wrapped

def rfftn():
raise NotImplementedError

@normalizer
@upcast
def fft(a: ArrayLike, n=None, axis=-1, norm=None):
    """One-dimensional discrete Fourier transform, delegated to ``torch.fft.fft``.

    NumPy's ``axis`` keyword is forwarded as torch's ``dim``.
    """
    return torch.fft.fft(a, n=n, dim=axis, norm=norm)

def irfftn():
raise NotImplementedError

@normalizer
@upcast
def ifft(a: ArrayLike, n=None, axis=-1, norm=None):
    """Inverse one-dimensional discrete Fourier transform via ``torch.fft.ifft``.

    NumPy's ``axis`` keyword is forwarded as torch's ``dim``.
    """
    return torch.fft.ifft(a, n=n, dim=axis, norm=norm)

def fft2():
raise NotImplementedError

@normalizer
@upcast
def rfft(a: ArrayLike, n=None, axis=-1, norm=None):
    """FFT of real input, delegated to ``torch.fft.rfft``.

    NumPy's ``axis`` keyword is forwarded as torch's ``dim``.
    """
    return torch.fft.rfft(a, n=n, dim=axis, norm=norm)

def ifft2():
raise NotImplementedError

@normalizer
@upcast
def irfft(a: ArrayLike, n=None, axis=-1, norm=None):
    """Inverse of ``rfft``, delegated to ``torch.fft.irfft``.

    NumPy's ``axis`` keyword is forwarded as torch's ``dim``.
    """
    return torch.fft.irfft(a, n=n, dim=axis, norm=norm)


@normalizer
@upcast
def fftn(a: ArrayLike, s=None, axes=None, norm=None):
    """N-dimensional discrete Fourier transform via ``torch.fft.fftn``.

    NumPy's ``axes`` keyword is forwarded as torch's ``dim``.
    """
    return torch.fft.fftn(a, s=s, dim=axes, norm=norm)


@normalizer
@upcast
def ifftn(a: ArrayLike, s=None, axes=None, norm=None):
    """Inverse N-dimensional discrete Fourier transform via ``torch.fft.ifftn``.

    NumPy's ``axes`` keyword is forwarded as torch's ``dim``.
    """
    return torch.fft.ifftn(a, s=s, dim=axes, norm=norm)


@normalizer
@upcast
def rfftn(a: ArrayLike, s=None, axes=None, norm=None):
    """N-dimensional FFT of real input via ``torch.fft.rfftn``.

    NumPy's ``axes`` keyword is forwarded as torch's ``dim``.
    """
    return torch.fft.rfftn(a, s=s, dim=axes, norm=norm)


@normalizer
@upcast
def irfftn(a: ArrayLike, s=None, axes=None, norm=None):
    """Inverse of ``rfftn``, delegated to ``torch.fft.irfftn``.

    NumPy's ``axes`` keyword is forwarded as torch's ``dim``.
    """
    return torch.fft.irfftn(a, s=s, dim=axes, norm=norm)


@normalizer
@upcast
def fft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None):
    """Two-dimensional discrete Fourier transform via ``torch.fft.fft2``.

    Operates over the last two axes by default; ``axes`` maps to torch's ``dim``.
    """
    return torch.fft.fft2(a, s=s, dim=axes, norm=norm)


@normalizer
@upcast
def ifft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None):
    """Inverse two-dimensional discrete Fourier transform via ``torch.fft.ifft2``.

    Operates over the last two axes by default; ``axes`` maps to torch's ``dim``.
    """
    return torch.fft.ifft2(a, s=s, dim=axes, norm=norm)


@normalizer
@upcast
def rfft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None):
    """Two-dimensional FFT of real input via ``torch.fft.rfft2``.

    Operates over the last two axes by default; ``axes`` maps to torch's ``dim``.
    """
    return torch.fft.rfft2(a, s=s, dim=axes, norm=norm)


@normalizer
@upcast
def irfft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None):
    """Inverse of ``rfft2``, delegated to ``torch.fft.irfft2``.

    Operates over the last two axes by default; ``axes`` maps to torch's ``dim``.
    """
    return torch.fft.irfft2(a, s=s, dim=axes, norm=norm)


@normalizer
@upcast
def hfft(a: ArrayLike, n=None, axis=-1, norm=None):
    """FFT of a signal with Hermitian symmetry, via ``torch.fft.hfft``.

    NumPy's ``axis`` keyword is forwarded as torch's ``dim``.
    """
    return torch.fft.hfft(a, n=n, dim=axis, norm=norm)


@normalizer
@upcast
def ihfft(a: ArrayLike, n=None, axis=-1, norm=None):
    """Inverse of ``hfft``, delegated to ``torch.fft.ihfft``.

    NumPy's ``axis`` keyword is forwarded as torch's ``dim``.
    """
    return torch.fft.ihfft(a, n=n, dim=axis, norm=norm)


@normalizer
def fftfreq(n, d=1.0):
    """Sample frequencies for a length-*n* DFT, via ``torch.fft.fftfreq``."""
    return torch.fft.fftfreq(n, d=d)


@normalizer
def rfftfreq(n, d=1.0):
    """Sample frequencies for a length-*n* real-input DFT, via ``torch.fft.rfftfreq``."""
    return torch.fft.rfftfreq(n, d=d)


@normalizer
def fftshift(x: ArrayLike, axes=None):
    """Shift the zero-frequency component to the center, via ``torch.fft.fftshift``.

    NumPy's ``axes`` keyword is forwarded as torch's ``dim``.
    """
    return torch.fft.fftshift(x, dim=axes)


@normalizer
def ifftshift(x: ArrayLike, axes=None):
    """Undo ``fftshift``, via ``torch.fft.ifftshift``.

    NumPy's ``axes`` keyword is forwarded as torch's ``dim``.
    """
    return torch.fft.ifftshift(x, dim=axes)
7 changes: 1 addition & 6 deletions torch_np/tests/numpy_tests/fft/test_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@
from torch_np.testing import assert_array_almost_equal
from torch_np import fft, pi

import pytest

@pytest.mark.xfail(reason="TODO")
class TestFFTShift:

def test_definition(self):
Expand Down Expand Up @@ -87,7 +85,7 @@ def test_uneven_dims(self):

def test_equal_to_original(self):
""" Test that the new (>=v1.15) implementation (see #10073) is equal to the original (<=v1.14) """
from numpy.core import asarray, concatenate, arange, take
from torch_np import asarray, concatenate, arange, take

def original_fftshift(x, axes=None):
""" How fftshift was implemented in v1.14"""
Expand Down Expand Up @@ -135,7 +133,6 @@ def original_ifftshift(x, axes=None):
original_ifftshift(inp, axes_keyword))


@pytest.mark.xfail(reason="TODO")
class TestFFTFreq:

def test_definition(self):
Expand All @@ -147,7 +144,6 @@ def test_definition(self):
assert_array_almost_equal(10*pi*fft.fftfreq(10, pi), x)


@pytest.mark.xfail(reason="TODO")
class TestRFFTFreq:

def test_definition(self):
Expand All @@ -159,7 +155,6 @@ def test_definition(self):
assert_array_almost_equal(10*pi*fft.rfftfreq(10, pi), x)


@pytest.mark.xfail(reason="TODO")
class TestIRFFTN:

def test_not_last_axis_success(self):
Expand Down
41 changes: 23 additions & 18 deletions torch_np/tests/numpy_tests/fft/test_pocketfft.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,17 @@ def fft1(x):
return np.sum(x*np.exp(phase), axis=1)


@pytest.mark.xfail(reason='TODO')
class TestFFTShift:

def test_fft_n(self):
assert_raises(ValueError, np.fft.fft, [1, 2, 3], 0)
assert_raises((ValueError, RuntimeError), np.fft.fft, [1, 2, 3], 0)


@pytest.mark.xfail(reason='TODO')
class TestFFT1D:

def setup_method(self):
np.random.seed(123456)

def test_identity(self):
maxlen = 512
x = random(maxlen) + 1j*random(maxlen)
Expand All @@ -40,13 +41,15 @@ def test_identity(self):
xr[0:i], atol=1e-12)

def test_fft(self):

np.random.seed(1234)
x = random(30) + 1j*random(30)
assert_allclose(fft1(x), np.fft.fft(x), atol=1e-6)
assert_allclose(fft1(x), np.fft.fft(x, norm="backward"), atol=1e-6)
assert_allclose(fft1(x), np.fft.fft(x), atol=2e-5)
Copy link
Collaborator Author

@ev-br ev-br Apr 23, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Had to bump the tolerance. Not much we can do I guess, this falls straight through to pytorch.

assert_allclose(fft1(x), np.fft.fft(x, norm="backward"), atol=2e-5)
assert_allclose(fft1(x) / np.sqrt(30),
np.fft.fft(x, norm="ortho"), atol=1e-6)
np.fft.fft(x, norm="ortho"), atol=5e-6)
assert_allclose(fft1(x) / 30.,
np.fft.fft(x, norm="forward"), atol=1e-6)
np.fft.fft(x, norm="forward"), atol=5e-6)

@pytest.mark.parametrize('norm', (None, 'backward', 'ortho', 'forward'))
def test_ifft(self, norm):
Expand All @@ -55,8 +58,8 @@ def test_ifft(self, norm):
x, np.fft.ifft(np.fft.fft(x, norm=norm), norm=norm),
atol=1e-6)
# Ensure we get the correct error message
with pytest.raises(ValueError,
match='Invalid number of FFT data points'):
with pytest.raises((ValueError, RuntimeError),
match='Invalid number of data points'):
np.fft.ifft([], norm=norm)

def test_fft2(self):
Expand Down Expand Up @@ -175,7 +178,7 @@ def test_irfftn(self):
def test_hfft(self):
x = random(14) + 1j*random(14)
x_herm = np.concatenate((random(1), x, random(1)))
x = np.concatenate((x_herm, x[::-1].conj()))
x = np.concatenate((x_herm, np.flip(x).conj()))
assert_allclose(np.fft.fft(x), np.fft.hfft(x_herm), atol=1e-6)
assert_allclose(np.fft.hfft(x_herm),
np.fft.hfft(x_herm, norm="backward"), atol=1e-6)
Expand All @@ -187,7 +190,7 @@ def test_hfft(self):
def test_ihfft(self):
x = random(14) + 1j*random(14)
x_herm = np.concatenate((random(1), x, random(1)))
x = np.concatenate((x_herm, x[::-1].conj()))
x = np.concatenate((x_herm, np.flip(x).conj()))
assert_allclose(x_herm, np.fft.ihfft(np.fft.hfft(x_herm)), atol=1e-6)
assert_allclose(x_herm, np.fft.ihfft(np.fft.hfft(x_herm,
norm="backward"), norm="backward"), atol=1e-6)
Expand Down Expand Up @@ -234,7 +237,6 @@ def test_dtypes(self, dtype):
assert_allclose(np.fft.irfft(np.fft.rfft(x)), x, atol=1e-6)


@pytest.mark.xfail(reason='TODO')
@pytest.mark.parametrize(
"dtype",
[np.float32, np.float64, np.complex64, np.complex128])
Expand All @@ -246,16 +248,20 @@ def test_dtypes(self, dtype):
def test_fft_with_order(dtype, order, fft):
# Check that FFT/IFFT produces identical results for C, Fortran and
# non contiguous arrays
rng = np.random.RandomState(42)
X = rng.rand(8, 7, 13).astype(dtype, copy=False)
# rng = np.random.RandomState(42)
rng = np.random
X = rng.rand(8, 7, 13).astype(dtype) #, copy=False)
# See discussion in pull/14178
_tol = 8.0 * np.sqrt(np.log2(X.size)) * np.finfo(X.dtype).eps
_tol = float(8.0 * np.sqrt(np.log2(X.size)) * np.finfo(X.dtype).eps)
if order == 'F':
pytest.skip("Fortran order arrays")
Y = np.asfortranarray(X)
else:
# Make a non contiguous array
Y = X[::-1]
X = np.ascontiguousarray(X[::-1])
Z = np.empty((16, 7, 13), dtype=X.dtype)
Z[::2] = X
Y = Z[::2]
X = Y.copy()

if fft.__name__.endswith('fft'):
for axis in range(3):
Expand All @@ -274,7 +280,6 @@ def test_fft_with_order(dtype, order, fft):
raise ValueError()


@pytest.mark.xfail(reason='TODO')
@pytest.mark.skipif(IS_WASM, reason="Cannot start thread")
class TestFFTThreadSafe:
threads = 16
Expand Down