diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index 8f05bc13..c88d0a53 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -20,6 +20,8 @@ def kwarg_dtypes(dtype: DataType) -> st.SearchStrategy[Optional[DataType]]: dtypes = [d2 for d1, d2 in dh.promotion_table if d1 == dtype] + if hh.FILTER_UNDEFINED_DTYPES: + dtypes = [d for d in dtypes if not isinstance(d, _UndefinedStub)] return st.none() | st.sampled_from(dtypes)