Refactor the internals to better separate wrappers from ops on tensors #22

Merged: 26 commits, Jan 31, 2023

Commits
afec98d
MAINT: move several utils to _util from _helpers
ev-br Jan 13, 2023
053e015
MAINT: inline the standardize_axis helper (only used once)
ev-br Jan 13, 2023
5ed0cae
BUG: ndarrays are not hashable
ev-br Jan 13, 2023
50c0384
MAINT: rationalize DTypes
ev-br Jan 14, 2023
83ac0df
BUG: ones_like dtype default to arg
ev-br Jan 14, 2023
a2c0dcd
MAINT: move _scalar_types to the _detail namespace
ev-br Jan 14, 2023
9575b25
BUG: torch.full needs explicit dtype
ev-br Jan 14, 2023
ed4c30d
BUG: correctly infer dtype of a Tensor instance
ev-br Jan 14, 2023
d882d4c
MAINT: consolidate default dtype handling
ev-br Jan 14, 2023
261a035
MAINT: move _util to _detail
ev-br Jan 15, 2023
d3f2c14
MAINT: move can_cast impl to work with torch dtypes
ev-br Jan 15, 2023
e1a5b50
MAINT: make cast_dont_broadcast operate on tensors, simplify concatenate
ev-br Jan 15, 2023
feb52e6
MAINT: move cast_and_broadcast down to _detail/_util
ev-br Jan 16, 2023
e0b9896
MAINT: reductions: split into unwrap to tensors + operate on tensors
ev-br Jan 18, 2023
4dae067
MAINT: reductions: simplify signatures
ev-br Jan 18, 2023
6525969
MAINT: move axis_out_keepdims to _util/_decorators
ev-br Jan 21, 2023
d3c1afb
MAINT: move core asarray(...) logic to util / make torch only
ev-br Jan 22, 2023
b34afdb
MAINT: split ufunc wrappers to _detail (tensors only) and decorators
ev-br Jan 23, 2023
03d75e0
MAINT: simplify ndarray.__dunders__
ev-br Jan 23, 2023
fe7b8b4
MAINT: move axis_keepdims_wrapper to decorators
ev-br Jan 23, 2023
09655c9
MAINT: simplify ndarray.__richcompare__
ev-br Jan 23, 2023
1d3995f
BUG: fix isclose, actually check if args are close
ev-br Jan 23, 2023
d568f2a
MAINT: remove stale comments
ev-br Jan 24, 2023
214e0f1
MAINT: address review comments
ev-br Jan 25, 2023
e3fdc3b
Update torch_np/_detail/_util.py
ev-br Jan 26, 2023
7cf3344
MAINT: address review comments
ev-br Jan 26, 2023
autogen/gen_dtypes.py: 11 additions, 4 deletions
@@ -22,11 +22,18 @@ def __init__(self, name):
     'int64',
     'bool']
 
 
+tmap = {dt: torch.as_tensor(np.ones(1, dtype=dt)).dtype for dt in dt_names}
+
+
+
 templ = """\
 {name} = dtype("{name}")
 """
 
+
+
+
 ############### Output the dtypes #############
 
 src_lines = [templ.format(name=name) for name in dt_names]
@@ -51,8 +58,8 @@ def generate_can_cast(casting):
         for dtyp2 in dt_names:
             can_cast = np.can_cast(np.dtype(dtyp1), np.dtype(dtyp2),
                                    casting=casting)
-            dct_dtyp1[dtyp2] = can_cast
-        dct[dtyp1] = dct_dtyp1
+            dct_dtyp1[tmap[dtyp2]] = can_cast
+        dct[tmap[dtyp1]] = dct_dtyp1
     return dct
 
 
@@ -63,8 +70,8 @@ def generate_result_type():
         dct_dtyp1 = {}
         for dtyp2 in dt_names:
             result_type = np.result_type(np.dtype(dtyp1), np.dtype(dtyp2))
-            dct_dtyp1[dtyp2] = result_type.name
-        dct[dtyp1] = dct_dtyp1
+            dct_dtyp1[tmap[dtyp2]] = tmap[result_type.name]
+        dct[tmap[dtyp1]] = dct_dtyp1
     return dct
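
For reference, a minimal runnable sketch (a few dtype names only; the actual script builds the map over all of dt_names) of what the new tmap lookup evaluates to:

import numpy as np
import torch

# round-trip each NumPy dtype name through a tensor to get the torch dtype
tmap = {dt: torch.as_tensor(np.ones(1, dtype=dt)).dtype
        for dt in ('float32', 'int64', 'bool')}
print(tmap)   # {'float32': torch.float32, 'int64': torch.int64, 'bool': torch.bool}

With tmap in place, generate_can_cast and generate_result_type key the generated lookup tables by torch dtypes rather than by dtype-name strings, which is the point of this diff.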
torch_np/__init__.py: 2 additions, 7 deletions
@@ -1,20 +1,15 @@
 from ._dtypes import *
-from ._scalar_types import *
+from ._detail._scalar_types import *
 from ._wrapper import *
 #from . import testing
 
 from ._unary_ufuncs import *
 from ._binary_ufuncs import *
 from ._ndarray import can_cast, result_type, newaxis
-from ._util import AxisError, UFuncTypeError
+from ._detail._util import AxisError, UFuncTypeError
 from ._getlimits import iinfo, finfo
 from ._getlimits import errstate
 
 inf = float('inf')
 nan = float('nan')
 
-
-#### HACK HACK HACK ####
-import torch
-torch.set_default_dtype(torch.float64)
-del torch
torch_np/_binary_ufuncs.py: 44 additions, 70 deletions
@@ -1,79 +1,53 @@
-import functools
-import torch
-
-from . import _util
-
-from ._ndarray import asarray
-from . import _dtypes
-from . import _helpers
-
-from . import _ufunc_impl
+from ._decorators import deco_binary_ufunc_from_impl
+from ._detail import _ufunc_impl
 
 #
-# Functions in _ufunc_impl receive arrays, implement common tasks with ufunc args
-# and delegate heavy lifting to pytorch equivalents.
-#
-# Functions in this file implement binary ufuncs: wrap two first arguments in
-# asarray and delegate to functions from _ufunc_impl.
-#
-# One other user of _ufunc_impl functions in ndarray, where its __add__ method
-# calls _ufunc_impl.add and so on. Note that ndarray dunders already know
-# that its first arg is an array, so they only convert the second argument.
+# Functions in _detail/_ufunc_impl.py receive tensors, implement common tasks
+# with ufunc args, and delegate heavy lifting to pytorch equivalents.
+#
+# XXX: While it sounds tempting to merge _binary_ufuncs.py and _ufunc_impl.py
+# files, doing it would currently create import cycles.
 #
 
-# TODO: deduplicate with _unary_ufuncs/deco_unary_ufunc_from_impl,
-# _ndarray/asarray_replacer, and _wrapper/concatenate et al
-def deco_ufunc_from_impl(impl_func):
-    @functools.wraps(impl_func)
-    def wrapped(x1, x2, *args, **kwds):
-        x1_array = asarray(x1)
-        x2_array = asarray(x2)
-        return impl_func(x1_array, x2_array, *args, **kwds)
-    return wrapped
-
-
 # the list is autogenerated, cf autogen/gen_ufunc_2.py
-add = deco_ufunc_from_impl(_ufunc_impl.add)
-arctan2 = deco_ufunc_from_impl(_ufunc_impl.arctan2)
-bitwise_and = deco_ufunc_from_impl(_ufunc_impl.bitwise_and)
-bitwise_or = deco_ufunc_from_impl(_ufunc_impl.bitwise_or)
-bitwise_xor = deco_ufunc_from_impl(_ufunc_impl.bitwise_xor)
-copysign = deco_ufunc_from_impl(_ufunc_impl.copysign)
-divide = deco_ufunc_from_impl(_ufunc_impl.divide)
-equal = deco_ufunc_from_impl(_ufunc_impl.equal)
-float_power = deco_ufunc_from_impl(_ufunc_impl.float_power)
-floor_divide = deco_ufunc_from_impl(_ufunc_impl.floor_divide)
-fmax = deco_ufunc_from_impl(_ufunc_impl.fmax)
-fmin = deco_ufunc_from_impl(_ufunc_impl.fmin)
-fmod = deco_ufunc_from_impl(_ufunc_impl.fmod)
-gcd = deco_ufunc_from_impl(_ufunc_impl.gcd)
-greater = deco_ufunc_from_impl(_ufunc_impl.greater)
-greater_equal = deco_ufunc_from_impl(_ufunc_impl.greater_equal)
-heaviside = deco_ufunc_from_impl(_ufunc_impl.heaviside)
-hypot = deco_ufunc_from_impl(_ufunc_impl.hypot)
-lcm = deco_ufunc_from_impl(_ufunc_impl.lcm)
-ldexp = deco_ufunc_from_impl(_ufunc_impl.ldexp)
-left_shift = deco_ufunc_from_impl(_ufunc_impl.left_shift)
-less = deco_ufunc_from_impl(_ufunc_impl.less)
-less_equal = deco_ufunc_from_impl(_ufunc_impl.less_equal)
-logaddexp = deco_ufunc_from_impl(_ufunc_impl.logaddexp)
-logaddexp2 = deco_ufunc_from_impl(_ufunc_impl.logaddexp2)
-logical_and = deco_ufunc_from_impl(_ufunc_impl.logical_and)
-logical_or = deco_ufunc_from_impl(_ufunc_impl.logical_or)
-logical_xor = deco_ufunc_from_impl(_ufunc_impl.logical_xor)
-matmul = deco_ufunc_from_impl(_ufunc_impl.matmul)
-maximum = deco_ufunc_from_impl(_ufunc_impl.maximum)
-minimum = deco_ufunc_from_impl(_ufunc_impl.minimum)
-remainder = deco_ufunc_from_impl(_ufunc_impl.remainder)
-multiply = deco_ufunc_from_impl(_ufunc_impl.multiply)
-nextafter = deco_ufunc_from_impl(_ufunc_impl.nextafter)
-not_equal = deco_ufunc_from_impl(_ufunc_impl.not_equal)
-power = deco_ufunc_from_impl(_ufunc_impl.power)
-remainder = deco_ufunc_from_impl(_ufunc_impl.remainder)
-right_shift = deco_ufunc_from_impl(_ufunc_impl.right_shift)
-subtract = deco_ufunc_from_impl(_ufunc_impl.subtract)
-divide = deco_ufunc_from_impl(_ufunc_impl.divide)
+add = deco_binary_ufunc_from_impl(_ufunc_impl.add)
+arctan2 = deco_binary_ufunc_from_impl(_ufunc_impl.arctan2)
+bitwise_and = deco_binary_ufunc_from_impl(_ufunc_impl.bitwise_and)
+bitwise_or = deco_binary_ufunc_from_impl(_ufunc_impl.bitwise_or)
+bitwise_xor = deco_binary_ufunc_from_impl(_ufunc_impl.bitwise_xor)
+copysign = deco_binary_ufunc_from_impl(_ufunc_impl.copysign)
+divide = deco_binary_ufunc_from_impl(_ufunc_impl.divide)
+equal = deco_binary_ufunc_from_impl(_ufunc_impl.equal)
+float_power = deco_binary_ufunc_from_impl(_ufunc_impl.float_power)
+floor_divide = deco_binary_ufunc_from_impl(_ufunc_impl.floor_divide)
+fmax = deco_binary_ufunc_from_impl(_ufunc_impl.fmax)
+fmin = deco_binary_ufunc_from_impl(_ufunc_impl.fmin)
+fmod = deco_binary_ufunc_from_impl(_ufunc_impl.fmod)
+gcd = deco_binary_ufunc_from_impl(_ufunc_impl.gcd)
+greater = deco_binary_ufunc_from_impl(_ufunc_impl.greater)
+greater_equal = deco_binary_ufunc_from_impl(_ufunc_impl.greater_equal)
+heaviside = deco_binary_ufunc_from_impl(_ufunc_impl.heaviside)
+hypot = deco_binary_ufunc_from_impl(_ufunc_impl.hypot)
+lcm = deco_binary_ufunc_from_impl(_ufunc_impl.lcm)
+ldexp = deco_binary_ufunc_from_impl(_ufunc_impl.ldexp)
+left_shift = deco_binary_ufunc_from_impl(_ufunc_impl.left_shift)
+less = deco_binary_ufunc_from_impl(_ufunc_impl.less)
+less_equal = deco_binary_ufunc_from_impl(_ufunc_impl.less_equal)
+logaddexp = deco_binary_ufunc_from_impl(_ufunc_impl.logaddexp)
+logaddexp2 = deco_binary_ufunc_from_impl(_ufunc_impl.logaddexp2)
+logical_and = deco_binary_ufunc_from_impl(_ufunc_impl.logical_and)
+logical_or = deco_binary_ufunc_from_impl(_ufunc_impl.logical_or)
+logical_xor = deco_binary_ufunc_from_impl(_ufunc_impl.logical_xor)
+matmul = deco_binary_ufunc_from_impl(_ufunc_impl.matmul)
+maximum = deco_binary_ufunc_from_impl(_ufunc_impl.maximum)
+minimum = deco_binary_ufunc_from_impl(_ufunc_impl.minimum)
+remainder = deco_binary_ufunc_from_impl(_ufunc_impl.remainder)
+multiply = deco_binary_ufunc_from_impl(_ufunc_impl.multiply)
+nextafter = deco_binary_ufunc_from_impl(_ufunc_impl.nextafter)
+not_equal = deco_binary_ufunc_from_impl(_ufunc_impl.not_equal)
+power = deco_binary_ufunc_from_impl(_ufunc_impl.power)
+remainder = deco_binary_ufunc_from_impl(_ufunc_impl.remainder)
+right_shift = deco_binary_ufunc_from_impl(_ufunc_impl.right_shift)
+subtract = deco_binary_ufunc_from_impl(_ufunc_impl.subtract)
+divide = deco_binary_ufunc_from_impl(_ufunc_impl.divide)
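
To make the new layering concrete, a minimal self-contained sketch (stand-in names; plain torch.Tensors instead of torch_np's ndarray wrapper, and no dtype/out handling) of the wrapper-vs-impl split that deco_binary_ufunc_from_impl implements:

import functools
import torch

def _add_impl(tensors):
    # "_detail" layer: sees only torch.Tensors, returns a torch.Tensor
    t1, t2 = torch.broadcast_tensors(*tensors)
    return t1 + t2

def deco_binary(impl_func):
    # "wrapper" layer: coerce the two array-like args, then delegate
    @functools.wraps(impl_func)
    def wrapped(x1, x2, *args, **kwds):
        t1, t2 = torch.as_tensor(x1), torch.as_tensor(x2)
        return impl_func((t1, t2), *args, **kwds)
    return wrapped

add = deco_binary(_add_impl)
print(add([1, 2], 3))   # tensor([4, 5])

In the real code the wrapper layer returns wrapped ndarrays via asarray(...).get() and the decorators in _decorators.py, but the division of labor is the same.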

torch_np/_decorators.py: 107 additions (new file)
@@ -0,0 +1,107 @@
import functools
import operator

import torch

from . import _dtypes
from . import _helpers
from ._detail import _util

NoValue = None


def dtype_to_torch(func):
    @functools.wraps(func)
    def wrapped(*args, dtype=None, **kwds):
        torch_dtype = None
        if dtype is not None:
            dtype = _dtypes.dtype(dtype)
            torch_dtype = dtype._scalar_type.torch_dtype
        return func(*args, dtype=torch_dtype, **kwds)
    return wrapped


def emulate_out_arg(func):
    """Simulate the out=... handling: move the result tensor to the out array.

    With this decorator, the inner function just does not see the out array.
    """
    @functools.wraps(func)
    def wrapped(*args, out=None, **kwds):
        result_tensor = func(*args, **kwds)
        return _helpers.result_or_out(result_tensor, out)

    return wrapped


def out_shape_dtype(func):
    """Handle out=... kwarg for ufuncs.

    With ufuncs, the `out` array can typecast and broadcast ufunc arguments;
    hence extract the shape and dtype of the tensor which backs the `out`
    array and pass these through.
    """
    @functools.wraps(func)
    def wrapped(*args, out=None, **kwds):
        if out is not None:
            kwds.update({'out_shape_dtype': (out.get().dtype, out.get().shape)})
        result_tensor = func(*args, **kwds)
        return _helpers.result_or_out(result_tensor, out)

    return wrapped


def deco_unary_ufunc_from_impl(impl_func):
    @functools.wraps(impl_func)
    @dtype_to_torch
    @out_shape_dtype
    def wrapped(x1, *args, **kwds):
        from ._ndarray import asarray
        x1_tensor = asarray(x1).get()
        result = impl_func((x1_tensor,), *args, **kwds)
        return result
    return wrapped


# TODO: deduplicate with _ndarray/asarray_replacer,
# and _wrapper/concatenate et al
def deco_binary_ufunc_from_impl(impl_func):
    @functools.wraps(impl_func)
    @dtype_to_torch
    @out_shape_dtype
    def wrapped(x1, x2, *args, **kwds):
        from ._ndarray import asarray
        x1_tensor = asarray(x1).get()
        x2_tensor = asarray(x2).get()
        return impl_func((x1_tensor, x2_tensor), *args, **kwds)
    return wrapped


def axis_keepdims_wrapper(func):
    """`func` accepts an array-like as a 1st arg, returns a tensor.

    This decorator implements the generic handling of axis, out and keepdims
    arguments for reduction functions.

    Note that we peel off `out=...` and `keepdims=...` args (torch functions
    never see them). The `axis` argument we normalize and pass through to
    pytorch functions.
    """
    # XXX: move this out of _ndarray.py (circular imports)
    #
    # TODO: 1. get rid of _helpers.result_or_out
    #       2. sort out function signatures: how they flow through all decorators etc
    @functools.wraps(func)
    def wrapped(a, axis=None, keepdims=NoValue, *args, **kwds):
        from ._ndarray import ndarray, asarray
        tensor = asarray(a).get()

        # standardize the axis argument
        if isinstance(axis, ndarray):
            axis = operator.index(axis)

        result = _util.axis_keepdims(func, tensor, axis, keepdims, *args, **kwds)
        return result

    return wrapped
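
Two self-contained sketches of the behaviours implemented above, using plain torch.Tensors and hypothetical names (the real decorators route results through _helpers.result_or_out and wrapped ndarrays). First, the emulate_out_arg idea: the inner function never sees out=...; the decorator copies the result into the out buffer afterwards:

import functools
import torch

def emulate_out_sketch(func):
    @functools.wraps(func)
    def wrapped(*args, out=None, **kwds):
        result = func(*args, **kwds)      # inner func is out-agnostic
        if out is None:
            return result
        out.copy_(result)                 # move the result into `out`
        return out
    return wrapped

@emulate_out_sketch
def double(t):
    return t * 2

buf = torch.empty(3)
print(double(torch.ones(3), out=buf))     # tensor([2., 2., 2.]), stored in buf

Second, the axis/keepdims pattern that axis_keepdims_wrapper delegates to _util.axis_keepdims (the real helper is generic over reduction functions; here a single hardcoded reduction, integer axes only):

def sum_keepdims(tensor, axis=None, keepdims=False):
    # reduce, then optionally restore the reduced dimension as size 1
    if axis is None:
        result = tensor.sum()
        return result.reshape((1,) * tensor.ndim) if keepdims else result
    result = tensor.sum(dim=axis)
    return result.unsqueeze(axis) if keepdims else result

t = torch.arange(6.).reshape(2, 3)
print(sum_keepdims(t, axis=0, keepdims=True).shape)   # torch.Size([1, 3])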
torch_np/_detail/__init__.py: new empty file