Skip to content

Commit fc7c8b7

Browse files
committed
Fix get_scalar_type() for complex dtypes
1 parent 2e43666 commit fc7c8b7

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

array_api_tests/dtype_helpers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,9 @@ def is_float_dtype(dtype):
169169

170170

171171
def get_scalar_type(dtype: DataType) -> ScalarType:
172-
if is_int_dtype(dtype):
172+
if dtype in all_int_dtypes:
173173
return int
174-
elif is_float_dtype(dtype):
174+
elif dtype in float_dtypes:
175175
return float
176176
elif dtype in complex_dtypes:
177177
return complex

0 commit comments

Comments
 (0)