Refactor the internals to better separate wrappers from ops on tensors #22
Changes from 23 commits
The first hunk re-points two imports at the new `_detail` subpackage, which now holds the tensor-level helpers:

```diff
@@ -1,20 +1,15 @@
 from ._dtypes import *
-from ._scalar_types import *
+from ._detail._scalar_types import *
 from ._wrapper import *
 #from . import testing

 from ._unary_ufuncs import *
 from ._binary_ufuncs import *
 from ._ndarray import can_cast, result_type, newaxis
-from ._util import AxisError, UFuncTypeError
+from ._detail._util import AxisError, UFuncTypeError
 from ._getlimits import iinfo, finfo
 from ._getlimits import errstate

 inf = float('inf')
 nan = float('nan')


 #### HACK HACK HACK ####
 import torch
 torch.set_default_dtype(torch.float64)
 del torch
```
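The `HACK` block at the end switches torch's default floating-point dtype from float32 to float64, so that freshly created tensors match NumPy's default. For instance:

```python
import torch

torch.set_default_dtype(torch.float64)

# tensors created without an explicit dtype now default to float64,
# matching NumPy's behaviour
print(torch.ones(2).dtype)        # torch.float64
print(torch.tensor([1.5]).dtype)  # torch.float64
```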
The second file is new (`@@ -0,0 +1,112 @@`); it collects the decorators that implement the NumPy-style argument handling around functions which operate on bare tensors:

```python
import functools
import operator

import torch

from . import _dtypes
from . import _helpers
from ._detail import _util

NoValue = None


def dtype_to_torch(func):
    ## @functools.wraps  # XXX: why
    def wrapped(*args, dtype=None, **kwds):
        torch_dtype = None
        if dtype is not None:
            dtype = _dtypes.dtype(dtype)
            torch_dtype = dtype._scalar_type.torch_dtype
        return func(*args, dtype=torch_dtype, **kwds)
    return wrapped
```
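A minimal standalone sketch of the same pattern, using a plain dict in place of `_dtypes.dtype` (which this diff does not show):

```python
import torch

# illustrative mapping only; the real code goes through _dtypes.dtype
_NAME_TO_TORCH = {"float64": torch.float64, "int32": torch.int32}

def dtype_to_torch(func):
    # translate a NumPy-style dtype spec into a torch.dtype before calling func
    def wrapped(*args, dtype=None, **kwds):
        torch_dtype = _NAME_TO_TORCH[dtype] if dtype is not None else None
        return func(*args, dtype=torch_dtype, **kwds)
    return wrapped

@dtype_to_torch
def zeros(n, dtype=None):
    return torch.zeros(n, dtype=dtype)

print(zeros(3, dtype="int32").dtype)   # torch.int32
```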
```python
def emulate_out_arg(func):
    """Simulate the out=... handling *for functions which do not need it*.

    With this decorator, the inner function just does not see the out array.
    """
    def wrapped(*args, out=None, **kwds):
        from ._ndarray import ndarray
        if out is not None:
            if not isinstance(out, ndarray):
                raise TypeError("Return arrays must be of ArrayType")
        result_tensor = func(*args, **kwds)
        return _helpers.result_or_out(result_tensor, out)

    return wrapped
```
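`_helpers.result_or_out` is not part of this diff; presumably it copies the computed tensor into `out` when one is given and returns `out`. A self-contained sketch of that behaviour:

```python
import torch

def emulate_out_arg(func):
    # standalone analogue: copy the computed tensor into `out` if given
    def wrapped(*args, out=None, **kwds):
        result = func(*args, **kwds)
        if out is not None:
            out.copy_(result)   # assumes matching shape; copy_ casts the dtype
            return out
        return result
    return wrapped

@emulate_out_arg
def double(x):
    return x * 2

buf = torch.empty(3, dtype=torch.float64)
print(double(torch.arange(3.), out=buf))   # tensor([0., 2., 4.], dtype=torch.float64)
```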
```python
def out_shape_dtype(func):
    """Handle out=... kwarg for ufuncs.

    With ufuncs, the `out` array can typecast and broadcast ufunc arguments,
    hence extract the shape and dtype of the tensor which backs the `out`
    array and pass these through.
    """
    def wrapped(*args, out=None, **kwds):
        from ._ndarray import ndarray
        if out is not None:
            if not isinstance(out, ndarray):
                raise TypeError("Return arrays must be of ArrayType")
            kwds.update({'out_shape_dtype': (out.get().dtype, out.get().shape)})

        result_tensor = func(*args, **kwds)
        return _helpers.result_or_out(result_tensor, out)

    return wrapped
```
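A hypothetical impl-level function on the receiving end would use the extracted `(dtype, shape)` pair to cast and broadcast its result; the names here are illustrative, not from the diff:

```python
import torch

def add_impl(tensors, out_shape_dtype=None):
    # hypothetical impl-level ufunc: receives bare tensors plus the
    # (dtype, shape) pair that out_shape_dtype extracted from the out array
    x, y = tensors
    result = torch.add(x, y)
    if out_shape_dtype is not None:
        dtype, shape = out_shape_dtype
        result = torch.broadcast_to(result.to(dtype), shape)
    return result

x = torch.tensor([1, 2, 3])
y = torch.tensor([10, 20, 30])
print(add_impl((x, y), out_shape_dtype=(torch.float64, (2, 3))))
```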
```python
def deco_unary_ufunc_from_impl(impl_func):
    @functools.wraps(impl_func)
    @dtype_to_torch
    @out_shape_dtype
    def wrapped(x1, *args, **kwds):
        from ._ndarray import asarray
        x1_tensor = asarray(x1).get()
        result = impl_func((x1_tensor,), *args, **kwds)
        return result
    return wrapped


# TODO: deduplicate with _ndarray/asarray_replacer,
# and _wrapper/concatenate et al
def deco_ufunc_from_impl(impl_func):
    @functools.wraps(impl_func)
    @dtype_to_torch
    @out_shape_dtype
    def wrapped(x1, x2, *args, **kwds):
        from ._ndarray import asarray
        x1_tensor = asarray(x1).get()
        x2_tensor = asarray(x2).get()
        return impl_func((x1_tensor, x2_tensor), *args, **kwds)
    return wrapped
```
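Putting the stacked decorators together, a public binary ufunc would presumably be wired up as below; `add_impl` echoes the hypothetical impl sketched earlier, and the actual wiring module is not shown in this diff:

```python
# hypothetical wiring: dtype_to_torch and out_shape_dtype are applied inside
# deco_ufunc_from_impl, so the public function accepts NumPy-style kwargs
add = deco_ufunc_from_impl(add_impl)

# callers then get dtype=... and out=... handling for free, e.g.
# add([1, 2], [3, 4], dtype="float64")
```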
```python
def axis_keepdims_wrapper(func):
    """`func` accepts an array-like as a 1st arg, returns a tensor.

    This decorator implements the generic handling of the axis, out and
    keepdims arguments for reduction functions.

    Note that we peel off the `out=...` and `keepdims=...` arguments (torch
    functions never see them). The `axis` argument is normalized and passed
    through to the pytorch functions.
    """
    # XXX: move this out of _ndarray.py (circular imports)
    #
    # TODO: 1. get rid of _helpers.result_or_out
    #       2. sort out function signatures: how they flow through all decorators etc
    @functools.wraps(func)
    def wrapped(a, axis=None, out=None, keepdims=NoValue, *args, **kwds):
        from ._ndarray import ndarray, asarray
        tensor = asarray(a).get()

        # standardize the axis argument
        if isinstance(axis, ndarray):
            axis = operator.index(axis)

        result = _util.axis_keepdims(func, tensor, axis, keepdims, *args, **kwds)
        return result

    return wrapped
```

Review thread on `wrapped`:

Reviewer: I get the feeling that this function does too much. I really liked how you split the dtype and out processing above, but here you have mixed them all together. As a result, you seem to forget handling the […]

Author: In fact, this decorator itself simply unwraps ndarrays and passes the heavy lifting to […]. Not sure how to simplify it further. Would a comment help? The fact that the usage is not very clear is, I guess, a direct consequence of having separate decorators for various arguments. So if it's the direction we want to go, there will be more of this, I'm afraid.

Reviewer: So, there are two general issues here. As mentioned, I think that this function and the one that implements the out behaviour do two very different things, and should be independent of each other.
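The helper this decorator delegates to, `_detail._util.axis_keepdims`, is not part of this diff. A standalone sketch of what it presumably does, restricted to a single integer axis for brevity:

```python
import torch

def axis_keepdims(func, tensor, axis, keepdims, *args, **kwds):
    # run the reduction, then restore the reduced dimension(s)
    # if keepdims is requested; tuples of axes are left out for brevity
    if axis is None:
        result = func(tensor, *args, **kwds)
        if keepdims:
            result = result.reshape((1,) * tensor.ndim)
        return result
    result = func(tensor, axis, *args, **kwds)
    if keepdims:
        result = result.unsqueeze(axis)
    return result

t = torch.arange(6.).reshape(2, 3)
print(axis_keepdims(torch.sum, t, 1, True))      # tensor([[ 3.], [12.]])
print(axis_keepdims(torch.sum, t, None, False))  # tensor(15.)
```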