1
1
from __future__ import annotations
2
2
3
3
import functools
4
+
4
5
import torch
5
6
7
+ from . import _dtypes_impl , _util
6
8
from ._normalizations import ArrayLike , normalizer
7
- from . import _util
8
9
9
10
10
11
def upcast (func ):
11
12
"""NumPy fft casts inputs to 64 bit and *returns 64-bit results*."""
13
+
12
14
@functools .wraps (func )
13
15
def wrapped (tensor , * args , ** kwds ):
14
- target_dtype = torch .complex128 if tensor .is_complex () else torch .float64
16
+ target_dtype = (
17
+ _dtypes_impl .default_dtypes .complex_dtype
18
+ if tensor .is_complex ()
19
+ else _dtypes_impl .default_dtypes .float_dtype
20
+ )
15
21
tensor = _util .cast_if_needed (tensor , target_dtype )
16
22
return func (tensor , * args , ** kwds )
23
+
17
24
return wrapped
18
25
19
26
@@ -70,6 +77,7 @@ def irfftn(a: ArrayLike, s=None, axes=None, norm=None):
70
77
def fft2 (a : ArrayLike , s = None , axes = (- 2 , - 1 ), norm = None ):
71
78
return torch .fft .fft2 (a , s , dim = axes , norm = norm )
72
79
80
+
73
81
@normalizer
74
82
@upcast
75
83
def ifft2 (a : ArrayLike , s = None , axes = (- 2 , - 1 ), norm = None ):
@@ -93,9 +101,8 @@ def irfft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None):
93
101
def hfft (a : ArrayLike , n = None , axis = - 1 , norm = None ):
94
102
return torch .fft .hfft (a , n , dim = axis , norm = norm )
95
103
104
+
96
105
@normalizer
97
106
@upcast
98
107
def ihfft (a : ArrayLike , n = None , axis = - 1 , norm = None ):
99
108
return torch .fft .ihfft (a , n , dim = axis , norm = norm )
100
-
101
-
0 commit comments