Refactor the internals to better separate wrappers from ops on tensors #22

Merged
merged 26 commits on Jan 31, 2023
Changes from 24 commits
Commits (26)
afec98d
MAINT: move several utils to _util from _helpers
ev-br Jan 13, 2023
053e015
MAINT: inline the standardize_axis helper (only used once)
ev-br Jan 13, 2023
5ed0cae
BUG: ndarrays are not hashable
ev-br Jan 13, 2023
50c0384
MAINT: rationalize DTypes
ev-br Jan 14, 2023
83ac0df
BUG: ones_like dtype default to arg
ev-br Jan 14, 2023
a2c0dcd
MAINT: move _scalar_types to the _detail namespace
ev-br Jan 14, 2023
9575b25
BUG: torch.full needs explicit dtype
ev-br Jan 14, 2023
ed4c30d
BUG: correctly infer dtype of a Tensor instance
ev-br Jan 14, 2023
d882d4c
MAINT: consolidate default dtype handling
ev-br Jan 14, 2023
261a035
MAINT: move _util to _detail
ev-br Jan 15, 2023
d3f2c14
MAINT: move can_cast impl to work with torch dtypes
ev-br Jan 15, 2023
e1a5b50
MAINT: make cast_dont_broadcast operate on tensors, simplify concatenate
ev-br Jan 15, 2023
feb52e6
MAINT: move cast_and_broadcast down to _detail/_util
ev-br Jan 16, 2023
e0b9896
MAINT: reductions: split into unwrap to tensors + operate on tensors
ev-br Jan 18, 2023
4dae067
MAINT: reductions: simplify signatures
ev-br Jan 18, 2023
6525969
MAINT: move axis_out_keepdims to _util/_decorators
ev-br Jan 21, 2023
d3c1afb
MAINT: move core asarray(...) logic to util / make torch only
ev-br Jan 22, 2023
b34afdb
MAINT: split ufunc wrappers to _detail (tensors only) and decorators
ev-br Jan 23, 2023
03d75e0
MAINT: simplify ndarray.__dunders__
ev-br Jan 23, 2023
fe7b8b4
MAINT: move axis_keepdims_wrapper to decorators
ev-br Jan 23, 2023
09655c9
MAINT: simplify ndarray.__richcompare__
ev-br Jan 23, 2023
1d3995f
BUG: fix isclose, actually check if args are close
ev-br Jan 23, 2023
d568f2a
MAINT: remove stale comments
ev-br Jan 24, 2023
214e0f1
MAINT: address review comments
ev-br Jan 25, 2023
e3fdc3b
Update torch_np/_detail/_util.py
ev-br Jan 26, 2023
7cf3344
MAINT: address review comments
ev-br Jan 26, 2023
15 changes: 11 additions & 4 deletions autogen/gen_dtypes.py
@@ -22,11 +22,18 @@ def __init__(self, name):
'int64',
'bool']


tmap = {dt: torch.as_tensor(np.ones(1, dtype=dt)).dtype for dt in dt_names}



templ = """\
{name} = dtype("{name}")
"""




############### Output the dtypes #############

src_lines = [templ.format(name=name) for name in dt_names]
@@ -51,8 +58,8 @@ def generate_can_cast(casting):
for dtyp2 in dt_names:
can_cast = np.can_cast(np.dtype(dtyp1), np.dtype(dtyp2),
casting=casting)
dct_dtyp1[dtyp2] = can_cast
dct[dtyp1] = dct_dtyp1
dct_dtyp1[tmap[dtyp2]] = can_cast
dct[tmap[dtyp1]] = dct_dtyp1
return dct


@@ -63,8 +70,8 @@ def generate_result_type():
dct_dtyp1 = {}
for dtyp2 in dt_names:
result_type = np.result_type(np.dtype(dtyp1), np.dtype(dtyp2))
dct_dtyp1[dtyp2] = result_type.name
dct[dtyp1] = dct_dtyp1
dct_dtyp1[tmap[dtyp2]] = tmap[result_type.name]
dct[tmap[dtyp1]] = dct_dtyp1
return dct


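For orientation, a minimal sketch of what the tmap mapping and the regenerated casting table look like once keyed by torch dtypes. This is illustrative only: the casting mode and variable names here are assumptions, and the real script writes the tables out as Python source rather than building them at import time.

import numpy as np
import torch

dt_names = ['float16', 'float32', 'float64', 'uint8', 'int8', 'int16',
            'int32', 'int64', 'bool']

# dtype-name -> torch.dtype, built the same way as tmap above
tmap = {dt: torch.as_tensor(np.ones(1, dtype=dt)).dtype for dt in dt_names}
print(tmap['float32'])   # torch.float32
print(tmap['bool'])      # torch.bool

# the generated can-cast table is now keyed by torch dtypes on both axes
can_cast_table = {
    tmap[dt1]: {tmap[dt2]: np.can_cast(np.dtype(dt1), np.dtype(dt2),
                                       casting='same_kind')
                for dt2 in dt_names}
    for dt1 in dt_names
}
print(can_cast_table[torch.int8][torch.float32])   # True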
9 changes: 2 additions & 7 deletions torch_np/__init__.py
@@ -1,20 +1,15 @@
from ._dtypes import *
from ._scalar_types import *
from ._detail._scalar_types import *
from ._wrapper import *
#from . import testing

from ._unary_ufuncs import *
from ._binary_ufuncs import *
from ._ndarray import can_cast, result_type, newaxis
from ._util import AxisError, UFuncTypeError
from ._detail._util import AxisError, UFuncTypeError
from ._getlimits import iinfo, finfo
from ._getlimits import errstate

inf = float('inf')
nan = float('nan')


#### HACK HACK HACK ####
import torch
torch.set_default_dtype(torch.float64)
del torch
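A side note on the set_default_dtype hack kept at the bottom of __init__.py: NumPy creates float64 arrays by default while PyTorch defaults to float32, so bumping the default keeps results in line with NumPy. A minimal illustration, independent of this package:

import torch

print(torch.ones(3).dtype)             # torch.float32 -- PyTorch's default
torch.set_default_dtype(torch.float64)
print(torch.ones(3).dtype)             # torch.float64 -- matches NumPy's default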
34 changes: 4 additions & 30 deletions torch_np/_binary_ufuncs.py
@@ -1,39 +1,13 @@
import functools
import torch
from ._decorators import deco_ufunc_from_impl
from ._detail import _ufunc_impl

from . import _util

from ._ndarray import asarray
from . import _dtypes
from . import _helpers

from . import _ufunc_impl

#
# Functions in _ufunc_impl receive arrays, implement common tasks with ufunc args
# and delegate heavy lifting to pytorch equivalents.
#
# Functions in this file implement binary ufuncs: wrap two first arguments in
# asarray and delegate to functions from _ufunc_impl.
#
# One other user of _ufunc_impl functions in ndarray, where its __add__ method
# calls _ufunc_impl.add and so on. Note that ndarray dunders already know
# that its first arg is an array, so they only convert the second argument.
# Functions in _detail/_ufunc_impl.py receive tensors, implement common tasks
# with ufunc args, and delegate heavy lifting to pytorch equivalents.
#
# XXX: While it sounds tempting to merge _binary_ufuncs.py and _ufunc_impl.py
# files, doing it would currently create import cycles.
#

# TODO: deduplicate with _unary_ufuncs/deco_unary_ufunc_from_impl,
# _ndarray/asarray_replacer, and _wrapper/concatenate et al
def deco_ufunc_from_impl(impl_func):
@functools.wraps(impl_func)
def wrapped(x1, x2, *args, **kwds):
x1_array = asarray(x1)
x2_array = asarray(x2)
return impl_func(x1_array, x2_array, *args, **kwds)
return wrapped


# the list is autogenerated, cf autogen/gen_ufunc_2.py
add = deco_ufunc_from_impl(_ufunc_impl.add)
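For orientation, a stripped-down sketch of the wrap-then-delegate pattern described in the comment above: the impl function only ever sees tensors, and the decorator converts array-likes before delegating. The names below are stand-ins (the real decorator is deco_ufunc_from_impl in torch_np/_decorators.py, and the conversion goes through asarray rather than torch.as_tensor):

import functools
import torch

def add_impl(tensors):
    # tensor-only implementation: the heavy lifting is pure torch
    t1, t2 = tensors
    return torch.add(t1, t2)

def wrap_binary_ufunc(impl_func):
    # simplified stand-in for deco_ufunc_from_impl
    @functools.wraps(impl_func)
    def wrapped(x1, x2, *args, **kwds):
        x1_t = torch.as_tensor(x1)   # the real code uses asarray(x1).get()
        x2_t = torch.as_tensor(x2)
        return impl_func((x1_t, x2_t), *args, **kwds)
    return wrapped

add = wrap_binary_ufunc(add_impl)
print(add([1, 2, 3], 4))             # tensor([5, 6, 7])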
112 changes: 112 additions & 0 deletions torch_np/_decorators.py
@@ -0,0 +1,112 @@
import functools
import operator

import torch

from . import _dtypes
from . import _helpers
from ._detail import _util

NoValue = None


def dtype_to_torch(func):
## @functools.wraps # XXX: why
def wrapped(*args, dtype=None, **kwds):
torch_dtype = None
if dtype is not None:
dtype = _dtypes.dtype(dtype)
torch_dtype = dtype._scalar_type.torch_dtype
return func(*args, dtype=torch_dtype, **kwds)
return wrapped


def emulate_out_arg(func):
"""Simulate the out=... handling *for functions which do not need it*.

With this decorator, the inner function just does not see the out array.
"""
def wrapped(*args, out=None, **kwds):
from ._ndarray import ndarray
if out is not None:
if not isinstance(out, ndarray):
raise TypeError("Return arrays must be of ArrayType")
result_tensor = func(*args, **kwds)
return _helpers.result_or_out(result_tensor, out)

return wrapped


def out_shape_dtype(func):
"""Handle out=... kwarg for ufuncs.

With ufuncs, the `out` array can typecast and broadcast ufunc arguments, hence
extract the shape and dtype of the tensor which backs the `out` array
and pass these through.
"""
def wrapped(*args, out=None, **kwds):
from ._ndarray import ndarray
if out is not None:
if not isinstance(out, ndarray):
raise TypeError("Return arrays must be of ArrayType")
kwds.update({'out_shape_dtype': (out.get().dtype, out.get().shape)})
result_tensor = func(*args, **kwds)
return _helpers.result_or_out(result_tensor, out)

return wrapped


def deco_unary_ufunc_from_impl(impl_func):
@functools.wraps(impl_func)
@dtype_to_torch
@out_shape_dtype
def wrapped(x1, *args, **kwds):
from ._ndarray import asarray
x1_tensor = asarray(x1).get()
result = impl_func((x1_tensor,), *args, **kwds)
return result
return wrapped


# TODO: deduplicate with _ndarray/asarray_replacer,
# and _wrapper/concatenate et al
def deco_ufunc_from_impl(impl_func):
@functools.wraps(impl_func)
@dtype_to_torch
@out_shape_dtype
def wrapped(x1, x2, *args, **kwds):
from ._ndarray import asarray
x1_tensor = asarray(x1).get()
x2_tensor = asarray(x2).get()
return impl_func((x1_tensor, x2_tensor), *args, **kwds)
return wrapped



def axis_keepdims_wrapper(func):
"""`func` accepts an array-like as a 1st arg, returns a tensor.

This decorator implements the generic handling of axis, out and keepdims
arguments for reduction functions.

Note that we peel off `out=...` and `keepdims=...` args (torch functions never
see them). The `axis` argument we normalize and pass through to pytorch functions.

"""
# XXX: move this out of _ndarray.py (circular imports)
#
# TODO: 1. get rid of _helpers.result_or_out
# 2. sort out function signatures: how they flow through all decorators etc
@functools.wraps(func)
def wrapped(a, axis=None, out=None, keepdims=NoValue, *args, **kwds):
Collaborator commented:

I get the feeling that this function does too much. I really liked how you split the dtype and out processing above, but here you have mixed them all together. As a result, you seem to have forgotten to handle the out kwarg.

Collaborator (author) replied:

In fact, the out= arg should not be here any more (it was removed in the last commit).
Note the usage is emulate_out_arg(axis_keepdims(...)): https://github.com/Quansight-Labs/numpy_pytorch_interop/blob/refactor/torch_np/_ndarray.py#L281

This decorator itself simply unwraps ndarrays and passes the heavy lifting to _util.axis_keepdims, which does just what it says on the tin: handles axis tuples and keepdims=True.

Not sure how to simplify it further. Would a comment help?

The fact that the usage is not very clear is, I guess, a direct consequence of having separate decorators for the various arguments. So if this is the direction we want to go, there will be more of this, I'm afraid.

@lezcano (Collaborator) commented on Jan 30, 2023:

So, there are two general issues here.
I don't see why you would always need to use this in conjunction with emulate_out_arg. These two functions may be used independently, so having out in the signature of this wrapper is unnecessary.
Even more, not only is it unnecessary, it's also incorrect. If you use it with emulate_out_arg afterwards, you happen to get the right behaviour almost by chance. Now, if you swap the order, you'll get a function with the same signature, but one that takes an out= kwarg and... discards it!

As mentioned, I think that this function and the one that implements the out behaviour do two very different things, and they should be independent of each other.

from ._ndarray import ndarray, asarray
tensor = asarray(a).get()

# standardize the axis argument
if isinstance(axis, ndarray):
axis = operator.index(axis)

result = _util.axis_keepdims(func, tensor, axis, keepdims, *args, **kwds)
return result

return wrapped
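To make the composition discussed in the review thread concrete, here is a hedged sketch of how emulate_out_arg(axis_keepdims_wrapper(...)) would assemble a reduction. The decorators below are simplified stand-ins for the ones in this diff, and sum_impl is a hypothetical tensor-level reduction:

import functools
import torch

def axis_keepdims_wrapper(func):
    # simplified stand-in: unwrap array-likes, normalize axis/keepdims
    @functools.wraps(func)
    def wrapped(a, axis=None, keepdims=False, *args, **kwds):
        tensor = torch.as_tensor(a)          # the real code uses asarray(a).get()
        result = func(tensor, axis, *args, **kwds)
        if keepdims and axis is not None:
            result = result.unsqueeze(axis)  # crude keepdims emulation
        return result
    return wrapped

def emulate_out_arg(func):
    # simplified stand-in: layer out= handling on top of an out-less function
    @functools.wraps(func)
    def wrapped(*args, out=None, **kwds):
        result = func(*args, **kwds)
        if out is not None:
            out.copy_(result)                # stand-in for _helpers.result_or_out
            return out
        return result
    return wrapped

def sum_impl(tensor, axis):
    # hypothetical tensor-only reduction; never sees ndarray, out= or keepdims=
    return tensor.sum() if axis is None else tensor.sum(dim=axis)

my_sum = emulate_out_arg(axis_keepdims_wrapper(sum_impl))
print(my_sum([[1, 2], [3, 4]], axis=0, keepdims=True))   # tensor([[4, 6]])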
Empty file added torch_np/_detail/__init__.py