From 4bfcef3941b553e5a8b27f1261bdf87f0fe5452b Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 22 Apr 2023 21:27:33 +0300 Subject: [PATCH 1/5] TST: fft tests pass collection and xfail --- torch_np/__init__.py | 2 +- torch_np/fft.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_np/__init__.py b/torch_np/__init__.py index 8ea1bf3c..a3158b5b 100644 --- a/torch_np/__init__.py +++ b/torch_np/__init__.py @@ -1,4 +1,4 @@ -from . import fft, linalg, random +from . import linalg, random, fft from ._dtypes import * from ._funcs import * from ._getlimits import errstate, finfo, iinfo diff --git a/torch_np/fft.py b/torch_np/fft.py index b6bb763c..3d26eae6 100644 --- a/torch_np/fft.py +++ b/torch_np/fft.py @@ -28,3 +28,4 @@ def fft2(): def ifft2(): raise NotImplementedError + From c9e5fc97fe11a4ecbe35c46462b113ce750330f3 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 23 Apr 2023 00:50:24 +0300 Subject: [PATCH 2/5] ENH: implement fft --- torch_np/fft.py | 102 +++++++++++++++--- .../tests/numpy_tests/fft/test_pocketfft.py | 41 +++---- 2 files changed, 109 insertions(+), 34 deletions(-) diff --git a/torch_np/fft.py b/torch_np/fft.py index 3d26eae6..0c155113 100644 --- a/torch_np/fft.py +++ b/torch_np/fft.py @@ -1,31 +1,101 @@ -def fft(): - raise NotImplementedError +from __future__ import annotations +import functools +import torch -def ifft(): - raise NotImplementedError +from ._normalizations import ArrayLike, normalizer +from . import _util -def fftn(): - raise NotImplementedError +def upcast(func): + """NumPy fft casts inputs to 64 bit and *returns 64-bit results*.""" + @functools.wraps(func) + def wrapped(tensor, *args, **kwds): + target_dtype = torch.complex128 if tensor.is_complex() else torch.float64 + tensor = _util.cast_if_needed(tensor, target_dtype) + return func(tensor, *args, **kwds) + return wrapped -def ifftn(): - raise NotImplementedError +@normalizer +@upcast +def fft(a: ArrayLike, n=None, axis=-1, norm=None): + return torch.fft.fft(a, n, dim=axis, norm=norm) -def rfftn(): - raise NotImplementedError +@normalizer +@upcast +def ifft(a: ArrayLike, n=None, axis=-1, norm=None): + return torch.fft.ifft(a, n, dim=axis, norm=norm) -def irfftn(): - raise NotImplementedError +@normalizer +@upcast +def rfft(a: ArrayLike, n=None, axis=-1, norm=None): + return torch.fft.rfft(a, n, dim=axis, norm=norm) -def fft2(): - raise NotImplementedError +@normalizer +@upcast +def irfft(a: ArrayLike, n=None, axis=-1, norm=None): + return torch.fft.irfft(a, n, dim=axis, norm=norm) -def ifft2(): - raise NotImplementedError +@normalizer +@upcast +def fftn(a: ArrayLike, s=None, axes=None, norm=None): + return torch.fft.fftn(a, s, dim=axes, norm=norm) + + +@normalizer +@upcast +def ifftn(a: ArrayLike, s=None, axes=None, norm=None): + return torch.fft.ifftn(a, s, dim=axes, norm=norm) + + +@normalizer +@upcast +def rfftn(a: ArrayLike, s=None, axes=None, norm=None): + return torch.fft.rfftn(a, s, dim=axes, norm=norm) + + +@normalizer +@upcast +def irfftn(a: ArrayLike, s=None, axes=None, norm=None): + return torch.fft.irfftn(a, s, dim=axes, norm=norm) + + +@normalizer +@upcast +def fft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None): + return torch.fft.fft2(a, s, dim=axes, norm=norm) + +@normalizer +@upcast +def ifft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None): + return torch.fft.ifft2(a, s, dim=axes, norm=norm) + + +@normalizer +@upcast +def rfft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None): + return torch.fft.rfft2(a, s, dim=axes, norm=norm) + + +@normalizer +@upcast +def irfft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None): + return torch.fft.irfft2(a, s, dim=axes, norm=norm) + + +@normalizer +@upcast +def hfft(a: ArrayLike, n=None, axis=-1, norm=None): + return torch.fft.hfft(a, n, dim=axis, norm=norm) + +@normalizer +@upcast +def ihfft(a: ArrayLike, n=None, axis=-1, norm=None): + return torch.fft.ihfft(a, n, dim=axis, norm=norm) + diff --git a/torch_np/tests/numpy_tests/fft/test_pocketfft.py b/torch_np/tests/numpy_tests/fft/test_pocketfft.py index e5b4c770..c3534eb2 100644 --- a/torch_np/tests/numpy_tests/fft/test_pocketfft.py +++ b/torch_np/tests/numpy_tests/fft/test_pocketfft.py @@ -19,16 +19,17 @@ def fft1(x): return np.sum(x*np.exp(phase), axis=1) -@pytest.mark.xfail(reason='TODO') class TestFFTShift: def test_fft_n(self): - assert_raises(ValueError, np.fft.fft, [1, 2, 3], 0) + assert_raises((ValueError, RuntimeError), np.fft.fft, [1, 2, 3], 0) -@pytest.mark.xfail(reason='TODO') class TestFFT1D: + def setup_method(self): + np.random.seed(123456) + def test_identity(self): maxlen = 512 x = random(maxlen) + 1j*random(maxlen) @@ -40,13 +41,15 @@ def test_identity(self): xr[0:i], atol=1e-12) def test_fft(self): + + np.random.seed(1234) x = random(30) + 1j*random(30) - assert_allclose(fft1(x), np.fft.fft(x), atol=1e-6) - assert_allclose(fft1(x), np.fft.fft(x, norm="backward"), atol=1e-6) + assert_allclose(fft1(x), np.fft.fft(x), atol=2e-5) + assert_allclose(fft1(x), np.fft.fft(x, norm="backward"), atol=2e-5) assert_allclose(fft1(x) / np.sqrt(30), - np.fft.fft(x, norm="ortho"), atol=1e-6) + np.fft.fft(x, norm="ortho"), atol=5e-6) assert_allclose(fft1(x) / 30., - np.fft.fft(x, norm="forward"), atol=1e-6) + np.fft.fft(x, norm="forward"), atol=5e-6) @pytest.mark.parametrize('norm', (None, 'backward', 'ortho', 'forward')) def test_ifft(self, norm): @@ -55,8 +58,8 @@ def test_ifft(self, norm): x, np.fft.ifft(np.fft.fft(x, norm=norm), norm=norm), atol=1e-6) # Ensure we get the correct error message - with pytest.raises(ValueError, - match='Invalid number of FFT data points'): + with pytest.raises((ValueError, RuntimeError), + match='Invalid number of data points'): np.fft.ifft([], norm=norm) def test_fft2(self): @@ -175,7 +178,7 @@ def test_irfftn(self): def test_hfft(self): x = random(14) + 1j*random(14) x_herm = np.concatenate((random(1), x, random(1))) - x = np.concatenate((x_herm, x[::-1].conj())) + x = np.concatenate((x_herm, np.flip(x).conj())) assert_allclose(np.fft.fft(x), np.fft.hfft(x_herm), atol=1e-6) assert_allclose(np.fft.hfft(x_herm), np.fft.hfft(x_herm, norm="backward"), atol=1e-6) @@ -187,7 +190,7 @@ def test_hfft(self): def test_ihfft(self): x = random(14) + 1j*random(14) x_herm = np.concatenate((random(1), x, random(1))) - x = np.concatenate((x_herm, x[::-1].conj())) + x = np.concatenate((x_herm, np.flip(x).conj())) assert_allclose(x_herm, np.fft.ihfft(np.fft.hfft(x_herm)), atol=1e-6) assert_allclose(x_herm, np.fft.ihfft(np.fft.hfft(x_herm, norm="backward"), norm="backward"), atol=1e-6) @@ -234,7 +237,6 @@ def test_dtypes(self, dtype): assert_allclose(np.fft.irfft(np.fft.rfft(x)), x, atol=1e-6) -@pytest.mark.xfail(reason='TODO') @pytest.mark.parametrize( "dtype", [np.float32, np.float64, np.complex64, np.complex128]) @@ -246,16 +248,20 @@ def test_dtypes(self, dtype): def test_fft_with_order(dtype, order, fft): # Check that FFT/IFFT produces identical results for C, Fortran and # non contiguous arrays - rng = np.random.RandomState(42) - X = rng.rand(8, 7, 13).astype(dtype, copy=False) + # rng = np.random.RandomState(42) + rng = np.random + X = rng.rand(8, 7, 13).astype(dtype) #, copy=False) # See discussion in pull/14178 - _tol = 8.0 * np.sqrt(np.log2(X.size)) * np.finfo(X.dtype).eps + _tol = float(8.0 * np.sqrt(np.log2(X.size)) * np.finfo(X.dtype).eps) if order == 'F': + pytest.skip("Fortran order arrays") Y = np.asfortranarray(X) else: # Make a non contiguous array - Y = X[::-1] - X = np.ascontiguousarray(X[::-1]) + Z = np.empty((16, 7, 13), dtype=X.dtype) + Z[::2] = X + Y = Z[::2] + X = Y.copy() if fft.__name__.endswith('fft'): for axis in range(3): @@ -274,7 +280,6 @@ def test_fft_with_order(dtype, order, fft): raise ValueError() -@pytest.mark.xfail(reason='TODO') @pytest.mark.skipif(IS_WASM, reason="Cannot start thread") class TestFFTThreadSafe: threads = 16 From 9e0de0fefd2cc19fb12d27cf54d0f2f2e009c75b Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 23 Apr 2023 10:23:01 +0300 Subject: [PATCH 3/5] API: fft: make fft output dtype configurable --- torch_np/__init__.py | 2 +- torch_np/fft.py | 15 +++++++++++---- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/torch_np/__init__.py b/torch_np/__init__.py index a3158b5b..8ea1bf3c 100644 --- a/torch_np/__init__.py +++ b/torch_np/__init__.py @@ -1,4 +1,4 @@ -from . import linalg, random, fft +from . import fft, linalg, random from ._dtypes import * from ._funcs import * from ._getlimits import errstate, finfo, iinfo diff --git a/torch_np/fft.py b/torch_np/fft.py index 0c155113..aa0d506f 100644 --- a/torch_np/fft.py +++ b/torch_np/fft.py @@ -1,19 +1,26 @@ from __future__ import annotations import functools + import torch +from . import _dtypes_impl, _util from ._normalizations import ArrayLike, normalizer -from . import _util def upcast(func): """NumPy fft casts inputs to 64 bit and *returns 64-bit results*.""" + @functools.wraps(func) def wrapped(tensor, *args, **kwds): - target_dtype = torch.complex128 if tensor.is_complex() else torch.float64 + target_dtype = ( + _dtypes_impl.default_dtypes.complex_dtype + if tensor.is_complex() + else _dtypes_impl.default_dtypes.float_dtype + ) tensor = _util.cast_if_needed(tensor, target_dtype) return func(tensor, *args, **kwds) + return wrapped @@ -70,6 +77,7 @@ def irfftn(a: ArrayLike, s=None, axes=None, norm=None): def fft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None): return torch.fft.fft2(a, s, dim=axes, norm=norm) + @normalizer @upcast def ifft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None): @@ -93,9 +101,8 @@ def irfft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None): def hfft(a: ArrayLike, n=None, axis=-1, norm=None): return torch.fft.hfft(a, n, dim=axis, norm=norm) + @normalizer @upcast def ihfft(a: ArrayLike, n=None, axis=-1, norm=None): return torch.fft.ihfft(a, n, dim=axis, norm=norm) - - From e71ce13a3dc08c974f6bb5cc493a0b37f74a02aa Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 23 Apr 2023 10:27:36 +0300 Subject: [PATCH 4/5] TST: vendor numpy's fft/test_helper.py --- torch_np/tests/numpy_tests/fft/test_helper.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/torch_np/tests/numpy_tests/fft/test_helper.py b/torch_np/tests/numpy_tests/fft/test_helper.py index 41650354..da65fb50 100644 --- a/torch_np/tests/numpy_tests/fft/test_helper.py +++ b/torch_np/tests/numpy_tests/fft/test_helper.py @@ -7,9 +7,7 @@ from torch_np.testing import assert_array_almost_equal from torch_np import fft, pi -import pytest -@pytest.mark.xfail(reason="TODO") class TestFFTShift: def test_definition(self): @@ -135,7 +133,6 @@ def original_ifftshift(x, axes=None): original_ifftshift(inp, axes_keyword)) -@pytest.mark.xfail(reason="TODO") class TestFFTFreq: def test_definition(self): @@ -147,7 +144,6 @@ def test_definition(self): assert_array_almost_equal(10*pi*fft.fftfreq(10, pi), x) -@pytest.mark.xfail(reason="TODO") class TestRFFTFreq: def test_definition(self): @@ -159,7 +155,6 @@ def test_definition(self): assert_array_almost_equal(10*pi*fft.rfftfreq(10, pi), x) -@pytest.mark.xfail(reason="TODO") class TestIRFFTN: def test_not_last_axis_success(self): From 8037d56057ce4941d703cd2d8c73f7e893a67347 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 23 Apr 2023 11:06:39 +0300 Subject: [PATCH 5/5] FFT: finish up the dir(np.fft) contents --- torch_np/fft.py | 20 +++++++++++++++++++ torch_np/tests/numpy_tests/fft/test_helper.py | 2 +- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/torch_np/fft.py b/torch_np/fft.py index aa0d506f..87a32387 100644 --- a/torch_np/fft.py +++ b/torch_np/fft.py @@ -106,3 +106,23 @@ def hfft(a: ArrayLike, n=None, axis=-1, norm=None): @upcast def ihfft(a: ArrayLike, n=None, axis=-1, norm=None): return torch.fft.ihfft(a, n, dim=axis, norm=norm) + + +@normalizer +def fftfreq(n, d=1.0): + return torch.fft.fftfreq(n, d) + + +@normalizer +def rfftfreq(n, d=1.0): + return torch.fft.rfftfreq(n, d) + + +@normalizer +def fftshift(x: ArrayLike, axes=None): + return torch.fft.fftshift(x, axes) + + +@normalizer +def ifftshift(x: ArrayLike, axes=None): + return torch.fft.ifftshift(x, axes) diff --git a/torch_np/tests/numpy_tests/fft/test_helper.py b/torch_np/tests/numpy_tests/fft/test_helper.py index da65fb50..d3dec841 100644 --- a/torch_np/tests/numpy_tests/fft/test_helper.py +++ b/torch_np/tests/numpy_tests/fft/test_helper.py @@ -85,7 +85,7 @@ def test_uneven_dims(self): def test_equal_to_original(self): """ Test that the new (>=v1.15) implementation (see #10073) is equal to the original (<=v1.14) """ - from numpy.core import asarray, concatenate, arange, take + from torch_np import asarray, concatenate, arange, take def original_fftshift(x, axes=None): """ How fftshift was implemented in v1.14"""