Skip to content

Commit 71ea537

Browse files
committed
MAINT: rework unary and binary ufuncs w/ normalizations
1 parent 1d33f1a commit 71ea537

File tree

8 files changed

+185
-359
lines changed

8 files changed

+185
-359
lines changed

torch_np/_binary_ufuncs.py

Lines changed: 56 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,60 @@
1-
from ._decorators import deco_binary_ufunc_from_impl
2-
from ._detail import _ufunc_impl
1+
from ._detail import _binary_ufuncs
2+
3+
__all__ = [name for name in dir(_binary_ufuncs) if not name.startswith("_") and name != "torch"]
4+
5+
6+
# TODO: consolidate normalizations
7+
from ._funcs import normalizer, ArrayLike, SubokLike, DTypeLike
8+
from ._detail import _util
9+
from . import _helpers
10+
11+
12+
def deco_binary_ufunc(torch_func):
13+
"""Common infra for unary ufuncs.
14+
15+
Normalize arguments, sort out type casting, broadcasting and delegate to
16+
the pytorch functions for the actual work.
17+
"""
18+
def wrapped(
19+
x1 : ArrayLike,
20+
x2 : ArrayLike,
21+
/,
22+
out=None,
23+
*,
24+
where=True,
25+
casting="same_kind",
26+
order="K",
27+
dtype: DTypeLike=None,
28+
subok: SubokLike=False,
29+
signature=None,
30+
extobj=None
31+
):
32+
if order != "K" or not where or signature or extobj:
33+
raise NotImplementedError
34+
35+
# XXX: dtype=... parameter
36+
if dtype is not None:
37+
raise NotImplementedError
38+
39+
out_shape_dtype = None
40+
if out is not None:
41+
out_shape_dtype = (out.get().dtype, out.get().shape)
42+
43+
tensors = _util.cast_and_broadcast((x1, x2), out_shape_dtype, casting)
44+
45+
result = torch_func(*tensors)
46+
return _helpers.result_or_out(result, out)
47+
48+
return wrapped
349

450
#
5-
# Functions in this file implement binary ufuncs: wrap two first arguments in
6-
# asarray and delegate to functions from _ufunc_impl.
7-
#
8-
# Functions in _detail/_ufunc_impl.py receive tensors, implement common tasks
9-
# with ufunc args, and delegate heavy lifting to pytorch equivalents.
51+
# For each torch ufunc implementation, decorate and attach the decorated name
52+
# to this module. Its contents is then exported to the public namespace in __init__.py
1053
#
54+
for name in __all__:
55+
ufunc = getattr(_binary_ufuncs, name)
56+
decorated = normalizer(deco_binary_ufunc(ufunc))
1157

