Add ndarray.dot #72


Merged
merged 5 commits into main on Feb 28, 2023

Conversation

Collaborator

@ev-br ev-br commented Feb 28, 2023

Now that both np.dot and ndarray.dot are there, we can remove the xfails in the tests. This, in turn, smokes out several problems in arange(..., dtype=complex), so fix those. And while at it, remove the xfails of TestArange itself and fix it, too :-)
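For context, a quick sanity check of the arange behaviour being fixed, against NumPy itself (illustration only, not code from this PR):

```python
import numpy as np

# NumPy produces a complex array from arange when dtype=complex,
# so the torch-backed implementation needs to match this.
a = np.arange(5, dtype=complex)
print(a.dtype)   # complex128
print(a[2])      # (2+0j)
```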

@ev-br ev-br requested a review from lezcano February 28, 2023 08:29
Collaborator

@lezcano lezcano left a comment


A few minor points, but overall LGTM

@@ -362,6 +362,7 @@ def reshape(self, *shape, order="C"):

diagonal = _funcs.diagonal
trace = _funcs.trace
dot = _funcs.dot
Collaborator

missing also vdot?

Collaborator Author

Missing in numpy, yes :-)

In [23]: hasattr(np.array([1, 2, 3]), 'vdot')
Out[23]: False

Collaborator

lol. Let's still add it to be forward-looking.

Collaborator Author
@ev-br ev-br Feb 28, 2023

OK, can do.
If we start doing this, however, what's the guiding principle, and where do we stop? E.g. do we want feature parity between the main namespace and the ndarray methods?

In [24]: len(dir(np))
Out[24]: 595

In [25]: len(dir(np.ndarray))
Out[25]: 165

Collaborator

If we can do it "for free" it may be a fine thing to do? WDYT @rgommers?

Member

There is no guiding principle; it's an accident of history. It's highly unlikely, though, that we'll add more methods to numpy.ndarray, there are already too many. So I wouldn't add anything that's not already present.

Comment on lines 146 to 152
def _matmul(x, y):
# work around RuntimeError: expected scalar type Int but found Double
dtype = _dtypes_impl.result_type_impl((x.dtype, y.dtype))
x = x.to(dtype)
y = y.to(dtype)
result = torch.matmul(x, y)
return result
Collaborator

Conditional cast.
Also, torch.matmul is pretty much an alias for np.dot (and, I believe, np.matmul), so you should just need to implement one of them.

Collaborator Author

The devil's in the details (edge cases). I'll see if the implementations can be merged after all the wrinkles are ironed out. For now there are still xfails.

From https://numpy.org/doc/stable/reference/generated/numpy.matmul.html

matmul differs from dot in two important ways:

Multiplication by scalars is not allowed, use * instead.

Stacks of matrices are broadcast together as if the matrices were elements, respecting the signature (n,k),(k,m)->(n,m):
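The shape difference is easy to see on stacked matrices (a quick illustration of the NumPy docs quote above, not code from this PR):

```python
import numpy as np

a = np.ones((2, 3, 4))
b = np.ones((2, 4, 5))

# matmul treats the leading dimension as a batch: (2,3,4) @ (2,4,5) -> (2,3,5)
print(np.matmul(a, b).shape)   # (2, 3, 5)

# dot instead takes an outer product over the stacked dimensions:
# result shape is a.shape[:-1] + b.shape[:-2] + b.shape[-1:]
print(np.dot(a, b).shape)      # (2, 3, 2, 5)

# and dot accepts scalars, while matmul rejects them
print(np.dot(2, a).shape)      # (2, 3, 4)
```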

Collaborator

wow, ok, then torch.matmul is np.dot. You'll need to implement np.matmul, with that weird broadcasting behaviour, by hand.

Collaborator Author
@ev-br ev-br Feb 28, 2023

Yes (sigh). Am going to postpone this a bit in favor of gh-70. Once that stabilizes, will turn back to matmul.

BTW, this 'signature' thing is gufuncs; I wonder if pytorch has an equivalent, or whether we're facing the need to mirror the full gufunc machinery.

Collaborator

I don't quite understand what you mean by "this signature thing", but I happen to know that PyTorch doesn't have generalised ufuncs, so we'll need to replicate that machinery at some point.

Collaborator Author

"This signature thing":

respecting the signature (n,k),(k,m)->(n,m)

also referred to as that weird broadcasting a couple of messages above :-).

So yes, am going to postpone this to some point in the future.
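For what it's worth, NumPy exposes a slow, generic version of that machinery via np.vectorize with a signature argument, which shows what the (n,k),(k,m)->(n,m) gufunc contract means in practice (illustration only, not how this would be implemented here):

```python
import numpy as np

# A gufunc-style matmul: the core signature fixes the last two axes,
# and everything in front of them is broadcast like ordinary elements.
gu_matmul = np.vectorize(np.dot, signature="(n,k),(k,m)->(n,m)")

a = np.ones((7, 2, 3, 4))   # batch shape (7, 2), core shape (3, 4)
b = np.ones((4, 5))         # no batch dims; broadcast against (7, 2)
print(gu_matmul(a, b).shape)   # (7, 2, 3, 5)
```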

@ev-br ev-br merged commit 828d7e0 into main Feb 28, 2023
@ev-br ev-br deleted the unxfail_dot branch February 28, 2023 09:38