@@ -84,6 +84,12 @@ def category(dtyp):
84
84
85
85
86
86
def nep50_to_tensors (x1 , x2 ):
87
+ """If either of inputs is a python scalar, type-promote with NEP 50.
88
+
89
+ NB: NEP 50 mandates RuntimeWarnings on some overflows. We do not emit them:
90
+ we either raise OverflowError or just do the computation.
91
+ """
92
+
87
93
x1_type , x2_type = type (x1 ), type (x2 )
88
94
x1_is_weak = x1_type in SCALAR_TYPES
89
95
x2_is_weak = x2_type in SCALAR_TYPES
@@ -112,6 +118,11 @@ def nep50_to_tensors(x1, x2):
112
118
dt = torch .complex64
113
119
114
120
# finally, can cast make `weak` into a 0D tensor
115
- weak = torch .as_tensor (weak , dtype = dt )
121
+ weak_ = torch .as_tensor (weak , dtype = dt )
122
+
123
+ # detect uint overflow: in PyTorch, uint8(-1) wraps around to 255,
124
+ # while NEP50 mandates an exception.
125
+ if weak_ .dtype == torch .uint8 and weak_ .item () != weak :
126
+ raise OverflowError (f"Python integer { weak } out of bounds for { weak_ .dtype } " )
116
127
117
- return (weak , not_weak ) if x1_is_weak else (not_weak , weak )
128
+ return (weak_ , not_weak ) if x1_is_weak else (not_weak , weak_ )
0 commit comments