
Refactor the internals to better separate wrappers from ops on tensors #22


Merged: ev-br merged 26 commits into main from refactor on Jan 31, 2023

Conversation

@ev-br (Collaborator) commented Jan 24, 2023

Introduce the _detail subpackage: its contents operate only on pytorch tensors and know nothing about wrappers. User-facing functionality mostly deals with array-likes and argument differences between numpy and pytorch, and delegates the heavy lifting to functions in _detail.
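A minimal sketch of the layering, with illustrative names only (not the PR's actual functions):

import torch

# torch_np/_detail/_reductions.py: tensor in, tensor out; no wrappers in sight
def sum_impl(tensor, axis=None, dtype=None):
    if axis is None:
        return tensor.sum(dtype=dtype)
    return tensor.sum(dim=axis, dtype=dtype)

# user-facing layer: accepts array-likes, irons out numpy-vs-pytorch
# argument differences, then delegates to _detail
def sum(a, axis=None, dtype=None):
    tensor = asarray(a).get()  # hypothetical: pull the torch.Tensor out of the wrapper
    return wrap_tensor(sum_impl(tensor, axis=axis, dtype=dtype))  # hypothetical re-wrap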

ev-br added 23 commits January 13, 2023 10:50
Move dealing with tensors to _detail._reductions
A drawback of the current version is that the out=... arg
needs to be passed as a keyword argument, and it cannot be a positional arg.
decorators in decorators.py know about arrays, _detail does not.
Also mimic numpy, which casts inputs to inexact types (pytorch raises if
dtypes differ).
@rgommers (Member)
This sounds like a useful split. I'll resist having a closer look until the PR is no longer WIP.

@rgommers (Member) left a comment

Looks pretty clean overall. A few comments from just reading the diff.

@lezcano lezcano self-requested a review January 24, 2023 16:42
@lezcano (Collaborator) left a comment

What's the semantic difference between torch_np/_util.py and torch_np/_helpers.py? Should _helpers.py be merged into _util.py?

Also, feel free to stop prepending underscores to all the files / folders. I don't think it should be needed, and it makes everything a bit odd to read, as you don't know what's public, what's private, and what's what.

All in all, I think this is certainly going in the right direction. Now the mappings are much cleaner, which is great.

# TODO: 1. get rid of _helpers.result_or_out
# 2. sort out function signatures: how they flow through all decorators etc
@functools.wraps(func)
def wrapped(a, axis=None, out=None, keepdims=NoValue, *args, **kwds):
Collaborator left a comment
I get the feeling that this function does too much. I really liked how you split the dtype and out processing above, but here you have mixed them all together. As a result, you seem to have forgotten to handle the out kwarg.

@ev-br (Collaborator, Author) replied

In fact, the out= arg should not be here any more (removed in the last commit).
Note that the usage is emulate_out_arg(axis_keepdims(...)): https://github.com/Quansight-Labs/numpy_pytorch_interop/blob/refactor/torch_np/_ndarray.py#L281

This decorator itself simply unwraps ndarrays and delegates the heavy lifting to _util.axis_keepdims, which does just what it says on the tin: it handles axis tuples and keepdims=True.

Not sure how to simplify it further. Would a comment help?

The fact that the usage is not very clear is, I guess, a direct consequence of having separate decorators for the various arguments. So if this is the direction we want to go, there will be more of this, I'm afraid.
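For reference, roughly what such a decorator does; a sketch only (the real _util.axis_keepdims differs, and this assumes axis entries are already normalised to non-negative ints, cf. normalize_axis_index):

import functools

def axis_keepdims(func):
    # handle tuple-valued axis and keepdims around a tensor-level reduction
    @functools.wraps(func)
    def wrapped(tensor, axis=None, keepdims=False, **kwds):
        if isinstance(axis, tuple):
            result = tensor
            for ax in axis:
                # keep dims while iterating so the remaining axis indices stay valid
                result = func(result, axis=ax, keepdims=True, **kwds)
            if not keepdims:
                for ax in sorted(axis, reverse=True):
                    result = result.squeeze(ax)
            return result
        return func(tensor, axis=axis, keepdims=keepdims, **kwds)
    return wrapped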

@lezcano (Collaborator) commented Jan 30, 2023

So, there are two general issues here.
First, I don't see why you would always need to use this in conjunction with emulate_out_arg. The two functions may be used independently, so having out in the signature of this wrapper is unnecessary.
Second, it is not only unnecessary, it's also incorrect. If you apply emulate_out_arg afterwards, you happen to get the right behaviour almost by chance. Now, if you swap the order, you'll get a function with the same signature, but one that takes an out= kwarg and... discards it!

As mentioned, I think that this function and the one that implements the out behaviour do two very different things, and should be independent of each other.
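Concretely (tensor_sum and the wrapper names are illustrative):

# correct: out= is handled by the outermost wrapper
sum = emulate_out_arg(axis_keepdims(tensor_sum))

# swapped: the axis/keepdims wrapper in the diff declares out= in its
# signature but never forwards it, so the out= argument is silently dropped
sum_swapped = axis_keepdims(emulate_out_arg(tensor_sum))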

Fuse logical ops in normalize_axis_index

Co-authored-by: Mario Lezcano Casado <[email protected]>
@ev-br (Collaborator, Author) commented Jan 27, 2023

What's the semantic difference between torch_np/_util.py and torch_np/_helpers.py? Should _helpers.py be merged into _util.py?

torch_np/_detail/_util.py is in the _detail subpackage, hence it only knows about pytorch tensors, and knows nothing about wrapper ndarrays.

torch_np/_helpers.py is a collection of small helpers which do know about wrapper ndarrays and handle things like extracting tensors from ndarrays before handing the work over to utilities in _detail.

Maybe it makes sense to merge _helpers.py and decorators.py, if that's wanted (I personally am on the fence).

@ev-br ev-br changed the title WIP: Refactor the internals to better separate wrappers from ops on tensors Refactor the internals to better separate wrappers from ops on tensors Jan 28, 2023
@lezcano (Collaborator) commented Jan 30, 2023

About _helpers.py and _util.py: if _helpers.py contains decorators, I would rename it to _decorators.py. That would simplify things.

A point that was not discussed: why not remove the underscore that precedes the names of all the files in the project?

About the general structure of the project and what this PR aims to achieve: I think it's going in the right direction, but it can be further improved. The reason I say this is that there are still a number of local imports, which means that there are still some code smells.

Proposal for how to structure the project

I propose refactoring the code still further, as follows, to have a more definite "chain of command" when it comes to which files do what:

  • We have an ndarray class in torch_np/ndarray.py. This class is nothing but a wrapper around a torch.Tensor.
  • Within torch_np/ndarray.py we have a function wrap_ndarray that wraps an np.ndarray, a list, or any other weird thing NumPy accepts into an ndarray.ndarray (i.e. it generates a torch.Tensor from those and wraps it in an ndarray.ndarray).

wrap_ndarray will need some support functionality, like that in torch_np/{dtypes,scalar_types}.py and elsewhere.

We will also need to cast other NumPy-specific arguments into PyTorch specific arguments (e.g. np.float32 -> torch.float32). We will call this process of going from NumPy args to PyTorch args (also the one performed by wrap_ndarray) "normalisation".

We will have a number of functions that perform normalisation on dtype args, ndarray args, axis args, etc. All these functions will take one parameter and will return a normalised version of it.
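For instance, a dtype normaliser might look roughly like this (mapping abbreviated, names illustrative):

import numpy as np
import torch

_DTYPE_MAP = {np.float32: torch.float32, np.float64: torch.float64}  # abbreviated

def normalize_dtype_arg(dtype):
    # one parameter in, its normalised (PyTorch) version out
    return _DTYPE_MAP.get(dtype, dtype)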

Summarising what we have so far: at the top level of the torch_np/ folder we have an ndarray class and a number of normalisation functions. We have not yet discussed how to use these normalisation functions; we'll get to that in a bit.

  • In the inner folder (we may call it np_refs, in reference to _refs/ in the PrimTorch project), we will implement all the relevant mappings.
    • These will assume that the inputs are torch.Tensors or otherwise normalised inputs, and they will return torch.Tensors. These functions will, of course, implement the NumPy behaviour using torch.Tensor.
    • These functions will always assume that their inputs are normalised.
  • In this folder we will also implement wrappers that help implement generic functionality. One such example is the out_wrapper, but there may be others. See for example the wrappers in https://github.com/pytorch/pytorch/blob/master/torch/_prims_common/wrappers.py.

Consider now how to implement a non-trivial example like concatenate. This function needs to perform some non-standard preprocessing on the inputs (and perhaps some non-standard postprocessing). Given that we cannot normalise the inputs to this function in a generic way, we will implement a specific wrapper for it in the top-level folder that does the preprocessing of the inputs (performs all the necessary checks and sets the relevant dtype correctly). These normalised inputs are then passed to the concatenate function, which just sees torch.Tensors and torch.dtypes and returns a torch.Tensor. And that's it.
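A hedged sketch of that split (wrap_ndarray as proposed above; the .tensor accessor and normalize_dtype_arg are assumed for illustration):

import torch
import np_refs  # the inner, tensor-only layer proposed above

# in torch_np/ (top level): specific wrapper, normalises the inputs by hand
def concatenate(arrays, axis=0, dtype=None):
    tensors = tuple(wrap_ndarray(a).tensor for a in arrays)  # assumed accessor
    if dtype is not None:
        tensors = tuple(t.to(normalize_dtype_arg(dtype)) for t in tensors)
    return ndarray(np_refs.concatenate(tensors, axis))

# in np_refs/: sees only torch.Tensors and returns a torch.Tensor
def concatenate(tensors, axis=0):
    return torch.cat(tensors, dim=axis)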

The point that we have not discussed is how to map the normalisation functions to their relevant args. I propose doing so via some light typing. Let's do a simple example where we omit plenty of kwargs.

def add(a: Tensor, b: Tensor, *, dtype: Optional[torch.dtype]) -> Tensor

(Note that the dtype kwarg may be implemented generically via a decorator, or even better, all those params could be implemented generically for all ufuncs.)

We would then have a function in the main folder with the signature:

GenericNdArray = Union[np.ndarray, list[float], NumpyScalar]  # Incomplete list of types. NumpyScalar will be its own union type

def add(a: GenericNdArray, b: GenericNdArray, *, dtype: Optional[np.dtype]):

Note: we don't need to annotate the return type. We will see why afterwards.

Then, we would have a number of normalisation functions that would look something like this:

import inspect
from functools import wraps
from typing import Optional

import numpy as np
import torch

import np_refs
from .ndarray import ndarray, wrap_ndarray


def bind(torch_impl):
    def inner(np_interface):
        fn_params = inspect.signature(np_interface).parameters
        # Do this or have it as a global variable. Either works.
        # You can also have them in their own file and iterate over all the functions in that file.
        normalizations = [
            v for v in globals().values()
            if getattr(v, "__name__", "").startswith("normalize_")
        ]
        fn = torch_impl
        for normalization in normalizations:
            fn = normalization(fn, fn_params)
        return fn
    return inner


def normalize(fn, fn_params, normalization_fn, condition):
    @wraps(fn)
    def _fn(*args, **kwargs):
        # args is a tuple: zip the parameters with the args and build a new list
        args = list(args)
        for i, param in enumerate(list(fn_params)[:len(args)]):
            if condition(args[i], fn_params[param]):
                args[i] = normalization_fn(args[i])

        # same as above, for the keyword arguments
        for kwarg, value in kwargs.items():
            if condition(value, fn_params[kwarg]):
                kwargs[kwarg] = normalization_fn(value)
        return fn(*args, **kwargs)

    return _fn


def normalize_ndarray(fn, fn_params):
    """Normalise parameters that accept any array-like object."""
    return normalize(fn, fn_params, wrap_ndarray,
                     lambda val, param: param.annotation is GenericNdArray)


def normalize_tensor(fn, fn_params):
    """Normalise parameters that accept ndarrays."""
    return normalize(fn, fn_params, torch.from_numpy,
                     lambda val, param: param.annotation is torch.Tensor)


def normalize_dtype(fn, fn_params):
    """Normalise NumPy dtype arguments into torch dtypes."""
    dtype_map = {np.float32: torch.float32, np.float64: torch.float64}  # abbreviated
    return normalize(fn, fn_params, lambda d: dtype_map.get(d, d),
                     lambda val, param: param.annotation == Optional[np.dtype])


def normalize_output(fn, fn_params):
    """Wrap a returned torch.Tensor back into an ndarray."""
    @wraps(fn)
    def _fn(*args, **kwargs):
        out = fn(*args, **kwargs)
        # TODO implement support for tuples
        return ndarray(out) if isinstance(out, torch.Tensor) else out

    return _fn


@bind(np_refs.add)
def add(a: GenericNdArray, b: GenericNdArray, *, dtype: Optional[np.dtype] = None):
    pass

You could even do all this within the ndarray.py file, and then afterwards create all the ndarray methods from these in a for loop.
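For example (illustrative):

# attach the wrapped functions as ndarray methods in a loop
for name in ("add", "subtract", "multiply"):  # hypothetical list of wrapped functions
    func = globals()[name]

    def method(self, *args, _func=func, **kwargs):
        # bind _func at definition time to avoid the late-binding pitfall
        return _func(self, *args, **kwargs)

    setattr(ndarray, name, method)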

@rgommers (Member)

A point that was not discussed: why not remove the underscore that precedes the names of all the files in the project?

If there's now an underscore in _detail then the files in that subdir don't need an underscore. In general it's very good practice to add underscores to filenames that are not public modules. Most projects aren't great at it, but it's the only way to prevent people from doing import somepkg.some_private_file and having it look public.

@rgommers (Member)

The overall structure you propose seems nice and logical @lezcano.

The point that we have not discussed is how to map the normalisation functions to their relevant args. I propose doing so via some light typing.

This I'm not 100% sure about. Typing for the PyTorch layer is straightforward; the GenericNdArray part is not. It's not completely clear that an incomplete set of types in that union will be sufficient, and it's very hard to be complete here: a look at how NumPy implements its NDArray static type shows that quickly enough.

@lezcano (Collaborator) commented Jan 30, 2023

It's not completely clear that an incomplete set of types in that union will be sufficient, and it's very hard to be complete here

My point here is that these will serve as annotations to:

  • Have the tools to apply the normalisation functions via introspection
  • Have a place where we sort of annotate (a subset of) the types the normalisation functions support. The set of supported types may be larger just by chance, but that's alright.

Note that these annotations are not meant to be used to strictly typecheck programs.

@ev-br (Collaborator, Author) commented Jan 30, 2023

As discussed, let's merge this one as is and discuss the next restructure in #32

@ev-br ev-br mentioned this pull request Jan 30, 2023
@ev-br (Collaborator, Author) commented Jan 30, 2023

OK, no, I cannot merge this. I am failing to appease the powers of linting to merge this into main, with black and all. Giving up for today. Hopefully @honno will be able to give me a hand with this.

@ev-br ev-br merged commit 15e6704 into main Jan 31, 2023
@ev-br ev-br deleted the refactor branch January 31, 2023 10:35