Skip to content

Commit 9eeb526

Browse files
committed
API: fft: make fft output dtype configurable
1 parent f9b2855 commit 9eeb526

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

torch_np/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from . import linalg, random, fft
1+
from . import fft, linalg, random
22
from ._dtypes import *
33
from ._funcs import *
44
from ._getlimits import errstate, finfo, iinfo

torch_np/fft.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,29 @@
11
from __future__ import annotations
22

33
import functools
4+
45
import torch
56

7+
from . import _dtypes_impl, _util
68
from ._normalizations import ArrayLike, normalizer
7-
from . import _util
89

910

1011
def upcast(func):
1112
"""NumPy fft casts inputs to 64 bit and *returns 64-bit results*."""
13+
1214
@functools.wraps(func)
1315
def wrapped(tensor, *args, **kwds):
14-
target_dtype = torch.complex128 if tensor.is_complex() else torch.float64
16+
target_dtype = (
17+
_dtypes_impl.default_dtypes.complex_dtype
18+
if tensor.is_complex()
19+
else _dtypes_impl.default_dtypes.float_dtype
20+
)
1521
tensor = _util.cast_if_needed(tensor, target_dtype)
1622
return func(tensor, *args, **kwds)
23+
1724
return wrapped
1825

26+
1927
@normalizer
2028
@upcast
2129
def fft(a: ArrayLike, n=None, axis=-1, norm=None):
@@ -69,6 +77,7 @@ def irfftn(a: ArrayLike, s=None, axes=None, norm=None):
6977
def fft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None):
7078
return torch.fft.fft2(a, s, dim=axes, norm=norm)
7179

80+
7281
@normalizer
7382
@upcast
7483
def ifft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None):
@@ -92,9 +101,8 @@ def irfft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None):
92101
def hfft(a: ArrayLike, n=None, axis=-1, norm=None):
93102
return torch.fft.hfft(a, n, dim=axis, norm=norm)
94103

104+
95105
@normalizer
96106
@upcast
97107
def ihfft(a: ArrayLike, n=None, axis=-1, norm=None):
98108
return torch.fft.ihfft(a, n, dim=axis, norm=norm)
99-
100-

0 commit comments

Comments
 (0)