Skip to content

Commit 8ecc725

Browse files
committed
Tests for converting numpy dtypes to interop dtypes (mostly failing)
1 parent 65c2127 commit 8ecc725

File tree

2 files changed

+39
-0
lines changed

2 files changed

+39
-0
lines changed

torch_np/tests/test_dtype.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import numpy as np
2+
import pytest
3+
4+
import torch_np as tnp
5+
6+
dtype_names = [
7+
"bool_",
8+
*[f"int{w}" for w in [8, 16, 32, 64]],
9+
"uint8",
10+
*[f"float{w}" for w in [16, 32, 64]],
11+
*[f"complex{w}" for w in [64, 128]],
12+
]
13+
np_dtype_params = []
14+
np_dtype_params.append(pytest.param("bool", "bool", id="'bool'"))
15+
np_dtype_params.append(pytest.param("bool", np.dtype("bool"), id=f"np.dtype('bool')"))
16+
for name in dtype_names:
17+
np_dtype_params.append(pytest.param(name, name, id=repr(name)))
18+
np_dtype_params.append(pytest.param(name, getattr(np, name), id=f"np.{name}"))
19+
np_dtype_params.append(pytest.param(name, np.dtype(name), id=f"np.dtype({name!r})"))
20+
21+
22+
@pytest.mark.parametrize("name, np_dtype", np_dtype_params)
23+
def test_convert_np_dtypes(name, np_dtype):
24+
tnp_dtype = tnp.dtype(np_dtype)
25+
if name == "bool_":
26+
assert tnp_dtype == tnp.bool
27+
elif tnp_dtype.name == "bool_":
28+
assert np_dtype.startswith("bool")
29+
else:
30+
assert tnp_dtype.name == name

torch_np/tests/test_xps.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,3 +144,12 @@ def test_put(np_x, data):
144144
note(f"(after put) {tnp_x=}")
145145

146146
assert_array_equal(tnp_x, tnp.asarray(np_x).astype(tnp_x.dtype))
147+
148+
149+
@given(a=nps.arrays(dtype=nps.scalar_dtypes(), shape=nps.array_shapes()))
150+
def test_asarray_np_arrays(a):
151+
x = tnp.asarray(a)
152+
if a.dtype == np.bool_:
153+
assert x.dtype == tnp.bool
154+
else:
155+
assert x.dtype.name == a.dtype.name

0 commit comments

Comments
 (0)