Skip to content

Commit ae47a9f

Browse files
committed
TST: vendor numpy's fft/test_helper.py
1 parent 9eeb526 commit ae47a9f

File tree

1 file changed

+167
-0
lines changed

1 file changed

+167
-0
lines changed
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
"""Test functions for fftpack.helper module
2+
3+
Copied from fftpack.helper by Pearu Peterson, October 2005
4+
5+
"""
6+
import torch_np as np
7+
from torch_np.testing import assert_array_almost_equal
8+
from torch_np import fft, pi
9+
10+
11+
class TestFFTShift:
12+
13+
def test_definition(self):
14+
x = [0, 1, 2, 3, 4, -4, -3, -2, -1]
15+
y = [-4, -3, -2, -1, 0, 1, 2, 3, 4]
16+
assert_array_almost_equal(fft.fftshift(x), y)
17+
assert_array_almost_equal(fft.ifftshift(y), x)
18+
x = [0, 1, 2, 3, 4, -5, -4, -3, -2, -1]
19+
y = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4]
20+
assert_array_almost_equal(fft.fftshift(x), y)
21+
assert_array_almost_equal(fft.ifftshift(y), x)
22+
23+
def test_inverse(self):
24+
for n in [1, 4, 9, 100, 211]:
25+
x = np.random.random((n,))
26+
assert_array_almost_equal(fft.ifftshift(fft.fftshift(x)), x)
27+
28+
def test_axes_keyword(self):
29+
freqs = [[0, 1, 2], [3, 4, -4], [-3, -2, -1]]
30+
shifted = [[-1, -3, -2], [2, 0, 1], [-4, 3, 4]]
31+
assert_array_almost_equal(fft.fftshift(freqs, axes=(0, 1)), shifted)
32+
assert_array_almost_equal(fft.fftshift(freqs, axes=0),
33+
fft.fftshift(freqs, axes=(0,)))
34+
assert_array_almost_equal(fft.ifftshift(shifted, axes=(0, 1)), freqs)
35+
assert_array_almost_equal(fft.ifftshift(shifted, axes=0),
36+
fft.ifftshift(shifted, axes=(0,)))
37+
38+
assert_array_almost_equal(fft.fftshift(freqs), shifted)
39+
assert_array_almost_equal(fft.ifftshift(shifted), freqs)
40+
41+
def test_uneven_dims(self):
42+
""" Test 2D input, which has uneven dimension sizes """
43+
freqs = [
44+
[0, 1],
45+
[2, 3],
46+
[4, 5]
47+
]
48+
49+
# shift in dimension 0
50+
shift_dim0 = [
51+
[4, 5],
52+
[0, 1],
53+
[2, 3]
54+
]
55+
assert_array_almost_equal(fft.fftshift(freqs, axes=0), shift_dim0)
56+
assert_array_almost_equal(fft.ifftshift(shift_dim0, axes=0), freqs)
57+
assert_array_almost_equal(fft.fftshift(freqs, axes=(0,)), shift_dim0)
58+
assert_array_almost_equal(fft.ifftshift(shift_dim0, axes=[0]), freqs)
59+
60+
# shift in dimension 1
61+
shift_dim1 = [
62+
[1, 0],
63+
[3, 2],
64+
[5, 4]
65+
]
66+
assert_array_almost_equal(fft.fftshift(freqs, axes=1), shift_dim1)
67+
assert_array_almost_equal(fft.ifftshift(shift_dim1, axes=1), freqs)
68+
69+
# shift in both dimensions
70+
shift_dim_both = [
71+
[5, 4],
72+
[1, 0],
73+
[3, 2]
74+
]
75+
assert_array_almost_equal(fft.fftshift(freqs, axes=(0, 1)), shift_dim_both)
76+
assert_array_almost_equal(fft.ifftshift(shift_dim_both, axes=(0, 1)), freqs)
77+
assert_array_almost_equal(fft.fftshift(freqs, axes=[0, 1]), shift_dim_both)
78+
assert_array_almost_equal(fft.ifftshift(shift_dim_both, axes=[0, 1]), freqs)
79+
80+
# axes=None (default) shift in all dimensions
81+
assert_array_almost_equal(fft.fftshift(freqs, axes=None), shift_dim_both)
82+
assert_array_almost_equal(fft.ifftshift(shift_dim_both, axes=None), freqs)
83+
assert_array_almost_equal(fft.fftshift(freqs), shift_dim_both)
84+
assert_array_almost_equal(fft.ifftshift(shift_dim_both), freqs)
85+
86+
def test_equal_to_original(self):
87+
""" Test that the new (>=v1.15) implementation (see #10073) is equal to the original (<=v1.14) """
88+
from numpy.core import asarray, concatenate, arange, take
89+
90+
def original_fftshift(x, axes=None):
91+
""" How fftshift was implemented in v1.14"""
92+
tmp = asarray(x)
93+
ndim = tmp.ndim
94+
if axes is None:
95+
axes = list(range(ndim))
96+
elif isinstance(axes, int):
97+
axes = (axes,)
98+
y = tmp
99+
for k in axes:
100+
n = tmp.shape[k]
101+
p2 = (n + 1) // 2
102+
mylist = concatenate((arange(p2, n), arange(p2)))
103+
y = take(y, mylist, k)
104+
return y
105+
106+
def original_ifftshift(x, axes=None):
107+
""" How ifftshift was implemented in v1.14 """
108+
tmp = asarray(x)
109+
ndim = tmp.ndim
110+
if axes is None:
111+
axes = list(range(ndim))
112+
elif isinstance(axes, int):
113+
axes = (axes,)
114+
y = tmp
115+
for k in axes:
116+
n = tmp.shape[k]
117+
p2 = n - (n + 1) // 2
118+
mylist = concatenate((arange(p2, n), arange(p2)))
119+
y = take(y, mylist, k)
120+
return y
121+
122+
# create possible 2d array combinations and try all possible keywords
123+
# compare output to original functions
124+
for i in range(16):
125+
for j in range(16):
126+
for axes_keyword in [0, 1, None, (0,), (0, 1)]:
127+
inp = np.random.rand(i, j)
128+
129+
assert_array_almost_equal(fft.fftshift(inp, axes_keyword),
130+
original_fftshift(inp, axes_keyword))
131+
132+
assert_array_almost_equal(fft.ifftshift(inp, axes_keyword),
133+
original_ifftshift(inp, axes_keyword))
134+
135+
136+
class TestFFTFreq:
137+
138+
def test_definition(self):
139+
x = [0, 1, 2, 3, 4, -4, -3, -2, -1]
140+
assert_array_almost_equal(9*fft.fftfreq(9), x)
141+
assert_array_almost_equal(9*pi*fft.fftfreq(9, pi), x)
142+
x = [0, 1, 2, 3, 4, -5, -4, -3, -2, -1]
143+
assert_array_almost_equal(10*fft.fftfreq(10), x)
144+
assert_array_almost_equal(10*pi*fft.fftfreq(10, pi), x)
145+
146+
147+
class TestRFFTFreq:
148+
149+
def test_definition(self):
150+
x = [0, 1, 2, 3, 4]
151+
assert_array_almost_equal(9*fft.rfftfreq(9), x)
152+
assert_array_almost_equal(9*pi*fft.rfftfreq(9, pi), x)
153+
x = [0, 1, 2, 3, 4, 5]
154+
assert_array_almost_equal(10*fft.rfftfreq(10), x)
155+
assert_array_almost_equal(10*pi*fft.rfftfreq(10, pi), x)
156+
157+
158+
class TestIRFFTN:
159+
160+
def test_not_last_axis_success(self):
161+
ar, ai = np.random.random((2, 16, 8, 32))
162+
a = ar + 1j*ai
163+
164+
axes = (-2,)
165+
166+
# Should not raise error
167+
fft.irfftn(a, axes=axes)

0 commit comments

Comments
 (0)