|
| 1 | + |
| 2 | +from torch_np import array, uint8, int64, float32, float64, inf |
| 3 | +from torch_np.testing import assert_allclose |
| 4 | +uint16 = uint8 # can be anything here, see below |
| 5 | + |
| 6 | + |
| 7 | +#from numpy import array, uint8, uint16, int64, float32, float64, inf |
| 8 | +#from numpy.testing import assert_allclose |
| 9 | +#import numpy as np |
| 10 | +#np._set_promotion_state('weak') |
| 11 | + |
| 12 | +import pytest |
| 13 | +from pytest import raises as assert_raises |
| 14 | + |
| 15 | +unchanged = None |
| 16 | + |
| 17 | +# expression old result new_result |
| 18 | +# uint8(1) + 2 int64(3) uint8(3) [T1] |
| 19 | +examples = {"uint8(1) + 2": (int64(3), uint8(3)), |
| 20 | + "array([1], uint8) + int64(1)": (array([2], uint8), array([2], int64)), |
| 21 | + "array([1], uint8) + array(1, int64)": (array([2], uint8), array([2], int64)), |
| 22 | + "array([1.], float32) + float64(1.)": (array([2.], float32), array([2.], float64)), |
| 23 | + "array([1.], float32) + array(1., float64)": (array([2.], float32), array([2.], float64)), |
| 24 | + "array([1], uint8) + 1": (array([2], uint8), unchanged), |
| 25 | + "array([1], uint8) + 200": (array([201], uint8), unchanged), |
| 26 | + "array([100], uint8) + 200": (array([ 44], uint8), unchanged), |
| 27 | + "array([1], uint8) + 300": (array([301], uint16), Exception), |
| 28 | + "uint8(1) + 300": (int64(301), Exception), |
| 29 | + "uint8(100) + 200": (int64(301), uint8(44)), # and RuntimeWarning |
| 30 | + "float32(1) + 3e100" : (float64(3e100), float32(inf)), # and RuntimeWarning [T7] |
| 31 | + # "array([0.1], float32) == 0.1": (array([False]), unchanged), # XXX: a typo in NEP50? |
| 32 | + "array([0.1], float32) == float64(0.1)": (array([ True]), array([False])), |
| 33 | + "array([1.], float32) + 3": (array([4.], float32), unchanged), |
| 34 | + "array([1.], float32) + int64(3)": (array([4.], float32), array([4.], float64)) |
| 35 | +} |
| 36 | + |
| 37 | + |
| 38 | +@pytest.mark.parametrize("example", examples) |
| 39 | +def test_nep50_exceptions(example): |
| 40 | + old, new = examples[example] |
| 41 | + |
| 42 | + if new == Exception: |
| 43 | + with assert_raises(OverflowError): |
| 44 | + eval(example) |
| 45 | + |
| 46 | + else: |
| 47 | + result = eval(example) |
| 48 | + |
| 49 | + if new is unchanged: |
| 50 | + new = old |
| 51 | + |
| 52 | + assert_allclose(result, new, atol=1e-16) |
| 53 | + assert result.dtype == new.dtype |
| 54 | + |
0 commit comments