bare-bones normalizations via type hints #70


Merged
ev-br merged 33 commits into main from normalizations on Mar 22, 2023

Conversation

ev-br (Collaborator) commented Feb 27, 2023

An experiment along the lines of the suggestion in #32: (ab)use type annotations to mark the logic for normalizing user-facing arguments into pytorch tensors and dtypes, reshuffling arguments etc. via static types.

A somewhat clumsy bit is normalizing *args; case in point: atleast_1d(*arys).
The issue is that func(*args: Annotation) gives a single annotation for a runtime-determined number of arguments. There is no way to annotate individual elements of *args, AFAICS.

Thus, register a special annotation to repack args into a tuple, and a corresponding normalizer to normalize this tuple.
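
For illustration, a minimal sketch of the pattern under discussion. The marker classes, helper names and the atleast_1d body here are illustrative stand-ins, not the PR's exact code: a decorator inspects each parameter's annotation and applies whatever normalizer is registered for it.

import functools
import inspect

import torch

# Hypothetical marker annotations; the PR's actual names and details differ.
class ArrayLike:
    pass  # "convert this argument to a torch.Tensor"

class UnpackedSeqArrayLike:
    pass  # "*args, gathered into a tuple; normalize each element"

def normalize_array_like(x):
    return x if isinstance(x, torch.Tensor) else torch.as_tensor(x)

def normalize_seq_array_like(xs):
    return tuple(normalize_array_like(x) for x in xs)

normalizers = {
    ArrayLike: normalize_array_like,
    UnpackedSeqArrayLike: normalize_seq_array_like,
}

def normalizer(func):
    sig = inspect.signature(func)

    @functools.wraps(func)
    def wrapped(*args, **kwds):
        bound = sig.bind(*args, **kwds)
        for name, value in bound.arguments.items():
            norm = normalizers.get(sig.parameters[name].annotation)
            if norm:
                bound.arguments[name] = norm(value)
        return func(*bound.args, **bound.kwargs)

    return wrapped

@normalizer
def atleast_1d(*arys: UnpackedSeqArrayLike):
    # by this point, every element of arys is a torch.Tensor
    return tuple(a if a.ndim >= 1 else a.reshape(1) for a in arys)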

Base automatically changed from free_funcs to main on February 27, 2023
ev-br mentioned this pull request on Feb 28, 2023
ev-br force-pushed the normalizations branch 4 times, most recently from c81bb9e to 3dfddd6, on March 4, 2023
ev-br changed the title from "WIP: bare-bones normalizations via type hints" to "bare-bones normalizations via type hints" on Mar 4, 2023
ev-br (Collaborator, Author) commented Mar 4, 2023

The initial scaffolding is there, so I am removing the WIP admonition. While this is not yet complete (ufuncs, reductions), it is ready for a look, if only to judge whether the general pattern is tolerable.

ev-br requested review from lezcano, honno and rgommers on March 4, 2023
ev-br force-pushed the normalizations branch 3 times, most recently from e94647a to edda25b, on March 10, 2023
ev-br added 10 commits on March 10, 2023
Gradual (!) typing WTF: annotating only the dtype lets us get rid of the dtype_to_torch decorator.
Annotating SeqArrayLike: typing TBD.
This is a bit clumsy: func(*args : Annotation) gives a single annotation for a
runtime-determined number of arguments. There is no way to annotate individual
elements of *args AFAICS.

Thus register a special annotation to repack args into a tuple and a normalizer
to normalize this tuple.
ev-br (Collaborator, Author) commented Mar 11, 2023

This is ready from my side.

Further possible cleanups include normalizing return values and implementing the out= kwarg handling (see the discussion below); both are somewhat large and need some experimentation, so they are best postponed to follow-up PRs.

@ev-br ev-br mentioned this pull request Mar 11, 2023
lezcano (Collaborator) commented Mar 11, 2023

Just passing by to say that the output one can be done without an annotation. Just do if isinstance(out, torch.Tensor): return as_array(out) in the normalization wrapper, and that'll get rid of all those _helper.array_from. You can do the same for tuples. This works because a function's returned value is already "normalised" (it'll always be a tensor, a tuple of tensors, or things that we don't want to touch).
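
A sketch of that suggestion (with as_array as a stand-in for the project's tensor-to-ndarray helper, and the tuple recursion being one reading of "do the same for tuples"):

import torch

def as_array(t):
    return t  # stand-in for the actual tensor -> ndarray conversion helper

def postprocess_return(result):
    # wrap returned tensors, recurse into tuples, leave everything else alone
    if isinstance(result, torch.Tensor):
        return as_array(result)
    if isinstance(result, tuple):
        return tuple(postprocess_return(r) for r in result)
    return result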

For the out= kwarg, we can do it like in PrimTorch:
https://github.com/pytorch/pytorch/blob/ab148da66cb9433effac90c7bd4930a961481d19/torch/_prims_common/wrappers.py#L187
Note that that code is quite tricky, because it handles namedtuples and it also puts in the correct annotation, but you can probably simplify it in our case.
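
A much-simplified sketch of that PrimTorch pattern (deliberately ignoring the namedtuple handling and annotation fix-ups that the linked code deals with):

import functools

def out_wrapper(func):
    @functools.wraps(func)
    def wrapped(*args, out=None, **kwds):
        result = func(*args, **kwds)
        if out is not None:
            out.copy_(result)  # write the result into the user-supplied tensor
            return out
        return result
    return wrapped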

ev-br (Collaborator, Author) commented Mar 11, 2023

Just do a if isinstance(out, torch.Tensor): return as_array(out)

Hmm, this breaks in the presence of an out= argument, unless as_array has the out= semantics?

lezcano (Collaborator) commented Mar 11, 2023

But if we implement the out= kwarg ourselves as a decorator, we should be able to make both things work with each other.

ev-br (Collaborator, Author) commented Mar 11, 2023

Certainly can. My point is simply that it's a bit more than just using if isinstance(result, torch.Tensor): return asarray(result), and I'd prefer to limit the scope of this PR and deal with returns in a follow-up.

lezcano (Collaborator) commented Mar 12, 2023

About annotating variadic args, does this SO answer solve the question? https://stackoverflow.com/a/37032111/5280578

Never mind, I see what you mean. Adding a different annotation for this sort of argument LGTM.

ev-br mentioned this pull request on Mar 16, 2023
lezcano (Collaborator) left a comment

torch_np/_wrapper starts looking good! Let's revisit the idea of moving the actual implementations of the torch functions to the wrapper after we have merged this PR and the return PR.

Comment on lines 104 to 116
# first, check for *args in positional parameters. Case in point:
#     atleast_1d(*arys: UnpackedSequenceArrayLike)
# if found, consume all args into a tuple to normalize as a whole
for j, param in enumerate(sig.parameters.values()):
    if param.annotation == UnpackedSeqArrayLike:
        if j == 0:
            args = (args,)
        else:
            # args = args[:j] + (args[j:],) would likely work;
            # not present in the numpy codebase, so do not bother just yet.
            # NB: branching on j == 0 avoids the empty tuple args[:j]
            raise NotImplementedError
        break
lezcano (Collaborator):

A better way to do this is to check below whether param.kind == VAR_POSITIONAL, and then, if so, treat it as a List[T], where T is the annotation.

ev-br (Collaborator, Author):

So your suggestion is to check for param.kind == VAR_POSITIONAL instead of adding a dedicated annotation?
Note that args = (args,) is still needed in some form; in typing language, it's not just List[T], it's Union[T, List[T]].
Probably I'm being dense here; would you mind elaborating?

lezcano (Collaborator):

It'd be a matter of adding an if in normalize_this checking whether it's a VAR_POSITIONAL arg and then processing the argument(s) accordingly.

ev-br (Collaborator, Author) commented Mar 21, 2023

That won't work, because tnp.atleast_1d(1, 2) gets two arguments but only a single annotation. I can of course check param.kind == VAR_POSITIONAL instead of a special annotation, and consume the rest of args instead of checking param.annotation == UnpackedSeqArrayLike, but that does not seem any simpler or clearer, does it? I mean, what's the endgame here.

ev-br (Collaborator, Author):

OK, how about 9d75cab?


import torch

# renames
lezcano (Collaborator):

Why the different imports here, rather than just doing a from torch import (blablalba)?

ev-br (Collaborator, Author):

This is isort, from the pre-commit run.

lezcano (Collaborator):

Terribly weird. @honno any idea what's going on here?

    return tuple(asarray(_) for _ in res)

@normalizer
def broadcast_arrays(*args: UnpackedSeqArrayLike, subok: SubokLike = False):
    args = args[0]  # undo the *args wrapping in normalizer
lezcano (Collaborator):

We should hopefully be able to fix this with a proper preprocessing of variadic inputs in the normalizer.

ev-br (Collaborator, Author):

Done in 9d75cab

lezcano (Collaborator) commented Mar 22, 2023

Let me make a counter-offer:

import inspect

def maybe_normalize(arg, parm):
    """Normalize arg if a normalizer is registered."""
    normalizer = normalizers.get(parm.annotation, None)
    return normalizer(arg) if normalizer else arg

def normalizer(func):
    def wrapped(*args, **kwds):
        params = inspect.signature(func).parameters
        first_param = next(iter(params.values()))
        # NumPy's API does not have positional args before variadic positional args
        if first_param.kind == inspect.Parameter.VAR_POSITIONAL:
            args = [maybe_normalize(arg, first_param) for arg in args]
        else:
            args = [maybe_normalize(arg, parm) for arg, parm in zip(args, params.values())]
        kwds = {
            name: maybe_normalize(arg, params[name]) if name in params else arg
            for name, arg in kwds.items()
        }
        return func(*args, **kwds)
    return wrapped

If you think that the comment about NumPy's API is not good enough and you want extra safety (I don't think we need it), you can assert that no parameter other than the first one is variadic. This should be one line as well.

Note: Do we need the lst += args[len(lst):]? What would be an example of a call where we need it?

ev-br (Collaborator, Author) commented Mar 22, 2023

This is slick! Taken over in 8dc2628, together with the args += extra_args_to_raise_later addition. What's the email to attribute the commit to?

Note: Do we need the lst += args[len(lst):]? What would be an example of a call where we need it?

Yes. Extra unknown positional arguments. The issue is that zip(short_sequence, longer_sequence) drops trailing elements from longer_sequence. Here's a test (np.nonzero only accepts a single argument):

    def test_unknown_args(self):
        # Check that unknown args to decorated functions fail
        a = w.arange(7) % 2 == 0
    
        # unknown positional args
        with assert_raises(TypeError):
>           w.nonzero(a, "kaboom")
E           Failed: DID NOT RAISE <class 'TypeError'>
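
For reference, a sketch of why the tail append matters (maybe_normalize and the lst name come from the counter-offer above; the function wrapper around them is illustrative):

def normalize_positional(args, params):
    # zip() stops at the shorter sequence, so unknown trailing positional
    # arguments would be silently dropped here...
    lst = [maybe_normalize(arg, parm) for arg, parm in zip(args, params)]
    # ...so append the unmatched tail; calling func(*lst) then lets Python
    # raise TypeError for functions that don't accept the extra arguments.
    lst += args[len(lst):]
    return lst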

lezcano (Collaborator):

Ah, right, I forgot we need to error out on those. In my head I always thought of the non-error case.

Also, I appreciate it, but no need to attribute the commit really :)


from . import _dtypes, _helpers, _decorators # isort: skip # XXX
from ._ndarray import array, asarray, maybe_set_base, ndarray
from ._normalizations import (
lezcano (Collaborator):

What's the difference between _funcs.py and _wrapper.py? Should we merge them?

ev-br (Collaborator, Author):

Absolutely. The split is temporary, and we'll cleanly merge them once the cleanups from other PRs further up the stack are done.
At this stage, imports from ._ndarray cause circular imports. So either we live with the split for a while, or I bloat this PR with what's up the stack. Reviewer's choice :-).

lezcano (Collaborator) left a comment

Let's avoid copying the args, but otherwise this LGTM. As discussed, let's merge all the PRs as-is (as-are?) and let's try to factor out the out= kwarg implementation and then have a simple implementation of the wrapping of the outputs.

ev-br force-pushed the normalizations branch 2 times, most recently from cc6cc7c to f5e5eaf, on March 22, 2023
ev-br merged commit 023d453 into main on Mar 22, 2023
ev-br deleted the normalizations branch on March 22, 2023