Skip to content

Commit d0bbc9c

Browse files
Fix test for neg on unsigned
Due to changes in numpy conversion rules (NEP 50), overflows are not ignored; in particular, negating a unsigned int causes an overflow error. The test for `neg` has been changed to check that this error is raised.
1 parent 5b61cd4 commit d0bbc9c

File tree

2 files changed

+31
-1
lines changed

2 files changed

+31
-1
lines changed

tests/tensor/test_math.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -391,11 +391,20 @@ def test_maximum_minimum_grad():
391391
grad=_grad_broadcast_unary_normal,
392392
)
393393

394+
395+
# in numpy >= 2.0, negating a uint raises an error
396+
neg_good = _good_broadcast_unary_normal.copy()
397+
if np.lib.NumpyVersion(np.__version__) >= "2.0.0rc1":
398+
neg_bad = {"uint8": neg_good.pop("uint8"), "uint16": neg_good.pop("uint16")}
399+
else:
400+
neg_bad = None
401+
394402
TestNegBroadcast = makeBroadcastTester(
395403
op=neg,
396404
expected=lambda x: -x,
397-
good=_good_broadcast_unary_normal,
405+
good=neg_good,
398406
grad=_grad_broadcast_unary_normal,
407+
bad_compile=neg_bad,
399408
)
400409

401410
TestSgnBroadcast = makeBroadcastTester(

tests/tensor/utils.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,7 @@ def makeTester(
339339
good=None,
340340
bad_build=None,
341341
bad_runtime=None,
342+
bad_compile=None,
342343
grad=None,
343344
mode=None,
344345
grad_rtol=None,
@@ -373,6 +374,7 @@ def makeTester(
373374
_test_memmap = test_memmap
374375
_check_name = check_name
375376
_grad_eps = grad_eps
377+
_bad_compile = bad_compile or {}
376378

377379
class Checker:
378380
op = staticmethod(_op)
@@ -382,6 +384,7 @@ class Checker:
382384
good = _good
383385
bad_build = _bad_build
384386
bad_runtime = _bad_runtime
387+
bad_compile = _bad_compile
385388
grad = _grad
386389
mode = _mode
387390
skip = skip_
@@ -539,6 +542,24 @@ def test_bad_build(self):
539542
# instantiated on the following bad inputs: %s"
540543
# % (self.op, testname, node, inputs))
541544

545+
@config.change_flags(compute_test_value="off")
546+
@pytest.mark.skipif(skip, reason="Skipped")
547+
def test_bad_compile(self):
548+
for testname, inputs in self.bad_compile.items():
549+
inputrs = [shared(input) for input in inputs]
550+
try:
551+
node = safe_make_node(self.op, *inputrs)
552+
except Exception as exc:
553+
err_msg = (
554+
f"Test {self.op}::{testname}: Error occurred while trying"
555+
f" to make a node with inputs {inputs}"
556+
)
557+
exc.args += (err_msg,)
558+
raise
559+
560+
with pytest.raises(Exception):
561+
inplace_func([], node.outputs, mode=mode, name="test_bad_runtime")
562+
542563
@config.change_flags(compute_test_value="off")
543564
@pytest.mark.skipif(skip, reason="Skipped")
544565
def test_bad_runtime(self):

0 commit comments

Comments
 (0)