Skip to content

Commit 085fa8f

Browse files
authored
Merge pull request #34 from honno/refactor-fmt
Refactor w/ formatting and resolved merge conflicts
2 parents 69888d7 + 15e6704 commit 085fa8f

22 files changed

+2153
-2109
lines changed

autogen/gen_dtypes.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ def __init__(self, name):
2828
"bool",
2929
]
3030

31+
32+
tmap = {dt: torch.as_tensor(np.ones(1, dtype=dt)).dtype for dt in dt_names}
33+
34+
3135
templ = """\
3236
{name} = dtype("{name}")
3337
"""
@@ -55,8 +59,8 @@ def generate_can_cast(casting):
5559
dct_dtyp1 = {}
5660
for dtyp2 in dt_names:
5761
can_cast = np.can_cast(np.dtype(dtyp1), np.dtype(dtyp2), casting=casting)
58-
dct_dtyp1[dtyp2] = can_cast
59-
dct[dtyp1] = dct_dtyp1
62+
dct_dtyp1[tmap[dtyp2]] = can_cast
63+
dct[tmap[dtyp1]] = dct_dtyp1
6064
return dct
6165

6266

@@ -67,8 +71,8 @@ def generate_result_type():
6771
dct_dtyp1 = {}
6872
for dtyp2 in dt_names:
6973
result_type = np.result_type(np.dtype(dtyp1), np.dtype(dtyp2))
70-
dct_dtyp1[dtyp2] = result_type.name
71-
dct[dtyp1] = dct_dtyp1
74+
dct_dtyp1[tmap[dtyp2]] = tmap[result_type.name]
75+
dct[tmap[dtyp1]] = dct_dtyp1
7276
return dct
7377

7478

torch_np/__init__.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,14 @@
1+
from ._wrapper import * # isort: skip # XXX: currently this prevents circular imports
12
from ._binary_ufuncs import *
3+
from ._detail._scalar_types import *
4+
from ._detail._util import AxisError, UFuncTypeError
25
from ._dtypes import *
36
from ._getlimits import errstate, finfo, iinfo
47
from ._ndarray import can_cast, newaxis, result_type
5-
from ._scalar_types import *
68
from ._unary_ufuncs import *
7-
from ._util import AxisError, UFuncTypeError
8-
from ._wrapper import *
99

1010
# from . import testing
1111

1212

1313
inf = float("inf")
1414
nan = float("nan")
15-
16-
17-
alltrue = all
18-
19-
#### HACK HACK HACK ####
20-
import torch
21-
22-
torch.set_default_dtype(torch.float64)
23-
del torch

torch_np/_binary_ufuncs.py

Lines changed: 44 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,75 +1,52 @@
1-
import functools
2-
3-
import torch
4-
5-
from . import _dtypes, _helpers, _ufunc_impl, _util
6-
from ._ndarray import asarray
1+
from ._decorators import deco_binary_ufunc_from_impl
2+
from ._detail import _ufunc_impl
73

8-
#
9-
# Functions in _ufunc_impl receive arrays, implement common tasks with ufunc args
10-
# and delegate heavy lifting to pytorch equivalents.
114
#
125
# Functions in this file implement binary ufuncs: wrap two first arguments in
136
# asarray and delegate to functions from _ufunc_impl.
147
#
15-
# One other user of _ufunc_impl functions in ndarray, where its __add__ method
16-
# calls _ufunc_impl.add and so on. Note that ndarray dunders already know
17-
# that its first arg is an array, so they only convert the second argument.
18-
#
19-
# XXX: While it sounds tempting to merge _binary_ufuncs.py and _ufunc_impl.py
20-
# files, doing it would currently create import cycles.
8+
# Functions in _detail/_ufunc_impl.py receive tensors, implement common tasks
9+
# with ufunc args, and delegate heavy lifting to pytorch equivalents.
2110
#
2211

