Skip to content

Commit aaabfda

Browse files
authored
Merge pull request #123 from Quansight-Labs/fft
add wrappers for np.fft
2 parents 428f073 + 8037d56 commit aaabfda

File tree

3 files changed

+138
-40
lines changed

3 files changed

+138
-40
lines changed

torch_np/fft.py

Lines changed: 114 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,128 @@
1-
def fft():
2-
raise NotImplementedError
1+
from __future__ import annotations
32

3+
import functools
44

5-
def ifft():
6-
raise NotImplementedError
5+
import torch
76

7+
from . import _dtypes_impl, _util
8+
from ._normalizations import ArrayLike, normalizer
89

9-
def fftn():
10-
raise NotImplementedError
1110

11+
def upcast(func):
12+
"""NumPy fft casts inputs to 64 bit and *returns 64-bit results*."""
1213

13-
def ifftn():
14-
raise NotImplementedError
14+
@functools.wraps(func)
15+
def wrapped(tensor, *args, **kwds):
16+
target_dtype = (
17+
_dtypes_impl.default_dtypes.complex_dtype
18+
if tensor.is_complex()
19+
else _dtypes_impl.default_dtypes.float_dtype
20+
)
21+
tensor = _util.cast_if_needed(tensor, target_dtype)
22+
return func(tensor, *args, **kwds)
1523

24+
return wrapped
1625

17-
def rfftn():
18-
raise NotImplementedError
1926

27+
@normalizer
28+
@upcast
29+
def fft(a: ArrayLike, n=None, axis=-1, norm=None):
30+
return torch.fft.fft(a, n, dim=axis, norm=norm)
2031

21-
def irfftn():
22-
raise NotImplementedError
2332

33+
@normalizer
34+
@upcast
35+
def ifft(a: ArrayLike, n=None, axis=-1, norm=None):
36+
return torch.fft.ifft(a, n, dim=axis, norm=norm)
2437

25-
def fft2():
26-
raise NotImplementedError
2738

39+
@normalizer
40+
@upcast
41+
def rfft(a: ArrayLike, n=None, axis=-1, norm=None):
42+
return torch.fft.rfft(a, n, dim=axis, norm=norm)
2843

29-
def ifft2():
30-
raise NotImplementedError
44+
45+
@normalizer
46+
@upcast
47+
def irfft(a: ArrayLike, n=None, axis=-1, norm=None):
48+
return torch.fft.irfft(a, n, dim=axis, norm=norm)
49+
50+
51+
@normalizer
52+
@upcast
53+
def fftn(a: ArrayLike, s=None, axes=None, norm=None):
54+
return torch.fft.fftn(a, s, dim=axes, norm=norm)
55+
56+
57+
@normalizer
58+
@upcast
59+
def ifftn(a: ArrayLike, s=None, axes=None, norm=None):
60+
return torch.fft.ifftn(a, s, dim=axes, norm=norm)
61+
62+
63+
@normalizer
64+
@upcast
65+
def rfftn(a: ArrayLike, s=None, axes=None, norm=None):
66+
return torch.fft.rfftn(a, s, dim=axes, norm=norm)
67+
68+
69+
@normalizer
70+
@upcast
71+
def irfftn(a: ArrayLike, s=None, axes=None, norm=None):
72+
return torch.fft.irfftn(a, s, dim=axes, norm=norm)
73+
74+
75+
@normalizer
76+
@upcast
77+
def fft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None):
78+
return torch.fft.fft2(a, s, dim=axes, norm=norm)
79+
80+
81+
@normalizer
82+
@upcast
83+
def ifft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None):
84+
return torch.fft.ifft2(a, s, dim=axes, norm=norm)
85+
86+
87+
@normalizer
88+
@upcast
89+
def rfft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None):
90+
return torch.fft.rfft2(a, s, dim=axes, norm=norm)
91+
92+
93+
@normalizer
94+
@upcast
95+
def irfft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None):
96+
return torch.fft.irfft2(a, s, dim=axes, norm=norm)
97+
98+
99+
@normalizer
100+
@upcast
101+
def hfft(a: ArrayLike, n=None, axis=-1, norm=None):
102+
return torch.fft.hfft(a, n, dim=axis, norm=norm)
103+
104+
105+
@normalizer
106+
@upcast
107+
def ihfft(a: ArrayLike, n=None, axis=-1, norm=None):
108+
return torch.fft.ihfft(a, n, dim=axis, norm=norm)
109+
110+
111+
@normalizer
112+
def fftfreq(n, d=1.0):
113+
return torch.fft.fftfreq(n, d)
114+
115+
116+
@normalizer
117+
def rfftfreq(n, d=1.0):
118+
return torch.fft.rfftfreq(n, d)
119+
120+
121+
@normalizer
122+
def fftshift(x: ArrayLike, axes=None):
123+
return torch.fft.fftshift(x, axes)
124+
125+
126+
@normalizer
127+
def ifftshift(x: ArrayLike, axes=None):
128+
return torch.fft.ifftshift(x, axes)

