|
| 1 | +"""Test examples for NEP 50.""" |
| 2 | + |
| 3 | +from torch_np import array, float32, float64, inf, int64, uint8 |
| 4 | +from torch_np.testing import assert_allclose |
| 5 | + |
| 6 | +uint16 = uint8 # can be anything here, see below |
| 7 | + |
| 8 | + |
| 9 | +# from numpy import array, uint8, uint16, int64, float32, float64, inf |
| 10 | +# from numpy.testing import assert_allclose |
| 11 | +# import numpy as np |
| 12 | +# np._set_promotion_state('weak') |
| 13 | + |
| 14 | +import pytest |
| 15 | +from pytest import raises as assert_raises |
| 16 | + |
| 17 | +unchanged = None |
| 18 | + |
| 19 | +# expression old result new_result |
| 20 | +examples = { |
| 21 | + "uint8(1) + 2": (int64(3), uint8(3)), |
| 22 | + "array([1], uint8) + int64(1)": (array([2], uint8), array([2], int64)), |
| 23 | + "array([1], uint8) + array(1, int64)": (array([2], uint8), array([2], int64)), |
| 24 | + "array([1.], float32) + float64(1.)": ( |
| 25 | + array([2.0], float32), |
| 26 | + array([2.0], float64), |
| 27 | + ), |
| 28 | + "array([1.], float32) + array(1., float64)": ( |
| 29 | + array([2.0], float32), |
| 30 | + array([2.0], float64), |
| 31 | + ), |
| 32 | + "array([1], uint8) + 1": (array([2], uint8), unchanged), |
| 33 | + "array([1], uint8) + 200": (array([201], uint8), unchanged), |
| 34 | + "array([100], uint8) + 200": (array([44], uint8), unchanged), |
| 35 | + "array([1], uint8) + 300": (array([301], uint16), Exception), |
| 36 | + "uint8(1) + 300": (int64(301), Exception), |
| 37 | + "uint8(100) + 200": (int64(301), uint8(44)), # and RuntimeWarning |
| 38 | + "float32(1) + 3e100": (float64(3e100), float32(inf)), # and RuntimeWarning [T7] |
| 39 | + # "array([0.1], float32) == 0.1": (array([False]), unchanged), # XXX: a typo in NEP50? |
| 40 | + "array([0.1], float32) == float64(0.1)": (array([True]), array([False])), |
| 41 | + "array([1.], float32) + 3": (array([4.0], float32), unchanged), |
| 42 | + "array([1.], float32) + int64(3)": (array([4.0], float32), array([4.0], float64)), |
| 43 | +} |
| 44 | + |
| 45 | + |
| 46 | +fails = [ |
| 47 | + "uint8(1) + 2", |
| 48 | + "array([1], uint8) + 1", |
| 49 | + "array([1], uint8) + 200", |
| 50 | + "array([1], uint8) + array(1, int64)", |
| 51 | + "array([100], uint8) + 200", |
| 52 | + "array([1], uint8) + 300", |
| 53 | + "uint8(1) + 300", |
| 54 | + "uint8(100) + 200", |
| 55 | + "float32(1) + 3e100", |
| 56 | + "array([1.], float32) + 3", |
| 57 | +] |
| 58 | + |
| 59 | + |
| 60 | +@pytest.mark.parametrize("example", examples) |
| 61 | +def test_nep50_exceptions(example): |
| 62 | + |
| 63 | + if example in fails: |
| 64 | + pytest.xfail(reason="scalars") |
| 65 | + |
| 66 | + old, new = examples[example] |
| 67 | + |
| 68 | + if new == Exception: |
| 69 | + with assert_raises(OverflowError): |
| 70 | + eval(example) |
| 71 | + |
| 72 | + else: |
| 73 | + result = eval(example) |
| 74 | + |
| 75 | + if new is unchanged: |
| 76 | + new = old |
| 77 | + |
| 78 | + assert_allclose(result, new, atol=1e-16) |
| 79 | + assert result.dtype == new.dtype |
0 commit comments