Skip to content

Commit 37dcf49

Browse files
committed
MAINT: an incremental simplification of wrapping of unary ufunc
1 parent 1acb5aa commit 37dcf49

File tree

4 files changed

+171
-791
lines changed

4 files changed

+171
-791
lines changed

autogen/gen_ufuncs_2.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,17 +85,29 @@ def test_{np_name}():
8585
main_text = header
8686
test_text = test_header
8787

88+
_impl_list = []
89+
_ufunc_list = []
90+
8891
for ufunc in dct['ufunc']:
8992
if ufunc in skip:
9093
continue
9194

9295
if ufunc.nin == 1:
93-
# print(get_signature(ufunc))
96+
# print(get_signature(ufunc))
9497

9598
torch_name = torch_names.get(ufunc)
9699
if torch_name is None:
97100
torch_name = ufunc.__name__
98101

102+
# print(ufunc.__name__, ' -- ', torch_name)
103+
104+
_impl_stanza = "{np_name} = deco_unary_ufunc(torch.{torch_name})"
105+
_impl_stanza = _impl_stanza.format(np_name=ufunc.__name__,
106+
torch_name=torch_name,)
107+
_impl_list.append(_impl_stanza)
108+
109+
continue
110+
99111
torch_stanza = stanzas.get(ufunc)
100112
if torch_stanza is None:
101113
torch_stanza = default_stanza.format(torch_name=torch_name)
@@ -109,6 +121,11 @@ def test_{np_name}():
109121

110122
_all_list.append(ufunc.__name__)
111123

124+
125+
print("\n".join(_impl_list))
126+
print("\n\n-----\n\n")
127+
128+
'''
112129
main_text += "\n\n__all__ = %s" % _all_list
113130
114131
@@ -117,7 +134,7 @@ def test_{np_name}():
117134
118135
with open("test_unary_ufuncs.py", "w") as f:
119136
f.write(test_text)
120-
137+
'''
121138

122139
###### BINARY UFUNCS ###################################
123140

@@ -163,6 +180,7 @@ def test_{np_name}():
163180
_impl_list = []
164181
_ufunc_list = []
165182

183+
166184
for ufunc in dct['ufunc']:
167185

168186
if ufunc in skip:
@@ -190,3 +208,6 @@ def test_{np_name}():
190208
print("\n\n")
191209
print("\n".join(_ufunc_list))
192210

211+
212+
213+

torch_np/_binary_ufuncs.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@
2424
# files, doing it would currently create import cycles.
2525
#
2626

27-
27+
# TODO: deduplicate with _unary_ufuncs/deco_unary_ufunc_from_impl,
28+
# _ndarray/asarray_replacer, and _wrapper/concatenate et al
2829
def deco_ufunc_from_impl(impl_func):
2930
@functools.wraps(impl_func)
3031
def wrapped(x1, x2, *args, **kwds):
@@ -75,3 +76,4 @@ def wrapped(x1, x2, *args, **kwds):
7576
right_shift = deco_ufunc_from_impl(_ufunc_impl.right_shift)
7677
subtract = deco_ufunc_from_impl(_ufunc_impl.subtract)
7778
divide = deco_ufunc_from_impl(_ufunc_impl.divide)
79+

torch_np/_ufunc_impl.py

