5
5
from typing import Any , Dict , NamedTuple , Sequence , Tuple , Union
6
6
from warnings import warn
7
7
8
- from . import api_version
9
8
from . import _array_module as xp
9
+ from . import api_version
10
10
from ._array_module import _UndefinedStub
11
+ from ._array_module import mod as _xp
11
12
from .stubs import name_to_func
12
13
from .typing import DataType , ScalarType
13
14
@@ -88,6 +89,12 @@ def __repr__(self):
88
89
return f"EqualityMapping({ self } )"
89
90
90
91
92
+ def _filter_stubs (* args ):
93
+ for a in args :
94
+ if not isinstance (a , _UndefinedStub ):
95
+ yield a
96
+
97
+
91
98
_uint_names = ("uint8" , "uint16" , "uint32" , "uint64" )
92
99
_int_names = ("int8" , "int16" , "int32" , "int64" )
93
100
_float_names = ("float32" , "float64" )
@@ -113,7 +120,14 @@ def __repr__(self):
113
120
bool_and_all_int_dtypes = (xp .bool ,) + all_int_dtypes
114
121
115
122
116
- dtype_to_name = EqualityMapping ([(getattr (xp , name ), name ) for name in _dtype_names ])
123
+ _dtype_name_pairs = []
124
+ for name in _dtype_names :
125
+ try :
126
+ dtype = getattr (_xp , name )
127
+ except AttributeError :
128
+ continue
129
+ _dtype_name_pairs .append ((dtype , name ))
130
+ dtype_to_name = EqualityMapping (_dtype_name_pairs )
117
131
118
132
119
133
dtype_to_scalars = EqualityMapping (
@@ -173,12 +187,13 @@ class MinMax(NamedTuple):
173
187
]
174
188
)
175
189
190
+
176
191
dtype_nbits = EqualityMapping (
177
- [(d , 8 ) for d in [ xp .int8 , xp .uint8 ] ]
178
- + [(d , 16 ) for d in [ xp .int16 , xp .uint16 ] ]
179
- + [(d , 32 ) for d in [ xp .int32 , xp .uint32 , xp .float32 ] ]
180
- + [(d , 64 ) for d in [ xp .int64 , xp .uint64 , xp .float64 , xp .complex64 ] ]
181
- + [(xp . complex128 , 128 )]
192
+ [(d , 8 ) for d in _filter_stubs ( xp .int8 , xp .uint8 ) ]
193
+ + [(d , 16 ) for d in _filter_stubs ( xp .int16 , xp .uint16 ) ]
194
+ + [(d , 32 ) for d in _filter_stubs ( xp .int32 , xp .uint32 , xp .float32 ) ]
195
+ + [(d , 64 ) for d in _filter_stubs ( xp .int64 , xp .uint64 , xp .float64 , xp .complex64 ) ]
196
+ + [(d , 128 ) for d in _filter_stubs ( xp . complex128 )]
182
197
)
183
198
184
199
@@ -265,7 +280,6 @@ class MinMax(NamedTuple):
265
280
((xp .complex64 , xp .complex64 ), xp .complex64 ),
266
281
((xp .complex64 , xp .complex128 ), xp .complex128 ),
267
282
((xp .complex128 , xp .complex128 ), xp .complex128 ),
268
-
269
283
]
270
284
_numeric_promotions += [((d2 , d1 ), res ) for (d1 , d2 ), res in _numeric_promotions ]
271
285
_promotion_table = list (set (_numeric_promotions ))
0 commit comments