Skip to content

Commit 2ce7d21

Browse files
authored
Merge pull request #125 from Quansight-Labs/nep50
TST: check status of NEP 50
2 parents aaabfda + 1d8a769 commit 2ce7d21

File tree

2 files changed

+81
-0
lines changed

2 files changed

+81
-0
lines changed

autogen/gen_dtypes.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import numpy as np
99
import torch
1010

11+
np._set_promotion_state("weak")
12+
1113

1214
class dtype:
1315
def __init__(self, name):

torch_np/tests/test_nep50_examples.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
"""Test examples for NEP 50."""
2+
3+
from torch_np import array, float32, float64, inf, int64, uint8
4+
from torch_np.testing import assert_allclose
5+
6+
uint16 = uint8 # can be anything here, see below
7+
8+
9+
# from numpy import array, uint8, uint16, int64, float32, float64, inf
10+
# from numpy.testing import assert_allclose
11+
# import numpy as np
12+
# np._set_promotion_state('weak')
13+
14+
import pytest
15+
from pytest import raises as assert_raises
16+
17+
unchanged = None
18+
19+
# expression old result new_result
20+
examples = {
21+
"uint8(1) + 2": (int64(3), uint8(3)),
22+
"array([1], uint8) + int64(1)": (array([2], uint8), array([2], int64)),
23+
"array([1], uint8) + array(1, int64)": (array([2], uint8), array([2], int64)),
24+
"array([1.], float32) + float64(1.)": (
25+
array([2.0], float32),
26+
array([2.0], float64),
27+
),
28+
"array([1.], float32) + array(1., float64)": (
29+
array([2.0], float32),
30+
array([2.0], float64),
31+
),
32+
"array([1], uint8) + 1": (array([2], uint8), unchanged),
33+
"array([1], uint8) + 200": (array([201], uint8), unchanged),
34+
"array([100], uint8) + 200": (array([44], uint8), unchanged),
35+
"array([1], uint8) + 300": (array([301], uint16), Exception),
36+
"uint8(1) + 300": (int64(301), Exception),
37+
"uint8(100) + 200": (int64(301), uint8(44)), # and RuntimeWarning
38+
"float32(1) + 3e100": (float64(3e100), float32(inf)), # and RuntimeWarning [T7]
39+
# "array([0.1], float32) == 0.1": (array([False]), unchanged), # XXX: a typo in NEP50?
40+
"array([0.1], float32) == float64(0.1)": (array([True]), array([False])),
41+
"array([1.], float32) + 3": (array([4.0], float32), unchanged),
42+
"array([1.], float32) + int64(3)": (array([4.0], float32), array([4.0], float64)),
43+
}
44+
45+
46+
fails = [
47+
"uint8(1) + 2",
48+
"array([1], uint8) + 1",
49+
"array([1], uint8) + 200",
50+
"array([1], uint8) + array(1, int64)",
51+
"array([100], uint8) + 200",
52+
"array([1], uint8) + 300",
53+
"uint8(1) + 300",
54+
"uint8(100) + 200",
55+
"float32(1) + 3e100",
56+
"array([1.], float32) + 3",
57+
]
58+
59+
60+
@pytest.mark.parametrize("example", examples)
61+
def test_nep50_exceptions(example):
62+
63+
if example in fails:
64+
pytest.xfail(reason="scalars")
65+
66+
old, new = examples[example]
67+
68+
if new == Exception:
69+
with assert_raises(OverflowError):
70+
eval(example)
71+
72+
else:
73+
result = eval(example)
74+
75+
if new is unchanged:
76+
new = old
77+
78+
assert_allclose(result, new, atol=1e-16)
79+
assert result.dtype == new.dtype

0 commit comments

Comments
 (0)