Skip to content

Commit aefe64f

Browse files
committed
Try str(arg) when inferring dtype
1 parent 20fab73 commit aefe64f

File tree

3 files changed

+20
-6
lines changed

3 files changed

+20
-6
lines changed

torch_np/_dtypes.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -231,8 +231,12 @@ class bool_(generic):
231231
}
232232

233233

234-
def sctype_from_string(s):
234+
def sctype_from_string(s: str):
235235
"""Normalize a string value: a type 'name' or a typecode or a width alias."""
236+
# Normalise "bool_" (i.e. from NumPy) as "bool"
237+
if s == "bool_":
238+
s = "bool"
239+
236240
if s in _names:
237241
return _names[s]
238242
if s in _name_aliases.keys():
@@ -273,10 +277,12 @@ def __init__(self, arg):
273277
elif isinstance(arg, DType):
274278
sctype = arg._scalar_type
275279
# a has a right attribute?
276-
elif hasattr(arg, "dtype"):
280+
elif hasattr(arg, "dtype") and hasattr(arg.dtype, "_scalar_type"):
277281
sctype = arg.dtype._scalar_type
278-
else:
282+
elif isinstance(arg, str):
279283
sctype = sctype_from_string(arg)
284+
else:
285+
sctype = sctype_from_string(str(arg))
280286
self._scalar_type = sctype
281287

282288
@property

torch_np/tests/test_dtype.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,23 @@
1515
np_dtype_params.append(pytest.param("bool", np.dtype("bool"), id=f"np.dtype('bool')"))
1616
for name in dtype_names:
1717
np_dtype_params.append(pytest.param(name, name, id=repr(name)))
18-
np_dtype_params.append(pytest.param(name, getattr(np, name), id=f"np.{name}"))
18+
np_dtype_params.append(
19+
pytest.param(
20+
name,
21+
getattr(np, name),
22+
id=f"np.{name}",
23+
marks=pytest.mark.xfail(reason="XXX: namespaced dtypes not supported"),
24+
)
25+
)
1926
np_dtype_params.append(pytest.param(name, np.dtype(name), id=f"np.dtype({name!r})"))
2027

2128

2229
@pytest.mark.parametrize("name, np_dtype", np_dtype_params)
2330
def test_convert_np_dtypes(name, np_dtype):
2431
tnp_dtype = tnp.dtype(np_dtype)
2532
if name == "bool_":
26-
assert tnp_dtype == tnp.bool
33+
assert tnp_dtype == tnp.bool_
2734
elif tnp_dtype.name == "bool_":
28-
assert np_dtype.startswith("bool")
35+
assert name.startswith("bool")
2936
else:
3037
assert tnp_dtype.name == name

torch_np/tests/test_xps.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ def test_put(np_x, data):
146146
assert_array_equal(tnp_x, tnp.asarray(np_x).astype(tnp_x.dtype))
147147

148148

149+
@pytest.mark.xfail(reason="XXX: support converting namespaced dtypes")
149150
@given(a=nps.arrays(dtype=nps.scalar_dtypes(), shape=nps.array_shapes()))
150151
def test_asarray_np_arrays(a):
151152
x = tnp.asarray(a)

0 commit comments

Comments
 (0)