ENH: add einsum #127
Conversation
I can just imagine the pain that was implementing this one. I left a few comments, but overall looks good.
torch_np/_funcs_impl.py
Outdated
from ._normalizations import (
    maybe_copy_to,
    normalize_casting,
    normalize_dtype,
    normalize_not_implemented,
    normalize_outarray,
    wrap_tensors,
)
Just define it at the end of _funcs.py really, and leave a note as to why.
If removing imports is the goal, that won't help really. These normalizers are otherwise only used in _normalizations.py, so these need to be imported either way.
What I'd like to remove is local imports. They are a massive red flag
That would create an import cycle, since _funcs.py imports _funcs_impl. Also, from _ndarray import ndarray must be local (as it is in _normalizations).
We can, of course, import these at the _funcs_impl module level and special-case them so they don't get dumped into the global namespace, if that's really what you think is needed?
We would just need to import things from _normalizations (which we already do) and, well, we would still have the ndarray import dangling there, but there's not much that can be done about that really, as we have that one all throughout the codebase.
OK, done in the last commit
torch_np/_funcs_impl.py
Outdated
parm = lambda _: None  # a fake duck-typed inspect.Parameter stub
parm.name = "out"
out = normalize_outarray(out, parm=parm)

parm.default = "K"
parm.name = "order"
order = normalize_not_implemented(kwargs.pop("order", "K"), parm=parm)
if kwargs:
    raise TypeError("unknown arguments: ", kwargs)
These two normalizers don't do all that much, so let's just do the error checking directly here for conciseness.
Also assert that optimize is False, otherwise raise a NotImplementedError, similar to order and so on.
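For illustration, a minimal sketch of such inline checks; the signature and error messages here are assumptions, not the PR's final code:

def einsum(*operands, out=None, dtype=None, **kwargs):
    # Reject parameters that are not implemented, without going through
    # the generic normalizer machinery.
    order = kwargs.pop("order", "K")
    if order != "K":
        raise NotImplementedError(f"'order' is not supported (got {order!r})")
    optimize = kwargs.pop("optimize", False)
    if optimize is not False:
        raise NotImplementedError(f"'optimize' is not supported (got {optimize!r})")
    if kwargs:
        raise TypeError(f"unknown arguments: {tuple(kwargs)}")
    ...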
Oh. Now that I re-read https://pytorch.org/docs/stable/generated/torch.einsum.html, there is torch.backends.opt_einsum = "greedy", so we can support this :-). Question: is there a way of controlling torch.backends in a context manager, other than a try... finally block?
I don't think we have a context manager for that, no. try...finally seems reasonable to me.
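A minimal sketch of the try...finally approach, assuming the knob is torch.backends.opt_einsum.strategy (as in recent PyTorch releases) and using a hypothetical helper name:

import torch

def _einsum_with_strategy(strategy, subscripts, *operands):
    # Temporarily switch the contraction strategy, restoring the previous
    # value even if torch.einsum raises. The strategy only takes effect
    # when the opt_einsum package is available.
    prev = torch.backends.opt_einsum.strategy
    try:
        torch.backends.opt_einsum.strategy = strategy
        return torch.einsum(subscripts, *operands)
    finally:
        torch.backends.opt_einsum.strategy = prev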
Thanks, done in 211a461
torch_np/_funcs_impl.py
Outdated
else:
    # op, str, op, str ... format: normalize every other argument
    sublist_format = True
    array_operands = operands[:-1][::2]
Why the [:-1]? Isn't sublistout optional? Also, don't we want to preprocess that one as well, perhaps asserting that it's an ndarray?
Exactly!

- If sublistout is not given, the length of operands is even, and we pick the odd-numbered elements, which are arrays.
- If sublistout is given, the length of operands is odd; we peel off the last one and pick the odd-numbered elements, which are arrays. Without [:-1], we would have picked sublistout, too --- and it's a sublist, not an array.

And no, sublistout is not really an array; it can contain e.g. an Ellipsis:

assert_equal(np.einsum("...i->...", a, optimize=do_opt),
             np.sum(a, axis=-1).astype(dtype))
assert_equal(np.einsum(a, [Ellipsis, 0], [Ellipsis], optimize=do_opt),
             np.sum(a, axis=-1).astype(dtype))
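Spelled out, the splitting described above is roughly equivalent to the following sketch (names are illustrative, not the PR's code):

def split_sublist_operands(operands):
    # operands is (op0, sublist0, op1, sublist1, ..., [sublistout]).
    operands = list(operands)
    sublistout = operands.pop() if len(operands) % 2 == 1 else None
    arrays = operands[::2]      # positions 0, 2, 4, ... are the array operands
    sublists = operands[1::2]   # positions 1, 3, 5, ... are the index sublists
    return arrays, sublists, sublistout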
Right. Could you please leave a comment or a link to this explanation?
torch_np/_funcs_impl.py
Outdated
is_short_int = target_dtype in [torch.uint8, torch.int8, torch.int16, torch.int32]
if is_short_int:
    target_dtype, result_dtype = torch.int64, target_dtype
result_dtype is not being used.
Thanks, removed.

NB it is not used because the current implementation follows pytorch in that e.g.

In [48]: a = np.arange(8, dtype=np.int8)

In [49]: torch.einsum(torch.as_tensor(a), [0], []).dtype
Out[49]: torch.int64

unlike numpy, where the last line is int8 (sigh).
Yeah, this is because, as with sum and other reductions that accumulate, integers are upcast to int64. This function simply uses sum and bmm internally, so it has the same semantics as those two.
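The promotion is easy to see in isolation (a quick check, not code from the PR):

import torch

a = torch.arange(8, dtype=torch.int8)
print(a.sum().dtype)   # torch.int64 -- integer reductions accumulate in int64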
Great! What a pain of a function...
Thanks Mario for the review!
A couple of notes:

- This sort of signature only appears in einsum and gradient, so keep the generic machinery simple and do things manually here.
- NumPy's einsum takes **kwargs, so follow that. Not sure why it is what it is; one possible reason could be that keyword-only args after varargs is a syntax error (?), e.g. def func(a, *args, *, out=None): pass (illustrated below).
- torch.einsum does not have the optimize = {False, True, 'greedy', 'optimal'} argument, so silently ignore this argument.

Also, while at it, add a dedicated annotation to validate the casting=... argument. This is strictly speaking extraneous to this PR, so can take it out if desired.
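For what it's worth, the syntax point can be checked directly; a quick illustration (not code from the PR) showing that the bare-* spelling is rejected by the parser, while a keyword-only out after *args can still be written without it:

# def func(a, *args, *, out=None): pass   # SyntaxError: a second bare `*` is not allowed
def func(a, *args, out=None):             # parameters after *args are keyword-only anyway
    return args, out

print(func(1, 2, 3, out="x"))             # ((2, 3), 'x')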