12-
# the list is autogenerated, cf autogen/gen_ufunc_2.py
13-
add = deco_binary_ufunc_from_impl(_ufunc_impl.add)
14-
arctan2 = deco_binary_ufunc_from_impl(_ufunc_impl.arctan2)
15-
bitwise_and = deco_binary_ufunc_from_impl(_ufunc_impl.bitwise_and)
16-
bitwise_or = deco_binary_ufunc_from_impl(_ufunc_impl.bitwise_or)
17-
bitwise_xor = deco_binary_ufunc_from_impl(_ufunc_impl.bitwise_xor)
18-
copysign = deco_binary_ufunc_from_impl(_ufunc_impl.copysign)
19-
divide = deco_binary_ufunc_from_impl(_ufunc_impl.divide)
20-
equal = deco_binary_ufunc_from_impl(_ufunc_impl.equal)
21-
float_power = deco_binary_ufunc_from_impl(_ufunc_impl.float_power)
22-
floor_divide = deco_binary_ufunc_from_impl(_ufunc_impl.floor_divide)
23-
fmax = deco_binary_ufunc_from_impl(_ufunc_impl.fmax)
24-
fmin = deco_binary_ufunc_from_impl(_ufunc_impl.fmin)
25-
fmod = deco_binary_ufunc_from_impl(_ufunc_impl.fmod)
26-
gcd = deco_binary_ufunc_from_impl(_ufunc_impl.gcd)
27-
greater = deco_binary_ufunc_from_impl(_ufunc_impl.greater)
28-
greater_equal = deco_binary_ufunc_from_impl(_ufunc_impl.greater_equal)
29-
heaviside = deco_binary_ufunc_from_impl(_ufunc_impl.heaviside)
30-
hypot = deco_binary_ufunc_from_impl(_ufunc_impl.hypot)
31-
lcm = deco_binary_ufunc_from_impl(_ufunc_impl.lcm)
32-
ldexp = deco_binary_ufunc_from_impl(_ufunc_impl.ldexp)
33-
left_shift = deco_binary_ufunc_from_impl(_ufunc_impl.left_shift)
34-
less = deco_binary_ufunc_from_impl(_ufunc_impl.less)
35-
less_equal = deco_binary_ufunc_from_impl(_ufunc_impl.less_equal)
36-
logaddexp = deco_binary_ufunc_from_impl(_ufunc_impl.logaddexp)
37-
logaddexp2 = deco_binary_ufunc_from_impl(_ufunc_impl.logaddexp2)
38-
logical_and = deco_binary_ufunc_from_impl(_ufunc_impl.logical_and)
39-
logical_or = deco_binary_ufunc_from_impl(_ufunc_impl.logical_or)
40-
logical_xor = deco_binary_ufunc_from_impl(_ufunc_impl.logical_xor)
41-
matmul = deco_binary_ufunc_from_impl(_ufunc_impl.matmul)
42-
maximum = deco_binary_ufunc_from_impl(_ufunc_impl.maximum)
43-
minimum = deco_binary_ufunc_from_impl(_ufunc_impl.minimum)
44-
remainder = deco_binary_ufunc_from_impl(_ufunc_impl.remainder)
45-
multiply = deco_binary_ufunc_from_impl(_ufunc_impl.multiply)
46-
nextafter = deco_binary_ufunc_from_impl(_ufunc_impl.nextafter)
47-
not_equal = deco_binary_ufunc_from_impl(_ufunc_impl.not_equal)
48-
power = deco_binary_ufunc_from_impl(_ufunc_impl.power)
49-
remainder = deco_binary_ufunc_from_impl(_ufunc_impl.remainder)
50-
right_shift = deco_binary_ufunc_from_impl(_ufunc_impl.right_shift)
51-
subtract = deco_binary_ufunc_from_impl(_ufunc_impl.subtract)
52-
divide = deco_binary_ufunc_from_impl(_ufunc_impl.divide)
58+
decorated.__qualname__ = name # XXX: is this really correct?
59+
decorated.__name__ = name
60+
vars()[name] = decorated

torch_np/_decorators.py

Lines changed: 0 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -9,31 +9,6 @@
99
NoValue = None
1010

1111

12-
def dtype_to_torch(func):
13-
@functools.wraps(func)
14-
def wrapped(*args, dtype=None, **kwds):
15-
torch_dtype = None
16-
if dtype is not None:
17-
dtype = _dtypes.dtype(dtype)
18-
torch_dtype = dtype.torch_dtype
19-
return func(*args, dtype=torch_dtype, **kwds)
20-
21-
return wrapped
22-
23-
24-
def emulate_out_arg(func):
25-
"""Simulate the out=... handling: move the result tensor to the out array.
26-
27-
With this decorator, the inner function just does not see the out array.
28-
"""
29-
30-
@functools.wraps(func)
31-
def wrapped(*args, out=None, **kwds):
32-
result_tensor = func(*args, **kwds)
33-
return _helpers.result_or_out(result_tensor, out)
34-
35-
return wrapped
36-
3712