23-
# TODO: deduplicate with _unary_ufuncs/deco_unary_ufunc_from_impl,
24-
# _ndarray/asarray_replacer, and _wrapper/concatenate et al
25-
def deco_ufunc_from_impl(impl_func):
26-
@functools.wraps(impl_func)
27-
def wrapped(x1, x2, *args, **kwds):
28-
x1_array = asarray(x1)
29-
x2_array = asarray(x2)
30-
return impl_func(x1_array, x2_array, *args, **kwds)
31-
32-
return wrapped
33-
34-
3512
# the list is autogenerated, cf autogen/gen_ufunc_2.py
36-
add = deco_ufunc_from_impl(_ufunc_impl.add)
37-
arctan2 = deco_ufunc_from_impl(_ufunc_impl.arctan2)
38-
bitwise_and = deco_ufunc_from_impl(_ufunc_impl.bitwise_and)
39-
bitwise_or = deco_ufunc_from_impl(_ufunc_impl.bitwise_or)
40-
bitwise_xor = deco_ufunc_from_impl(_ufunc_impl.bitwise_xor)
41-
copysign = deco_ufunc_from_impl(_ufunc_impl.copysign)
42-
divide = deco_ufunc_from_impl(_ufunc_impl.divide)
43-
equal = deco_ufunc_from_impl(_ufunc_impl.equal)
44-
float_power = deco_ufunc_from_impl(_ufunc_impl.float_power)
45-
floor_divide = deco_ufunc_from_impl(_ufunc_impl.floor_divide)
46-
fmax = deco_ufunc_from_impl(_ufunc_impl.fmax)
47-
fmin = deco_ufunc_from_impl(_ufunc_impl.fmin)
48-
fmod = deco_ufunc_from_impl(_ufunc_impl.fmod)
49-
gcd = deco_ufunc_from_impl(_ufunc_impl.gcd)
50-
greater = deco_ufunc_from_impl(_ufunc_impl.greater)
51-
greater_equal = deco_ufunc_from_impl(_ufunc_impl.greater_equal)
52-
heaviside = deco_ufunc_from_impl(_ufunc_impl.heaviside)
53-
hypot = deco_ufunc_from_impl(_ufunc_impl.hypot)
54-
lcm = deco_ufunc_from_impl(_ufunc_impl.lcm)
55-
ldexp = deco_ufunc_from_impl(_ufunc_impl.ldexp)
56-
left_shift = deco_ufunc_from_impl(_ufunc_impl.left_shift)
57-
less = deco_ufunc_from_impl(_ufunc_impl.less)
58-
less_equal = deco_ufunc_from_impl(_ufunc_impl.less_equal)
59-
logaddexp = deco_ufunc_from_impl(_ufunc_impl.logaddexp)
60-
logaddexp2 = deco_ufunc_from_impl(_ufunc_impl.logaddexp2)
61-
logical_and = deco_ufunc_from_impl(_ufunc_impl.logical_and)
62-
logical_or = deco_ufunc_from_impl(_ufunc_impl.logical_or)
63-
logical_xor = deco_ufunc_from_impl(_ufunc_impl.logical_xor)
64-
matmul = deco_ufunc_from_impl(_ufunc_impl.matmul)
65-
maximum = deco_ufunc_from_impl(_ufunc_impl.maximum)
66-
minimum = deco_ufunc_from_impl(_ufunc_impl.minimum)
67-
remainder = deco_ufunc_from_impl(_ufunc_impl.remainder)
68-
multiply = deco_ufunc_from_impl(_ufunc_impl.multiply)
69-
nextafter = deco_ufunc_from_impl(_ufunc_impl.nextafter)
70-
not_equal = deco_ufunc_from_impl(_ufunc_impl.not_equal)
71-
power = deco_ufunc_from_impl(_ufunc_impl.power)
72-
remainder = deco_ufunc_from_impl(_ufunc_impl.remainder)
73-
right_shift = deco_ufunc_from_impl(_ufunc_impl.right_shift)
74-
subtract = deco_ufunc_from_impl(_ufunc_impl.subtract)
75-
divide = deco_ufunc_from_impl(_ufunc_impl.divide)
13+
add = deco_binary_ufunc_from_impl(_ufunc_impl.add)
14+
arctan2 = deco_binary_ufunc_from_impl(_ufunc_impl.arctan2)
15+
bitwise_and = deco_binary_ufunc_from_impl(_ufunc_impl.bitwise_and)
16+
bitwise_or = deco_binary_ufunc_from_impl(_ufunc_impl.bitwise_or)
17+
bitwise_xor = deco_binary_ufunc_from_impl(_ufunc_impl.bitwise_xor)
18+
copysign = deco_binary_ufunc_from_impl(_ufunc_impl.copysign)
19+
divide = deco_binary_ufunc_from_impl(_ufunc_impl.divide)
20+
equal = deco_binary_ufunc_from_impl(_ufunc_impl.equal)
21+
float_power = deco_binary_ufunc_from_impl(_ufunc_impl.float_power)
22+
floor_divide = deco_binary_ufunc_from_impl(_ufunc_impl.floor_divide)
23+
fmax = deco_binary_ufunc_from_impl(_ufunc_impl.fmax)
24+
fmin = deco_binary_ufunc_from_impl(_ufunc_impl.fmin)
25+
fmod = deco_binary_ufunc_from_impl(_ufunc_impl.fmod)
26+
gcd = deco_binary_ufunc_from_impl(_ufunc_impl.gcd)
27+
greater = deco_binary_ufunc_from_impl(_ufunc_impl.greater)
28+
greater_equal = deco_binary_ufunc_from_impl(_ufunc_impl.greater_equal)
29+
heaviside = deco_binary_ufunc_from_impl(_ufunc_impl.heaviside)
30+
hypot = deco_binary_ufunc_from_impl(_ufunc_impl.hypot)
31+
lcm = deco_binary_ufunc_from_impl(_ufunc_impl.lcm)
32+
ldexp = deco_binary_ufunc_from_impl(_ufunc_impl.ldexp)
33+
left_shift = deco_binary_ufunc_from_impl(_ufunc_impl.left_shift)
34+
less = deco_binary_ufunc_from_impl(_ufunc_impl.less)
35+
less_equal = deco_binary_ufunc_from_impl(_ufunc_impl.less_equal)
36+
logaddexp = deco_binary_ufunc_from_impl(_ufunc_impl.logaddexp)
37+
logaddexp2 = deco_binary_ufunc_from_impl(_ufunc_impl.logaddexp2)
38+
logical_and = deco_binary_ufunc_from_impl(_ufunc_impl.logical_and)
39+
logical_or = deco_binary_ufunc_from_impl(_ufunc_impl.logical_or)
40+
logical_xor = deco_binary_ufunc_from_impl(_ufunc_impl.logical_xor)
41+
matmul = deco_binary_ufunc_from_impl(_ufunc_impl.matmul)
42+
maximum = deco_binary_ufunc_from_impl(_ufunc_impl.maximum)
43+
minimum = deco_binary_ufunc_from_impl(_ufunc_impl.minimum)
44+
remainder = deco_binary_ufunc_from_impl(_ufunc_impl.remainder)
45+
multiply = deco_binary_ufunc_from_impl(_ufunc_impl.multiply)
46+
nextafter = deco_binary_ufunc_from_impl(_ufunc_impl.nextafter)
47+
not_equal = deco_binary_ufunc_from_impl(_ufunc_impl.not_equal)
48+
power = deco_binary_ufunc_from_impl(_ufunc_impl.power)
49+
remainder = deco_binary_ufunc_from_impl(_ufunc_impl.remainder)
50+
right_shift = deco_binary_ufunc_from_impl(_ufunc_impl.right_shift)
51+
subtract = deco_binary_ufunc_from_impl(_ufunc_impl.subtract)
52+
divide = deco_binary_ufunc_from_impl(_ufunc_impl.divide)

