|
4 | 4 |
|
5 | 5 | if TYPE_CHECKING:
|
6 | 6 | from typing import Union, Optional, Literal
|
7 |
| - from ._typing import Device |
| 7 | + from ._typing import Device, Dtype as DType |
8 | 8 | from collections.abc import Sequence
|
9 | 9 |
|
10 | 10 | from ._dtypes import (
|
@@ -251,26 +251,52 @@ def ihfft(
|
251 | 251 | return res
|
252 | 252 |
|
253 | 253 | @requires_extension('fft')
|
254 |
| -def fftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> Array: |
| 254 | +def fftfreq( |
| 255 | + n: int, |
| 256 | + /, |
| 257 | + *, |
| 258 | + d: float = 1.0, |
| 259 | + dtype: Optional[DType] = None, |
| 260 | + device: Optional[Device] = None |
| 261 | +) -> Array: |
255 | 262 | """
|
256 | 263 | Array API compatible wrapper for :py:func:`np.fft.fftfreq <numpy.fft.fftfreq>`.
|
257 | 264 |
|
258 | 265 | See its docstring for more information.
|
259 | 266 | """
|
260 | 267 | if device is not None and device not in ALL_DEVICES:
|
261 | 268 | raise ValueError(f"Unsupported device {device!r}")
|
262 |
| - return Array._new(np.fft.fftfreq(n, d=d), device=device) |
| 269 | + if dtype and not dtype in _real_floating_dtypes: |
| 270 | + raise ValueError(f"`dtype` must be a real floating-point type. Got {dtype=}.") |
| 271 | + |
| 272 | + np_result = np.fft.fftfreq(n, d=d) |
| 273 | + if dtype: |
| 274 | + np_result = np_result.astype(dtype._np_dtype) |
| 275 | + return Array._new(np_result, device=device) |
263 | 276 |
|
264 | 277 | @requires_extension('fft')
|
265 |
| -def rfftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> Array: |
| 278 | +def rfftfreq( |
| 279 | + n: int, |
| 280 | + /, |
| 281 | + *, |
| 282 | + d: float = 1.0, |
| 283 | + dtype: Optional[DType] = None, |
| 284 | + device: Optional[Device] = None |
| 285 | +) -> Array: |
266 | 286 | """
|
267 | 287 | Array API compatible wrapper for :py:func:`np.fft.rfftfreq <numpy.fft.rfftfreq>`.
|
268 | 288 |
|
269 | 289 | See its docstring for more information.
|
270 | 290 | """
|
271 | 291 | if device is not None and device not in ALL_DEVICES:
|
272 | 292 | raise ValueError(f"Unsupported device {device!r}")
|
273 |
| - return Array._new(np.fft.rfftfreq(n, d=d), device=device) |
| 293 | + if dtype and not dtype in _real_floating_dtypes: |
| 294 | + raise ValueError(f"`dtype` must be a real floating-point type. Got {dtype=}.") |
| 295 | + |
| 296 | + np_result = np.fft.rfftfreq(n, d=d) |
| 297 | + if dtype: |
| 298 | + np_result = np_result.astype(dtype._np_dtype) |
| 299 | + return Array._new(np_result, device=device) |
274 | 300 |
|
275 | 301 | @requires_extension('fft')
|
276 | 302 | def fftshift(x: Array, /, *, axes: Union[int, Sequence[int]] = None) -> Array:
|
|
0 commit comments