2
2
from collections import defaultdict
3
3
from collections .abc import Mapping
4
4
from functools import lru_cache
5
- from typing import Any , DefaultDict , NamedTuple , Sequence , Tuple , Union
5
+ from typing import Any , DefaultDict , Dict , List , NamedTuple , Sequence , Tuple , Union
6
6
from warnings import warn
7
7
8
- from . import _array_module as xp
9
8
from . import api_version
10
- from ._array_module import _UndefinedStub
11
- from ._array_module import mod as _xp
9
+ from ._array_module import mod as xp
12
10
from .stubs import name_to_func
13
11
from .typing import DataType , ScalarType
14
12
15
13
__all__ = [
14
+ "uint_names" ,
15
+ "int_names" ,
16
+ "float_names" ,
17
+ "real_names" ,
18
+ "complex_names" ,
19
+ "numeric_names" ,
20
+ "dtype_names" ,
16
21
"int_dtypes" ,
17
22
"uint_dtypes" ,
18
23
"all_int_dtypes" ,
@@ -90,27 +95,42 @@ def __repr__(self):
90
95
return f"EqualityMapping({ self } )"
91
96
92
97
93
- def _filter_stubs (* args ):
94
- for a in args :
95
- if not isinstance (a , _UndefinedStub ):
96
- yield a
98
+ uint_names = ("uint8" , "uint16" , "uint32" , "uint64" )
99
+ int_names = ("int8" , "int16" , "int32" , "int64" )
100
+ float_names = ("float32" , "float64" )
101
+ real_names = uint_names + int_names + float_names
102
+ complex_names = ("complex64" , "complex128" )
103
+ numeric_names = real_names + complex_names
104
+ dtype_names = ("bool" ,) + numeric_names
97
105
98
106
99
- _uint_names = ("uint8" , "uint16" , "uint32" , "uint64" )
100
- _int_names = ("int8" , "int16" , "int32" , "int64" )
101
- _float_names = ("float32" , "float64" )
102
- _real_names = _uint_names + _int_names + _float_names
103
- _complex_names = ("complex64" , "complex128" )
104
- _numeric_names = _real_names + _complex_names
105
- _dtype_names = ("bool" ,) + _numeric_names
107
+ _name_to_dtype = {}
108
+ for name in dtype_names :
109
+ try :
110
+ dtype = getattr (xp , name )
111
+ except AttributeError :
112
+ continue
113
+ _name_to_dtype [name ] = dtype
114
+ dtype_to_name = EqualityMapping ([(d , n ) for n , d in _name_to_dtype .items ()])
106
115
107
116
108
- uint_dtypes = tuple (getattr (xp , name ) for name in _uint_names )
109
- int_dtypes = tuple (getattr (xp , name ) for name in _int_names )
110
- float_dtypes = tuple (getattr (xp , name ) for name in _float_names )
117
+ def _make_dtype_tuple_from_names (names : List [str ]) -> Tuple [DataType ]:
118
+ dtypes = []
119
+ for name in names :
120
+ try :
121
+ dtype = _name_to_dtype [name ]
122
+ except KeyError :
123
+ continue
124
+ dtypes .append (dtype )
125
+ return tuple (dtypes )
126
+
127
+
128
+ uint_dtypes = _make_dtype_tuple_from_names (uint_names )
129
+ int_dtypes = _make_dtype_tuple_from_names (int_names )
130
+ float_dtypes = _make_dtype_tuple_from_names (float_names )
111
131
all_int_dtypes = uint_dtypes + int_dtypes
112
132
real_dtypes = all_int_dtypes + float_dtypes
113
- complex_dtypes = tuple ( getattr ( xp , name ) for name in _complex_names )
133
+ complex_dtypes = _make_dtype_tuple_from_names ( complex_names )
114
134
numeric_dtypes = real_dtypes
115
135
if api_version > "2021.12" :
116
136
numeric_dtypes += complex_dtypes
@@ -121,16 +141,6 @@ def _filter_stubs(*args):
121
141
bool_and_all_int_dtypes = (xp .bool ,) + all_int_dtypes
122
142
123
143
124
- _dtype_name_pairs = []
125
- for name in _dtype_names :
126
- try :
127
- dtype = getattr (_xp , name )
128
- except AttributeError :
129
- continue
130
- _dtype_name_pairs .append ((dtype , name ))
131
- dtype_to_name = EqualityMapping (_dtype_name_pairs )
132
-
133
-
134
144
dtype_to_scalars = EqualityMapping (
135
145
[
136
146
(xp .bool , [bool ]),
@@ -179,47 +189,59 @@ def get_scalar_type(dtype: DataType) -> ScalarType:
179
189
return bool
180
190
181
191
192
+ def _make_dtype_mapping_from_names (mapping : Dict [str , Any ]) -> EqualityMapping :
193
+ dtype_value_pairs = []
194
+ for name , value in mapping .items ():
195
+ assert isinstance (name , str ) and name in dtype_names # sanity check
196
+ try :
197
+ dtype = getattr (xp , name )
198
+ except AttributeError :
199
+ continue
200
+ dtype_value_pairs .append ((dtype , value ))
201
+ return EqualityMapping (dtype_value_pairs )
202
+
203
+
182
204
class MinMax (NamedTuple ):
183
205
min : Union [int , float ]
184
206
max : Union [int , float ]
185
207
186
208
187
- dtype_ranges = EqualityMapping (
188
- [
189
- ( xp . int8 , MinMax (- 128 , + 127 ) ),
190
- ( xp . int16 , MinMax (- 32_768 , + 32_767 ) ),
191
- ( xp . int32 , MinMax (- 2_147_483_648 , + 2_147_483_647 ) ),
192
- ( xp . int64 , MinMax (- 9_223_372_036_854_775_808 , + 9_223_372_036_854_775_807 ) ),
193
- ( xp . uint8 , MinMax (0 , + 255 ) ),
194
- ( xp . uint16 , MinMax (0 , + 65_535 ) ),
195
- ( xp . uint32 , MinMax (0 , + 4_294_967_295 ) ),
196
- ( xp . uint64 , MinMax (0 , + 18_446_744_073_709_551_615 ) ),
197
- ( xp . float32 , MinMax (- 3.4028234663852886e38 , 3.4028234663852886e38 ) ),
198
- ( xp . float64 , MinMax (- 1.7976931348623157e308 , 1.7976931348623157e308 ) ),
199
- ]
209
+ dtype_ranges = _make_dtype_mapping_from_names (
210
+ {
211
+ " int8" : MinMax (- 128 , + 127 ),
212
+ " int16" : MinMax (- 32_768 , + 32_767 ),
213
+ " int32" : MinMax (- 2_147_483_648 , + 2_147_483_647 ),
214
+ " int64" : MinMax (- 9_223_372_036_854_775_808 , + 9_223_372_036_854_775_807 ),
215
+ " uint8" : MinMax (0 , + 255 ),
216
+ " uint16" : MinMax (0 , + 65_535 ),
217
+ " uint32" : MinMax (0 , + 4_294_967_295 ),
218
+ " uint64" : MinMax (0 , + 18_446_744_073_709_551_615 ),
219
+ " float32" : MinMax (- 3.4028234663852886e38 , 3.4028234663852886e38 ),
220
+ " float64" : MinMax (- 1.7976931348623157e308 , 1.7976931348623157e308 ),
221
+ }
200
222
)
201
223
202
224
203
- dtype_nbits = EqualityMapping (
204
- [( d , 8 ) for d in _filter_stubs ( xp . int8 , xp . uint8 )]
205
- + [( d , 16 ) for d in _filter_stubs ( xp . int16 , xp . uint16 )]
206
- + [( d , 32 ) for d in _filter_stubs ( xp . int32 , xp . uint32 , xp . float32 )]
207
- + [( d , 64 ) for d in _filter_stubs ( xp . int64 , xp . uint64 , xp . float64 , xp . complex64 )]
208
- + [( d , 128 ) for d in _filter_stubs ( xp . complex128 )]
209
- )
225
+ r_nbits = re . compile ( r"[a-z]+([0-9]+)" )
226
+ _dtype_nbits : Dict [ str , int ] = {}
227
+ for name in numeric_names :
228
+ m = r_nbits . fullmatch ( name )
229
+ assert m is not None # sanity check / for mypy
230
+ _dtype_nbits [ name ] = int ( m . group ( 1 ))
231
+ dtype_nbits = _make_dtype_mapping_from_names ( _dtype_nbits )
210
232
211
233
212
- dtype_signed = EqualityMapping (
213
- [( d , True ) for d in int_dtypes ] + [( d , False ) for d in uint_dtypes ]
234
+ dtype_signed = _make_dtype_mapping_from_names (
235
+ { ** { name : True for name in int_names }, ** { name : False for name in uint_names }}
214
236
)
215
237
216
238
217
- dtype_components = EqualityMapping (
218
- [( xp . complex64 , xp .float32 ), ( xp . complex128 , xp .float64 )]
239
+ dtype_components = _make_dtype_mapping_from_names (
240
+ { " complex64" : xp .float32 , " complex128" : xp .float64 }
219
241
)
220
242
221
243
222
- if isinstance (xp . asarray , _UndefinedStub ):
244
+ if not hasattr (xp , "asarray" ):
223
245
default_int = xp .int32
224
246
default_float = xp .float32
225
247
warn (
@@ -243,60 +265,73 @@ class MinMax(NamedTuple):
243
265
else :
244
266
default_complex = None
245
267
if dtype_nbits [default_int ] == 32 :
246
- default_uint = xp . uint32
268
+ default_uint = getattr ( xp , " uint32" , None )
247
269
else :
248
- default_uint = xp .uint64
249
-
270
+ default_uint = getattr (xp , "uint64" , None )
250
271
251
- _numeric_promotions = [
272
+ _promotion_table : Dict [Tuple [str , str ], str ] = {
273
+ ("bool" , "bool" ): "bool" ,
252
274
# ints
253
- (( xp . int8 , xp . int8 ), xp . int8 ) ,
254
- (( xp . int8 , xp . int16 ), xp . int16 ) ,
255
- (( xp . int8 , xp . int32 ), xp . int32 ) ,
256
- (( xp . int8 , xp . int64 ), xp . int64 ) ,
257
- (( xp . int16 , xp . int16 ), xp . int16 ) ,
258
- (( xp . int16 , xp . int32 ), xp . int32 ) ,
259
- (( xp . int16 , xp . int64 ), xp . int64 ) ,
260
- (( xp . int32 , xp . int32 ), xp . int32 ) ,
261
- (( xp . int32 , xp . int64 ), xp . int64 ) ,
262
- (( xp . int64 , xp . int64 ), xp . int64 ) ,
275
+ (" int8" , " int8" ): " int8" ,
276
+ (" int8" , " int16" ): " int16" ,
277
+ (" int8" , " int32" ): " int32" ,
278
+ (" int8" , " int64" ): " int64" ,
279
+ (" int16" , " int16" ): " int16" ,
280
+ (" int16" , " int32" ): " int32" ,
281
+ (" int16" , " int64" ): " int64" ,
282
+ (" int32" , " int32" ): " int32" ,
283
+ (" int32" , " int64" ): " int64" ,
284
+ (" int64" , " int64" ): " int64" ,
263
285
# uints
264
- (( xp . uint8 , xp . uint8 ), xp . uint8 ) ,
265
- (( xp . uint8 , xp . uint16 ), xp . uint16 ) ,
266
- (( xp . uint8 , xp . uint32 ), xp . uint32 ) ,
267
- (( xp . uint8 , xp . uint64 ), xp . uint64 ) ,
268
- (( xp . uint16 , xp . uint16 ), xp . uint16 ) ,
269
- (( xp . uint16 , xp . uint32 ), xp . uint32 ) ,
270
- (( xp . uint16 , xp . uint64 ), xp . uint64 ) ,
271
- (( xp . uint32 , xp . uint32 ), xp . uint32 ) ,
272
- (( xp . uint32 , xp . uint64 ), xp . uint64 ) ,
273
- (( xp . uint64 , xp . uint64 ), xp . uint64 ) ,
286
+ (" uint8" , " uint8" ): " uint8" ,
287
+ (" uint8" , " uint16" ): " uint16" ,
288
+ (" uint8" , " uint32" ): " uint32" ,
289
+ (" uint8" , " uint64" ): " uint64" ,
290
+ (" uint16" , " uint16" ): " uint16" ,
291
+ (" uint16" , " uint32" ): " uint32" ,
292
+ (" uint16" , " uint64" ): " uint64" ,
293
+ (" uint32" , " uint32" ): " uint32" ,
294
+ (" uint32" , " uint64" ): " uint64" ,
295
+ (" uint64" , " uint64" ): " uint64" ,
274
296
# ints and uints (mixed sign)
275
- (( xp . int8 , xp . uint8 ), xp . int16 ) ,
276
- (( xp . int8 , xp . uint16 ), xp . int32 ) ,
277
- (( xp . int8 , xp . uint32 ), xp . int64 ) ,
278
- (( xp . int16 , xp . uint8 ), xp . int16 ) ,
279
- (( xp . int16 , xp . uint16 ), xp . int32 ) ,
280
- (( xp . int16 , xp . uint32 ), xp . int64 ) ,
281
- (( xp . int32 , xp . uint8 ), xp . int32 ) ,
282
- (( xp . int32 , xp . uint16 ), xp . int32 ) ,
283
- (( xp . int32 , xp . uint32 ), xp . int64 ) ,
284
- (( xp . int64 , xp . uint8 ), xp . int64 ) ,
285
- (( xp . int64 , xp . uint16 ), xp . int64 ) ,
286
- (( xp . int64 , xp . uint32 ), xp . int64 ) ,
297
+ (" int8" , " uint8" ): " int16" ,
298
+ (" int8" , " uint16" ): " int32" ,
299
+ (" int8" , " uint32" ): " int64" ,
300
+ (" int16" , " uint8" ): " int16" ,
301
+ (" int16" , " uint16" ): " int32" ,
302
+ (" int16" , " uint32" ): " int64" ,
303
+ (" int32" , " uint8" ): " int32" ,
304
+ (" int32" , " uint16" ): " int32" ,
305
+ (" int32" , " uint32" ): " int64" ,
306
+ (" int64" , " uint8" ): " int64" ,
307
+ (" int64" , " uint16" ): " int64" ,
308
+ (" int64" , " uint32" ): " int64" ,
287
309
# floats
288
- (( xp . float32 , xp . float32 ), xp . float32 ) ,
289
- (( xp . float32 , xp . float64 ), xp . float64 ) ,
290
- (( xp . float64 , xp . float64 ), xp . float64 ) ,
310
+ (" float32" , " float32" ): " float32" ,
311
+ (" float32" , " float64" ): " float64" ,
312
+ (" float64" , " float64" ): " float64" ,
291
313
# complex
292
- ((xp .complex64 , xp .complex64 ), xp .complex64 ),
293
- ((xp .complex64 , xp .complex128 ), xp .complex128 ),
294
- ((xp .complex128 , xp .complex128 ), xp .complex128 ),
295
- ]
296
- _numeric_promotions += [((d2 , d1 ), res ) for (d1 , d2 ), res in _numeric_promotions ]
297
- _promotion_table = list (set (_numeric_promotions ))
298
- _promotion_table .insert (0 , ((xp .bool , xp .bool ), xp .bool ))
299
- promotion_table = EqualityMapping (_promotion_table )
314
+ ("complex64" , "complex64" ): "complex64" ,
315
+ ("complex64" , "complex128" ): "complex128" ,
316
+ ("complex128" , "complex128" ): "complex128" ,
317
+ }
318
+ _promotion_table .update ({(d2 , d1 ): res for (d1 , d2 ), res in _promotion_table .items ()})
319
+ _promotion_table_pairs : List [Tuple [Tuple [DataType , DataType ], DataType ]] = []
320
+ for (in_name1 , in_name2 ), res_name in _promotion_table .items ():
321
+ try :
322
+ in_dtype1 = getattr (xp , in_name1 )
323
+ except AttributeError :
324
+ continue
325
+ try :
326
+ in_dtype2 = getattr (xp , in_name2 )
327
+ except AttributeError :
328
+ continue
329
+ try :
330
+ res_dtype = getattr (xp , res_name )
331
+ except AttributeError :
332
+ continue
333
+ _promotion_table_pairs .append (((in_dtype1 , in_dtype2 ), res_dtype ))
334
+ promotion_table = EqualityMapping (_promotion_table_pairs )
300
335
301
336
302
337
def result_type (* dtypes : DataType ):
@@ -325,6 +360,7 @@ def result_type(*dtypes: DataType):
325
360
}
326
361
func_in_dtypes : DefaultDict [str , Tuple [DataType , ...]] = defaultdict (lambda : all_dtypes )
327
362
for name , func in name_to_func .items ():
363
+ assert func .__doc__ is not None # for mypy
328
364
if m := r_in_dtypes .search (func .__doc__ ):
329
365
dtype_category = m .group (1 )
330
366
if dtype_category == "numeric" and r_int_note .search (func .__doc__ ):
@@ -457,11 +493,10 @@ def result_type(*dtypes: DataType):
457
493
}
458
494
459
495
496
+ # Construct func_in_dtypes and func_returns bool
460
497
for op , elwise_func in op_to_func .items ():
461
498
func_in_dtypes [op ] = func_in_dtypes [elwise_func ]
462
499
func_returns_bool [op ] = func_returns_bool [elwise_func ]
463
-
464
-
465
500
inplace_op_to_symbol = {}
466
501
for op , symbol in binary_op_to_symbol .items ():
467
502
if op == "__matmul__" or func_returns_bool [op ]:
@@ -470,8 +505,6 @@ def result_type(*dtypes: DataType):
470
505
inplace_op_to_symbol [iop ] = f"{ symbol } ="
471
506
func_in_dtypes [iop ] = func_in_dtypes [op ]
472
507
func_returns_bool [iop ] = func_returns_bool [op ]
473
-
474
-
475
508
func_in_dtypes ["__bool__" ] = (xp .bool ,)
476
509
func_in_dtypes ["__int__" ] = all_int_dtypes
477
510
func_in_dtypes ["__index__" ] = all_int_dtypes
0 commit comments