5
5
from typing import Any , Dict , NamedTuple , Sequence , Tuple , Union
6
6
from warnings import warn
7
7
8
+ from . import api_version
8
9
from . import _array_module as xp
9
10
from ._array_module import _UndefinedStub
10
11
from .stubs import name_to_func
15
16
"uint_dtypes" ,
16
17
"all_int_dtypes" ,
17
18
"float_dtypes" ,
19
+ "real_dtypes" ,
18
20
"numeric_dtypes" ,
19
21
"all_dtypes" ,
20
- "dtype_to_name " ,
22
+ "all_float_dtypes " ,
21
23
"bool_and_all_int_dtypes" ,
24
+ "dtype_to_name" ,
22
25
"dtype_to_scalars" ,
23
26
"is_int_dtype" ,
24
27
"is_float_dtype" ,
27
30
"default_int" ,
28
31
"default_uint" ,
29
32
"default_float" ,
33
+ "default_complex" ,
30
34
"promotion_table" ,
31
35
"dtype_nbits" ,
32
36
"dtype_signed" ,
37
+ "dtype_components" ,
33
38
"func_in_dtypes" ,
34
39
"func_returns_bool" ,
35
40
"binary_op_to_symbol" ,
@@ -86,15 +91,25 @@ def __repr__(self):
86
91
_uint_names = ("uint8" , "uint16" , "uint32" , "uint64" )
87
92
_int_names = ("int8" , "int16" , "int32" , "int64" )
88
93
_float_names = ("float32" , "float64" )
89
- _dtype_names = ("bool" ,) + _uint_names + _int_names + _float_names
94
+ _real_names = _uint_names + _int_names + _float_names
95
+ _complex_names = ("complex64" , "complex128" )
96
+ _numeric_names = _real_names + _complex_names
97
+ _dtype_names = ("bool" ,) + _numeric_names
90
98
91
99
92
100
uint_dtypes = tuple (getattr (xp , name ) for name in _uint_names )
93
101
int_dtypes = tuple (getattr (xp , name ) for name in _int_names )
94
102
float_dtypes = tuple (getattr (xp , name ) for name in _float_names )
95
103
all_int_dtypes = uint_dtypes + int_dtypes
96
- numeric_dtypes = all_int_dtypes + float_dtypes
104
+ real_dtypes = all_int_dtypes + float_dtypes
105
+ complex_dtypes = tuple (getattr (xp , name ) for name in _complex_names )
106
+ numeric_dtypes = real_dtypes
107
+ if api_version > "2021.12" :
108
+ numeric_dtypes += complex_dtypes
97
109
all_dtypes = (xp .bool ,) + numeric_dtypes
110
+ all_float_dtypes = float_dtypes
111
+ if api_version > "2021.12" :
112
+ all_float_dtypes += complex_dtypes
98
113
bool_and_all_int_dtypes = (xp .bool ,) + all_int_dtypes
99
114
100
115
@@ -121,14 +136,19 @@ def is_float_dtype(dtype):
121
136
# See https://github.com/numpy/numpy/issues/18434
122
137
if dtype is None :
123
138
return False
124
- return dtype in float_dtypes
139
+ valid_dtypes = float_dtypes
140
+ if api_version > "2021.12" :
141
+ valid_dtypes += complex_dtypes
142
+ return dtype in valid_dtypes
125
143
126
144
127
145
def get_scalar_type (dtype : DataType ) -> ScalarType :
128
146
if is_int_dtype (dtype ):
129
147
return int
130
148
elif is_float_dtype (dtype ):
131
149
return float
150
+ elif dtype in complex_dtypes :
151
+ return complex
132
152
else :
133
153
return bool
134
154
@@ -157,7 +177,8 @@ class MinMax(NamedTuple):
157
177
[(d , 8 ) for d in [xp .int8 , xp .uint8 ]]
158
178
+ [(d , 16 ) for d in [xp .int16 , xp .uint16 ]]
159
179
+ [(d , 32 ) for d in [xp .int32 , xp .uint32 , xp .float32 ]]
160
- + [(d , 64 ) for d in [xp .int64 , xp .uint64 , xp .float64 ]]
180
+ + [(d , 64 ) for d in [xp .int64 , xp .uint64 , xp .float64 , xp .complex64 ]]
181
+ + [(xp .complex128 , 128 )]
161
182
)
162
183
163
184
@@ -166,6 +187,11 @@ class MinMax(NamedTuple):
166
187
)
167
188
168
189
190
+ dtype_components = EqualityMapping (
191
+ [(xp .complex64 , xp .float32 ), (xp .complex128 , xp .float64 )]
192
+ )
193
+
194
+
169
195
if isinstance (xp .asarray , _UndefinedStub ):
170
196
default_int = xp .int32
171
197
default_float = xp .float32
@@ -180,6 +206,15 @@ class MinMax(NamedTuple):
180
206
default_float = xp .asarray (float ()).dtype
181
207
if default_float not in float_dtypes :
182
208
warn (f"inferred default float is { default_float !r} , which is not a float" )
209
+ if api_version > "2021.12" :
210
+ default_complex = xp .asarray (complex ()).dtype
211
+ if default_complex not in complex_dtypes :
212
+ warn (
213
+ f"inferred default complex is { default_complex !r} , "
214
+ "which is not a complex"
215
+ )
216
+ else :
217
+ default_complex = None
183
218
if dtype_nbits [default_int ] == 32 :
184
219
default_uint = xp .uint32
185
220
else :
@@ -226,6 +261,11 @@ class MinMax(NamedTuple):
226
261
((xp .float32 , xp .float32 ), xp .float32 ),
227
262
((xp .float32 , xp .float64 ), xp .float64 ),
228
263
((xp .float64 , xp .float64 ), xp .float64 ),
264
+ # complex
265
+ ((xp .complex64 , xp .complex64 ), xp .complex64 ),
266
+ ((xp .complex64 , xp .complex128 ), xp .complex128 ),
267
+ ((xp .complex128 , xp .complex128 ), xp .complex128 ),
268
+
229
269
]
230
270
_numeric_promotions += [((d2 , d1 ), res ) for (d1 , d2 ), res in _numeric_promotions ]
231
271
_promotion_table = list (set (_numeric_promotions ))
0 commit comments