diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index 4dd95937..1c7d324f 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -1783,48 +1783,42 @@ def test_trunc(x): def _check_binary_with_scalars(func_data, x1x2): x1, x2 = x1x2 - func, name, refimpl, kwds, expected_dtype = func_data + func_name, refimpl, kwds, expected_dtype = func_data + func = getattr(xp, func_name) out = func(x1, x2) in_dtypes, in_shapes, (x1a, x2a) = _convert_scalars_helper(x1, x2) _assert_correctness_binary( - name, refimpl, in_dtypes, in_shapes, (x1a, x2a), out, expected_dtype, **kwds + func_name, refimpl, in_dtypes, in_shapes, (x1a, x2a), out, expected_dtype, **kwds ) def _filter_zero(x): return x != 0 if dh.is_scalar(x) else (not xp.any(x == 0)) -# workarounds for xp.copysign etc only available in 2023.12 -# Without it, test suite fails to import with ARRAY_API_VERSION=2022.12 -_xp_copysign = getattr(xp, "copysign", None) -_xp_hypot = getattr(xp, "hypot", None) -_xp_maximum = getattr(xp, "maximum", None) -_xp_minimum = getattr(xp, "minimum", None) - @pytest.mark.min_version("2024.12") @pytest.mark.parametrize('func_data', - # xp_func, name, refimpl, kwargs, expected_dtype + # func_name, refimpl, kwargs, expected_dtype [ - (xp.add, "add", operator.add, {}, None), - (xp.atan2, "atan2", math.atan2, {}, None), - (_xp_copysign, "copysign", math.copysign, {}, None), - (xp.divide, "divide", operator.truediv, {"filter_": lambda s: s != 0}, None), - (_xp_hypot, "hypot", math.hypot, {}, None), - (xp.logaddexp, "logaddexp", logaddexp_refimpl, {}, None), - (_xp_maximum, "maximum", max, {'strict_check': True}, None), - (_xp_minimum, "minimum", min, {'strict_check': True}, None), - (xp.multiply, "mul", operator.mul, {}, None), - (xp.subtract, "sub", operator.sub, {}, None), - - (xp.equal, "equal", operator.eq, {}, xp.bool), - (xp.not_equal, "neq", operator.ne, {}, xp.bool), - (xp.less, "less", operator.lt, {}, xp.bool), - (xp.less_equal, "les_equal", operator.le, {}, xp.bool), - (xp.greater, "greater", operator.gt, {}, xp.bool), - (xp.greater_equal, "greater_equal", operator.ge, {}, xp.bool), + ("add", operator.add, {}, None), + ("atan2", math.atan2, {}, None), + ("copysign", math.copysign, {}, None), + ("divide", operator.truediv, {"filter_": lambda s: s != 0}, None), + ("hypot", math.hypot, {}, None), + ("logaddexp", logaddexp_refimpl, {}, None), + ("maximum", max, {'strict_check': True}, None), + ("minimum", min, {'strict_check': True}, None), + ("multiply", operator.mul, {}, None), + ("subtract", operator.sub, {}, None), + + ("equal", operator.eq, {}, xp.bool), + ("not_equal", operator.ne, {}, xp.bool), + ("less", operator.lt, {}, xp.bool), + ("less_equal", operator.le, {}, xp.bool), + ("greater", operator.gt, {}, xp.bool), + ("greater_equal", operator.ge, {}, xp.bool), ], - ids=lambda func_data: func_data[1] # use names for test IDs + ids=lambda func_data: func_data[0] # use names for test IDs ) @given(x1x2=hh.array_and_py_scalar(dh.real_float_dtypes)) def test_binary_with_scalars_real(func_data, x1x2): @@ -1833,13 +1827,13 @@ def test_binary_with_scalars_real(func_data, x1x2): @pytest.mark.min_version("2024.12") @pytest.mark.parametrize('func_data', - # xp_func, name, refimpl, kwargs, expected_dtype + # func_name, refimpl, kwargs, expected_dtype [ - (xp.logical_and, "logical_and", operator.and_, {"expr_template": "({} or {})={}"}, None), - (xp.logical_or, "logical_or", operator.or_, {"expr_template": "({} or {})={}"}, None), - (xp.logical_xor, "logical_xor", operator.xor, {"expr_template": "({} or {})={}"}, None), + ("logical_and", operator.and_, {"expr_template": "({} or {})={}"}, None), + ("logical_or", operator.or_, {"expr_template": "({} or {})={}"}, None), + ("logical_xor", operator.xor, {"expr_template": "({} or {})={}"}, None), ], - ids=lambda func_data: func_data[1] # use names for test IDs + ids=lambda func_data: func_data[0] # use names for test IDs ) @given(x1x2=hh.array_and_py_scalar([xp.bool])) def test_binary_with_scalars_bool(func_data, x1x2): @@ -1848,17 +1842,15 @@ def test_binary_with_scalars_bool(func_data, x1x2): @pytest.mark.min_version("2024.12") @pytest.mark.parametrize('func_data', - # xp_func, name, refimpl, kwargs, expected_dtype + # func_name, refimpl, kwargs, expected_dtype [ - - (xp.floor_divide, "floor_divide", operator.floordiv, {}, None), - (xp.remainder, "remainder", operator.mod, {}, None), + ("floor_divide", operator.floordiv, {}, None), + ("remainder", operator.mod, {}, None), ], - ids=lambda func_data: func_data[1] # use names for test IDs + ids=lambda func_data: func_data[0] # use names for test IDs ) @given(x1x2=hh.array_and_py_scalar([xp.int64])) def test_binary_with_scalars_int(func_data, x1x2): - assume(_filter_zero(x1x2[1])) assume(_filter_zero(x1x2[0]) and _filter_zero(x1x2[1])) _check_binary_with_scalars(func_data, x1x2) @@ -1866,18 +1858,18 @@ def test_binary_with_scalars_int(func_data, x1x2): @pytest.mark.min_version("2024.12") @pytest.mark.parametrize('func_data', - # xp_func, name, refimpl, kwargs, expected_dtype + # func_name, refimpl, kwargs, expected_dtype [ - (xp.bitwise_and, "bitwise_and", operator.and_, {}, None), - (xp.bitwise_or, "bitwise_or", operator.or_, {}, None), - (xp.bitwise_xor, "bitwise_xor", operator.xor, {}, None), + ("bitwise_and", operator.and_, {}, None), + ("bitwise_or", operator.or_, {}, None), + ("bitwise_xor", operator.xor, {}, None), ], - ids=lambda func_data: func_data[1] # use names for test IDs + ids=lambda func_data: func_data[0] # use names for test IDs ) @given(x1x2=hh.array_and_py_scalar([xp.int32])) def test_binary_with_scalars_bitwise(func_data, x1x2): - xp_func, name, refimpl, kwargs, expected = func_data + func_name, refimpl, kwargs, expected = func_data # repack the refimpl refimpl_ = lambda l, r: mock_int_dtype(refimpl(l, r), xp.int32 ) - _check_binary_with_scalars((xp_func, name, refimpl_, kwargs,expected), x1x2) + _check_binary_with_scalars((func_name, refimpl_, kwargs, expected), x1x2)