Skip to content

Commit 829f0f7

Browse files
committed
Remove use of _UndefinedStub in dtype_helpers.py
1 parent e0bb425 commit 829f0f7

File tree

1 file changed

+138
-105
lines changed

1 file changed

+138
-105
lines changed

array_api_tests/dtype_helpers.py

Lines changed: 138 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,22 @@
22
from collections import defaultdict
33
from collections.abc import Mapping
44
from functools import lru_cache
5-
from typing import Any, DefaultDict, NamedTuple, Sequence, Tuple, Union
5+
from typing import Any, DefaultDict, Dict, List, NamedTuple, Sequence, Tuple, Union
66
from warnings import warn
77

8-
from . import _array_module as xp
98
from . import api_version
10-
from ._array_module import _UndefinedStub
11-
from ._array_module import mod as _xp
9+
from ._array_module import mod as xp
1210
from .stubs import name_to_func
1311
from .typing import DataType, ScalarType
1412

1513
__all__ = [
14+
"uint_names",
15+
"int_names",
16+
"float_names",
17+
"real_names",
18+
"complex_names",
19+
"numeric_names",
20+
"dtype_names",
1621
"int_dtypes",
1722
"uint_dtypes",
1823
"all_int_dtypes",
@@ -90,27 +95,42 @@ def __repr__(self):
9095
return f"EqualityMapping({self})"
9196

9297

93-
def _filter_stubs(*args):
94-
for a in args:
95-
if not isinstance(a, _UndefinedStub):
96-
yield a
98+
uint_names = ("uint8", "uint16", "uint32", "uint64")
99+
int_names = ("int8", "int16", "int32", "int64")
100+
float_names = ("float32", "float64")
101+
real_names = uint_names + int_names + float_names
102+
complex_names = ("complex64", "complex128")
103+
numeric_names = real_names + complex_names
104+
dtype_names = ("bool",) + numeric_names
97105

98106

99-
_uint_names = ("uint8", "uint16", "uint32", "uint64")
100-
_int_names = ("int8", "int16", "int32", "int64")
101-
_float_names = ("float32", "float64")
102-
_real_names = _uint_names + _int_names + _float_names
103-
_complex_names = ("complex64", "complex128")
104-
_numeric_names = _real_names + _complex_names
105-
_dtype_names = ("bool",) + _numeric_names
107+
_name_to_dtype = {}
108+
for name in dtype_names:
109+
try:
110+
dtype = getattr(xp, name)
111+
except AttributeError:
112+
continue
113+
_name_to_dtype[name] = dtype
114+
dtype_to_name = EqualityMapping([(d, n) for n, d in _name_to_dtype.items()])
106115

107116

108-
uint_dtypes = tuple(getattr(xp, name) for name in _uint_names)
109-
int_dtypes = tuple(getattr(xp, name) for name in _int_names)
110-
float_dtypes = tuple(getattr(xp, name) for name in _float_names)
117+
def _make_dtype_tuple_from_names(names: List[str]) -> Tuple[DataType]:
118+
dtypes = []
119+
for name in names:
120+
try:
121+
dtype = _name_to_dtype[name]
122+
except KeyError:
123+
continue
124+
dtypes.append(dtype)
125+
return tuple(dtypes)
126+
127+
128+
uint_dtypes = _make_dtype_tuple_from_names(uint_names)
129+
int_dtypes = _make_dtype_tuple_from_names(int_names)
130+
float_dtypes = _make_dtype_tuple_from_names(float_names)
111131
all_int_dtypes = uint_dtypes + int_dtypes
112132
real_dtypes = all_int_dtypes + float_dtypes
113-
complex_dtypes = tuple(getattr(xp, name) for name in _complex_names)
133+
complex_dtypes = _make_dtype_tuple_from_names(complex_names)
114134
numeric_dtypes = real_dtypes
115135
if api_version > "2021.12":
116136
numeric_dtypes += complex_dtypes
@@ -121,16 +141,6 @@ def _filter_stubs(*args):
121141
bool_and_all_int_dtypes = (xp.bool,) + all_int_dtypes
122142

123143

124-
_dtype_name_pairs = []
125-
for name in _dtype_names:
126-
try:
127-
dtype = getattr(_xp, name)
128-
except AttributeError:
129-
continue
130-
_dtype_name_pairs.append((dtype, name))
131-
dtype_to_name = EqualityMapping(_dtype_name_pairs)
132-
133-
134144
dtype_to_scalars = EqualityMapping(
135145
[
136146
(xp.bool, [bool]),
@@ -179,47 +189,59 @@ def get_scalar_type(dtype: DataType) -> ScalarType:
179189
return bool
180190

181191

192+
def _make_dtype_mapping_from_names(mapping: Dict[str, Any]) -> EqualityMapping:
193+
dtype_value_pairs = []
194+
for name, value in mapping.items():
195+
assert isinstance(name, str) and name in dtype_names # sanity check
196+
try:
197+
dtype = getattr(xp, name)
198+
except AttributeError:
199+
continue
200+
dtype_value_pairs.append((dtype, value))
201+
return EqualityMapping(dtype_value_pairs)
202+
203+
182204
class MinMax(NamedTuple):
183205
min: Union[int, float]
184206
max: Union[int, float]
185207

186208

187-
dtype_ranges = EqualityMapping(
188-
[
189-
(xp.int8, MinMax(-128, +127)),
190-
(xp.int16, MinMax(-32_768, +32_767)),
191-
(xp.int32, MinMax(-2_147_483_648, +2_147_483_647)),
192-
(xp.int64, MinMax(-9_223_372_036_854_775_808, +9_223_372_036_854_775_807)),
193-
(xp.uint8, MinMax(0, +255)),
194-
(xp.uint16, MinMax(0, +65_535)),
195-
(xp.uint32, MinMax(0, +4_294_967_295)),
196-
(xp.uint64, MinMax(0, +18_446_744_073_709_551_615)),
197-
(xp.float32, MinMax(-3.4028234663852886e38, 3.4028234663852886e38)),
198-
(xp.float64, MinMax(-1.7976931348623157e308, 1.7976931348623157e308)),
199-
]
209+
dtype_ranges = _make_dtype_mapping_from_names(
210+
{
211+
"int8": MinMax(-128, +127),
212+
"int16": MinMax(-32_768, +32_767),
213+
"int32": MinMax(-2_147_483_648, +2_147_483_647),
214+
"int64": MinMax(-9_223_372_036_854_775_808, +9_223_372_036_854_775_807),
215+
"uint8": MinMax(0, +255),
216+
"uint16": MinMax(0, +65_535),
217+
"uint32": MinMax(0, +4_294_967_295),
218+
"uint64": MinMax(0, +18_446_744_073_709_551_615),
219+
"float32": MinMax(-3.4028234663852886e38, 3.4028234663852886e38),
220+
"float64": MinMax(-1.7976931348623157e308, 1.7976931348623157e308),
221+
}
200222
)
201223

202224

203-
dtype_nbits = EqualityMapping(
204-
[(d, 8) for d in _filter_stubs(xp.int8, xp.uint8)]
205-
+ [(d, 16) for d in _filter_stubs(xp.int16, xp.uint16)]
206-
+ [(d, 32) for d in _filter_stubs(xp.int32, xp.uint32, xp.float32)]
207-
+ [(d, 64) for d in _filter_stubs(xp.int64, xp.uint64, xp.float64, xp.complex64)]
208-
+ [(d, 128) for d in _filter_stubs(xp.complex128)]
209-
)
225+
r_nbits = re.compile(r"[a-z]+([0-9]+)")
226+
_dtype_nbits: Dict[str, int] = {}
227+
for name in numeric_names:
228+
m = r_nbits.fullmatch(name)
229+
assert m is not None # sanity check / for mypy
230+
_dtype_nbits[name] = int(m.group(1))
231+
dtype_nbits = _make_dtype_mapping_from_names(_dtype_nbits)
210232

211233

212-
dtype_signed = EqualityMapping(
213-
[(d, True) for d in int_dtypes] + [(d, False) for d in uint_dtypes]
234+
dtype_signed = _make_dtype_mapping_from_names(
235+
{**{name: True for name in int_names}, **{name: False for name in uint_names}}
214236
)
215237

216238

217-
dtype_components = EqualityMapping(
218-
[(xp.complex64, xp.float32), (xp.complex128, xp.float64)]
239+
dtype_components = _make_dtype_mapping_from_names(
240+
{"complex64": xp.float32, "complex128": xp.float64}
219241
)
220242

221243

222-
if isinstance(xp.asarray, _UndefinedStub):
244+
if not hasattr(xp, "asarray"):
223245
default_int = xp.int32
224246
default_float = xp.float32
225247
warn(
@@ -243,60 +265,73 @@ class MinMax(NamedTuple):
243265
else:
244266
default_complex = None
245267
if dtype_nbits[default_int] == 32:
246-
default_uint = xp.uint32
268+
default_uint = getattr(xp, "uint32", None)
247269
else:
248-
default_uint = xp.uint64
249-
270+
default_uint = getattr(xp, "uint64", None)
250271

251-
_numeric_promotions = [
272+
_promotion_table: Dict[Tuple[str, str], str] = {
273+
("bool", "bool"): "bool",
252274
# ints
253-
((xp.int8, xp.int8), xp.int8),
254-
((xp.int8, xp.int16), xp.int16),
255-
((xp.int8, xp.int32), xp.int32),
256-
((xp.int8, xp.int64), xp.int64),
257-
((xp.int16, xp.int16), xp.int16),
258-
((xp.int16, xp.int32), xp.int32),
259-
((xp.int16, xp.int64), xp.int64),
260-
((xp.int32, xp.int32), xp.int32),
261-
((xp.int32, xp.int64), xp.int64),
262-
((xp.int64, xp.int64), xp.int64),
275+
("int8", "int8"): "int8",
276+
("int8", "int16"): "int16",
277+
("int8", "int32"): "int32",
278+
("int8", "int64"): "int64",
279+
("int16", "int16"): "int16",
280+
("int16", "int32"): "int32",
281+
("int16", "int64"): "int64",
282+
("int32", "int32"): "int32",
283+
("int32", "int64"): "int64",
284+
("int64", "int64"): "int64",
263285
# uints
264-
((xp.uint8, xp.uint8), xp.uint8),
265-
((xp.uint8, xp.uint16), xp.uint16),
266-
((xp.uint8, xp.uint32), xp.uint32),
267-
((xp.uint8, xp.uint64), xp.uint64),
268-
((xp.uint16, xp.uint16), xp.uint16),
269-
((xp.uint16, xp.uint32), xp.uint32),
270-
((xp.uint16, xp.uint64), xp.uint64),
271-
((xp.uint32, xp.uint32), xp.uint32),
272-
((xp.uint32, xp.uint64), xp.uint64),
273-
((xp.uint64, xp.uint64), xp.uint64),
286+
("uint8", "uint8"): "uint8",
287+
("uint8", "uint16"): "uint16",
288+
("uint8", "uint32"): "uint32",
289+
("uint8", "uint64"): "uint64",
290+
("uint16", "uint16"): "uint16",
291+
("uint16", "uint32"): "uint32",
292+
("uint16", "uint64"): "uint64",
293+
("uint32", "uint32"): "uint32",
294+
("uint32", "uint64"): "uint64",
295+
("uint64", "uint64"): "uint64",
274296
# ints and uints (mixed sign)
275-
((xp.int8, xp.uint8), xp.int16),
276-
((xp.int8, xp.uint16), xp.int32),
277-
((xp.int8, xp.uint32), xp.int64),
278-
((xp.int16, xp.uint8), xp.int16),
279-
((xp.int16, xp.uint16), xp.int32),
280-
((xp.int16, xp.uint32), xp.int64),
281-
((xp.int32, xp.uint8), xp.int32),
282-
((xp.int32, xp.uint16), xp.int32),
283-
((xp.int32, xp.uint32), xp.int64),
284-
((xp.int64, xp.uint8), xp.int64),
285-
((xp.int64, xp.uint16), xp.int64),
286-
((xp.int64, xp.uint32), xp.int64),
297+
("int8", "uint8"): "int16",
298+
("int8", "uint16"): "int32",
299+
("int8", "uint32"): "int64",
300+
("int16", "uint8"): "int16",
301+
("int16", "uint16"): "int32",
302+
("int16", "uint32"): "int64",
303+
("int32", "uint8"): "int32",
304+
("int32", "uint16"): "int32",
305+
("int32", "uint32"): "int64",
306+
("int64", "uint8"): "int64",
307+
("int64", "uint16"): "int64",
308+
("int64", "uint32"): "int64",
287309
# floats
288-
((xp.float32, xp.float32), xp.float32),
289-
((xp.float32, xp.float64), xp.float64),
290-
((xp.float64, xp.float64), xp.float64),
310+
("float32", "float32"): "float32",
311+
("float32", "float64"): "float64",
312+
("float64", "float64"): "float64",
291313
# complex
292-
((xp.complex64, xp.complex64), xp.complex64),
293-
((xp.complex64, xp.complex128), xp.complex128),
294-
((xp.complex128, xp.complex128), xp.complex128),
295-
]
296-
_numeric_promotions += [((d2, d1), res) for (d1, d2), res in _numeric_promotions]
297-
_promotion_table = list(set(_numeric_promotions))
298-
_promotion_table.insert(0, ((xp.bool, xp.bool), xp.bool))
299-
promotion_table = EqualityMapping(_promotion_table)
314+
("complex64", "complex64"): "complex64",
315+
("complex64", "complex128"): "complex128",
316+
("complex128", "complex128"): "complex128",
317+
}
318+
_promotion_table.update({(d2, d1): res for (d1, d2), res in _promotion_table.items()})
319+
_promotion_table_pairs: List[Tuple[Tuple[DataType, DataType], DataType]] = []
320+
for (in_name1, in_name2), res_name in _promotion_table.items():
321+
try:
322+
in_dtype1 = getattr(xp, in_name1)
323+
except AttributeError:
324+
continue
325+
try:
326+
in_dtype2 = getattr(xp, in_name2)
327+
except AttributeError:
328+
continue
329+
try:
330+
res_dtype = getattr(xp, res_name)
331+
except AttributeError:
332+
continue
333+
_promotion_table_pairs.append(((in_dtype1, in_dtype2), res_dtype))
334+
promotion_table = EqualityMapping(_promotion_table_pairs)
300335

301336

302337
def result_type(*dtypes: DataType):
@@ -325,6 +360,7 @@ def result_type(*dtypes: DataType):
325360
}
326361
func_in_dtypes: DefaultDict[str, Tuple[DataType, ...]] = defaultdict(lambda: all_dtypes)
327362
for name, func in name_to_func.items():
363+
assert func.__doc__ is not None # for mypy
328364
if m := r_in_dtypes.search(func.__doc__):
329365
dtype_category = m.group(1)
330366
if dtype_category == "numeric" and r_int_note.search(func.__doc__):
@@ -457,11 +493,10 @@ def result_type(*dtypes: DataType):
457493
}
458494

459495

496+
# Construct func_in_dtypes and func_returns bool
460497
for op, elwise_func in op_to_func.items():
461498
func_in_dtypes[op] = func_in_dtypes[elwise_func]
462499
func_returns_bool[op] = func_returns_bool[elwise_func]
463-
464-
465500
inplace_op_to_symbol = {}
466501
for op, symbol in binary_op_to_symbol.items():
467502
if op == "__matmul__" or func_returns_bool[op]:
@@ -470,8 +505,6 @@ def result_type(*dtypes: DataType):
470505
inplace_op_to_symbol[iop] = f"{symbol}="
471506
func_in_dtypes[iop] = func_in_dtypes[op]
472507
func_returns_bool[iop] = func_returns_bool[op]
473-
474-
475508
func_in_dtypes["__bool__"] = (xp.bool,)
476509
func_in_dtypes["__int__"] = all_int_dtypes
477510
func_in_dtypes["__index__"] = all_int_dtypes

0 commit comments

Comments
 (0)