|
| 1 | +import numpy as np |
| 2 | +import pytest |
| 3 | + |
| 4 | +import torch_np as tnp |
| 5 | + |
| 6 | +dtype_names = [ |
| 7 | + "bool_", |
| 8 | + *[f"int{w}" for w in [8, 16, 32, 64]], |
| 9 | + "uint8", |
| 10 | + *[f"float{w}" for w in [16, 32, 64]], |
| 11 | + *[f"complex{w}" for w in [64, 128]], |
| 12 | +] |
| 13 | +np_dtype_params = [] |
| 14 | +np_dtype_params.append(pytest.param("bool", "bool", id="'bool'")) |
| 15 | +np_dtype_params.append( |
| 16 | + pytest.param( |
| 17 | + "bool", |
| 18 | + np.dtype("bool"), |
| 19 | + id=f"np.dtype('bool')", |
| 20 | + marks=pytest.mark.xfail(reason="XXX: np.dtype() objects not supported"), |
| 21 | + ) |
| 22 | +) |
| 23 | +for name in dtype_names: |
| 24 | + np_dtype_params.append(pytest.param(name, name, id=repr(name))) |
| 25 | + np_dtype_params.append( |
| 26 | + pytest.param( |
| 27 | + name, |
| 28 | + getattr(np, name), |
| 29 | + id=f"np.{name}", |
| 30 | + marks=pytest.mark.xfail(reason="XXX: namespaced dtypes not supported"), |
| 31 | + ) |
| 32 | + ) |
| 33 | + np_dtype_params.append( |
| 34 | + pytest.param( |
| 35 | + name, |
| 36 | + np.dtype(name), |
| 37 | + id=f"np.dtype({name!r})", |
| 38 | + marks=pytest.mark.xfail(reason="XXX: np.dtype() objects not supported"), |
| 39 | + ) |
| 40 | + ) |
| 41 | + |
| 42 | + |
| 43 | +@pytest.mark.parametrize("name, np_dtype", np_dtype_params) |
| 44 | +def test_convert_np_dtypes(name, np_dtype): |
| 45 | + tnp_dtype = tnp.dtype(np_dtype) |
| 46 | + if name == "bool_": |
| 47 | + assert tnp_dtype == tnp.bool_ |
| 48 | + elif tnp_dtype.name == "bool_": |
| 49 | + assert name.startswith("bool") |
| 50 | + else: |
| 51 | + assert tnp_dtype.name == name |
0 commit comments