diff --git a/torch_np/tests/test_dtype.py b/torch_np/tests/test_dtype.py new file mode 100644 index 00000000..76392abb --- /dev/null +++ b/torch_np/tests/test_dtype.py @@ -0,0 +1,51 @@ +import numpy as np +import pytest + +import torch_np as tnp + +dtype_names = [ + "bool_", + *[f"int{w}" for w in [8, 16, 32, 64]], + "uint8", + *[f"float{w}" for w in [16, 32, 64]], + *[f"complex{w}" for w in [64, 128]], +] +np_dtype_params = [] +np_dtype_params.append(pytest.param("bool", "bool", id="'bool'")) +np_dtype_params.append( + pytest.param( + "bool", + np.dtype("bool"), + id=f"np.dtype('bool')", + marks=pytest.mark.xfail(reason="XXX: np.dtype() objects not supported"), + ) +) +for name in dtype_names: + np_dtype_params.append(pytest.param(name, name, id=repr(name))) + np_dtype_params.append( + pytest.param( + name, + getattr(np, name), + id=f"np.{name}", + marks=pytest.mark.xfail(reason="XXX: namespaced dtypes not supported"), + ) + ) + np_dtype_params.append( + pytest.param( + name, + np.dtype(name), + id=f"np.dtype({name!r})", + marks=pytest.mark.xfail(reason="XXX: np.dtype() objects not supported"), + ) + ) + + +@pytest.mark.parametrize("name, np_dtype", np_dtype_params) +def test_convert_np_dtypes(name, np_dtype): + tnp_dtype = tnp.dtype(np_dtype) + if name == "bool_": + assert tnp_dtype == tnp.bool_ + elif tnp_dtype.name == "bool_": + assert name.startswith("bool") + else: + assert tnp_dtype.name == name diff --git a/torch_np/tests/test_xps.py b/torch_np/tests/test_xps.py index 10700092..1d434ce9 100644 --- a/torch_np/tests/test_xps.py +++ b/torch_np/tests/test_xps.py @@ -144,3 +144,13 @@ def test_put(np_x, data): note(f"(after put) {tnp_x=}") assert_array_equal(tnp_x, tnp.asarray(np_x).astype(tnp_x.dtype)) + + +@pytest.mark.xfail(reason="XXX: support converting namespaced dtypes") +@given(a=nps.arrays(dtype=nps.scalar_dtypes(), shape=nps.array_shapes())) +def test_asarray_np_arrays(a): + x = tnp.asarray(a) + if a.dtype == np.bool_: + assert x.dtype == tnp.bool + else: + assert x.dtype.name == a.dtype.name