Skip to content

Commit fe9760e

Browse files
committed
ENH: add dtype kwarg to fft.{fftfreq, rfftfreq}
1 parent 1a288de commit fe9760e

File tree

1 file changed

+31
-5
lines changed

1 file changed

+31
-5
lines changed

array_api_strict/_fft.py

+31-5
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
if TYPE_CHECKING:
66
from typing import Union, Optional, Literal
7-
from ._typing import Device
7+
from ._typing import Device, Dtype as DType
88
from collections.abc import Sequence
99

1010
from ._dtypes import (
@@ -251,26 +251,52 @@ def ihfft(
251251
return res
252252

253253
@requires_extension('fft')
254-
def fftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> Array:
254+
def fftfreq(
255+
n: int,
256+
/,
257+
*,
258+
d: float = 1.0,
259+
dtype: Optional[DType] = None,
260+
device: Optional[Device] = None
261+
) -> Array:
255262
"""
256263
Array API compatible wrapper for :py:func:`np.fft.fftfreq <numpy.fft.fftfreq>`.
257264
258265
See its docstring for more information.
259266
"""
260267
if device is not None and device not in ALL_DEVICES:
261268
raise ValueError(f"Unsupported device {device!r}")
262-
return Array._new(np.fft.fftfreq(n, d=d), device=device)
269+
if dtype and not dtype in _real_floating_dtypes:
270+
raise ValueError(f"`dtype` must be a real floating-point type. Got {dtype=}.")
271+
272+
np_result = np.fft.fftfreq(n, d=d)
273+
if dtype:
274+
np_result = np_result.astype(dtype._np_dtype)
275+
return Array._new(np_result, device=device)
263276

264277
@requires_extension('fft')
265-
def rfftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> Array:
278+
def rfftfreq(
279+
n: int,
280+
/,
281+
*,
282+
d: float = 1.0,
283+
dtype: Optional[DType] = None,
284+
device: Optional[Device] = None
285+
) -> Array:
266286
"""
267287
Array API compatible wrapper for :py:func:`np.fft.rfftfreq <numpy.fft.rfftfreq>`.
268288
269289
See its docstring for more information.
270290
"""
271291
if device is not None and device not in ALL_DEVICES:
272292
raise ValueError(f"Unsupported device {device!r}")
273-
return Array._new(np.fft.rfftfreq(n, d=d), device=device)
293+
if dtype and not dtype in _real_floating_dtypes:
294+
raise ValueError(f"`dtype` must be a real floating-point type. Got {dtype=}.")
295+
296+
np_result = np.fft.rfftfreq(n, d=d)
297+
if dtype:
298+
np_result = np_result.astype(dtype._np_dtype)
299+
return Array._new(np_result, device=device)
274300

275301
@requires_extension('fft')
276302
def fftshift(x: Array, /, *, axes: Union[int, Sequence[int]] = None) -> Array:

0 commit comments

Comments
 (0)