Skip to content

Commit 035969e

Browse files
committed
Wrap fft for dask
1 parent 8b9e0c0 commit 035969e

File tree

2 files changed

+29
-3
lines changed

2 files changed

+29
-3
lines changed

array_api_compat/dask/array/fft.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from dask.array.fft import * # noqa: F403
2+
from numpy.fft import __all__ as fft_all
3+
4+
from ..common import _fft
5+
from .._internal import get_xp
6+
7+
import dask.array as da
8+
9+
fft = get_xp(da)(_fft.fft)
10+
ifft = get_xp(da)(_fft.ifft)
11+
fftn = get_xp(da)(_fft.fftn)
12+
ifftn = get_xp(da)(_fft.ifftn)
13+
rfft = get_xp(da)(_fft.rfft)
14+
irfft = get_xp(da)(_fft.irfft)
15+
rfftn = get_xp(da)(_fft.rfftn)
16+
irfftn = get_xp(da)(_fft.irfftn)
17+
hfft = get_xp(da)(_fft.hfft)
18+
ihfft = get_xp(da)(_fft.ihfft)
19+
fftfreq = get_xp(da)(_fft.fftfreq)
20+
rfftfreq = get_xp(da)(_fft.rfftfreq)
21+
fftshift = get_xp(da)(_fft.fftshift)
22+
ifftshift = get_xp(da)(_fft.ifftshift)
23+
24+
__all__ = fft_all + _fft.__all__
25+
26+
del get_xp
27+
del da
28+
del fft_all
29+
del _fft

dask-skips.txt

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,2 @@
1-
# FFT isn't conformant
2-
array_api_tests/test_fft.py
3-
41
# slow and not implemented in dask
52
array_api_tests/test_linalg.py::test_matrix_power

0 commit comments

Comments
 (0)