|
| 1 | +"""Test functions for fftpack.helper module |
| 2 | +
|
| 3 | +Copied from fftpack.helper by Pearu Peterson, October 2005 |
| 4 | +
|
| 5 | +""" |
| 6 | +import torch_np as np |
| 7 | +from torch_np.testing import assert_array_almost_equal |
| 8 | +from torch_np import fft, pi |
| 9 | + |
| 10 | +import pytest |
| 11 | + |
| 12 | +@pytest.mark.xfail(reason="TODO") |
| 13 | +class TestFFTShift: |
| 14 | + |
| 15 | + def test_definition(self): |
| 16 | + x = [0, 1, 2, 3, 4, -4, -3, -2, -1] |
| 17 | + y = [-4, -3, -2, -1, 0, 1, 2, 3, 4] |
| 18 | + assert_array_almost_equal(fft.fftshift(x), y) |
| 19 | + assert_array_almost_equal(fft.ifftshift(y), x) |
| 20 | + x = [0, 1, 2, 3, 4, -5, -4, -3, -2, -1] |
| 21 | + y = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4] |
| 22 | + assert_array_almost_equal(fft.fftshift(x), y) |
| 23 | + assert_array_almost_equal(fft.ifftshift(y), x) |
| 24 | + |
| 25 | + def test_inverse(self): |
| 26 | + for n in [1, 4, 9, 100, 211]: |
| 27 | + x = np.random.random((n,)) |
| 28 | + assert_array_almost_equal(fft.ifftshift(fft.fftshift(x)), x) |
| 29 | + |
| 30 | + def test_axes_keyword(self): |
| 31 | + freqs = [[0, 1, 2], [3, 4, -4], [-3, -2, -1]] |
| 32 | + shifted = [[-1, -3, -2], [2, 0, 1], [-4, 3, 4]] |
| 33 | + assert_array_almost_equal(fft.fftshift(freqs, axes=(0, 1)), shifted) |
| 34 | + assert_array_almost_equal(fft.fftshift(freqs, axes=0), |
| 35 | + fft.fftshift(freqs, axes=(0,))) |
| 36 | + assert_array_almost_equal(fft.ifftshift(shifted, axes=(0, 1)), freqs) |
| 37 | + assert_array_almost_equal(fft.ifftshift(shifted, axes=0), |
| 38 | + fft.ifftshift(shifted, axes=(0,))) |
| 39 | + |
| 40 | + assert_array_almost_equal(fft.fftshift(freqs), shifted) |
| 41 | + assert_array_almost_equal(fft.ifftshift(shifted), freqs) |
| 42 | + |
| 43 | + def test_uneven_dims(self): |
| 44 | + """ Test 2D input, which has uneven dimension sizes """ |
| 45 | + freqs = [ |
| 46 | + [0, 1], |
| 47 | + [2, 3], |
| 48 | + [4, 5] |
| 49 | + ] |
| 50 | + |
| 51 | + # shift in dimension 0 |
| 52 | + shift_dim0 = [ |
| 53 | + [4, 5], |
| 54 | + [0, 1], |
| 55 | + [2, 3] |
| 56 | + ] |
| 57 | + assert_array_almost_equal(fft.fftshift(freqs, axes=0), shift_dim0) |
| 58 | + assert_array_almost_equal(fft.ifftshift(shift_dim0, axes=0), freqs) |
| 59 | + assert_array_almost_equal(fft.fftshift(freqs, axes=(0,)), shift_dim0) |
| 60 | + assert_array_almost_equal(fft.ifftshift(shift_dim0, axes=[0]), freqs) |
| 61 | + |
| 62 | + # shift in dimension 1 |
| 63 | + shift_dim1 = [ |
| 64 | + [1, 0], |
| 65 | + [3, 2], |
| 66 | + [5, 4] |
| 67 | + ] |
| 68 | + assert_array_almost_equal(fft.fftshift(freqs, axes=1), shift_dim1) |
| 69 | + assert_array_almost_equal(fft.ifftshift(shift_dim1, axes=1), freqs) |
| 70 | + |
| 71 | + # shift in both dimensions |
| 72 | + shift_dim_both = [ |
| 73 | + [5, 4], |
| 74 | + [1, 0], |
| 75 | + [3, 2] |
| 76 | + ] |
| 77 | + assert_array_almost_equal(fft.fftshift(freqs, axes=(0, 1)), shift_dim_both) |
| 78 | + assert_array_almost_equal(fft.ifftshift(shift_dim_both, axes=(0, 1)), freqs) |
| 79 | + assert_array_almost_equal(fft.fftshift(freqs, axes=[0, 1]), shift_dim_both) |
| 80 | + assert_array_almost_equal(fft.ifftshift(shift_dim_both, axes=[0, 1]), freqs) |
| 81 | + |
| 82 | + # axes=None (default) shift in all dimensions |
| 83 | + assert_array_almost_equal(fft.fftshift(freqs, axes=None), shift_dim_both) |
| 84 | + assert_array_almost_equal(fft.ifftshift(shift_dim_both, axes=None), freqs) |
| 85 | + assert_array_almost_equal(fft.fftshift(freqs), shift_dim_both) |
| 86 | + assert_array_almost_equal(fft.ifftshift(shift_dim_both), freqs) |
| 87 | + |
| 88 | + def test_equal_to_original(self): |
| 89 | + """ Test that the new (>=v1.15) implementation (see #10073) is equal to the original (<=v1.14) """ |
| 90 | + from numpy.core import asarray, concatenate, arange, take |
| 91 | + |
| 92 | + def original_fftshift(x, axes=None): |
| 93 | + """ How fftshift was implemented in v1.14""" |
| 94 | + tmp = asarray(x) |
| 95 | + ndim = tmp.ndim |
| 96 | + if axes is None: |
| 97 | + axes = list(range(ndim)) |
| 98 | + elif isinstance(axes, int): |
| 99 | + axes = (axes,) |
| 100 | + y = tmp |
| 101 | + for k in axes: |
| 102 | + n = tmp.shape[k] |
| 103 | + p2 = (n + 1) // 2 |
| 104 | + mylist = concatenate((arange(p2, n), arange(p2))) |
| 105 | + y = take(y, mylist, k) |
| 106 | + return y |
| 107 | + |
| 108 | + def original_ifftshift(x, axes=None): |
| 109 | + """ How ifftshift was implemented in v1.14 """ |
| 110 | + tmp = asarray(x) |
| 111 | + ndim = tmp.ndim |
| 112 | + if axes is None: |
| 113 | + axes = list(range(ndim)) |
| 114 | + elif isinstance(axes, int): |
| 115 | + axes = (axes,) |
| 116 | + y = tmp |
| 117 | + for k in axes: |
| 118 | + n = tmp.shape[k] |
| 119 | + p2 = n - (n + 1) // 2 |
| 120 | + mylist = concatenate((arange(p2, n), arange(p2))) |
| 121 | + y = take(y, mylist, k) |
| 122 | + return y |
| 123 | + |
| 124 | + # create possible 2d array combinations and try all possible keywords |
| 125 | + # compare output to original functions |
| 126 | + for i in range(16): |
| 127 | + for j in range(16): |
| 128 | + for axes_keyword in [0, 1, None, (0,), (0, 1)]: |
| 129 | + inp = np.random.rand(i, j) |
| 130 | + |
| 131 | + assert_array_almost_equal(fft.fftshift(inp, axes_keyword), |
| 132 | + original_fftshift(inp, axes_keyword)) |
| 133 | + |
| 134 | + assert_array_almost_equal(fft.ifftshift(inp, axes_keyword), |
| 135 | + original_ifftshift(inp, axes_keyword)) |
| 136 | + |
| 137 | + |
| 138 | +@pytest.mark.xfail(reason="TODO") |
| 139 | +class TestFFTFreq: |
| 140 | + |
| 141 | + def test_definition(self): |
| 142 | + x = [0, 1, 2, 3, 4, -4, -3, -2, -1] |
| 143 | + assert_array_almost_equal(9*fft.fftfreq(9), x) |
| 144 | + assert_array_almost_equal(9*pi*fft.fftfreq(9, pi), x) |
| 145 | + x = [0, 1, 2, 3, 4, -5, -4, -3, -2, -1] |
| 146 | + assert_array_almost_equal(10*fft.fftfreq(10), x) |
| 147 | + assert_array_almost_equal(10*pi*fft.fftfreq(10, pi), x) |
| 148 | + |
| 149 | + |
| 150 | +@pytest.mark.xfail(reason="TODO") |
| 151 | +class TestRFFTFreq: |
| 152 | + |
| 153 | + def test_definition(self): |
| 154 | + x = [0, 1, 2, 3, 4] |
| 155 | + assert_array_almost_equal(9*fft.rfftfreq(9), x) |
| 156 | + assert_array_almost_equal(9*pi*fft.rfftfreq(9, pi), x) |
| 157 | + x = [0, 1, 2, 3, 4, 5] |
| 158 | + assert_array_almost_equal(10*fft.rfftfreq(10), x) |
| 159 | + assert_array_almost_equal(10*pi*fft.rfftfreq(10, pi), x) |
| 160 | + |
| 161 | + |
| 162 | +@pytest.mark.xfail(reason="TODO") |
| 163 | +class TestIRFFTN: |
| 164 | + |
| 165 | + def test_not_last_axis_success(self): |
| 166 | + ar, ai = np.random.random((2, 16, 8, 32)) |
| 167 | + a = ar + 1j*ai |
| 168 | + |
| 169 | + axes = (-2,) |
| 170 | + |
| 171 | + # Should not raise error |
| 172 | + fft.irfftn(a, axes=axes) |
0 commit comments