Lines changed: 82 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,38 @@ def wrapped(x1, x2, /, out=None, *, where=True,
1919
# XXX: dtype=... parameter is silently ignored
2020

2121
arrays = (x1, x2)
22-
x1_tensor, x2_tensor = _helpers.cast_and_broadcast(arrays, out, casting)
22+
tensors = _helpers.cast_and_broadcast(arrays, out, casting)
2323

24-
result = torch_func(x1_tensor, x2_tensor)
24+
result = torch_func(*tensors)
2525

2626
return _helpers.result_or_out(result, out)
2727
return wrapped
2828

2929

30-
# the list is autogenerated, cf autogen/gen_ufunc_2.py
30+
def deco_unary_ufunc(torch_func):
31+
# TODO: deduplicate with `deco_binary_ufunc` above. Need to figure out the
32+
# effect of the `dtype` parameter, does it differ between unary and binary ufuncs.
33+
def wrapped(x1, /, out=None, *, where=True,
34+
casting='same_kind', order='K', dtype=None, subok=False, **kwds):
35+
_util.subok_not_ok(subok=subok)
36+
if order != 'K' or not where:
37+
raise NotImplementedError
38+
39+
# XXX: dtype=... parameter is silently ignored
40+
41+
arrays = (x1, )
42+
tensors = _helpers.cast_and_broadcast(arrays, out, casting)
43+
44+
result = torch_func(*tensors)
45+
46+
return _helpers.result_or_out(result, out)
47+
return wrapped
48+
49+
50+
51+
52+
# binary ufuncs: the list is autogenerated, cf autogen/gen_ufunc_2.py
53+
# And edited manually! np.equal <--> torch.eq, not torch.equal
3154
add = deco_binary_ufunc(torch.add)
3255
arctan2 = deco_binary_ufunc(torch.arctan2)
3356
bitwise_and = deco_binary_ufunc(torch.bitwise_and)
@@ -69,3 +92,59 @@ def wrapped(x1, x2, /, out=None, *, where=True,
6992
subtract = deco_binary_ufunc(torch.subtract)
7093
divide = deco_binary_ufunc(torch.divide)
7194

95+
96+
97+
# unary ufuncs: the list is autogenerated, cf autogen/gen_ufunc_2.py
98+
absolute = deco_unary_ufunc(torch.absolute)
99+
#absolute = deco_unary_ufunc(torch.absolute)
100+
arccos = deco_unary_ufunc(torch.arccos)
101+
arccosh = deco_unary_ufunc(torch.arccosh)
102+
arcsin = deco_unary_ufunc(torch.arcsin)
103+
arcsinh = deco_unary_ufunc(torch.arcsinh)
104+
arctan = deco_unary_ufunc(torch.arctan)
105+
arctanh = deco_unary_ufunc(torch.arctanh)
106+
ceil = deco_unary_ufunc(torch.ceil)
107+
conjugate = deco_unary_ufunc(torch.conj_physical)
108+
#conjugate = deco_unary_ufunc(torch.conj_physical)
109+
cos = deco_unary_ufunc(torch.cos)
110+
cosh = deco_unary_ufunc(torch.cosh)
111+
deg2rad = deco_unary_ufunc(torch.deg2rad)
112+
degrees = deco_unary_ufunc(torch.rad2deg)
113+
exp = deco_unary_ufunc(torch.exp)
114+
exp2 = deco_unary_ufunc(torch.exp2)
115+
expm1 = deco_unary_ufunc(torch.expm1)
116+
fabs = deco_unary_ufunc(torch.absolute)
117+
floor = deco_unary_ufunc(torch.floor)
118+
isfinite = deco_unary_ufunc(torch.isfinite)
119+
isinf = deco_unary_ufunc(torch.isinf)
120+
isnan = deco_unary_ufunc(torch.isnan)
121+
log = deco_unary_ufunc(torch.log)
122+
log10 = deco_unary_ufunc(torch.log10)
123+
log1p = deco_unary_ufunc(torch.log1p)
124+
log2 = deco_unary_ufunc(torch.log2)
125+
logical_not = deco_unary_ufunc(torch.logical_not)
126+
negative = deco_unary_ufunc(torch.negative)
127+
rad2deg = deco_unary_ufunc(torch.rad2deg)
128+
radians = deco_unary_ufunc(torch.deg2rad)
129+
reciprocal = deco_unary_ufunc(torch.reciprocal)
130+
rint = deco_unary_ufunc(torch.round)
131+
sign = deco_unary_ufunc(torch.sign)
132+
signbit = deco_unary_ufunc(torch.signbit)
133+
sin = deco_unary_ufunc(torch.sin)
134+
sinh = deco_unary_ufunc(torch.sinh)
135+
sqrt = deco_unary_ufunc(torch.sqrt)
136+
square = deco_unary_ufunc(torch.square)
137+
tan = deco_unary_ufunc(torch.tan)
138+
tanh = deco_unary_ufunc(torch.tanh)
139+
trunc = deco_unary_ufunc(torch.trunc)
140+
141+
# special cases: torch does not export these names
142+
def _cbrt(x):
143+
return torch.pow(x, 1/3)
144+
145+
def _positive(x):
146+
return +x
147+
148+
cbrt = deco_unary_ufunc(_cbrt)
149+
positive = deco_unary_ufunc(_positive)
150+

0 commit comments

Comments
 (0)