Skip to content

Commit 599238d

Browse files
committed
TST: add NEP 50 examples as a test
1 parent aaabfda commit 599238d

File tree

2 files changed

+55
-0
lines changed

2 files changed

+55
-0
lines changed

autogen/gen_dtypes.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import numpy as np
99
import torch
1010

11+
np._set_promotion_state("weak")
1112

1213
class dtype:
1314
def __init__(self, name):

torch_np/tests/test_nep50_examples.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
2+
from torch_np import array, uint8, int64, float32, float64, inf
3+
from torch_np.testing import assert_allclose
4+
uint16 = uint8 # can be anything here, see below
5+
6+
7+
#from numpy import array, uint8, uint16, int64, float32, float64, inf
8+
#from numpy.testing import assert_allclose
9+
#import numpy as np
10+
#np._set_promotion_state('weak')
11+
12+
import pytest
13+
from pytest import raises as assert_raises
14+
15+
unchanged = None
16+
17+
# expression old result new_result
18+
# uint8(1) + 2 int64(3) uint8(3) [T1]
19+
examples = {"uint8(1) + 2": (int64(3), uint8(3)),
20+
"array([1], uint8) + int64(1)": (array([2], uint8), array([2], int64)),
21+
"array([1], uint8) + array(1, int64)": (array([2], uint8), array([2], int64)),
22+
"array([1.], float32) + float64(1.)": (array([2.], float32), array([2.], float64)),
23+
"array([1.], float32) + array(1., float64)": (array([2.], float32), array([2.], float64)),
24+
"array([1], uint8) + 1": (array([2], uint8), unchanged),
25+
"array([1], uint8) + 200": (array([201], uint8), unchanged),
26+
"array([100], uint8) + 200": (array([ 44], uint8), unchanged),
27+
"array([1], uint8) + 300": (array([301], uint16), Exception),
28+
"uint8(1) + 300": (int64(301), Exception),
29+
"uint8(100) + 200": (int64(301), uint8(44)), # and RuntimeWarning
30+
"float32(1) + 3e100" : (float64(3e100), float32(inf)), # and RuntimeWarning [T7]
31+
# "array([0.1], float32) == 0.1": (array([False]), unchanged), # XXX: a typo in NEP50?
32+
"array([0.1], float32) == float64(0.1)": (array([ True]), array([False])),
33+
"array([1.], float32) + 3": (array([4.], float32), unchanged),
34+
"array([1.], float32) + int64(3)": (array([4.], float32), array([4.], float64))
35+
}
36+
37+
38+
@pytest.mark.parametrize("example", examples)
39+
def test_nep50_exceptions(example):
40+
old, new = examples[example]
41+
42+
if new == Exception:
43+
with assert_raises(OverflowError):
44+
eval(example)
45+
46+
else:
47+
result = eval(example)
48+
49+
if new is unchanged:
50+
new = old
51+
52+
assert_allclose(result, new, atol=1e-16)
53+
assert result.dtype == new.dtype
54+

0 commit comments

Comments
 (0)