diff --git a/torch_np/_ndarray.py b/torch_np/_ndarray.py index df7b54a8..8b765f14 100644 --- a/torch_np/_ndarray.py +++ b/torch_np/_ndarray.py @@ -449,12 +449,28 @@ def __dlpack_device__(self): return self.tensor.__dlpack_device__() +def _tolist(obj): + """Recusrively convert tensors into lists.""" + a1 = [] + for elem in obj: + if isinstance(elem, (list, tuple)): + elem = _tolist(elem) + if isinstance(elem, ndarray): + a1.append(elem.tensor.tolist()) + else: + a1.append(elem) + return a1 + + # This is the ideally the only place which talks to ndarray directly. # The rest goes through asarray (preferred) or array. def array(obj, dtype=None, *, copy=True, order="K", subok=False, ndmin=0, like=None): - _util.subok_not_ok(like, subok) + if subok is not False: + raise NotImplementedError(f"'subok' parameter is not supported.") + if like is not None: + raise NotImplementedError(f"'like' parameter is not supported.") if order != "K": raise NotImplementedError @@ -469,13 +485,7 @@ def array(obj, dtype=None, *, copy=True, order="K", subok=False, ndmin=0, like=N # lists of ndarrays: [1, [2, 3], ndarray(4)] convert to lists of lists if isinstance(obj, (list, tuple)): - a1 = [] - for elem in obj: - if isinstance(elem, ndarray): - a1.append(elem.tensor.tolist()) - else: - a1.append(elem) - obj = a1 + obj = _tolist(obj) # is obj an ndarray already? if isinstance(obj, ndarray): diff --git a/torch_np/_util.py b/torch_np/_util.py index e120898c..be1f9c7c 100644 --- a/torch_np/_util.py +++ b/torch_np/_util.py @@ -19,13 +19,6 @@ def is_sequence(seq): return True -def subok_not_ok(like=None, subok=False): - if like is not None: - raise ValueError("like=... parameter is not supported.") - if subok: - raise ValueError("subok parameter is not supported.") - - class AxisError(ValueError, IndexError): pass diff --git a/torch_np/tests/test_basic.py b/torch_np/tests/test_basic.py index 7bcab5ab..194f7c3a 100644 --- a/torch_np/tests/test_basic.py +++ b/torch_np/tests/test_basic.py @@ -553,3 +553,9 @@ def test_exported_objects(self): ) diff = set(exported_fns).difference(set(dir(_np))) assert len(diff) == 0 + + +class TestCtorNested: + def test_arrays_in_lists(self): + lst = [[1, 2], [3, w.array(4)]] + assert_equal(w.asarray(lst), [[1, 2], [3, 4]])