Skip to content

Commit c51ce80

Browse files
committed
Use better wording for the return dtypes for fft functions
Fixes data-apis#717
1 parent 5a14534 commit c51ce80

File tree

1 file changed

+14
-14
lines changed
  • src/array_api_stubs/_draft

1 file changed

+14
-14
lines changed

src/array_api_stubs/_draft/fft.py

+14-14
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def fft(
6060
Returns
6161
-------
6262
out: array
63-
an array transformed along the axis (dimension) indicated by ``axis``. The returned array must have a complex floating-point data type determined by :ref:`type-promotion`.
63+
an array transformed along the axis (dimension) indicated by ``axis``. If ``x`` has a complex floating-point data type, the returned array must have the same data type as ``x``. If ``x`` has a real-valued floating-point data type, the returned array must have a complex floating-point data type whose precision matches the precision of ``x`` (e.g., if ``x`` is ``float64``, then the returned array must have a ``complex128`` data type).
6464
6565
Notes
6666
-----
@@ -111,7 +111,7 @@ def ifft(
111111
Returns
112112
-------
113113
out: array
114-
an array transformed along the axis (dimension) indicated by ``axis``. The returned array must have a complex floating-point data type determined by :ref:`type-promotion`.
114+
an array transformed along the axis (dimension) indicated by ``axis``. If ``x`` has a complex floating-point data type, the returned array must have the same data type as ``x``. If ``x`` has a real-valued floating-point data type, the returned array must have a complex floating-point data type whose precision matches the precision of ``x`` (e.g., if ``x`` is ``float64``, then the returned array must have a ``complex128`` data type).
115115
116116
Notes
117117
-----
@@ -169,7 +169,7 @@ def fftn(
169169
Returns
170170
-------
171171
out: array
172-
an array transformed along the axes (dimension) indicated by ``axes``. The returned array must have a complex floating-point data type determined by :ref:`type-promotion`.
172+
an array transformed along the axes (dimension) indicated by ``axes``. If ``x`` has a complex floating-point data type, the returned array must have the same data type as ``x``. If ``x`` has a real-valued floating-point data type, the returned array must have a complex floating-point data type whose precision matches the precision of ``x`` (e.g., if ``x`` is ``float64``, then the returned array must have a ``complex128`` data type).
173173
174174
Notes
175175
-----
@@ -227,7 +227,7 @@ def ifftn(
227227
Returns
228228
-------
229229
out: array
230-
an array transformed along the axes (dimension) indicated by ``axes``. The returned array must have a complex floating-point data type determined by :ref:`type-promotion`.
230+
an array transformed along the axes (dimension) indicated by ``axes``. If ``x`` has a complex floating-point data type, the returned array must have the same data type as ``x``. If ``x`` has a real-valued floating-point data type, the returned array must have a complex floating-point data type whose precision matches the precision of ``x`` (e.g., if ``x`` is ``float64``, then the returned array must have a ``complex128`` data type).
231231
232232
Notes
233233
-----
@@ -278,7 +278,7 @@ def rfft(
278278
Returns
279279
-------
280280
out: array
281-
an array transformed along the axis (dimension) indicated by ``axis``. The returned array must have a complex-valued floating-point data type determined by :ref:`type-promotion`.
281+
an array transformed along the axis (dimension) indicated by ``axis``. The returned array must have a complex floating-point data type whose precision matches the precision of ``x`` (e.g., if ``x`` is ``float64``, then the returned array must have a ``complex128`` data type).
282282
283283
Notes
284284
-----
@@ -329,7 +329,7 @@ def irfft(
329329
Returns
330330
-------
331331
out: array
332-
an array transformed along the axis (dimension) indicated by ``axis``. The returned array must have a real-valued floating-point data type determined by :ref:`type-promotion`. The length along the transformed axis is ``n`` (if given) or ``2*(m-1)`` (otherwise).
332+
an array transformed along the axis (dimension) indicated by ``axis``. The returned array must have the same data type as ``x``. The length along the transformed axis is ``n`` (if given) or ``2*(m-1)`` (otherwise).
333333
334334
Notes
335335
-----
@@ -387,7 +387,7 @@ def rfftn(
387387
Returns
388388
-------
389389
out: array
390-
an array transformed along the axes (dimension) indicated by ``axes``. The returned array must have a complex-valued floating-point data type determined by :ref:`type-promotion`.
390+
an array transformed along the axes (dimension) indicated by ``axes``. The returned array must have a complex floating-point data type whose precision matches the precision of ``x`` (e.g., if ``x`` is ``float64``, then the returned array must have a ``complex128`` data type).
391391
392392
Notes
393393
-----
@@ -445,7 +445,7 @@ def irfftn(
445445
Returns
446446
-------
447447
out: array
448-
an array transformed along the axes (dimension) indicated by ``axes``. The returned array must have a real-valued floating-point data type determined by :ref:`type-promotion`. The length along the last transformed axis is ``s[-1]`` (if given) or ``2*(m - 1)`` (otherwise), and all other axes ``s[i]``.
448+
an array transformed along the axes (dimension) indicated by ``axes``. The returned array must have a real-valued floating-point data type whose precision matches the precision of ``x`` (e.g., if ``x`` is ``complex128``, then the returned array must have a ``float64`` data type). The length along the last transformed axis is ``s[-1]`` (if given) or ``2*(m - 1)`` (otherwise), and all other axes ``s[i]``.
449449
450450
Notes
451451
-----
@@ -493,7 +493,7 @@ def hfft(
493493
Returns
494494
-------
495495
out: array
496-
an array transformed along the axis (dimension) indicated by ``axis``. The returned array must have a real-valued floating-point data type determined by :ref:`type-promotion`.
496+
an array transformed along the axis (dimension) indicated by ``axis``. If ``x`` has a complex floating-point data type, the returned array must have the same data type as ``x``. If ``x`` has a real-valued floating-point data type, the returned array must have a complex floating-point data type whose precision matches the precision of ``x`` (e.g., if ``x`` is ``float64``, then the returned array must have a ``complex128`` data type).
497497
498498
Notes
499499
-----
@@ -541,7 +541,7 @@ def ihfft(
541541
Returns
542542
-------
543543
out: array
544-
an array transformed along the axis (dimension) indicated by ``axis``. The returned array must have a complex-valued floating-point data type determined by :ref:`type-promotion`.
544+
an array transformed along the axis (dimension) indicated by ``axis``. The returned array must have a complex floating-point data type whose precision matches the precision of ``x`` (e.g., if ``x`` is ``float64``, then the returned array must have a ``complex128`` data type).
545545
546546
Notes
547547
-----
@@ -552,7 +552,7 @@ def ihfft(
552552

553553
def fftfreq(n: int, /, *, d: float = 1.0, device: Optional[device] = None) -> array:
554554
"""
555-
Returns the discrete Fourier transform sample frequencies.
555+
Computes the discrete Fourier transform sample frequencies.
556556
557557
For a Fourier transform of length ``n`` and length unit of ``d`` the frequencies are described as:
558558
@@ -573,7 +573,7 @@ def fftfreq(n: int, /, *, d: float = 1.0, device: Optional[device] = None) -> ar
573573
Returns
574574
-------
575575
out: array
576-
an array of length ``n`` containing the sample frequencies. The returned array must have a real-valued floating-point data type determined by :ref:`type-promotion`.
576+
an array of length ``n`` containing the sample frequencies. The returned array must have the default real-valued floating-point data type.
577577
578578
Notes
579579
-----
@@ -584,7 +584,7 @@ def fftfreq(n: int, /, *, d: float = 1.0, device: Optional[device] = None) -> ar
584584

585585
def rfftfreq(n: int, /, *, d: float = 1.0, device: Optional[device] = None) -> array:
586586
"""
587-
Returns the discrete Fourier transform sample frequencies (for ``rfft`` and ``irfft``).
587+
Computes the discrete Fourier transform sample frequencies (for ``rfft`` and ``irfft``).
588588
589589
For a Fourier transform of length ``n`` and length unit of ``d`` the frequencies are described as:
590590
@@ -607,7 +607,7 @@ def rfftfreq(n: int, /, *, d: float = 1.0, device: Optional[device] = None) -> a
607607
Returns
608608
-------
609609
out: array
610-
an array of length ``n//2+1`` containing the sample frequencies. The returned array must have a real-valued floating-point data type determined by :ref:`type-promotion`.
610+
an array of length ``n//2+1`` containing the sample frequencies. The returned array must have the default real-valued floating-point data type.
611611
612612
Notes
613613
-----

0 commit comments

Comments
 (0)