Skip to content

Commit 9e0de0f

Browse files
committed
API: fft: make fft output dtype configurable
1 parent c9e5fc9 commit 9e0de0f

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

torch_np/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from . import linalg, random, fft
1+
from . import fft, linalg, random
22
from ._dtypes import *
33
from ._funcs import *
44
from ._getlimits import errstate, finfo, iinfo

torch_np/fft.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,26 @@
11
from __future__ import annotations
22

33
import functools
4+
45
import torch
56

7+
from . import _dtypes_impl, _util
68
from ._normalizations import ArrayLike, normalizer
7-
from . import _util
89

910

1011
def upcast(func):
1112
"""NumPy fft casts inputs to 64 bit and *returns 64-bit results*."""
13+
1214
@functools.wraps(func)
1315
def wrapped(tensor, *args, **kwds):
14-
target_dtype = torch.complex128 if tensor.is_complex() else torch.float64
16+
target_dtype = (
17+
_dtypes_impl.default_dtypes.complex_dtype
18+
if tensor.is_complex()
19+
else _dtypes_impl.default_dtypes.float_dtype
20+
)
1521
tensor = _util.cast_if_needed(tensor, target_dtype)
1622
return func(tensor, *args, **kwds)
23+
1724
return wrapped
1825

1926

@@ -70,6 +77,7 @@ def irfftn(a: ArrayLike, s=None, axes=None, norm=None):
7077
def fft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None):
7178
return torch.fft.fft2(a, s, dim=axes, norm=norm)
7279

80+
7381
@normalizer
7482
@upcast
7583
def ifft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None):
@@ -93,9 +101,8 @@ def irfft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None):
93101
def hfft(a: ArrayLike, n=None, axis=-1, norm=None):
94102
return torch.fft.hfft(a, n, dim=axis, norm=norm)
95103

104+
96105
@normalizer
97106
@upcast
98107
def ihfft(a: ArrayLike, n=None, axis=-1, norm=None):
99108
return torch.fft.ihfft(a, n, dim=axis, norm=norm)
100-
101-

0 commit comments

Comments
 (0)