Skip to content

Commit 13ee073

Browse files
committed
MAINT: regenerate binary ufuncs
1 parent 2a1a1f2 commit 13ee073

File tree

6 files changed

+315
-623
lines changed

6 files changed

+315
-623
lines changed

autogen/gen_ufuncs_2.py

Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
from dump_namespace import grab_namespace, get_signature
2+
3+
import numpy as np
4+
5+
namespace = np
6+
7+
dct = grab_namespace(namespace)
8+
9+
10+
# SKIP these (need special handling)
11+
skip = {np.frexp, np.modf, # non-standard unary ufunc signatures
12+
np.isnat,
13+
np.invert, # bitwise NOT operator
14+
np.spacing, # niche, does not have a direct equivalent
15+
}
16+
17+
# np functions where torch names differ
18+
torch_names = {np.radians : "deg2rad",
19+
np.degrees : "rad2deg",
20+
np.conjugate : "conj_physical",
21+
np.fabs : "absolute", # FIXME: np.fabs raises form complex
22+
np.rint : "round",
23+
np.left_shift: "bitwise_left_shift",
24+
np.right_shift: "bitwise_right_shift",
25+
}
26+
27+
28+
# np functions which do not have a torch equivalent
29+
default_stanza = "torch.{torch_name}(x, out=out)"
30+
31+
stanzas = {np.cbrt : "torch.pow(x, 1/3, out=out)",
32+
33+
# XXX what on earth is np.positive
34+
np.positive: "+x",
35+
36+
# these three do not have an out arg
37+
np.isinf: "torch.isinf(x)",
38+
np.isnan: "torch.isnan(x)",
39+
np.isfinite: "torch.isfinite(x)",
40+
}
41+
42+
43+
# for these np functions, pytorch analog does not have the out= arg
44+
needs_out = {np.isinf, np.isnan, np.isfinite, np.positive}
45+
add_out_stanza = """
46+
if out is not None:
47+
out[...] = result
48+
"""
49+
50+
51+
header = """\
52+
# this file is autogenerated via gen_ufuncs.py
53+
# do not edit manually!
54+
55+
import torch
56+
57+
import _util
58+
from _ndarray import asarray_replacer
59+
60+
"""
61+
62+
test_header = header + """\
63+
import numpy as np
64+
import torch
65+
66+
from _unary_ufuncs import *
67+
from testing import assert_allclose
68+
"""
69+
70+
71+
template = """ """
72+
73+
test_template = """
74+
75+
def test_{np_name}():
76+
assert_allclose(np.{np_name}(0.5),
77+
{np_name}(0.5), atol=1e-14, check_dtype=False)
78+
79+
"""
80+
81+
82+
###### UNARY UFUNCS ###################################
83+
84+
_all_list = []
85+
main_text = header
86+
test_text = test_header
87+
88+
for ufunc in dct['ufunc']:
89+
if ufunc in skip:
90+
continue
91+
92+
if ufunc.nin == 1:
93+
# print(get_signature(ufunc))
94+
95+
torch_name = torch_names.get(ufunc)
96+
if torch_name is None:
97+
torch_name = ufunc.__name__
98+
99+
torch_stanza = stanzas.get(ufunc)
100+
if torch_stanza is None:
101+
torch_stanza = default_stanza.format(torch_name=torch_name)
102+
103+
out_stanza= add_out_stanza if ufunc in needs_out else ""
104+
105+
main_text += template.format(np_name=ufunc.__name__,
106+
torch_stanza=torch_stanza,
107+
out_stanza=out_stanza)
108+
test_text += test_template.format(np_name=ufunc.__name__)
109+
110+
_all_list.append(ufunc.__name__)
111+
112+
main_text += "\n\n__all__ = %s" % _all_list
113+
114+
115+
with open("_unary_ufuncs.py", "w") as f:
116+
f.write(main_text)
117+
118+
with open("test_unary_ufuncs.py", "w") as f:
119+
f.write(test_text)
120+
121+
122+
###### BINARY UFUNCS ###################################
123+
124+
125+
126+
test_header = header + """\
127+
import numpy as np
128+
import torch
129+
130+
from _binary_ufuncs import *
131+
from testing import assert_allclose
132+
"""
133+
134+
135+
template = """
136+
137+
138+
"""
139+
140+
test_template = """
141+
142+
def test_{np_name}():
143+
assert_allclose(np.{np_name}(0.5, 0.6),
144+
{np_name}(0.5, 0.6), atol=1e-7, check_dtype=False)
145+
146+
"""
147+
148+
149+
150+
skip = {np.divmod, # two outputs
151+
}
152+
153+
154+
torch_names = {np.power: "pow",
155+
np.equal: "eq",
156+
}
157+
158+
159+
_all_list = []
160+
main_text = header
161+
test_text = test_header
162+
163+
_impl_list = []
164+
_ufunc_list = []
165+
166+
for ufunc in dct['ufunc']:
167+
168+
if ufunc in skip:
169+
continue
170+
171+
if ufunc.nin == 2:
172+
## print(get_signature(ufunc))
173+
174+
torch_name = torch_names.get(ufunc)
175+
if torch_name is None:
176+
torch_name = ufunc.__name__
177+
178+
_impl_stanza = "{np_name} = deco_binary_ufunc(torch.{torch_name})"
179+
_impl_stanza = _impl_stanza.format(np_name=ufunc.__name__,
180+
torch_name=torch_name,)
181+
_impl_list.append(_impl_stanza)
182+
183+
_ufunc_stanza = "{np_name} = deco_ufunc_from_impl(_ufunc_impl.{np_name})"
184+
_ufunc_stanza = _ufunc_stanza.format(np_name=ufunc.__name__)
185+
_ufunc_list.append(_ufunc_stanza)
186+
187+
188+
print("\n".join(_impl_list))
189+
190+
print("\n\n")
191+
print("\n".join(_ufunc_list))
192+

0 commit comments

Comments
 (0)