
Commit 023d453

Merge pull request #70 from Quansight-Labs/normalizations
bare-bones normalizations via type hints
2 parents 89a28ec + 7dced32 commit 023d453

22 files changed: 1288 additions & 1070 deletions

torch_np/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,3 @@
-from ._wrapper import *  # isort: skip  # XXX: currently this prevents circular imports
 from . import random
 from ._binary_ufuncs import *
 from ._detail._index_tricks import *
@@ -8,6 +7,7 @@
 from ._getlimits import errstate, finfo, iinfo
 from ._ndarray import array, asarray, can_cast, ndarray, newaxis, result_type
 from ._unary_ufuncs import *
+from ._wrapper import *
 
 # from . import testing

torch_np/_binary_ufuncs.py

Lines changed: 48 additions & 48 deletions
@@ -1,52 +1,52 @@
-from ._decorators import deco_binary_ufunc_from_impl
-from ._detail import _ufunc_impl
+from typing import Optional
+
+from . import _helpers
+from ._detail import _binary_ufuncs
+from ._normalizations import ArrayLike, DTypeLike, NDArray, SubokLike, normalizer
+
+__all__ = [
+    name for name in dir(_binary_ufuncs) if not name.startswith("_") and name != "torch"
+]
+
+
+def deco_binary_ufunc(torch_func):
+    """Common infra for binary ufuncs.
+
+    Normalize arguments, sort out type casting, broadcasting and delegate to
+    the pytorch functions for the actual work.
+    """
+
+    def wrapped(
+        x1: ArrayLike,
+        x2: ArrayLike,
+        /,
+        out: Optional[NDArray] = None,
+        *,
+        where=True,
+        casting="same_kind",
+        order="K",
+        dtype: DTypeLike = None,
+        subok: SubokLike = False,
+        signature=None,
+        extobj=None,
+    ):
+        tensors = _helpers.ufunc_preprocess(
+            (x1, x2), out, where, casting, order, dtype, subok, signature, extobj
+        )
+        result = torch_func(*tensors)
+        return _helpers.result_or_out(result, out)
+
+    return wrapped
+
 
 #
-# Functions in this file implement binary ufuncs: wrap two first arguments in
-# asarray and delegate to functions from _ufunc_impl.
-#
-# Functions in _detail/_ufunc_impl.py receive tensors, implement common tasks
-# with ufunc args, and delegate heavy lifting to pytorch equivalents.
+# For each torch ufunc implementation, decorate and attach the decorated name
+# to this module. Its contents is then exported to the public namespace in __init__.py
 #
+for name in __all__:
+    ufunc = getattr(_binary_ufuncs, name)
+    decorated = normalizer(deco_binary_ufunc(ufunc))
 
-# the list is autogenerated, cf autogen/gen_ufunc_2.py
-add = deco_binary_ufunc_from_impl(_ufunc_impl.add)
-arctan2 = deco_binary_ufunc_from_impl(_ufunc_impl.arctan2)
-bitwise_and = deco_binary_ufunc_from_impl(_ufunc_impl.bitwise_and)
-bitwise_or = deco_binary_ufunc_from_impl(_ufunc_impl.bitwise_or)
-bitwise_xor = deco_binary_ufunc_from_impl(_ufunc_impl.bitwise_xor)
-copysign = deco_binary_ufunc_from_impl(_ufunc_impl.copysign)
-divide = deco_binary_ufunc_from_impl(_ufunc_impl.divide)
-equal = deco_binary_ufunc_from_impl(_ufunc_impl.equal)
-float_power = deco_binary_ufunc_from_impl(_ufunc_impl.float_power)
-floor_divide = deco_binary_ufunc_from_impl(_ufunc_impl.floor_divide)
-fmax = deco_binary_ufunc_from_impl(_ufunc_impl.fmax)
-fmin = deco_binary_ufunc_from_impl(_ufunc_impl.fmin)
-fmod = deco_binary_ufunc_from_impl(_ufunc_impl.fmod)
-gcd = deco_binary_ufunc_from_impl(_ufunc_impl.gcd)
-greater = deco_binary_ufunc_from_impl(_ufunc_impl.greater)
-greater_equal = deco_binary_ufunc_from_impl(_ufunc_impl.greater_equal)
-heaviside = deco_binary_ufunc_from_impl(_ufunc_impl.heaviside)
-hypot = deco_binary_ufunc_from_impl(_ufunc_impl.hypot)
-lcm = deco_binary_ufunc_from_impl(_ufunc_impl.lcm)
-ldexp = deco_binary_ufunc_from_impl(_ufunc_impl.ldexp)
-left_shift = deco_binary_ufunc_from_impl(_ufunc_impl.left_shift)
-less = deco_binary_ufunc_from_impl(_ufunc_impl.less)
-less_equal = deco_binary_ufunc_from_impl(_ufunc_impl.less_equal)
-logaddexp = deco_binary_ufunc_from_impl(_ufunc_impl.logaddexp)
-logaddexp2 = deco_binary_ufunc_from_impl(_ufunc_impl.logaddexp2)
-logical_and = deco_binary_ufunc_from_impl(_ufunc_impl.logical_and)
-logical_or = deco_binary_ufunc_from_impl(_ufunc_impl.logical_or)
-logical_xor = deco_binary_ufunc_from_impl(_ufunc_impl.logical_xor)
-matmul = deco_binary_ufunc_from_impl(_ufunc_impl.matmul)
-maximum = deco_binary_ufunc_from_impl(_ufunc_impl.maximum)
-minimum = deco_binary_ufunc_from_impl(_ufunc_impl.minimum)
-remainder = deco_binary_ufunc_from_impl(_ufunc_impl.remainder)
-multiply = deco_binary_ufunc_from_impl(_ufunc_impl.multiply)
-nextafter = deco_binary_ufunc_from_impl(_ufunc_impl.nextafter)
-not_equal = deco_binary_ufunc_from_impl(_ufunc_impl.not_equal)
-power = deco_binary_ufunc_from_impl(_ufunc_impl.power)
-remainder = deco_binary_ufunc_from_impl(_ufunc_impl.remainder)
-right_shift = deco_binary_ufunc_from_impl(_ufunc_impl.right_shift)
-subtract = deco_binary_ufunc_from_impl(_ufunc_impl.subtract)
-divide = deco_binary_ufunc_from_impl(_ufunc_impl.divide)
+    decorated.__qualname__ = name  # XXX: is this really correct?
+    decorated.__name__ = name
+    vars()[name] = decorated
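For readers unfamiliar with the pattern above: the new module builds each public ufunc by composing `deco_binary_ufunc` with the `normalizer` decorator from `_normalizations.py` and then attaching the result to the module namespace via `vars()`. The `_normalizations.py` module itself is not part of this excerpt, so the following is only a minimal, hypothetical sketch of how a type-hint-driven normalizer of this shape can work; the names `normalizer` and `ArrayLike` mirror the diff, everything else is illustrative.

import functools
import inspect
import typing

import torch

# Stand-in annotation meaning "anything coercible to a tensor" (sketch only).
ArrayLike = typing.TypeVar("ArrayLike")


def normalizer(func):
    """Convert arguments annotated as ArrayLike into torch tensors (illustrative sketch)."""
    sig = inspect.signature(func)

    @functools.wraps(func)
    def wrapped(*args, **kwds):
        bound = sig.bind(*args, **kwds)
        for name, value in bound.arguments.items():
            # Only touch parameters whose annotation is our ArrayLike marker.
            if sig.parameters[name].annotation is ArrayLike:
                bound.arguments[name] = torch.as_tensor(value)
        return func(*bound.args, **bound.kwargs)

    return wrapped


@normalizer
def add(x1: ArrayLike, x2: ArrayLike):
    return torch.add(x1, x2)


add([1, 2], [3.0, 4.0])  # both list arguments are normalized to tensors before torch.add runs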

torch_np/_decorators.py

Lines changed: 0 additions & 115 deletions
@@ -1,39 +1,10 @@
 import functools
-import operator
 
 import torch
 
 from . import _dtypes, _helpers
 from ._detail import _util
 
-NoValue = None
-
-
-def dtype_to_torch(func):
-    @functools.wraps(func)
-    def wrapped(*args, dtype=None, **kwds):
-        torch_dtype = None
-        if dtype is not None:
-            dtype = _dtypes.dtype(dtype)
-            torch_dtype = dtype.torch_dtype
-        return func(*args, dtype=torch_dtype, **kwds)
-
-    return wrapped
-
-
-def emulate_out_arg(func):
-    """Simulate the out=... handling: move the result tensor to the out array.
-
-    With this decorator, the inner function just does not see the out array.
-    """
-
-    @functools.wraps(func)
-    def wrapped(*args, out=None, **kwds):
-        result_tensor = func(*args, **kwds)
-        return _helpers.result_or_out(result_tensor, out)
-
-    return wrapped
-
 
 def out_shape_dtype(func):
     """Handle out=... kwarg for ufuncs.
@@ -51,89 +22,3 @@ def wrapped(*args, out=None, **kwds):
         return _helpers.result_or_out(result_tensor, out)
 
     return wrapped
-
-
-def deco_unary_ufunc_from_impl(impl_func):
-    @functools.wraps(impl_func)
-    @dtype_to_torch
-    @out_shape_dtype
-    def wrapped(x1, *args, **kwds):
-        from ._ndarray import asarray
-
-        x1_tensor = asarray(x1).get()
-        result = impl_func((x1_tensor,), *args, **kwds)
-        return result
-
-    return wrapped
-
-
-# TODO: deduplicate with _ndarray/asarray_replacer,
-# and _wrapper/concatenate et al
-def deco_binary_ufunc_from_impl(impl_func):
-    @functools.wraps(impl_func)
-    @dtype_to_torch
-    @out_shape_dtype
-    def wrapped(x1, x2, *args, **kwds):
-        from ._ndarray import asarray
-
-        x1_tensor = asarray(x1).get()
-        x2_tensor = asarray(x2).get()
-        return impl_func((x1_tensor, x2_tensor), *args, **kwds)
-
-    return wrapped
-
-
-def axis_keepdims_wrapper(func):
-    """`func` accepts an array-like as a 1st arg, returns a tensor.
-
-    This decorator implements the generic handling of axis, out and keepdims
-    arguments for reduction functions.
-
-    Note that we peel off `out=...` and `keepdims=...` args (torch functions never
-    see them). The `axis` argument we normalize and pass through to pytorch functions.
-
-    """
-    # TODO: sort out function signatures: how they flow through all decorators etc
-    @functools.wraps(func)
-    def wrapped(a, axis=None, keepdims=NoValue, *args, **kwds):
-        from ._ndarray import asarray, ndarray
-
-        tensor = asarray(a).get()
-
-        # standardize the axis argument
-        if isinstance(axis, ndarray):
-            axis = operator.index(axis)
-
-        result = _util.axis_expand_func(func, tensor, axis, *args, **kwds)
-
-        if keepdims:
-            result = _util.apply_keepdims(result, axis, tensor.ndim)
-
-        return result
-
-    return wrapped
-
-
-def axis_none_ravel_wrapper(func):
-    """`func` accepts an array-like as a 1st arg, returns a tensor.
-
-    This decorator implements the generic handling of axis=None acting on a
-    raveled array. One use is cumprod / cumsum. concatenate also uses a
-    similar logic.
-
-    """
-
-    @functools.wraps(func)
-    def wrapped(a, axis=None, *args, **kwds):
-        from ._ndarray import asarray, ndarray
-
-        tensor = asarray(a).get()
-
-        # standardize the axis argument
-        if isinstance(axis, ndarray):
-            axis = operator.index(axis)
-
-        result = _util.axis_ravel_func(func, tensor, axis, *args, **kwds)
-        return result
-
-    return wrapped
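The removed `axis_keepdims_wrapper` delegated the actual keepdims handling to `_util.apply_keepdims`, which is not shown in this diff. As a rough, self-contained illustration of the mechanics its docstring describes, here is a sketch using only public torch calls; the function name `keepdims_sum` is invented for the example and is not part of the project.

import torch


def keepdims_sum(tensor, axis=None, keepdims=False):
    # Reduce, then re-insert the reduced axis as a size-1 dimension when
    # keepdims is requested -- the generic shape of what a keepdims helper does.
    result = tensor.sum() if axis is None else tensor.sum(dim=axis)
    if keepdims:
        result = result.reshape((1,) * tensor.ndim) if axis is None else result.unsqueeze(axis)
    return result


t = torch.arange(6.0).reshape(2, 3)
print(keepdims_sum(t, axis=1, keepdims=True).shape)  # torch.Size([2, 1])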

torch_np/_detail/_binary_ufuncs.py

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
+"""Export torch work functions for binary ufuncs, rename/tweak to match numpy.
+This listing is further exported to public symbols in the `torch_np/_binary_ufuncs.py` module.
+"""
+
+import torch
+
+# renames
+from torch import add, arctan2, bitwise_and
+from torch import bitwise_left_shift as left_shift
+from torch import bitwise_or
+from torch import bitwise_right_shift as right_shift
+from torch import bitwise_xor, copysign, divide
+from torch import eq as equal
+from torch import (
+    float_power,
+    floor_divide,
+    fmax,
+    fmin,
+    fmod,
+    gcd,
+    greater,
+    greater_equal,
+    heaviside,
+    hypot,
+    lcm,
+    ldexp,
+    less,
+    less_equal,
+    logaddexp,
+    logaddexp2,
+    logical_and,
+    logical_or,
+    logical_xor,
+    maximum,
+    minimum,
+    multiply,
+    nextafter,
+    not_equal,
+)
+from torch import pow as power
+from torch import remainder, subtract
+
+from . import _dtypes_impl, _util
+
+
+# work around torch limitations w.r.t. numpy
+def matmul(x, y):
+    # work around RuntimeError: expected scalar type Int but found Double
+    dtype = _dtypes_impl.result_type_impl((x.dtype, y.dtype))
+    x = _util.cast_if_needed(x, dtype)
+    y = _util.cast_if_needed(y, dtype)
+    result = torch.matmul(x, y)
+    return result
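The `matmul` workaround above relies on the in-tree helpers `_dtypes_impl.result_type_impl` and `_util.cast_if_needed`, neither of which appears in this diff. A standalone sketch of the same idea using only public torch APIs (illustrative, not the project's code):

import torch


def promoting_matmul(x, y):
    # Promote both operands to a common dtype before multiplying, since
    # torch.matmul refuses mixed-dtype operands (e.g. int32 @ float64).
    dtype = torch.result_type(x, y)
    return torch.matmul(x.to(dtype), y.to(dtype))


a = torch.arange(4, dtype=torch.int32).reshape(2, 2)
b = torch.eye(2, dtype=torch.float64)
print(promoting_matmul(a, b))  # plain torch.matmul(a, b) raises the dtype error quoted above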
