Skip to content

Commit 133c367

Browse files
authored
BUG: make array-detection logic in array(...) more robust (#138)
1 parent 736c29a commit 133c367

File tree

3 files changed

+24
-15
lines changed

3 files changed

+24
-15
lines changed

torch_np/_ndarray.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -449,12 +449,28 @@ def __dlpack_device__(self):
449449
return self.tensor.__dlpack_device__()
450450

451451

452+
def _tolist(obj):
453+
"""Recusrively convert tensors into lists."""
454+
a1 = []
455+
for elem in obj:
456+
if isinstance(elem, (list, tuple)):
457+
elem = _tolist(elem)
458+
if isinstance(elem, ndarray):
459+
a1.append(elem.tensor.tolist())
460+
else:
461+
a1.append(elem)
462+
return a1
463+
464+
452465
# This is the ideally the only place which talks to ndarray directly.
453466
# The rest goes through asarray (preferred) or array.
454467

455468

456469
def array(obj, dtype=None, *, copy=True, order="K", subok=False, ndmin=0, like=None):
457-
_util.subok_not_ok(like, subok)
470+
if subok is not False:
471+
raise NotImplementedError(f"'subok' parameter is not supported.")
472+
if like is not None:
473+
raise NotImplementedError(f"'like' parameter is not supported.")
458474
if order != "K":
459475
raise NotImplementedError
460476

@@ -469,13 +485,7 @@ def array(obj, dtype=None, *, copy=True, order="K", subok=False, ndmin=0, like=N
469485

470486
# lists of ndarrays: [1, [2, 3], ndarray(4)] convert to lists of lists
471487
if isinstance(obj, (list, tuple)):
472-
a1 = []
473-
for elem in obj:
474-
if isinstance(elem, ndarray):
475-
a1.append(elem.tensor.tolist())
476-
else:
477-
a1.append(elem)
478-
obj = a1
488+
obj = _tolist(obj)
479489

480490
# is obj an ndarray already?
481491
if isinstance(obj, ndarray):

torch_np/_util.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,6 @@ def is_sequence(seq):
1919
return True
2020

2121

22-
def subok_not_ok(like=None, subok=False):
23-
if like is not None:
24-
raise ValueError("like=... parameter is not supported.")
25-
if subok:
26-
raise ValueError("subok parameter is not supported.")
27-
28-
2922
class AxisError(ValueError, IndexError):
3023
pass
3124

torch_np/tests/test_basic.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -553,3 +553,9 @@ def test_exported_objects(self):
553553
)
554554
diff = set(exported_fns).difference(set(dir(_np)))
555555
assert len(diff) == 0
556+
557+
558+
class TestCtorNested:
559+
def test_arrays_in_lists(self):
560+
lst = [[1, 2], [3, w.array(4)]]
561+
assert_equal(w.asarray(lst), [[1, 2], [3, 4]])

0 commit comments

Comments
 (0)