
ENH: add einsum #127


Merged: 6 commits into einsum_tests from einsum, Apr 27, 2023

Conversation

@ev-br (Collaborator) commented Apr 25, 2023

A couple of notes:

  • The generic normalization machinery does not handle signatures with positional args before varargs. Those are einsum and gradient, so keep the generic machinery simple and do things manually here.
  • The einsum signature is a bit weird even in numpy, with **kwargs, so follow that. Not sure why it is what it is; one possible reason could be that a bare * after varargs is a syntax error (?), e.g. def func(a, *args, *, out=None): pass (see the sketch below).
  • IIUC, pytorch does not allow the same level of control over the optimize = {False, True, ‘greedy’, ‘optimal’} argument, so silently ignore this argument.
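
For reference, a small sketch of the signature issue from the second bullet. This is illustrative only; the einsum_like name is made up and numpy's actual einsum signature has more keyword arguments:

# A bare `*` after `*args` is rejected by the parser:
#
#     def func(a, *args, *, out=None):   # SyntaxError
#         ...
#
# so, like numpy's einsum, a varargs signature can collect the keyword
# arguments in **kwargs and unpack them by hand:
def einsum_like(*operands, **kwargs):
    out = kwargs.pop("out", None)
    dtype = kwargs.pop("dtype", None)
    if kwargs:
        raise TypeError(f"unknown arguments: {kwargs}")
    return operands, out, dtype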

Also, while at it, add a dedicated annotation to validate the casting=... argument. This is, strictly speaking, extraneous to this PR, so I can take it out if desired.
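
Purely as an illustration of what such a validation could look like (the normalize_casting name comes from the imports quoted further down; the body here is a guess, not the PR's actual implementation):

def normalize_casting(arg, parm=None):
    # hypothetical validator: numpy's casting argument accepts these five modes
    allowed = ("no", "equiv", "safe", "same_kind", "unsafe")
    if arg not in allowed:
        raise ValueError(f"casting must be one of {allowed}, got {arg!r}")
    return arg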

@lezcano (Collaborator) left a comment:

I can just imagine the pain that was implementing this one. I left a few comments, but overall looks good.

Comment on lines 1239 to 1247
from ._normalizations import (
    maybe_copy_to,
    normalize_casting,
    normalize_dtype,
    normalize_not_implemented,
    normalize_outarray,
    wrap_tensors,
)
Collaborator:

Just define it at the end of _funcs.py really and leave a note as to why.

Collaborator Author:

If removing imports is the goal, that won't help really. These normalizers are otherwise only used in _normalizations.py, so these need to be imported either way.

Collaborator:

What I'd like to remove is local imports; they are a massive red flag.

@ev-br (Collaborator Author), Apr 27, 2023:

That would create an import cycle since _funcs.py imports _funcs_impl. Also from _ndarray import ndarray must be local (as it is in _normalizations).

We can of course, import these at the _funcs_impl module level and special-case them to not get dumped into the global namespace, if that's really what you think is needed?

@lezcano (Collaborator), Apr 27, 2023:

We would just need to import things from _normalizations (which we already do). We would still have the ndarray import dangling there, but there's not much that can be done about that, as we have that one all throughout the codebase.

Collaborator Author:

OK, done in the last commit

Comment on lines 1251 to 1260
parm = lambda _: None  # a fake duck-typed inspect.Parameter stub
parm.name = "out"
out = normalize_outarray(out, parm=parm)

parm.default = "K"
parm.name = "order"
order = normalize_not_implemented(kwargs.pop("order", "K"), parm=parm)
if kwargs:
    raise TypeError("unknown arguments: ", kwargs)
Collaborator:

These two normalizers don't do all that much, so let's just do the error checking directly here for conciseness.
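
For illustration, the inline checks being suggested could look roughly like this. It is only a sketch: out, kwargs and ndarray are assumed to be in scope as in the surrounding function, and the PR's final code may differ:

# sketch: replace the duck-typed Parameter stubs with direct checks
if out is not None and not isinstance(out, ndarray):
    raise TypeError("'out' must be an ndarray")

order = kwargs.pop("order", "K")
if order != "K":
    raise NotImplementedError(f"order={order!r} is not supported")

if kwargs:
    raise TypeError(f"unknown arguments: {kwargs}")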

Collaborator:

Also assert that optimize is False, otherwise raise a NotImplementedError, similar to order and so on.

Collaborator Author:

Oh. Now that I re-read https://pytorch.org/docs/stable/generated/torch.einsum.html, there is torch.backends.opt_einsum = "greedy" so we can support this :-). Question: is there a way of controlling torch.backends in a context manager, other than a try... finally block?

Collaborator:

I don't think we have a context manager for that, no. try...finally seems reasonable to me.
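
A minimal sketch of the try ... finally approach, assuming the torch.backends.opt_einsum module of recent PyTorch (is_available(), enabled, strategy). The wrapper name and the exact mapping of numpy's optimize values are made up here and need not match what commit 211a461 does:

import torch

def einsum_with_optimize(subscripts, *tensors, optimize=False):
    oe = torch.backends.opt_einsum
    if not oe.is_available():
        # the opt_einsum package is not installed; nothing to control
        return torch.einsum(subscripts, *tensors)
    saved_enabled, saved_strategy = oe.enabled, oe.strategy
    try:
        if optimize is False:
            oe.enabled = False
        else:
            oe.enabled = True
            # numpy accepts True / "greedy" / "optimal"; map True to torch's default
            oe.strategy = optimize if isinstance(optimize, str) else "auto"
        return torch.einsum(subscripts, *tensors)
    finally:
        oe.enabled, oe.strategy = saved_enabled, saved_strategy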

Collaborator Author:

Thanks, done in 211a461

else:
    # op, str, op, str ... format: normalize every other argument
    sublist_format = True
    array_operands = operands[:-1][::2]
Collaborator:

Why the [:-1]? Isn't sublistout optional? Also, don't we want to preprocess that one as well, perhaps asserting that it's an ndarray?

Collaborator Author:

Exactly!

  • If sublistout is not given, the length of operands is even, and we pick the odd-numbered elements (1st, 3rd, ...), which are the arrays.
  • If sublistout is given, the length of operands is odd, so we peel off the last element and then pick the odd-numbered ones. Without [:-1] we would have picked up sublistout too, and it is a sublist, not an array (see the sketch after the test snippet below).

And no, sublistout is not really an array: it can contain e.g. an Ellipsis:

            assert_equal(np.einsum("...i->...", a, optimize=do_opt),
                         np.sum(a, axis=-1).astype(dtype))
            assert_equal(np.einsum(a, [Ellipsis, 0], [Ellipsis], optimize=do_opt),
                         np.sum(a, axis=-1).astype(dtype))
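
The splitting described in the two bullets above, as a standalone sketch (split_sublist_format is a hypothetical helper for illustration, not the PR's code):

def split_sublist_format(*operands):
    if len(operands) % 2 == 0:
        # op, sublist, op, sublist, ... : no output sublist given
        rest, sublistout = operands, None
    else:
        # odd length: the last element is the output sublist
        *rest, sublistout = operands
    array_operands = rest[::2]   # the arrays (1st, 3rd, ... elements)
    sublists = rest[1::2]        # the index sublists
    return array_operands, sublists, sublistout

For the second test above, split_sublist_format(a, [Ellipsis, 0], [Ellipsis]) would separate the array a, the sublist [Ellipsis, 0] and the output sublist [Ellipsis].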

Collaborator:

Right. Could you please leave a comment or a link to this explanation?


is_short_int = target_dtype in [torch.uint8, torch.int8, torch.int16, torch.int32]
if is_short_int:
    target_dtype, result_dtype = torch.int64, target_dtype
Collaborator:

result_dtype is not being used.

Collaborator Author:

Thanks, removed.

NB it was not used because the current implementation follows pytorch in that, e.g.,

In [48]: a = np.arange(8, dtype=np.int8)

In [49]: torch.einsum(torch.as_tensor(a), [0], []).dtype
Out[49]: torch.int64

unlike numpy, where the last line is int8 (sigh).

@lezcano (Collaborator), Apr 27, 2023:

Yeah, this is because, as with sum and other reductions that accumulate, integers are upcast to int64. This function simply uses sum and bmm internally, so it has the same semantics as those two.
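
The same accumulation behaviour can be seen directly with sum (a quick check; dtypes as printed by a recent PyTorch):

import torch

a = torch.arange(8, dtype=torch.int8)
print(a.sum().dtype)                   # torch.int64: integer reductions accumulate in int64
print(torch.einsum(a, [0], []).dtype)  # torch.int64 as well, since einsum reduces via sum/bmm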

@lezcano (Collaborator) left a comment:

Great! What a pain of a function...

@ev-br (Collaborator Author) commented Apr 27, 2023

Thanks Mario for the review!

@ev-br merged commit cd5f74a into einsum_tests Apr 27, 2023
@ev-br deleted the einsum branch April 27, 2023 10:18