Skip to content

Commit 736c29a

Browse files
authored
Tests for converting NumPy dtypes to their equivalent (#128)
1 parent dfc1db5 commit 736c29a

File tree

2 files changed

+61
-0
lines changed

2 files changed

+61
-0
lines changed

torch_np/tests/test_dtype.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import numpy as np
2+
import pytest
3+
4+
import torch_np as tnp
5+
6+
dtype_names = [
7+
"bool_",
8+
*[f"int{w}" for w in [8, 16, 32, 64]],
9+
"uint8",
10+
*[f"float{w}" for w in [16, 32, 64]],
11+
*[f"complex{w}" for w in [64, 128]],
12+
]
13+
np_dtype_params = []
14+
np_dtype_params.append(pytest.param("bool", "bool", id="'bool'"))
15+
np_dtype_params.append(
16+
pytest.param(
17+
"bool",
18+
np.dtype("bool"),
19+
id=f"np.dtype('bool')",
20+
marks=pytest.mark.xfail(reason="XXX: np.dtype() objects not supported"),
21+
)
22+
)
23+
for name in dtype_names:
24+
np_dtype_params.append(pytest.param(name, name, id=repr(name)))
25+
np_dtype_params.append(
26+
pytest.param(
27+
name,
28+
getattr(np, name),
29+
id=f"np.{name}",
30+
marks=pytest.mark.xfail(reason="XXX: namespaced dtypes not supported"),
31+
)
32+
)
33+
np_dtype_params.append(
34+
pytest.param(
35+
name,
36+
np.dtype(name),
37+
id=f"np.dtype({name!r})",
38+
marks=pytest.mark.xfail(reason="XXX: np.dtype() objects not supported"),
39+
)
40+
)
41+
42+
43+
@pytest.mark.parametrize("name, np_dtype", np_dtype_params)
44+
def test_convert_np_dtypes(name, np_dtype):
45+
tnp_dtype = tnp.dtype(np_dtype)
46+
if name == "bool_":
47+
assert tnp_dtype == tnp.bool_
48+
elif tnp_dtype.name == "bool_":
49+
assert name.startswith("bool")
50+
else:
51+
assert tnp_dtype.name == name

torch_np/tests/test_xps.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,3 +144,13 @@ def test_put(np_x, data):
144144
note(f"(after put) {tnp_x=}")
145145

146146
assert_array_equal(tnp_x, tnp.asarray(np_x).astype(tnp_x.dtype))
147+
148+
149+
@pytest.mark.xfail(reason="XXX: support converting namespaced dtypes")
150+
@given(a=nps.arrays(dtype=nps.scalar_dtypes(), shape=nps.array_shapes()))
151+
def test_asarray_np_arrays(a):
152+
x = tnp.asarray(a)
153+
if a.dtype == np.bool_:
154+
assert x.dtype == tnp.bool
155+
else:
156+
assert x.dtype.name == a.dtype.name

0 commit comments

Comments
 (0)