Refactor the internals to better separate wrappers from ops on tensors #22
Conversation
Move dealing with tensors to _detail._reductions
A drawback of the current version is that the out=... arg needs to be passed as a keyword argument, and it cannot be a positional arg.
Decorators in decorators.py know about arrays; _detail does not.
Also mimic numpy, which casts inputs to inexact types (pytorch raises if dtypes differ).
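To make the tensor-only side concrete, here is a minimal hedged sketch of what a `_detail._reductions`-style helper could look like under these rules; the function name `mean` and its exact signature are illustrative assumptions, not the PR's actual code. The point is that it sees only `torch.Tensor`s and applies the numpy-style cast to an inexact dtype itself:

```python
# Hypothetical sketch of a _detail._reductions-style helper: tensors in, tensor out,
# no knowledge of ndarray wrappers and no out= handling.
from typing import Optional

import torch


def mean(tensor: torch.Tensor, axis=None, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
    if dtype is None and not (tensor.dtype.is_floating_point or tensor.dtype.is_complex):
        # numpy casts integer inputs to an inexact type; torch.mean on an integer tensor raises
        dtype = torch.get_default_dtype()
    if dtype is not None:
        tensor = tensor.to(dtype)
    return tensor.mean() if axis is None else tensor.mean(dim=axis)
```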
This sounds like a useful split. I'll resist having a closer look until the PR is no longer WIP.
Looks pretty clean overall. A few comments from just reading the diff.
What's the semantic difference between `torch_np/_util.py` and `torch_np/_helpers.py`? Should `helpers.py` be merged into `util.py`?
Also, feel free to stop prepending underscores in all the files / folders. I don't think it should be needed, and it makes everything a bit odd to read as you don't know what's public, what's private and what's what.
All in all, I think this is certainly going in the right direction. Now the mappings are much cleaner, which is great.
torch_np/_decorators.py
```python
# TODO: 1. get rid of _helpers.result_or_out
#       2. sort out function signatures: how they flow through all decorators etc
@functools.wraps(func)
def wrapped(a, axis=None, out=None, keepdims=NoValue, *args, **kwds):
```
I get the feeling that this function does too much. I really liked how you split the dtype and out processing above, but here you have mixed them all together. As a result, you seem to forget handling the `out` kwarg.
In fact, the `out=` arg should not be here any more (it was removed in the last commit).
Note the usage is `emulate_out_arg(axis_keepdims(...))`: https://github.com/Quansight-Labs/numpy_pytorch_interop/blob/refactor/torch_np/_ndarray.py#L281
This decorator itself simply unwraps ndarrays and delegates the heavy lifting to `_util.axis_keepdims`, which does just what it says on the tin: handles axis tuples and `keepdims=True`.
Not sure how to simplify it further. Would a comment help?
The fact that the usage is not very clear is, I guess, a direct consequence of having separate decorators for various arguments. So if this is the direction we want to go, there will be more of this, I'm afraid.
So, there are two general issues here.
I don't see why you would always need to use this in conjunction with `emulate_out_arg`. These two functions may be used independently, so having `out` in the signature of this wrapper is unnecessary.
Even more, not only is it unnecessary, it's also incorrect. If you use it with `emulate_out_arg` afterwards, you happen to get the right behaviour almost by chance. Now, if you swap the order, you'll get a function with the same signature, but one that takes an `out=` kwarg and... discards it!
As mentioned, I think that this function and the one that implements the out behaviour do two very different things, and should be independent of each other.
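To make the ordering argument above concrete, here is a self-contained, hedged sketch of the two decorators being discussed (the real ones in `torch_np/_decorators.py` and `_util.py` also unwrap ndarrays and differ in detail); it only illustrates why `out=` should be consumed by the outermost wrapper:

```python
import functools

import torch


def axis_keepdims(func):
    """Sketch: accept tuple axes and re-insert reduced dims when keepdims=True."""
    @functools.wraps(func)
    def wrapped(tensor, axis=None, keepdims=False, **kwds):
        axes = (axis,) if isinstance(axis, int) else axis
        result = func(tensor, axis=axes, **kwds)
        if keepdims and axes is not None:
            for ax in sorted(a % tensor.ndim for a in axes):
                result = result.unsqueeze(ax)   # restore each reduced dim with size 1
        return result
    return wrapped


def emulate_out_arg(func):
    """Sketch: copy the result into out= when given, otherwise just return it."""
    @functools.wraps(func)
    def wrapped(*args, out=None, **kwds):
        result = func(*args, **kwds)
        if out is not None:
            out.copy_(result)
            return out
        return result
    return wrapped


def tensor_sum(tensor, axis=None):
    # stand-in for a tensor-level reduction in _detail._reductions
    return tensor.sum() if axis is None else tensor.sum(dim=axis)


# Order used in the PR: out= is consumed by the outermost wrapper.
sum_ = emulate_out_arg(axis_keepdims(tensor_sum))

# If axis_keepdims also listed out= in its own signature (as an earlier version of the
# decorator did) and were applied on the outside, it would accept the kwarg and silently
# drop it, because only emulate_out_arg knows what to do with out=.
```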
Fuse logical ops in normalize_axis_index
Co-authored-by: Mario Lezcano Casado <[email protected]>
Maybe it makes sense to merge
A point that was not discussed: why not remove the underscore that precedes the names of all the files in the project?

About the general structure of the project and what this PR aims to achieve: I think it's going in the right direction, but it can be further improved. The reason I say this is that there are still a number of local imports, which means that there are still some code smells.

Proposal for how to structure the project

I propose to further refactor the code as follows, to have a more definite "chain of command" when it comes to which files do what:
We will also need to cast other NumPy-specific arguments into PyTorch-specific arguments (e.g.

We will have a number of functions that perform normalisation on dtype args, ndarray args, axis args, etc. All these functions will take one parameter and will return a normalised version of it.

Summarising what we have until now, we have that in the top level of the
Consider how to implement now a non-trivial example like

The point that we have not discussed is how to map the normalisation functions to their relevant args. I propose doing so via some light typing. Let's do a simple example where we omit plenty of kwargs. The PyTorch-level function would have the signature

```python
def add(a: Tensor, b: Tensor, *, dtype: Optional[torch.dtype]) -> Tensor
```

(note that the

We would then have a function in the main folder with the signature:

```python
GenericNdArray = Union[np.ndarray, list[float], NumpyScalar]  # Incomplete list of types. NumpyScalar will be its own union type

def add(a: GenericNdArray, b: GenericNdArray, *, dtype: Optional[np.dtype]):
```

Note: we don't need to annotate the returned type. We will see why afterwards.

Then, we would have a number of normalisation functions that would look something like this:
```python
from functools import wraps
from typing import Optional
import inspect

import numpy as np
import torch

import np_refs
from .ndarray import ndarray

# GenericNdArray (defined above) and wrap_ndarray (array-like -> ndarray) come from elsewhere.


def bind(torch_impl):
    def inner(np_interface):
        fn_params = inspect.signature(np_interface).parameters
        # Do this or have it as a global variable. Either works.
        # You can also have them in their own file and iterate over all the functions in that file.
        normalizations = [
            v for v in globals().values()
            if callable(v) and v.__name__.startswith("normalize_")
        ]
        # Apply every normalize_* decorator to the PyTorch-level implementation.
        fn = torch_impl
        for normalization in normalizations:
            fn = normalization(fn, fn_params)
        return fn
    return inner


def normalize(fn, fn_params, normalization_fn, condition):
    @wraps(fn)
    def _fn(*args, **kwargs):
        # Edit in place, or should we just zip the parameters and the args and create a new list?
        args = list(args)
        for i, param in enumerate(list(fn_params)[: len(args)]):
            if condition(args[i], fn_params[param]):
                args[i] = normalization_fn(args[i])
        # Same as above
        for kwarg, v in kwargs.items():
            if condition(v, fn_params[kwarg]):
                kwargs[kwarg] = normalization_fn(v)
        return fn(*args, **kwargs)
    return _fn


def normalize_ndarray(fn, fn_params):
    """Wrap parameters that accept any array-like object."""
    return normalize(fn, fn_params, wrap_ndarray,
                     lambda val, param: param.annotation == GenericNdArray)


def normalize_tensor(fn, fn_params):
    """Wrap parameters that accept ndarrays."""
    return normalize(fn, fn_params, torch.from_numpy,
                     lambda val, param: param.annotation == torch.Tensor)


def normalize_dtype(fn, fn_params):
    ...


def normalize_output(fn, fn_params):
    @wraps(fn)
    def _fn(*args, **kwargs):
        out = fn(*args, **kwargs)
        # TODO implement support for tuples
        return ndarray(out) if isinstance(out, torch.Tensor) else out
    return _fn


@bind(np_refs.add)
def add(a: GenericNdArray, b: GenericNdArray, *, dtype: Optional[np.dtype]):
    pass
```

You could even do all this within the
If there's now an underscore in
The overall structure you propose seems nice and logical @lezcano.
This I'm not 100% sure about. Typing for the PyTorch layer is straightforward; the
My point here is that this will serve as some annotations to:
Note that these annotations are not meant to be used to strictly typecheck programs.
As discussed, let's merge this one as is and discuss the next restructure in #32
OK, no, cannot merge this. I am failing to appease the powers of linting to merge this with main with black and all. Am giving up for today. Hopefully @honno would be able to give me a hand with this.
Introduce the `_detail` subpackage: its contents only operate on pytorch tensors and do not know about wrappers. User-facing functionality mostly deals with array-likes and argument differences between numpy and pytorch, and delegates the heavy lifting to functions in `_detail`.
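As a closing illustration of this layering, here is a minimal hedged sketch of a user-facing wrapper, written out by hand (in the PR this plumbing goes through the decorators instead). The helpers `asarray`, `.get()`, `_from_tensor` and `_detail._reductions.sum` are assumed names used only for illustration:

```python
# Hypothetical user-facing wrapper: accepts array-likes, keeps out= keyword-only,
# and delegates all tensor work to the _detail subpackage.
from . import _detail
from ._ndarray import ndarray, asarray   # asarray: array-like -> ndarray (assumed helper)


def sum(a, axis=None, dtype=None, *, out=None, keepdims=False):
    tensor = asarray(a).get()            # unwrap to a torch.Tensor (assumed accessor)
    result = _detail._reductions.sum(tensor, axis=axis, dtype=dtype, keepdims=keepdims)
    if out is not None:                  # out= is emulated at this level only
        out.get().copy_(result)
        return out
    return ndarray._from_tensor(result)  # wrap back into an ndarray (assumed constructor)
```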