torch_np/_decorators.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
import functools
2+
import operator
3+
4+
import torch
5+
6+
from . import _dtypes, _helpers
7+
from ._detail import _util
8+
9+
NoValue = None
10+
11+
12+
def dtype_to_torch(func):
13+
@functools.wraps(func)
14+
def wrapped(*args, dtype=None, **kwds):
15+
torch_dtype = None
16+
if dtype is not None:
17+
dtype = _dtypes.dtype(dtype)
18+
torch_dtype = dtype._scalar_type.torch_dtype
19+
return func(*args, dtype=torch_dtype, **kwds)
20+
21+
return wrapped
22+
23+
24+
def emulate_out_arg(func):
25+
"""Simulate the out=... handling: move the result tensor to the out array.
26+
27+
With this decorator, the inner function just does not see the out array.
28+
"""
29+
30+
@functools.wraps(func)
31+
def wrapped(*args, out=None, **kwds):
32+
result_tensor = func(*args, **kwds)
33+
return _helpers.result_or_out(result_tensor, out)
34+
35+
return wrapped
36+
37+
38+
def out_shape_dtype(func):
39+
"""Handle out=... kwarg for ufuncs.
40+
41+
With ufuncs, `out` array can typcast and broadcast ufunc arguments, hence
42+
extract the shape and dtype of the tensor which backs the `out` array
43+
and pass these through.
44+
"""
45+
46+
@functools.wraps(func)
47+
def wrapped(*args, out=None, **kwds):
48+
if out is not None:
49+
kwds.update({"out_shape_dtype": (out.get().dtype, out.get().shape)})
50+
result_tensor = func(*args, **kwds)
51+
return _helpers.result_or_out(result_tensor, out)
52+
53+
return wrapped
54+
55+
56+
def deco_unary_ufunc_from_impl(impl_func):
57+
@functools.wraps(impl_func)
58+
@dtype_to_torch
59+
@out_shape_dtype
60+
def wrapped(x1, *args, **kwds):
61+
from ._ndarray import asarray
62+
63+
x1_tensor = asarray(x1).get()
64+
result = impl_func((x1_tensor,), *args, **kwds)
65+
return result
66+
67+
return wrapped
68+
69+
70+
# TODO: deduplicate with _ndarray/asarray_replacer,
71+
# and _wrapper/concatenate et al
72+
def deco_binary_ufunc_from_impl(impl_func):
73+
@functools.wraps(impl_func)
74+
@dtype_to_torch
75+
@out_shape_dtype
76+
def wrapped(x1, x2, *args, **kwds):
77+
from ._ndarray import asarray
78+
79+
x1_tensor = asarray(x1).get()
80+
x2_tensor = asarray(x2).get()
81+
return impl_func((x1_tensor, x2_tensor), *args, **kwds)
82+
83+
return wrapped
84+
85+
86+
def axis_keepdims_wrapper(func):
87+
"""`func` accepts an array-like as a 1st arg, returns a tensor.
88+
89+
This decorator implements the generic handling of axis, out and keepdims
90+
arguments for reduction functions.
91+
92+
Note that we peel off `out=...` and `keepdims=...` args (torch functions never
93+
see them). The `axis` argument we normalize and pass through to pytorch functions.
94+
95+
"""
96+
# XXX: move this out of _ndarray.py (circular imports)
97+
#
98+
# TODO: 1. get rid of _helpers.result_or_out
99+
# 2. sort out function signatures: how they flow through all decorators etc
100+
@functools.wraps(func)
101+
def wrapped(a, axis=None, keepdims=NoValue, *args, **kwds):
102+
from ._ndarray import asarray, ndarray
103+
104+
tensor = asarray(a).get()
105+
106+
# standardize the axis argument
107+
if isinstance(axis, ndarray):
108+
axis = operator.index(axis)
109+
110+
result = _util.axis_keepdims(func, tensor, axis, keepdims, *args, **kwds)
111+
return result
112+
113+
return wrapped

torch_np/_detail/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)