Skip to content

Commit 428f073

Browse files
authored
Merge pull request #124 from Quansight-Labs/fft_tests
Add numpy tests for np.fft
2 parents 5be250a + d70014f commit 428f073

File tree

4 files changed

+518
-1
lines changed

4 files changed

+518
-1
lines changed

torch_np/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from . import linalg, random
1+
from . import fft, linalg, random
22
from ._dtypes import *
33
from ._funcs import *
44
from ._getlimits import errstate, finfo, iinfo

torch_np/fft.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
def fft():
2+
raise NotImplementedError
3+
4+
5+
def ifft():
6+
raise NotImplementedError
7+
8+
9+
def fftn():
10+
raise NotImplementedError
11+
12+
13+
def ifftn():
14+
raise NotImplementedError
15+
16+
17+
def rfftn():
18+
raise NotImplementedError
19+
20+
21+
def irfftn():
22+
raise NotImplementedError
23+
24+
25+
def fft2():
26+
raise NotImplementedError
27+
28+
29+
def ifft2():
30+
raise NotImplementedError
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
"""Test functions for fftpack.helper module
2+
3+
Copied from fftpack.helper by Pearu Peterson, October 2005
4+
5+
"""
6+
import torch_np as np
7+
from torch_np.testing import assert_array_almost_equal
8+
from torch_np import fft, pi
9+
10+
import pytest
11+
12+
@pytest.mark.xfail(reason="TODO")
13+
class TestFFTShift:
14+
15+
def test_definition(self):
16+
x = [0, 1, 2, 3, 4, -4, -3, -2, -1]
17+
y = [-4, -3, -2, -1, 0, 1, 2, 3, 4]
18+
assert_array_almost_equal(fft.fftshift(x), y)
19+
assert_array_almost_equal(fft.ifftshift(y), x)
20+
x = [0, 1, 2, 3, 4, -5, -4, -3, -2, -1]
21+
y = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4]
22+
assert_array_almost_equal(fft.fftshift(x), y)
23+
assert_array_almost_equal(fft.ifftshift(y), x)
24+
25+
def test_inverse(self):
26+
for n in [1, 4, 9, 100, 211]:
27+
x = np.random.random((n,))
28+
assert_array_almost_equal(fft.ifftshift(fft.fftshift(x)), x)
29+
30+
def test_axes_keyword(self):
31+
freqs = [[0, 1, 2], [3, 4, -4], [-3, -2, -1]]
32+
shifted = [[-1, -3, -2], [2, 0, 1], [-4, 3, 4]]
33+
assert_array_almost_equal(fft.fftshift(freqs, axes=(0, 1)), shifted)
34+
assert_array_almost_equal(fft.fftshift(freqs, axes=0),
35+
fft.fftshift(freqs, axes=(0,)))
36+
assert_array_almost_equal(fft.ifftshift(shifted, axes=(0, 1)), freqs)
37+
assert_array_almost_equal(fft.ifftshift(shifted, axes=0),
38+
fft.ifftshift(shifted, axes=(0,)))
39+
40+
assert_array_almost_equal(fft.fftshift(freqs), shifted)
41+
assert_array_almost_equal(fft.ifftshift(shifted), freqs)
42+
43+
def test_uneven_dims(self):
44+
""" Test 2D input, which has uneven dimension sizes """
45+
freqs = [
46+
[0, 1],
47+
[2, 3],
48+
[4, 5]
49+
]
50+
51+
# shift in dimension 0
52+
shift_dim0 = [
53+
[4, 5],
54+
[0, 1],
55+
[2, 3]
56+
]
57+
assert_array_almost_equal(fft.fftshift(freqs, axes=0), shift_dim0)
58+
assert_array_almost_equal(fft.ifftshift(shift_dim0, axes=0), freqs)
59+
assert_array_almost_equal(fft.fftshift(freqs, axes=(0,)), shift_dim0)
60+
assert_array_almost_equal(fft.ifftshift(shift_dim0, axes=[0]), freqs)
61+
62+
# shift in dimension 1
63+
shift_dim1 = [
64+
[1, 0],
65+
[3, 2],
66+
[5, 4]
67+
]
68+
assert_array_almost_equal(fft.fftshift(freqs, axes=1), shift_dim1)
69+
assert_array_almost_equal(fft.ifftshift(shift_dim1, axes=1), freqs)
70+
71+
# shift in both dimensions
72+
shift_dim_both = [
73+
[5, 4],
74+
[1, 0],
75+
[3, 2]
76+
]
77+
assert_array_almost_equal(fft.fftshift(freqs, axes=(0, 1)), shift_dim_both)
78+
assert_array_almost_equal(fft.ifftshift(shift_dim_both, axes=(0, 1)), freqs)
79+
assert_array_almost_equal(fft.fftshift(freqs, axes=[0, 1]), shift_dim_both)
80+
assert_array_almost_equal(fft.ifftshift(shift_dim_both, axes=[0, 1]), freqs)
81+
82+
# axes=None (default) shift in all dimensions
83+
assert_array_almost_equal(fft.fftshift(freqs, axes=None), shift_dim_both)
84+
assert_array_almost_equal(fft.ifftshift(shift_dim_both, axes=None), freqs)
85+
assert_array_almost_equal(fft.fftshift(freqs), shift_dim_both)
86+
assert_array_almost_equal(fft.ifftshift(shift_dim_both), freqs)
87+
88+
def test_equal_to_original(self):
89+
""" Test that the new (>=v1.15) implementation (see #10073) is equal to the original (<=v1.14) """
90+
from numpy.core import asarray, concatenate, arange, take
91+
92+
def original_fftshift(x, axes=None):
93+
""" How fftshift was implemented in v1.14"""
94+
tmp = asarray(x)
95+
ndim = tmp.ndim
96+
if axes is None:
97+
axes = list(range(ndim))
98+
elif isinstance(axes, int):
99+
axes = (axes,)
100+
y = tmp
101+
for k in axes:
102+
n = tmp.shape[k]
103+
p2 = (n + 1) // 2
104+
mylist = concatenate((arange(p2, n), arange(p2)))
105+
y = take(y, mylist, k)
106+
return y
107+
108+
def original_ifftshift(x, axes=None):
109+
""" How ifftshift was implemented in v1.14 """
110+
tmp = asarray(x)
111+
ndim = tmp.ndim
112+
if axes is None:
113+
axes = list(range(ndim))
114+
elif isinstance(axes, int):
115+
axes = (axes,)
116+
y = tmp
117+
for k in axes:
118+
n = tmp.shape[k]
119+
p2 = n - (n + 1) // 2
120+
mylist = concatenate((arange(p2, n), arange(p2)))
121+
y = take(y, mylist, k)
122+
return y
123+
124+
# create possible 2d array combinations and try all possible keywords
125+
# compare output to original functions
126+
for i in range(16):
127+
for j in range(16):
128+
for axes_keyword in [0, 1, None, (0,), (0, 1)]:
129+
inp = np.random.rand(i, j)
130+
131+
assert_array_almost_equal(fft.fftshift(inp, axes_keyword),
132+
original_fftshift(inp, axes_keyword))
133+
134+
assert_array_almost_equal(fft.ifftshift(inp, axes_keyword),
135+
original_ifftshift(inp, axes_keyword))
136+
137+
138+
@pytest.mark.xfail(reason="TODO")
139+
class TestFFTFreq:
140+
141+
def test_definition(self):
142+
x = [0, 1, 2, 3, 4, -4, -3, -2, -1]
143+
assert_array_almost_equal(9*fft.fftfreq(9), x)
144+
assert_array_almost_equal(9*pi*fft.fftfreq(9, pi), x)
145+
x = [0, 1, 2, 3, 4, -5, -4, -3, -2, -1]
146+
assert_array_almost_equal(10*fft.fftfreq(10), x)
147+
assert_array_almost_equal(10*pi*fft.fftfreq(10, pi), x)
148+
149+
150+
@pytest.mark.xfail(reason="TODO")
151+
class TestRFFTFreq:
152+
153+
def test_definition(self):
154+
x = [0, 1, 2, 3, 4]
155+
assert_array_almost_equal(9*fft.rfftfreq(9), x)
156+
assert_array_almost_equal(9*pi*fft.rfftfreq(9, pi), x)
157+
x = [0, 1, 2, 3, 4, 5]
158+
assert_array_almost_equal(10*fft.rfftfreq(10), x)
159+
assert_array_almost_equal(10*pi*fft.rfftfreq(10, pi), x)
160+
161+
162+
@pytest.mark.xfail(reason="TODO")
163+
class TestIRFFTN:
164+
165+
def test_not_last_axis_success(self):
166+
ar, ai = np.random.random((2, 16, 8, 32))
167+
a = ar + 1j*ai
168+
169+
axes = (-2,)
170+
171+
# Should not raise error
172+
fft.irfftn(a, axes=axes)

0 commit comments

Comments
 (0)