Skip to content

Commit f943eb8

Browse files
committed
Refactor pos/neg zero utils
1 parent 8de7cdf commit f943eb8

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

array_api_tests/pytest_helpers.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
"doesnt_raise",
1515
"nargs",
1616
"fmt_kw",
17+
"is_pos_zero",
18+
"is_neg_zero",
1719
"assert_dtype",
1820
"assert_kw_dtype",
1921
"assert_default_float",
@@ -69,6 +71,14 @@ def fmt_kw(kw: Dict[str, Any]) -> str:
6971
return ", ".join(f"{k}={v}" for k, v in kw.items())
7072

7173

74+
def is_pos_zero(n: float) -> bool:
75+
return n == 0 and math.copysign(1, n) == 1
76+
77+
78+
def is_neg_zero(n: float) -> bool:
79+
return n == 0 and math.copysign(1, n) == -1
80+
81+
7282
def assert_dtype(
7383
func_name: str,
7484
in_dtype: Union[DataType, Sequence[DataType]],

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,10 @@ def default_filter(s: Scalar) -> bool:
123123
124124
Used by default as these values are typically special-cased.
125125
"""
126-
return math.isfinite(s) and s is not -0.0 and s is not +0.0
126+
if isinstance(s, int): # note bools are ints
127+
return True
128+
else:
129+
return math.isfinite(s) and s != 0
127130

128131

129132
T = TypeVar("T")
@@ -538,7 +541,7 @@ def test_abs(ctx, data):
538541
abs, # type: ignore
539542
expr_template="abs({})={}",
540543
filter_=lambda s: (
541-
s == float("infinity") or (math.isfinite(s) and s is not -0.0)
544+
s == float("infinity") or (math.isfinite(s) and not ph.is_neg_zero(s))
542545
),
543546
)
544547

0 commit comments

Comments
 (0)