|
| 1 | +import numpy as np |
| 2 | +import pytest |
| 3 | + |
| 4 | +import torch_np as tnp |
| 5 | + |
| 6 | +dtype_names = [ |
| 7 | + "bool_", |
| 8 | + *[f"int{w}" for w in [8, 16, 32, 64]], |
| 9 | + "uint8", |
| 10 | + *[f"float{w}" for w in [16, 32, 64]], |
| 11 | + *[f"complex{w}" for w in [64, 128]], |
| 12 | +] |
| 13 | +np_dtype_params = [] |
| 14 | +np_dtype_params.append(pytest.param("bool", "bool", id="'bool'")) |
| 15 | +np_dtype_params.append(pytest.param("bool", np.dtype("bool"), id=f"np.dtype('bool')")) |
| 16 | +for name in dtype_names: |
| 17 | + np_dtype_params.append(pytest.param(name, name, id=repr(name))) |
| 18 | + np_dtype_params.append(pytest.param(name, getattr(np, name), id=f"np.{name}")) |
| 19 | + np_dtype_params.append(pytest.param(name, np.dtype(name), id=f"np.dtype({name!r})")) |
| 20 | + |
| 21 | + |
| 22 | +@pytest.mark.parametrize("name, np_dtype", np_dtype_params) |
| 23 | +def test_convert_np_dtypes(name, np_dtype): |
| 24 | + tnp_dtype = tnp.dtype(np_dtype) |
| 25 | + if name == "bool_": |
| 26 | + assert tnp_dtype == tnp.bool |
| 27 | + elif tnp_dtype.name == "bool_": |
| 28 | + assert np_dtype.startswith("bool") |
| 29 | + else: |
| 30 | + assert tnp_dtype.name == name |
0 commit comments