Skip to content

Commit bef69cf

Browse files
committed
Implemented test_trunc
1 parent 70099e5 commit bef69cf

File tree

1 file changed

+15
-5
lines changed

1 file changed

+15
-5
lines changed

array_api_tests/test_elementwise_functions.py

+15-5
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@
2626
boolean_dtype_objects, floating_dtypes,
2727
numeric_dtypes, integer_or_boolean_dtypes,
2828
boolean_dtypes, mutually_promotable_dtypes,
29-
array_scalars, shared_arrays1, shared_arrays2)
29+
array_scalars, shared_arrays1, shared_arrays2,
30+
xps)
3031
from .array_helpers import (assert_exactly_equal, negative,
3132
positive_mathematical_sign,
3233
negative_mathematical_sign, logical_not,
@@ -37,7 +38,7 @@
3738
ndindex, promote_dtypes, is_integer_dtype,
3839
is_float_dtype, not_equal, float64, asarray,
3940
dtype_ranges, full, true, false, assert_same_sign,
40-
isnan, less)
41+
isnan, equal, less)
4142
# We might as well use this implementation rather than requiring
4243
# mod.broadcast_shapes(). See test_equal() and others.
4344
from .test_broadcasting import broadcast_shapes
@@ -901,7 +902,16 @@ def test_tanh(x):
901902
# a = _array_module.tanh(x)
902903
pass
903904

904-
@given(numeric_scalars)
905+
@given(xps.arrays(dtype=numeric_dtypes, shape=xps.array_shapes()))
905906
def test_trunc(x):
906-
# a = _array_module.trunc(x)
907-
pass
907+
a = _array_module.trunc(x)
908+
assert a.dtype == x.dtype, f"{x.dtype=!s}, but trunc() did not produce a {x.dtype} array - instead was {a.dtype}"
909+
if x.dtype in integer_dtype_objects:
910+
assert array_all(equal(x, a)), f"{x=!s} but trunc(x)={x} - {x.dtype=!s} so trunc(x) should do nothing"
911+
else:
912+
# TODO: a method that generates all indices, so we don't have to flatten first
913+
a = _array_module.reshape(a, a.size)
914+
finite_mask = _array_module.isfinite(a)
915+
for i in range(a.size):
916+
if finite_mask[i]:
917+
assert float(a[i]).is_integer(), f"trunc(x) did not round float {a[i]} to 0 decimals"

0 commit comments

Comments
 (0)