Skip to content

Commit 0f32292

Browse files
committed
TST: fft tests pass collection and xfail
1 parent f3e0685 commit 0f32292

File tree

3 files changed

+44
-7
lines changed

3 files changed

+44
-7
lines changed

torch_np/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from . import linalg, random
1+
from . import fft, linalg, random
22
from ._dtypes import *
33
from ._funcs import *
44
from ._getlimits import errstate, finfo, iinfo

torch_np/fft.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
def fft():
2+
raise NotImplementedError
3+
4+
5+
def ifft():
6+
raise NotImplementedError
7+
8+
9+
def fftn():
10+
raise NotImplementedError
11+
12+
13+
def ifftn():
14+
raise NotImplementedError
15+
16+
17+
def rfftn():
18+
raise NotImplementedError
19+
20+
21+
def irfftn():
22+
raise NotImplementedError
23+
24+
25+
def fft2():
26+
raise NotImplementedError
27+
28+
29+
def ifft2():
30+
raise NotImplementedError

torch_np/tests/numpy_tests/fft/test_pocketfft.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
1-
import numpy as np
1+
import torch_np as np
22
import pytest
3-
from numpy.random import random
4-
from numpy.testing import (
5-
assert_array_equal, assert_raises, assert_allclose, IS_WASM
3+
from pytest import raises as assert_raises
4+
5+
from torch_np.random import random
6+
from torch_np.testing import (
7+
assert_array_equal, assert_allclose #, IS_WASM
68
)
79
import threading
810
import queue
911

12+
IS_WASM = False
13+
1014

1115
def fft1(x):
1216
L = len(x)
@@ -15,12 +19,14 @@ def fft1(x):
1519
return np.sum(x*np.exp(phase), axis=1)
1620

1721

22+
@pytest.mark.xfail(reason='TODO')
1823
class TestFFTShift:
1924

2025
def test_fft_n(self):
2126
assert_raises(ValueError, np.fft.fft, [1, 2, 3], 0)
2227

2328

29+
@pytest.mark.xfail(reason='TODO')
2430
class TestFFT1D:
2531

2632
def test_identity(self):
@@ -219,8 +225,7 @@ def test_all_1d_norm_preserving(self):
219225
assert_allclose(x_norm,
220226
np.linalg.norm(tmp), atol=1e-6)
221227

222-
@pytest.mark.parametrize("dtype", [np.half, np.single, np.double,
223-
np.longdouble])
228+
@pytest.mark.parametrize("dtype", [np.half, np.single, np.double])
224229
def test_dtypes(self, dtype):
225230
# make sure that all input precisions are accepted and internally
226231
# converted to 64bit
@@ -229,6 +234,7 @@ def test_dtypes(self, dtype):
229234
assert_allclose(np.fft.irfft(np.fft.rfft(x)), x, atol=1e-6)
230235

231236

237+
@pytest.mark.xfail(reason='TODO')
232238
@pytest.mark.parametrize(
233239
"dtype",
234240
[np.float32, np.float64, np.complex64, np.complex128])
@@ -268,6 +274,7 @@ def test_fft_with_order(dtype, order, fft):
268274
raise ValueError()
269275

270276

277+
@pytest.mark.xfail(reason='TODO')
271278
@pytest.mark.skipif(IS_WASM, reason="Cannot start thread")
272279
class TestFFTThreadSafe:
273280
threads = 16

0 commit comments

Comments
 (0)