3813
def out_shape_dtype(func):
3914
"""Handle out=... kwarg for ufuncs.
@@ -53,32 +28,3 @@ def wrapped(*args, out=None, **kwds):
5328
return wrapped
5429

5530

56-
def deco_unary_ufunc_from_impl(impl_func):
57-
@functools.wraps(impl_func)
58-
@dtype_to_torch
59-
@out_shape_dtype
60-
def wrapped(x1, *args, **kwds):
61-
from ._ndarray import asarray
62-
63-
x1_tensor = asarray(x1).get()
64-
result = impl_func((x1_tensor,), *args, **kwds)
65-
return result
66-
67-
return wrapped
68-
69-
70-
# TODO: deduplicate with _ndarray/asarray_replacer,
71-
# and _wrapper/concatenate et al
72-
def deco_binary_ufunc_from_impl(impl_func):
73-
@functools.wraps(impl_func)
74-
@dtype_to_torch
75-
@out_shape_dtype
76-
def wrapped(x1, x2, *args, **kwds):
77-
from ._ndarray import asarray
78-
79-
x1_tensor = asarray(x1).get()
80-
x2_tensor = asarray(x2).get()
81-
return impl_func((x1_tensor, x2_tensor), *args, **kwds)
82-
83-
return wrapped
84-

torch_np/_detail/_binary_ufuncs.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
"""Export torch work functions for binary ufuncs, rename/tweak to match numpy.
2+
This listing is further exported to public symbols in the `torch_np/_binary_ufuncs.py` module.
3+
"""
4+
5+
import torch
6+
7+
from . import _dtypes_impl, _util
8+
9+
10+
from torch import (add, arctan2, bitwise_and, bitwise_or, bitwise_xor, copysign, divide,
11+
float_power, floor_divide, fmax, fmin, fmod, gcd, greater, greater_equal, heaviside,
12+
hypot, lcm, ldexp, less, less_equal, logaddexp, logaddexp2, logical_and,
13+
logical_or, logical_xor, maximum, minimum, remainder, multiply, nextafter, not_equal,
14+
remainder, subtract, divide)
15+
16+
17+
# renames
18+
from torch import (eq as equal, pow as power, bitwise_right_shift as right_shift,
19+
bitwise_left_shift as left_shift,)
20+
21+
22+
# work around torch limitations w.r.t. numpy
23+
def matmul(x, y):
24+
# work around RuntimeError: expected scalar type Int but found Double
25+
dtype = _dtypes_impl.result_type_impl((x.dtype, y.dtype))
26+
x = _util.cast_if_needed(x, dtype)
27+
y = _util.cast_if_needed(y, dtype)
28+
result = torch.matmul(x, y)
29+
return result
30+

torch_np/_detail/_ufunc_impl.py

Lines changed: 0 additions & 158 deletions
This file was deleted.

torch_np/_detail/_unary_ufuncs.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
"""Export torch work functions for unary ufuncs, rename/tweak to match numpy.
2+
This listing is further exported to public symbols in the `torch_np/_unary_ufuncs.py` module.
3+
"""
4+
5+
import torch
6+
7+
from torch import (arccos, arccosh, arcsin, arcsinh, arctan, arctanh, ceil,
8+
cos, cosh, deg2rad, exp, exp2, expm1,
9+
floor, isfinite, isinf, isnan, log, log10, log1p, log2, logical_not,
10+
negative, rad2deg, reciprocal, sign, signbit,
11+
sin, sinh, sqrt, square, tan, tanh, trunc)
12+
13+
# renames
14+
from torch import (conj_physical as conjugate, round as rint, bitwise_not as invert, rad2deg as degrees,
15+
deg2rad as radians, absolute as fabs, )
16+
17+
# special cases: torch does not export these names
18+
def cbrt(x):
19+
return torch.pow(x, 1 / 3)
20+
21+
22+
def positive(x):
23+
return +x
24+
25+
26+
def absolute(x):
27+
# work around torch.absolute not impl for bools
28+
if x.dtype == torch.bool:
29+
return x
30+
return torch.absolute(x)
31+
32+
33+
abs = absolute
34+
conj = conjugate
35+

0 commit comments

Comments
 (0)