Skip to content

Commit 4429b2f

Browse files
committed
BUG: detect uint8 overflow
1 parent 19e96c2 commit 4429b2f

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

torch_np/_dtypes_impl.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,12 @@ def category(dtyp):
8484

8585

8686
def nep50_to_tensors(x1, x2):
87+
"""If either of inputs is a python scalar, type-promote with NEP 50.
88+
89+
NB: NEP 50 mandates RuntimeWarnings on some overflows. We do not emit them:
90+
we either raise OverflowError or just do the computation.
91+
"""
92+
8793
x1_type, x2_type = type(x1), type(x2)
8894
x1_is_weak = x1_type in SCALAR_TYPES
8995
x2_is_weak = x2_type in SCALAR_TYPES
@@ -112,6 +118,11 @@ def nep50_to_tensors(x1, x2):
112118
dt = torch.complex64
113119

114120
# finally, can cast make `weak` into a 0D tensor
115-
weak = torch.as_tensor(weak, dtype=dt)
121+
weak_ = torch.as_tensor(weak, dtype=dt)
122+
123+
# detect uint overflow: in PyTorch, uint8(-1) wraps around to 255,
124+
# while NEP50 mandates an exception.
125+
if weak_.dtype == torch.uint8 and weak_.item() != weak:
126+
raise OverflowError(f"Python integer {weak} out of bounds for {weak_.dtype}")
116127

117-
return (weak, not_weak) if x1_is_weak else (not_weak, weak)
128+
return (weak_, not_weak) if x1_is_weak else (not_weak, weak_)

torch_np/tests/test_nep50_examples.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,6 @@
6565

6666

6767
fails = [
68-
"array([1], uint8) + 300",
69-
"uint8(1) + 300",
7068
"array([0.1], float32) == 0.1", # TODO: fix the example
7169
]
7270

0 commit comments

Comments
 (0)