torch_np/tests/numpy_tests/fft/test_helper.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,7 @@
77
from torch_np.testing import assert_array_almost_equal
88
from torch_np import fft, pi
99

10-
import pytest
1110

12-
@pytest.mark.xfail(reason="TODO")
1311
class TestFFTShift:
1412

1513
def test_definition(self):
@@ -87,7 +85,7 @@ def test_uneven_dims(self):
8785

8886
def test_equal_to_original(self):
8987
""" Test that the new (>=v1.15) implementation (see #10073) is equal to the original (<=v1.14) """
90-
from numpy.core import asarray, concatenate, arange, take
88+
from torch_np import asarray, concatenate, arange, take
9189

9290
def original_fftshift(x, axes=None):
9391
""" How fftshift was implemented in v1.14"""
@@ -135,7 +133,6 @@ def original_ifftshift(x, axes=None):
135133
original_ifftshift(inp, axes_keyword))
136134

137135

138-
@pytest.mark.xfail(reason="TODO")
139136
class TestFFTFreq:
140137

141138
def test_definition(self):
@@ -147,7 +144,6 @@ def test_definition(self):
147144
assert_array_almost_equal(10*pi*fft.fftfreq(10, pi), x)
148145

149146

150-
@pytest.mark.xfail(reason="TODO")
151147
class TestRFFTFreq:
152148

153149
def test_definition(self):
@@ -159,7 +155,6 @@ def test_definition(self):
159155
assert_array_almost_equal(10*pi*fft.rfftfreq(10, pi), x)
160156

161157

162-
@pytest.mark.xfail(reason="TODO")
163158
class TestIRFFTN:
164159

165160
def test_not_last_axis_success(self):

torch_np/tests/numpy_tests/fft/test_pocketfft.py

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,17 @@ def fft1(x):
1919
return np.sum(x*np.exp(phase), axis=1)
2020

2121

22-
@pytest.mark.xfail(reason='TODO')
2322
class TestFFTShift:
2423

2524
def test_fft_n(self):
26-
assert_raises(ValueError, np.fft.fft, [1, 2, 3], 0)
25+
assert_raises((ValueError, RuntimeError), np.fft.fft, [1, 2, 3], 0)
2726

2827

29-
@pytest.mark.xfail(reason='TODO')
3028
class TestFFT1D:
3129

30+
def setup_method(self):
31+
np.random.seed(123456)
32+
3233
def test_identity(self):
3334
maxlen = 512
3435
x = random(maxlen) + 1j*random(maxlen)
@@ -40,13 +41,15 @@ def test_identity(self):
4041
xr[0:i], atol=1e-12)
4142

4243
def test_fft(self):
44+
45+
np.random.seed(1234)
4346
x = random(30) + 1j*random(30)
44-
assert_allclose(fft1(x), np.fft.fft(x), atol=1e-6)
45-
assert_allclose(fft1(x), np.fft.fft(x, norm="backward"), atol=1e-6)
47+
assert_allclose(fft1(x), np.fft.fft(x), atol=2e-5)
48+
assert_allclose(fft1(x), np.fft.fft(x, norm="backward"), atol=2e-5)
4649
assert_allclose(fft1(x) / np.sqrt(30),
47-
np.fft.fft(x, norm="ortho"), atol=1e-6)
50+
np.fft.fft(x, norm="ortho"), atol=5e-6)
4851
assert_allclose(fft1(x) / 30.,
49-
np.fft.fft(x, norm="forward"), atol=1e-6)
52+
np.fft.fft(x, norm="forward"), atol=5e-6)
5053

