Skip to content

Commit a8c18aa

Browse files
committed
Merge branch 'ndarray_dunders_ufuncs' into main
Reviewed at #17
2 parents f63c6bb + 0e4001a commit a8c18aa

14 files changed

+904
-1545
lines changed

autogen/gen_ufuncs_2.py

Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
from dump_namespace import grab_namespace, get_signature
2+
3+
import numpy as np
4+
5+
namespace = np
6+
7+
dct = grab_namespace(namespace)
8+
9+
10+
# SKIP these (need special handling)
11+
skip = {np.frexp, np.modf, # non-standard unary ufunc signatures
12+
np.isnat,
13+
np.invert, # bitwise NOT operator
14+
np.spacing, # niche, does not have a direct equivalent
15+
}
16+
17+
# np functions where torch names differ
18+
torch_names = {np.radians : "deg2rad",
19+
np.degrees : "rad2deg",
20+
np.conjugate : "conj_physical",
21+
np.fabs : "absolute", # FIXME: np.fabs raises form complex
22+
np.rint : "round",
23+
np.left_shift: "bitwise_left_shift",
24+
np.right_shift: "bitwise_right_shift",
25+
}
26+
27+
28+
# np functions which do not have a torch equivalent
29+
default_stanza = "torch.{torch_name}(x, out=out)"
30+
31+
stanzas = {np.cbrt : "torch.pow(x, 1/3, out=out)",
32+
33+
# XXX what on earth is np.positive
34+
np.positive: "+x",
35+
36+
# these three do not have an out arg
37+
np.isinf: "torch.isinf(x)",
38+
np.isnan: "torch.isnan(x)",
39+
np.isfinite: "torch.isfinite(x)",
40+
}
41+
42+
43+
# for these np functions, pytorch analog does not have the out= arg
44+
needs_out = {np.isinf, np.isnan, np.isfinite, np.positive}
45+
add_out_stanza = """
46+
if out is not None:
47+
out[...] = result
48+
"""
49+
50+
51+
header = """\
52+
# this file is autogenerated via gen_ufuncs.py
53+
# do not edit manually!
54+
55+
import torch
56+
57+
import _util
58+
from _ndarray import asarray_replacer
59+
60+
"""
61+
62+
test_header = header + """\
63+
import numpy as np
64+
import torch
65+
66+
from _unary_ufuncs import *
67+
from testing import assert_allclose
68+
"""
69+
70+
71+
template = """ """
72+
73+
test_template = """
74+
75+
def test_{np_name}():
76+
assert_allclose(np.{np_name}(0.5),
77+
{np_name}(0.5), atol=1e-14, check_dtype=False)
78+
79+
"""
80+
81+
82+
###### UNARY UFUNCS ###################################
83+
84+
_all_list = []
85+
main_text = header
86+
test_text = test_header
87+
88+
_impl_list = []
89+
_ufunc_list = []
90+
91+
for ufunc in dct['ufunc']:
92+
if ufunc in skip:
93+
continue
94+
95+
if ufunc.nin == 1:
96+
# print(get_signature(ufunc))
97+
98+
torch_name = torch_names.get(ufunc)
99+
if torch_name is None:
100+
torch_name = ufunc.__name__
101+
102+
# print(ufunc.__name__, ' -- ', torch_name)
103+
104+
_impl_stanza = "{np_name} = deco_unary_ufunc(torch.{torch_name})"
105+
_impl_stanza = _impl_stanza.format(np_name=ufunc.__name__,
106+
torch_name=torch_name,)
107+
_impl_list.append(_impl_stanza)
108+
109+
continue
110+
111+
torch_stanza = stanzas.get(ufunc)
112+
if torch_stanza is None:
113+
torch_stanza = default_stanza.format(torch_name=torch_name)
114+
115+
out_stanza= add_out_stanza if ufunc in needs_out else ""
116+
117+
main_text += template.format(np_name=ufunc.__name__,
118+
torch_stanza=torch_stanza,
119+
out_stanza=out_stanza)
120+
test_text += test_template.format(np_name=ufunc.__name__)
121+
122+
_all_list.append(ufunc.__name__)
123+
124+
125+
print("\n".join(_impl_list))
126+
print("\n\n-----\n\n")
127+
128+
'''
129+
main_text += "\n\n__all__ = %s" % _all_list
130+
131+
132+
with open("_unary_ufuncs.py", "w") as f:
133+
f.write(main_text)
134+
135+
with open("test_unary_ufuncs.py", "w") as f:
136+
f.write(test_text)
137+
'''
138+
139+
###### BINARY UFUNCS ###################################
140+
141+
142+
143+
test_header = header + """\
144+
import numpy as np
145+
import torch
146+
147+
from _binary_ufuncs import *
148+
from testing import assert_allclose
149+
"""
150+
151+
152+
template = """
153+
154+
155+
"""
156+
157+
test_template = """
158+
159+
def test_{np_name}():
160+
assert_allclose(np.{np_name}(0.5, 0.6),
161+
{np_name}(0.5, 0.6), atol=1e-7, check_dtype=False)
162+
163+
"""
164+
165+
166+
167+
skip = {np.divmod, # two outputs
168+
}
169+
170+
171+
torch_names = {np.power: "pow",
172+
np.equal: "eq",
173+
}
174+
175+
176+
_all_list = []
177+
main_text = header
178+
test_text = test_header
179+
180+
_impl_list = []
181+
_ufunc_list = []
182+
183+
184+
for ufunc in dct['ufunc']:
185+
186+
if ufunc in skip:
187+
continue
188+
189+
if ufunc.nin == 2:
190+
## print(get_signature(ufunc))
191+
192+
torch_name = torch_names.get(ufunc)
193+
if torch_name is None:
194+
torch_name = ufunc.__name__
195+
196+
_impl_stanza = "{np_name} = deco_binary_ufunc(torch.{torch_name})"
197+
_impl_stanza = _impl_stanza.format(np_name=ufunc.__name__,
198+
torch_name=torch_name,)
199+
_impl_list.append(_impl_stanza)
200+
201+
_ufunc_stanza = "{np_name} = deco_ufunc_from_impl(_ufunc_impl.{np_name})"
202+
_ufunc_stanza = _ufunc_stanza.format(np_name=ufunc.__name__)
203+
_ufunc_list.append(_ufunc_stanza)
204+
205+
206+
print("\n".join(_impl_list))
207+
208+
print("\n\n")
209+
print("\n".join(_ufunc_list))
210+
211+
212+
213+

torch_np/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,15 @@
66
from ._unary_ufuncs import *
77
from ._binary_ufuncs import *
88
from ._ndarray import can_cast, result_type, newaxis
9-
from ._util import AxisError
9+
from ._util import AxisError, UFuncTypeError
1010
from ._getlimits import iinfo, finfo
1111
from ._getlimits import errstate
1212

1313
inf = float('inf')
1414
nan = float('nan')
1515

16+
17+
#### HACK HACK HACK ####
18+
import torch
19+
torch.set_default_dtype(torch.float64)
20+
del torch

0 commit comments

Comments
 (0)