tnp.put() + testing #116
Conversation
This looks good, but I believe it may be possible to simplify some bits of the implementation. Could you also try un-xfailing …
torch_np/_funcs_impl.py (outdated)
if ind.numel() > v.numel():
    result_shape = torch.broadcast_shapes(v.shape, ind.shape)
    v = torch.broadcast_to(v, result_shape)
Is this correct? If ind = torch.arange(4) and v = torch.arange(3).unsqueeze(1), then we'll broadcast v, and we'll then trim it in the next line. The issue here is that by doing the broadcasting in the last dimension, we'll be copying the elements 0, 0, 0, 0, rather than 0, 1, 2, 0. For the elements to repeat as we want when doing the flatten below, we want to do the broadcasting on the left-most dimension. This was also a bug in my proposed implementation. You can do that by using the ceil computation I proposed + unsqueeze(0) + expand.
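A minimal sketch of that suggestion, assuming a hypothetical helper name tile_values (the ceil-based count, unsqueeze(0), and expand come from the comment above; everything else is illustrative):

import math
import torch

def tile_values(v, ind):
    # Hypothetical helper, not the PR's actual code. Broadcasting on the
    # left-most dimension means flatten() cycles v's values (0, 1, 2, 0)
    # instead of copying the last one (0, 0, 0, 0).
    v = v.flatten()
    nrep = math.ceil(ind.numel() / v.numel())
    v = v.unsqueeze(0).expand(nrep, -1).flatten()
    return v[: ind.numel()]

ind = torch.arange(4)
v = torch.arange(3).unsqueeze(1)
print(tile_values(v, ind))  # tensor([0, 1, 2, 0])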
Ah yeah you're right, for reference:
>>> a = np.arange(5)
>>> ind = np.arange(4)
>>> v = np.arange(3).reshape(3, 1)
>>> np.put(a, ind, v)
>>> a
array([0, 1, 2, 0, 4])
>>> x = tnp.arange(5)
>>> ind = tnp.arange(4)
>>> v = tnp.asarray([[0], [1], [2]])
>>> tnp.put(x, ind, v)
>>> x
array_w([0, 0, 0, 0, 4])
So I'll need to do more exploration/testing that still satisfies the existing problem I came across. Will leave it for next week 😅
FWIW, here's the dirty/un-minified failing case I alluded to before, when using repeat instead:
AssertionError:
Arrays are not equal
Mismatched elements: 2 / 12 (16.7%)
x: array_w([[ True, False, False, True],
[False, False, False, False],
[False, False, False, False]])
y: array_w([[ True, False, True, False],
[False, False, False, False],
[False, False, False, False]])
Falsifying example: test_put(
np_x=array([[False, False, False, False],
[False, False, False, False],
[False, False, False, False]]),
data=data(...),
)
Draw 1 (ind): (array([[1],
[2]]), array([[3],
[0]]))
Draw 2 (v): array([[False],
[ True]])
(after put) np_x=array([[ True, False, True, False],
[False, False, False, False],
[False, False, False, False]])
tnp_x=array_w([[False, False, False, False],
[False, False, False, False],
[False, False, False, False]])
(after put) tnp_x=array_w([[ True, False, False, True],
[False, False, False, False],
[False, False, False, False]])
Pushed a fix using repeat again, and made sure it's tested for in test_put.
Notably, repeating just on the 0th axis seems the right call for our put() to align with NumPy-proper, as other ways of constructing *sizes fall short (i.e. Hypothesis will come up with failing examples in test_put). See the sketch below.
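A minimal sketch of why tiling along dim 0 is the behaviour we want, reusing the ceil-based count from earlier (variable names are illustrative, not the PR's actual code):

import math
import torch

ind = torch.arange(4)
v = torch.arange(3).unsqueeze(1)  # shape (3, 1)

nrep = math.ceil(ind.numel() / v.numel())
flat = v.flatten()                 # tensor([0, 1, 2])
tiled = flat.repeat(nrep)          # tiles along dim 0 -> tensor([0, 1, 2, 0, 1, 2])
print(tiled[: ind.numel()])        # tensor([0, 1, 2, 0]) -- matches np.put's cycling

# Repeating element-wise instead duplicates values rather than cycling them:
print(flat.repeat_interleave(nrep)[: ind.numel()])  # tensor([0, 0, 1, 1])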
Also partially un-xfails `test_multiarray.py::TestMethods::test_put`.
Prevents normalising non-ndarray arguments
* `list_at_ind` stuff should be covered when testing the normalizer
* Manually converting `ind` and `v` to `tnp.ndarray` is mostly redundant
Co-authored-by: Mario Lezcano Casado <[email protected]>
Also test 0d indices
Now we're cooking with gas! Thank you for the PR!
put() implementation and my own test. This was tricky but I think it's pretty robust, sans the TODO re: ind.
Notably, in test_put I'm testing tnp.put() output against NumPy's own np.put(), which as you can see is a bit noisy right now due to tnp.asarray() not carrying over dtypes (yet? something I'd like to fix, as it enables future "comparison testing").
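A minimal sketch of that comparison-testing style, under the assumption that a tnp array can be converted back with np.asarray (the dtype carry-over caveat above applies; indices and values here are arbitrary):

import numpy as np
import torch_np as tnp

x_np = np.arange(5)
x_tnp = tnp.arange(5)

# Run the same in-place put() on both arrays.
np.put(x_np, [0, 3], [7, 9])
tnp.put(x_tnp, [0, 3], [7, 9])

# np.asarray(x_tnp) is an assumed conversion for illustration;
# dtype mismatches are the "noise" mentioned above.
np.testing.assert_array_equal(x_np, np.asarray(x_tnp))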