5154
@pytest.mark.parametrize('norm', (None, 'backward', 'ortho', 'forward'))
5255
def test_ifft(self, norm):
@@ -55,8 +58,8 @@ def test_ifft(self, norm):
5558
x, np.fft.ifft(np.fft.fft(x, norm=norm), norm=norm),
5659
atol=1e-6)
5760
# Ensure we get the correct error message
58-
with pytest.raises(ValueError,
59-
match='Invalid number of FFT data points'):
61+
with pytest.raises((ValueError, RuntimeError),
62+
match='Invalid number of data points'):
6063
np.fft.ifft([], norm=norm)
6164

6265
def test_fft2(self):
@@ -175,7 +178,7 @@ def test_irfftn(self):
175178
def test_hfft(self):
176179
x = random(14) + 1j*random(14)
177180
x_herm = np.concatenate((random(1), x, random(1)))
178-
x = np.concatenate((x_herm, x[::-1].conj()))
181+
x = np.concatenate((x_herm, np.flip(x).conj()))
179182
assert_allclose(np.fft.fft(x), np.fft.hfft(x_herm), atol=1e-6)
180183
assert_allclose(np.fft.hfft(x_herm),
181184
np.fft.hfft(x_herm, norm="backward"), atol=1e-6)
@@ -187,7 +190,7 @@ def test_hfft(self):
187190
def test_ihfft(self):
188191
x = random(14) + 1j*random(14)
189192
x_herm = np.concatenate((random(1), x, random(1)))
190-
x = np.concatenate((x_herm, x[::-1].conj()))
193+
x = np.concatenate((x_herm, np.flip(x).conj()))
191194
assert_allclose(x_herm, np.fft.ihfft(np.fft.hfft(x_herm)), atol=1e-6)
192195
assert_allclose(x_herm, np.fft.ihfft(np.fft.hfft(x_herm,
193196
norm="backward"), norm="backward"), atol=1e-6)
@@ -234,7 +237,6 @@ def test_dtypes(self, dtype):
234237
assert_allclose(np.fft.irfft(np.fft.rfft(x)), x, atol=1e-6)
235238

236239

237-
@pytest.mark.xfail(reason='TODO')
238240
@pytest.mark.parametrize(
239241
"dtype",
240242
[np.float32, np.float64, np.complex64, np.complex128])
@@ -246,16 +248,20 @@ def test_dtypes(self, dtype):
246248
def test_fft_with_order(dtype, order, fft):
247249
# Check that FFT/IFFT produces identical results for C, Fortran and
248250
# non contiguous arrays
249-
rng = np.random.RandomState(42)
250-
X = rng.rand(8, 7, 13).astype(dtype, copy=False)
251+
# rng = np.random.RandomState(42)
252+
rng = np.random
253+
X = rng.rand(8, 7, 13).astype(dtype) #, copy=False)
251254
# See discussion in pull/14178
252-
_tol = 8.0 * np.sqrt(np.log2(X.size)) * np.finfo(X.dtype).eps
255+
_tol = float(8.0 * np.sqrt(np.log2(X.size)) * np.finfo(X.dtype).eps)
253256
if order == 'F':
257+
pytest.skip("Fortran order arrays")
254258
Y = np.asfortranarray(X)
255259
else:
256260
# Make a non contiguous array
257-
Y = X[::-1]
258-
X = np.ascontiguousarray(X[::-1])
261+
Z = np.empty((16, 7, 13), dtype=X.dtype)
262+
Z[::2] = X
263+
Y = Z[::2]
264+
X = Y.copy()
259265

260266
if fft.__name__.endswith('fft'):
261267
for axis in range(3):
@@ -274,7 +280,6 @@ def test_fft_with_order(dtype, order, fft):
274280
raise ValueError()
275281

276282

277-
@pytest.mark.xfail(reason='TODO')
278283
@pytest.mark.skipif(IS_WASM, reason="Cannot start thread")
279284
class TestFFTThreadSafe:
280285
threads = 16

0 commit comments

Comments
 (0)