BUG: make array-detection logic in array(...) more robust #138

Merged
2 commits merged May 17, 2023

Conversation

@ev-br ev-br commented May 13, 2023

Needed for arrays hiding as elements of nested lists, e.g. np.array([[1, 2], [3, np.array(4)]]).

Split off from gh-137.

ev-br added 2 commits May 13, 2023 12:37
This is needed for arrays hiding as elements of nested lists:
e.g. asarray([[1, 2], [3, np.array(4)]])
@ev-br ev-br requested a review from lezcano May 13, 2023 09:40
@@ -449,12 +449,28 @@ def __dlpack_device__(self):
return self.tensor.__dlpack_device__()


def _tolist(obj):
"""Recusrively convert tensors into lists."""
a1 = []
Collaborator

Bring the if isinstance(obj, (list, tuple)) check in here; otherwise this function assumes that obj is an iterable, which is a non-trivial thing to assume in this context.

Collaborator

Yeah. What I'm saying is that the logic of this function assumes it implicitly. The isinstance check should be inside this function, not at the caller site.

Comment on lines 487 to +488
if isinstance(obj, (list, tuple)):
    a1 = []
    for elem in obj:
        if isinstance(elem, ndarray):
            a1.append(elem.tensor.tolist())
        else:
            a1.append(elem)
    obj = a1
obj = _tolist(obj)
Collaborator

Also, in what context do we need this? as_tensor already seems to do this:

>>> torch.as_tensor([[2,3], np.array([2,3])])
tensor([[2, 3],
        [2, 3]])

ev-br (Collaborator Author) commented May 13, 2023

In main, it only looks into the first level of nesting:

In [3]: tnp.array([[1, 2], [tnp.int8(1), 2]])
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[3], line 1
----> 1 tnp.array([[1, 2], [tnp.int8(1), 2]])

File ~/sweethome/proj/scipy/torch_np_compat/torch_np/_ndarray.py:495, in array(obj, dtype, copy, order, subok, ndmin, like)
    493 def array(obj, dtype=None, *, copy=True, order="K", subok=False, ndmin=0, like=None):
    494     # The result of the public `np.array(obj)` is not weakly typed.
--> 495     return _array(obj, dtype, copy=copy, order=order, subok=subok, ndmin=ndmin, like=like, is_weak=False)

File ~/sweethome/proj/scipy/torch_np_compat/torch_np/_ndarray.py:489, in _array(obj, dtype, copy, order, subok, ndmin, like, is_weak)
    486 if dtype is not None:
    487     torch_dtype = _dtypes.dtype(dtype).torch_dtype
--> 489 tensor = _util._coerce_to_tensor(obj, torch_dtype, copy, ndmin, is_weak)
    490 return ndarray(tensor)

File ~/sweethome/proj/scipy/torch_np_compat/torch_np/_util.py:213, in _coerce_to_tensor(obj, dtype, copy, ndmin, is_weak)
    211     tensor = torch.as_tensor(obj, dtype=dtype)
    212 else:
--> 213     tensor = torch.as_tensor(obj)
    215     # tensor.dtype is the pytorch default, typically float32. If obj's elements
    216     # are not exactly representable in float32, we've lost precision:
    217     # >>> torch.as_tensor(1e12).item() - 1e12
   (...)
    220     # Therefore, we treat `tensor.dtype` as a hint, and convert the
    221     # original object *again*, this time with an explicit dtype.
    222     torch_dtype = _dtypes_impl.get_default_dtype_for(tensor.dtype)

File ~/sweethome/proj/scipy/torch_np_compat/torch_np/_ndarray.py:391, in ndarray.__len__(self)
    390 def __len__(self):
--> 391     return self.tensor.shape[0]

IndexError: tuple index out of range

A smoke test is https://github.com/Quansight-Labs/numpy_pytorch_interop/pull/138/files#diff-66e581e6a3a373197465b30dd84684d915c83f3c619773382e8f8a40140662a5R558

EDIT: the traceback shows it's from a branch, not main, but the point stands.
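The failure mode above can be reproduced in a self-contained way. This is a hypothetical illustration (MiniArray stands in for torch_np's ndarray): a single-level pass misses wrappers nested one level deeper, which is what the traceback demonstrates, while a recursive pass catches them.

```python
class MiniArray:
    """Toy stand-in for a tensor-backed array wrapper (hypothetical)."""
    def __init__(self, data):
        self.data = data

    def tolist(self):
        return self.data


def one_level(obj):
    # Converts only the direct elements of obj, as main effectively does.
    return [e.tolist() if isinstance(e, MiniArray) else e for e in obj]


def recursive(obj):
    # Walks arbitrarily deep nesting, as this PR's _tolist does.
    if isinstance(obj, MiniArray):
        return obj.tolist()
    if isinstance(obj, (list, tuple)):
        return [recursive(e) for e in obj]
    return obj


nested = [[1, 2], [MiniArray(1), 2]]
print(one_level(nested))  # the inner MiniArray survives untouched
print(recursive(nested))  # [[1, 2], [1, 2]]
```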

Collaborator

But I think this works in PyTorch?

>>> torch.as_tensor([[1,2],[np.int8(1), 2]])
tensor([[1, 2],
        [1, 2]])

ev-br (Collaborator Author) commented May 14, 2023

With numpy ndarrays, yes. With torch_np wrapper ndarrays: not in main, but yes with this PR.

Collaborator

Then why call tolist()? Shouldn't we just unwrap the tensor if it's a tnp array, and call item() if it's of weak dtype?

Collaborator

nvm, I just read #138 (comment). Can you add a comment explaining this point?

ev-br (Collaborator Author) commented May 14, 2023

Note that this function is almost identical to _helpers/ndarrays_to_tensors, which is used for fancy indexing and basically converts nested arrays to tensors. Dropping .tolist() here sounded tempting, but it does not quite work because numpy allows, for instance,

(Pdb) import numpy as _np
(Pdb) p x
array([[1., 2.],
       [3., 4.]], dtype=float32)
(Pdb) _np.asarray([x, 2*x, 3*x])
array([[[ 1.,  2.],
        [ 3.,  4.]],

       [[ 2.,  4.],
        [ 6.,  8.]],

       [[ 3.,  6.],
        [ 9., 12.]]], dtype=float32)

(Pdb) import torch
(Pdb) t
tensor([[1., 2.],
        [3., 4.]])
(Pdb) torch.as_tensor([t, 2*t, 3*t])
*** ValueError: only one element tensors can be converted to Python scalars
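This is why the PR converts wrapped sub-arrays all the way down to nested Python lists before handing them to torch.as_tensor. A minimal pure-Python sketch of the idea (MiniArray is a hypothetical stand-in for a 2-D array wrapper; the real code uses tensor.tolist()):

```python
class MiniArray:
    """Toy stand-in for a 2-D array wrapper (hypothetical)."""
    def __init__(self, data):
        self.data = data

    def tolist(self):
        return self.data

    def __mul__(self, k):
        # Elementwise scalar multiply, enough for this demonstration.
        return MiniArray([[k * v for v in row] for row in self.data])


def _tolist(obj):
    # Recursively replace wrappers with plain nested lists.
    if isinstance(obj, MiniArray):
        return obj.tolist()
    if isinstance(obj, (list, tuple)):
        return [_tolist(e) for e in obj]
    return obj


x = MiniArray([[1.0, 2.0], [3.0, 4.0]])
stacked = _tolist([x, x * 2, x * 3])
# `stacked` is now a plain 3x2x2 nested list, which torch.as_tensor can
# consume to build the 3-D tensor, mirroring np.asarray([x, 2*x, 3*x]).
print(stacked[1])  # [[2.0, 4.0], [6.0, 8.0]]
```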

@lezcano left a comment
Just missing a couple of nits discussed in the comments, but otherwise LGTM.

@@ -449,12 +449,28 @@ def __dlpack_device__(self):
return self.tensor.__dlpack_device__()


def _tolist(obj):
"""Recusrively convert tensors into lists."""
Collaborator

nit: Recursively

lezcano commented May 17, 2023

As discussed offline, this may be not super efficient, but we don't care about tracing efficiency atm, so in it goes.

@lezcano lezcano merged commit 133c367 into main May 17, 2023
@ev-br ev-br deleted the array_recurse branch May 17, 2023 15:46