Skip to content

Commit 64204a2

Browse files
committed
Add ndindex() and promote_dtype() helper functions
The latter should probably be used in some more places in the dtype tests, but for now I am not refactoring those. The conversion between dtype objects and their string representation is finicky, and probably should be removed entirely. It's not simple though, because the spec doesn't require dtype objects to be usable as dict keys (and, e.g., NumPy does weird stuff like float64 == None, which we have to be careful to guard against). Also we need to make sure that the tests will work even if a dtype object is _UndefinedStub.
1 parent 528dda4 commit 64204a2

File tree

1 file changed

+32
-1
lines changed

1 file changed

+32
-1
lines changed

array_api_tests/array_helpers.py

+32-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import itertools
2+
13
from ._array_module import (isnan, all, any, equal, not_equal, logical_and,
24
logical_or, isfinite, greater, less, less_equal,
35
zeros, ones, full, bool, int8, int16, int32,
@@ -22,7 +24,7 @@
2224
'assert_isinf', 'positive_mathematical_sign',
2325
'assert_positive_mathematical_sign', 'negative_mathematical_sign',
2426
'assert_negative_mathematical_sign', 'same_sign',
25-
'assert_same_sign']
27+
'assert_same_sign', 'ndindex', 'promote_dtypes']
2628

2729
def zero(shape, dtype):
2830
"""
@@ -333,3 +335,32 @@ def int_to_dtype(x, n, signed):
333335
if x & highest_bit:
334336
x = -((~x & mask) + 1)
335337
return x
338+
339+
def ndindex(shape):
340+
"""
341+
Iterator of n-D indices to an array
342+
343+
Yields tuples of integers to index every element of an array of shape
344+
`shape`. Same as np.ndindex().
345+
346+
"""
347+
return itertools.product(*[range(i) for i in shape])
348+
349+
def promote_dtypes(dtype1, dtype2):
350+
"""
351+
Special case of result_type() which uses the exact type promotion table
352+
from the spec.
353+
"""
354+
from .test_type_promotion import dtype_mapping, promotion_table
355+
356+
# Equivalent to this, but some libraries may not work properly with using
357+
# dtype objects as dict keys
358+
#
359+
# d1, d2 = reverse_dtype_mapping[dtype1], reverse_dtype_mapping[dtype2]
360+
361+
d1 = [i for i in dtype_mapping if dtype_mapping[i] == dtype1][0]
362+
d2 = [i for i in dtype_mapping if dtype_mapping[i] == dtype2][0]
363+
364+
if (d1, d2) not in promotion_table:
365+
raise ValueError(f"{d1} and {d2} are not type promotable according to the spec (this may indicate a bug in the test suite).")
366+
return dtype_mapping[promotion_table[d1, d2]]

0 commit comments

Comments
 (0)