Skip to content

Commit 675fa6c

Browse files
committed
Filter out stubs in some dtype_helpers
Required for testing `torch_np`
1 parent 96717db commit 675fa6c

File tree

1 file changed

+22
-8
lines changed

1 file changed

+22
-8
lines changed

array_api_tests/dtype_helpers.py

+22-8
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55
from typing import Any, Dict, NamedTuple, Sequence, Tuple, Union
66
from warnings import warn
77

8-
from . import api_version
98
from . import _array_module as xp
9+
from . import api_version
1010
from ._array_module import _UndefinedStub
11+
from ._array_module import mod as _xp
1112
from .stubs import name_to_func
1213
from .typing import DataType, ScalarType
1314

@@ -88,6 +89,12 @@ def __repr__(self):
8889
return f"EqualityMapping({self})"
8990

9091

92+
def _filter_stubs(*args):
93+
for a in args:
94+
if not isinstance(a, _UndefinedStub):
95+
yield a
96+
97+
9198
_uint_names = ("uint8", "uint16", "uint32", "uint64")
9299
_int_names = ("int8", "int16", "int32", "int64")
93100
_float_names = ("float32", "float64")
@@ -113,7 +120,14 @@ def __repr__(self):
113120
bool_and_all_int_dtypes = (xp.bool,) + all_int_dtypes
114121

115122

116-
dtype_to_name = EqualityMapping([(getattr(xp, name), name) for name in _dtype_names])
123+
_dtype_name_pairs = []
124+
for name in _dtype_names:
125+
try:
126+
dtype = getattr(_xp, name)
127+
except AttributeError:
128+
continue
129+
_dtype_name_pairs.append((dtype, name))
130+
dtype_to_name = EqualityMapping(_dtype_name_pairs)
117131

118132

119133
dtype_to_scalars = EqualityMapping(
@@ -173,12 +187,13 @@ class MinMax(NamedTuple):
173187
]
174188
)
175189

190+
176191
dtype_nbits = EqualityMapping(
177-
[(d, 8) for d in [xp.int8, xp.uint8]]
178-
+ [(d, 16) for d in [xp.int16, xp.uint16]]
179-
+ [(d, 32) for d in [xp.int32, xp.uint32, xp.float32]]
180-
+ [(d, 64) for d in [xp.int64, xp.uint64, xp.float64, xp.complex64]]
181-
+ [(xp.complex128, 128)]
192+
[(d, 8) for d in _filter_stubs(xp.int8, xp.uint8)]
193+
+ [(d, 16) for d in _filter_stubs(xp.int16, xp.uint16)]
194+
+ [(d, 32) for d in _filter_stubs(xp.int32, xp.uint32, xp.float32)]
195+
+ [(d, 64) for d in _filter_stubs(xp.int64, xp.uint64, xp.float64, xp.complex64)]
196+
+ [(d, 128) for d in _filter_stubs(xp.complex128)]
182197
)
183198

184199

@@ -265,7 +280,6 @@ class MinMax(NamedTuple):
265280
((xp.complex64, xp.complex64), xp.complex64),
266281
((xp.complex64, xp.complex128), xp.complex128),
267282
((xp.complex128, xp.complex128), xp.complex128),
268-
269283
]
270284
_numeric_promotions += [((d2, d1), res) for (d1, d2), res in _numeric_promotions]
271285
_promotion_table = list(set(_numeric_promotions))

0 commit comments

Comments
 (0)