Skip to content

Commit d2267e4

Browse files
committed
Change dtype helpers behaviour depending on api_version
1 parent 33a0f6c commit d2267e4

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

array_api_tests/dtype_helpers.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,13 @@ def __repr__(self):
103103
all_int_dtypes = uint_dtypes + int_dtypes
104104
real_dtypes = all_int_dtypes + float_dtypes
105105
complex_dtypes = tuple(getattr(xp, name) for name in _complex_names)
106-
numeric_dtypes = real_dtypes + complex_dtypes
106+
numeric_dtypes = real_dtypes
107+
if api_version > "2021.12":
108+
numeric_dtypes += complex_dtypes
107109
all_dtypes = (xp.bool,) + numeric_dtypes
108-
all_float_dtypes = float_dtypes + complex_dtypes
110+
all_float_dtypes = float_dtypes
111+
if api_version > "2021.12":
112+
all_float_dtypes += complex_dtypes
109113
bool_and_all_int_dtypes = (xp.bool,) + all_int_dtypes
110114

111115

@@ -132,7 +136,10 @@ def is_float_dtype(dtype):
132136
# See https://github.com/numpy/numpy/issues/18434
133137
if dtype is None:
134138
return False
135-
return dtype in float_dtypes
139+
valid_dtypes = float_dtypes
140+
if api_version > "2021.12":
141+
valid_dtypes += complex_dtypes
142+
return dtype in valid_dtypes
136143

137144

138145
def get_scalar_type(dtype: DataType) -> ScalarType:

0 commit comments

Comments
 (0)