Add ndarray.dot #72
Conversation
A few minor points, but overall LGTM
@@ -362,6 +362,7 @@ def reshape(self, *shape, order="C"):

    diagonal = _funcs.diagonal
    trace = _funcs.trace
    dot = _funcs.dot
vdot is also missing?
Missing in numpy, yes :-)
In [23]: hasattr(np.array([1, 2, 3]), 'vdot')
Out[23]: False
lol. Let's still add it to be forward-looking.
OK, can do.
If we're starting to do this, however, what's the guiding principle, and where do we stop? E.g., do we want feature parity between the main namespace and the ndarray methods?
In [24]: len(dir(np))
Out[24]: 595
In [25]: len(dir(np.ndarray))
Out[25]: 165
If we can do it "for free" it may be a fine thing to do? WDYT @rgommers?
There is no guiding principle; it's an accident of history. It's highly unlikely, though, that we'll add more methods to numpy.ndarray: there are already too many. So I wouldn't add anything that's not already present.
torch_np/_detail/_ufunc_impl.py (Outdated)
def _matmul(x, y):
    # work around RuntimeError: expected scalar type Int but found Double
    dtype = _dtypes_impl.result_type_impl((x.dtype, y.dtype))
    x = x.to(dtype)
    y = y.to(dtype)
    result = torch.matmul(x, y)
    return result
Make the cast conditional. Also, torch.matmul is pretty much an alias for np.dot (and, I believe, np.matmul), so you should just need to implement one of them.
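A minimal sketch of the conditional-cast idea, using NumPy stand-ins for the repo's torch helpers (`matmul_with_promotion` is a hypothetical name, and `np.result_type` plays the role of `_dtypes_impl.result_type_impl`):

```python
import numpy as np

def matmul_with_promotion(x, y):
    # Promote to the common dtype only when an operand actually differs from
    # it, instead of unconditionally casting both inputs.
    dtype = np.result_type(x.dtype, y.dtype)
    if x.dtype != dtype:
        x = x.astype(dtype)
    if y.dtype != dtype:
        y = y.astype(dtype)
    return np.matmul(x, y)
```

With the torch version, the same shape applies: check `x.dtype`/`y.dtype` against the result type and call `.to(dtype)` only when they differ.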
The devil's in the details (edge cases). I'll see whether the implementations can be merged after all the wrinkles are ironed out. For now there are still xfails.
From https://numpy.org/doc/stable/reference/generated/numpy.matmul.html
matmul differs from dot in two important ways:
Multiplication by scalars is not allowed, use * instead.
Stacks of matrices are broadcast together as if the matrices were elements, respecting the signature (n,k),(k,m)->(n,m):
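Both differences can be checked directly with NumPy (a small demonstration, not part of the patch):

```python
import numpy as np

# 1. Scalars: np.dot accepts a scalar operand, np.matmul rejects it.
assert np.dot(3, np.array([1, 2])).tolist() == [3, 6]
raised = False
try:
    np.matmul(3, np.array([1, 2]))
except (TypeError, ValueError):
    raised = True
assert raised  # multiplication by scalars is not allowed in matmul

# 2. Stacks: matmul broadcasts the leading dimensions as batch dimensions,
# while dot contracts into an outer product over them.
a = np.ones((3, 2, 4))
b = np.ones((3, 4, 5))
assert np.matmul(a, b).shape == (3, 2, 5)
assert np.dot(a, b).shape == (3, 2, 3, 5)
```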
Wow, OK, then torch.matmul is np.dot. You'll need to implement np.matmul yourself, implementing that weird broadcasting behaviour by hand.
Yes (sigh). I'm going to postpone this a bit in favor of gh-70. Once that stabilizes, I'll turn back to matmul.
BTW, this 'signature' thing is gufuncs; I wonder whether PyTorch has an equivalent, or whether we're facing the need to mirror the full gufunc machinery.
I don't quite understand what you mean by "this signature thing", but I happen to know that PyTorch doesn't have generalised ufuncs, so we'll need to replicate that machinery at some point.
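A rough sketch of what replicating that machinery by hand could look like for the (n,k),(k,m)->(n,m) signature, shown here in NumPy (`batched_matmul` is a hypothetical helper; a real implementation would operate on torch tensors and also handle 1-D promotion and dtype promotion):

```python
import numpy as np

def batched_matmul(x, y):
    # Broadcast the batch (leading) dimensions against each other, keeping
    # the core 2-D matrix dimensions intact, then contract pair by pair.
    batch = np.broadcast_shapes(x.shape[:-2], y.shape[:-2])
    xb = np.broadcast_to(x, batch + x.shape[-2:])
    yb = np.broadcast_to(y, batch + y.shape[-2:])
    out = np.empty(batch + (x.shape[-2], y.shape[-1]),
                   dtype=np.result_type(x, y))
    for idx in np.ndindex(*batch):
        out[idx] = np.dot(xb[idx], yb[idx])  # plain 2-D matrix product
    return out
```

The loop over `np.ndindex` is what a gufunc does internally: the core (n,k),(k,m)->(n,m) operation is applied once per broadcast batch index.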
"This signature thing":
respecting the signature (n,k),(k,m)->(n,m)
also referred to as that weird broadcasting a couple of messages above :-).
So yes, I'm going to postpone this to some point in the future.
Co-authored-by: Mario Lezcano Casado <[email protected]>
Now that both np.dot and ndarray.dot are there, can remove xfails in tests. Which, in turn, smokes out several problems in arange(..., dtype=complex), so fix them. And while at it, remove xfails of TestArange itself and fix it, too :-)