bare-bones normalizations via type hints #70
Conversation
The initial scaffolding is there, so I am removing the WIP admonition. While this is not yet complete (ufuncs, reductions), this is something to take a look at, if only to see whether the general pattern is tolerable.
Gradual (!) typing for the win: annotating only the dtype is enough to get rid of the dtype_to_torch decorator. Annotating SeqArrayLike is still TBD.
This is a bit clumsy: func(*args: Annotation) gives a single annotation for a runtime-determined number of arguments. There is no way to annotate individual elements of *args, as far as I can see. Thus we register a special annotation to repack args into a tuple, and a normalizer to normalize this tuple.
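As a minimal illustration of the constraint (ArrayLike here is just a placeholder name, not the real annotation):

```python
import inspect

def atleast_1d(*arys: "ArrayLike"):  # "ArrayLike" is a stand-in annotation
    return arys

param = inspect.signature(atleast_1d).parameters["arys"]

# A single Parameter of kind VAR_POSITIONAL carries one annotation,
# no matter how many arguments the caller passes at runtime:
print(param.kind)        # VAR_POSITIONAL
print(param.annotation)  # ArrayLike
```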
This is ready from my side. There are further possible cleanups; both are somewhat large and need some experimentation, so they are best postponed to follow-up PRs.
Just passing by to say that the output one can be done without annotation. Just do a … For the …
Hmm, this breaks in the presence of …
But if we implement the …
Certainly can. My point is simply that it's a bit more than just using …
Nevermind, I see what you mean. Adding a different annotation for this sort of argument LGTM.
torch_np/_wrapper starts looking good! Let's revisit the idea of moving the actual implementations of the torch functions to the wrapper after we have merged this PR and the return PR.
torch_np/_normalizations.py
Outdated
```python
# first, check for *args in positional parameters. Case in point:
#     atleast_1d(*arys: UnpackedSequenceArrayLike)
# if found, consume all args into a tuple to normalize as a whole
for j, param in enumerate(sig.parameters.values()):
    if param.annotation == UnpackedSeqArrayLike:
        if j == 0:
            args = (args,)
        else:
            # args = args[:j] + (args[j:],) would likely work;
            # not present in the numpy codebase, so do not bother just yet.
            # NB: branching on j == 0 avoids the empty tuple args[:j]
            raise NotImplementedError
        break
```
A better way to do this is to check below whether param.kind == VAR_POSITIONAL, and then, if so, treat it as a List[T], where T is the annotation.
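A rough sketch of that idea (hypothetical names; apply_annotation stands in for the registered-normalizer lookup):

```python
import inspect

def apply_annotation(value, annotation):
    # stand-in for looking up and applying a registered normalizer
    return ("normalized", value)

def normalize_args(func, *args):
    out = []
    for i, param in enumerate(inspect.signature(func).parameters.values()):
        if param.kind is inspect.Parameter.VAR_POSITIONAL:
            # the single annotation applies to each remaining argument
            out.extend(apply_annotation(a, param.annotation) for a in args[i:])
            break
        if i >= len(args):
            break  # parameter not passed positionally
        out.append(apply_annotation(args[i], param.annotation))
    return out

def atleast_1d(*arys: "ArrayLike"):
    return arys

print(normalize_args(atleast_1d, 1, 2, 3))
# [('normalized', 1), ('normalized', 2), ('normalized', 3)]
```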
So your suggestion is to check for param.kind == VAR_POSITIONAL instead of adding a dedicated annotation? Note that args = (args,) is still needed in some form; in typing language it's not just List[T], it's Union[T, List[T]]. Probably I'm being dense here, would you mind elaborating?
It'd be a matter of adding an if in normalize_this checking whether it's a VAR_POSITIONAL arg, and then processing the argument(s) accordingly.
That won't work, because tnp.atleast_1d(1, 2) gets two arguments and only a single annotation. I can of course check param.kind == VAR_POSITIONAL instead of a special annotation and consume the rest of args instead of checking param.annotation == UnpackedSeqArrayLike, but this does not seem any simpler or clearer, does it? I mean, what's the endgame here?
OK, how about 9d75cab?
```python
import torch

# renames
```
Why the different imports here, rather than just doing a from torch import (blablalba)?
This is isort from the pre-commit run.
Terribly weird. @honno any idea what's going on here?
torch_np/_wrapper.py
Outdated
```python
    return tuple(asarray(_) for _ in res)

@normalizer
def broadcast_arrays(*args: UnpackedSeqArrayLike, subok: SubokLike = False):
    args = args[0]  # undo the *args wrapping in normalizer
```
We should hopefully be able to fix this with a proper preprocessing of variadic inputs in the normalizer.
Done in 9d75cab
Let me make a counter-offer:

```python
import inspect

def maybe_normalize(arg, parm):
    """Normalize arg if a normalizer is registered."""
    normalizer = normalizers.get(parm.annotation, None)
    return normalizer(arg) if normalizer else arg

def normalizer(func):
    def wrapped(*args, **kwds):
        params = inspect.signature(func).parameters
        first_param = next(iter(params.values()))
        # NumPy's API does not have positional args before variadic positional args
        if first_param.kind == inspect.Parameter.VAR_POSITIONAL:
            args = [maybe_normalize(arg, first_param) for arg in args]
        else:
            args = [
                maybe_normalize(arg, parm)
                for arg, parm in zip(args, params.values())
            ]
        kwds = {
            name: maybe_normalize(arg, params[name]) if name in params else arg
            for name, arg in kwds.items()
        }
        return func(*args, **kwds)
    return wrapped
```
If you think that the comment about NumPy's API is not good enough and you want extra safety (I don't think we need it), you can assert that there isn't any other variadic parameter that's not in the first position. This should be one line as well.

Note: do we need the lst += args[len(lst):]? What would be an example of a call where we need it?
This is slick! Taken over in 8dc2628, together with the args += extra_args_to_raise_later addition. What's the email to attribute the commit to?

> Note: do we need the lst += args[len(lst):]? What would be an example of a call where we need it?

Yes: extra unknown positional arguments. The issue is that zip(short_sequence, longer_sequence) drops trailing elements from longer_sequence. Here's a test (np.nonzero only accepts a single argument):
```python
def test_unknown_args(self):
    # Check that unknown args to decorated functions fail
    a = w.arange(7) % 2 == 0
    # unknown positional args
    with assert_raises(TypeError):
        w.nonzero(a, "kaboom")
```

Without the lst += args[len(lst):] line, this fails with:

```
E   Failed: DID NOT RAISE <class 'TypeError'>
```
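The truncation itself is easy to demonstrate in isolation (toy values):

```python
# zip stops at the shorter sequence, silently dropping extras:
params = ["a"]          # one declared parameter
args = [1, "kaboom"]    # caller passed an extra positional argument

lst = [a for p, a in zip(params, args)]
print(lst)              # [1] -- "kaboom" vanished

# re-appending the tail keeps the extra argument around, so the wrapped
# function itself can raise TypeError on it:
lst += args[len(lst):]
print(lst)              # [1, 'kaboom']
```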
Ah, right, I forgot we need to error out on those. In my head I always thought of the non-error case.
Also, I appreciate it, but no need to attribute the commit really :)
torch_np/_wrapper.py
Outdated
```python
from . import _dtypes, _helpers, _decorators  # isort: skip  # XXX
from ._ndarray import array, asarray, maybe_set_base, ndarray
from ._normalizations import (
```
What's the difference between _funcs.py and _wrapper.py? Should we merge them?
Absolutely. The split is temporary, and we'll cleanly merge them once the cleanups from other PRs further up the stack are done.
At this stage, imports from ._ndarray cause circular imports. So either let's live with the split for a while, or I can bloat this PR with what's up the stack. Reviewer's choice :-).
Let's avoid copying the args, but otherwise this LGTM. As discussed, let's merge all the PRs as-is (as-are?), and then try to factor out the out= kwarg implementation and have a simple implementation of the wrapping of the outputs.
An experiment along the lines of the suggestion in #32: (ab)use type annotations to mark the logic for normalizing user-facing arguments into pytorch tensors and dtypes, reshuffling arguments etc. via static types.

A somewhat clumsy bit is normalizing *args; a case in point is e.g. atleast_1d(*arys). The issue is that func(*args: Annotation) gives a single annotation for a runtime-determined number of arguments. There is no way to annotate individual elements of *args, as far as I can see. Thus we register a special annotation to repack args into a tuple, and a corresponding normalizer to normalize this tuple.
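The repack trick can be sketched end to end (a toy reconstruction; UnpackedSeqArrayLike and normalize_tuple are illustrative stand-ins, and the "normalizer" here just tags elements rather than building tensors):

```python
import inspect

class UnpackedSeqArrayLike:
    """Marker annotation: repack *args into a single tuple before normalizing."""

def normalize_tuple(tup):
    # stand-in normalizer: "normalize" each element of the repacked tuple
    return tuple(("tensor", x) for x in tup)

def normalizer(func):
    sig = inspect.signature(func)

    def wrapped(*args, **kwds):
        for j, param in enumerate(sig.parameters.values()):
            if param.annotation is UnpackedSeqArrayLike:
                if j == 0:
                    # consume all positional args into one tuple, normalize as a whole
                    args = (normalize_tuple(args),)
                break
        return func(*args, **kwds)

    return wrapped

@normalizer
def atleast_1d(*arys: UnpackedSeqArrayLike):
    arys = arys[0]  # undo the *args wrapping done by the normalizer
    return arys

print(atleast_1d(1, 2))  # (('tensor', 1), ('tensor', 2))
```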