BUG: make array-detection logic in array(...) more robust #138
This is needed for arrays hiding as elements of nested lists: e.g. asarray([[1, 2], [3, np.array(4)]])
@@ -449,12 +449,28 @@ def __dlpack_device__(self):
        return self.tensor.__dlpack_device__()


def _tolist(obj):
    """Recusrively convert tensors into lists."""
    a1 = []
Move the if isinstance(obj, (list, tuple)) check in here; otherwise this function assumes that obj is an iterable, which is a non-trivial thing to assume in this context.
It's guarded with an if isinstance check at the call site?
https://github.com/Quansight-Labs/numpy_pytorch_interop/pull/138/files#diff-83812bd01681b687b8fced4eac464b076aecffc471e0364d905b38fba8794e6cR487
Yeah. What I'm saying is that the logic of this function implicitly assumes that check. The isinstance check should be inside this function, not at the call site.
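A minimal sketch of the suggestion above, with the type check moved inside the helper so that a non-list input is returned unchanged instead of being iterated. `FakeArray` and `FakeTensor` are hypothetical stand-ins for the torch_np `ndarray` wrapper and its tensor (only the `.tensor.tolist()` path is imitated), and `tolist_guarded` is an illustrative name, not the function merged in this PR.

```python
class FakeTensor:
    """Hypothetical stand-in for a torch.Tensor; only tolist() is imitated."""

    def __init__(self, data):
        self._data = data

    def tolist(self):
        return self._data


class FakeArray:
    """Hypothetical stand-in for the torch_np ndarray wrapper."""

    def __init__(self, data):
        self.tensor = FakeTensor(data)


def tolist_guarded(obj):
    """Recursively convert wrapped arrays inside lists/tuples into plain lists."""
    if isinstance(obj, FakeArray):
        return obj.tensor.tolist()
    if not isinstance(obj, (list, tuple)):
        # Scalars and other non-sequences pass through; nothing is iterated.
        return obj
    return [tolist_guarded(elem) for elem in obj]


nested = [[1, 2], [3, FakeArray(4)]]
print(tolist_guarded(nested))  # [[1, 2], [3, 4]]
print(tolist_guarded(5))       # 5
```

With the guard inside, callers no longer need to know that the helper only makes sense for lists and tuples.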
    if isinstance(obj, (list, tuple)):
        a1 = []
        for elem in obj:
            if isinstance(elem, ndarray):
                a1.append(elem.tensor.tolist())
            else:
                a1.append(elem)
        obj = a1
        obj = _tolist(obj)
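To illustrate why a one-level loop like the one in this hunk does not reach arrays nested deeper, the sketch below applies the same kind of loop to the nested input from the PR description. `Wrapped` is a hypothetical stand-in for the torch_np `ndarray` wrapper, and `flatten_one_level` is an illustrative name; the loop mirrors the hunk but is not the project's code.

```python
class Wrapped:
    """Hypothetical stand-in for a wrapper array exposing tensor.tolist()."""

    class _Tensor:
        def __init__(self, data):
            self._data = data

        def tolist(self):
            return self._data

    def __init__(self, data):
        self.tensor = Wrapped._Tensor(data)


def flatten_one_level(obj):
    """Convert wrappers found directly in obj; nested lists are not inspected."""
    out = []
    for elem in obj:
        if isinstance(elem, Wrapped):
            out.append(elem.tensor.tolist())
        else:
            out.append(elem)
    return out


top_level = flatten_one_level([1, Wrapped(2)])
nested = flatten_one_level([[1, 2], [3, Wrapped(4)]])

print(top_level)                          # [1, 2]
print(isinstance(nested[1][1], Wrapped))  # True: the inner wrapper survives
```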
Also, in what context do we need this? as_tensor already seems to do this:
>>> torch.as_tensor([[2,3], np.array([2,3])])
tensor([[2, 3],
        [2, 3]])
In main, it only looks into the first level of nesting:
In [3]: tnp.array([[1, 2], [tnp.int8(1), 2]])
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
Cell In[3], line 1
----> 1 tnp.array([[1, 2], [tnp.int8(1), 2]])
File ~/sweethome/proj/scipy/torch_np_compat/torch_np/_ndarray.py:495, in array(obj, dtype, copy, order, subok, ndmin, like)
493 def array(obj, dtype=None, *, copy=True, order="K", subok=False, ndmin=0, like=None):
494 # The result of the public `np.array(obj)` is not weakly typed.
--> 495 return _array(obj, dtype, copy=copy, order=order, subok=subok, ndmin=ndmin, like=like, is_weak=False)
File ~/sweethome/proj/scipy/torch_np_compat/torch_np/_ndarray.py:489, in _array(obj, dtype, copy, order, subok, ndmin, like, is_weak)
486 if dtype is not None:
487 torch_dtype = _dtypes.dtype(dtype).torch_dtype
--> 489 tensor = _util._coerce_to_tensor(obj, torch_dtype, copy, ndmin, is_weak)
490 return ndarray(tensor)
File ~/sweethome/proj/scipy/torch_np_compat/torch_np/_util.py:213, in _coerce_to_tensor(obj, dtype, copy, ndmin, is_weak)
211 tensor = torch.as_tensor(obj, dtype=dtype)
212 else:
--> 213 tensor = torch.as_tensor(obj)
215 # tensor.dtype is the pytorch default, typically float32. If obj's elements
216 # are not exactly representable in float32, we've lost precision:
217 # >>> torch.as_tensor(1e12).item() - 1e12
(...)
220 # Therefore, we treat `tensor.dtype` as a hint, and convert the
221 # original object *again*, this time with an explicit dtype.
222 torch_dtype = _dtypes_impl.get_default_dtype_for(tensor.dtype)
File ~/sweethome/proj/scipy/torch_np_compat/torch_np/_ndarray.py:391, in ndarray.__len__(self)
390 def __len__(self):
--> 391 return self.tensor.shape[0]
IndexError: tuple index out of range
EDIT: the traceback shows it's from a branch, not main, but the point stands.
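The last frame of the traceback is the key: a zero-dimensional array has an empty shape tuple, so `shape[0]` in `__len__` raises `IndexError` as soon as something asks for the scalar's length. A minimal pure-Python sketch, where `ZeroDim` is a hypothetical stand-in for a 0-d wrapper array such as the one in the failing call:

```python
class ZeroDim:
    """Hypothetical stand-in for a 0-d wrapper array."""

    shape = ()  # a scalar array has no dimensions

    def __len__(self):
        # Same expression as in the traceback; fails for an empty shape.
        return self.shape[0]


try:
    len(ZeroDim())
except IndexError as exc:
    message = str(exc)

print(message)  # tuple index out of range
```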
But I think this works in PyTorch?
>>> torch.as_tensor([[1,2],[np.int8(1), 2]])
tensor([[1, 2],
        [1, 2]])
With NumPy ndarrays, yes. With torch_np wrapper ndarrays: not in main, but yes with this PR.
Then why call tolist()? Shouldn't we just unwrap the tensor if it's a tnp array and call item() if it's of weak dtype?
nvm, I just read #138 (comment). Can you add a comment explaining this point?
Note that this function is almost identical to
Just missing a couple nits discussed in the comments, but otherwise LGTM.
@@ -449,12 +449,28 @@ def __dlpack_device__(self):
        return self.tensor.__dlpack_device__()


def _tolist(obj):
    """Recusrively convert tensors into lists."""
nit: Recursively
As discussed offline, this may not be super efficient, but we don't care about tracing efficiency atm, so in it goes.
Needed for arrays hiding as elements of nested lists, e.g. np.array([[1, 2], [3, np.array(4)]]).
Split off from gh-137.