Skip to content

Commit 3833f26

Browse files
committed
Run black on all files sans numpy_tests/
1 parent 731128e commit 3833f26

34 files changed

+5012
-2521
lines changed

autogen/_binary_ufuncs.py

Lines changed: 560 additions & 161 deletions
Large diffs are not rendered by default.

autogen/_unary_ufuncs.py

Lines changed: 707 additions & 267 deletions
Large diffs are not rendered by default.

autogen/array_api_dump.py

Lines changed: 378 additions & 112 deletions
Large diffs are not rendered by default.

autogen/dump_namespace.py

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@ def get_signature(obj):
1919
return obj.__name__ + str(inspect.signature(obj))
2020
except Exception:
2121
# builtins don't have it, try the first line of the docstring
22-
d = obj.__doc__.split('\n')
22+
d = obj.__doc__.split("\n")
2323
if d[0]:
2424
return d[0].strip()
2525
else:
2626
# empty line, maybe the is lines 2-3 (np.vectorize)
27-
return '\n'.join((d[1], d[2])).strip()
27+
return "\n".join((d[1], d[2])).strip()
2828

2929

3030
def dump_signatures(keys, namespace=None, replace=None):
@@ -49,49 +49,47 @@ def dump_signatures(keys, namespace=None, replace=None):
4949

5050
def dump_difference(namespace):
5151
import torch_np
52+
5253
dct_wrapper = grab_namespace(torch_np)
53-
wrapper_funcs =set([obj.__name__ for obj in dct_wrapper['function']])
54+
wrapper_funcs = set([obj.__name__ for obj in dct_wrapper["function"]])
5455

5556
dct_api = grab_namespace(namespace)
56-
namespace_funcs = set(obj.__name__ for obj in dct_api['function'])
57+
namespace_funcs = set(obj.__name__ for obj in dct_api["function"])
5758

5859
missing_names = namespace_funcs.difference(wrapper_funcs)
5960

6061
for name in sorted(missing_names):
61-
print('- [ ]', name)
62+
print("- [ ]", name)
6263

6364
breakpoint()
6465

6566
extras = wrapper_funcs.difference(namespace_funcs)
66-
print('\n\n')
67+
print("\n\n")
6768
for name in sorted(extras):
68-
print('- [ ]', name)
69+
print("- [ ]", name)
6970

7071

7172
if __name__ == "__main__":
7273

73-
# dct = grab_namespace(np)
74-
# print(dct.keys())
75-
76-
# for obj in dct['function']:
77-
# print( get_signature(obj) )
74+
# dct = grab_namespace(np)
75+
# print(dct.keys())
7876

77+
# for obj in dct['function']:
78+
# print( get_signature(obj) )
7979

80-
# dump array_api, full_signatures
81-
# from numpy import array_api
80+
# dump array_api, full_signatures
81+
# from numpy import array_api
8282

83-
# keys = ["builtin_function_or_method", "function"]
84-
# replace = {"<no value>": "NoValue"}
83+
# keys = ["builtin_function_or_method", "function"]
84+
# replace = {"<no value>": "NoValue"}
8585

86-
# print(dump_signatures(keys, namespace=array_api, replace=replace))
86+
# print(dump_signatures(keys, namespace=array_api, replace=replace))
8787

88-
# dump the difference
88+
# dump the difference
8989
from numpy import array_api
9090

9191
dump_difference(array_api)
9292

93-
# keys = ["builtin_function_or_method", "function"]
94-
# replace = {"<no value>": "NoValue"}
95-
# print(dump_signatures(keys, namespace=array_api, replace=replace))
96-
97-
93+
# keys = ["builtin_function_or_method", "function"]
94+
# replace = {"<no value>": "NoValue"}
95+
# print(dump_signatures(keys, namespace=array_api, replace=replace))

autogen/gen_dtypes.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,20 @@ class dtype:
1313
def __init__(self, name):
1414
self._name = name
1515

16-
dt_names = ['float16', 'float32', 'float64',
17-
'complex64', 'complex128',
18-
'uint8',
19-
'int8',
20-
'int16',
21-
'int32',
22-
'int64',
23-
'bool']
16+
17+
dt_names = [
18+
"float16",
19+
"float32",
20+
"float64",
21+
"complex64",
22+
"complex128",
23+
"uint8",
24+
"int8",
25+
"int16",
26+
"int32",
27+
"int64",
28+
"bool",
29+
]
2430

2531
templ = """\
2632
{name} = dtype("{name}")
@@ -34,13 +40,12 @@ def __init__(self, name):
3440
print(src)
3541

3642

37-
3843
############### Output the casting dict ############3
3944

40-
_casting_modes = ['no', 'equiv', 'safe', 'same_kind', 'unsafe']
45+
_casting_modes = ["no", "equiv", "safe", "same_kind", "unsafe"]
4146

42-
# The structure is
43-
#_can_cast_dict["safe"]["dtyp1"]["dtyp2"]
47+
# The structure is
48+
# _can_cast_dict["safe"]["dtyp1"]["dtyp2"]
4449

4550

4651
def generate_can_cast(casting):
@@ -49,8 +54,7 @@ def generate_can_cast(casting):
4954
for dtyp1 in dt_names:
5055
dct_dtyp1 = {}
5156
for dtyp2 in dt_names:
52-
can_cast = np.can_cast(np.dtype(dtyp1), np.dtype(dtyp2),
53-
casting=casting)
57+
can_cast = np.can_cast(np.dtype(dtyp1), np.dtype(dtyp2), casting=casting)
5458
dct_dtyp1[dtyp2] = can_cast
5559
dct[dtyp1] = dct_dtyp1
5660
return dct
@@ -69,8 +73,8 @@ def generate_result_type():
6973

7074

7175
# pprint compact=True doesn't quite work :-)
72-
#import pprint
73-
#pprint.pprint(_can_cast_dict['no']['int32'], compact=True, width=100)
76+
# import pprint
77+
# pprint.pprint(_can_cast_dict['no']['int32'], compact=True, width=100)
7478

7579

7680
preamble = f"""
@@ -87,4 +91,3 @@ def generate_result_type():
8791
print("_can_cast_dict = ", _can_cast_dict)
8892
print("\n")
8993
print("_result_type_dict = ", generate_result_type())
90-

autogen/gen_mapping.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,12 @@
1010

1111
_all.remove("ndarray")
1212
_all.remove("NoValue")
13-
#_all.remove("mapping")
13+
# _all.remove("mapping")
1414

15-
pieces = [" np.{np_name}: {wrapper_name}, ".format(np_name=name, wrapper_name=name)
16-
for name in sorted(_all)]
15+
pieces = [
16+
" np.{np_name}: {wrapper_name}, ".format(np_name=name, wrapper_name=name)
17+
for name in sorted(_all)
18+
]
1719

1820
# XXX: apply additional manual surgery here, if neeeded.
1921

@@ -22,4 +24,3 @@
2224
f.write("mapping = {\n")
2325
f.write("\n".join(pieces))
2426
f.write("\n}\n")
25-

autogen/gen_ufuncs.py

Lines changed: 45 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -8,33 +8,35 @@
88

99

1010
# SKIP these (need special handling)
11-
skip = {np.frexp, np.modf, # non-standard unary ufunc signatures
12-
np.isnat,
13-
np.invert, # bitwise NOT operator
14-
np.spacing, # niche, does not have a direct equivalent
11+
skip = {
12+
np.frexp,
13+
np.modf, # non-standard unary ufunc signatures
14+
np.isnat,
15+
np.invert, # bitwise NOT operator
16+
np.spacing, # niche, does not have a direct equivalent
1517
}
1618

1719
# np functions where torch names differ
18-
torch_names = {np.radians : "deg2rad",
19-
np.degrees : "rad2deg",
20-
np.conjugate : "conj_physical",
21-
np.fabs : "absolute", # FIXME: np.fabs raises form complex
22-
np.rint : "round"
20+
torch_names = {
21+
np.radians: "deg2rad",
22+
np.degrees: "rad2deg",
23+
np.conjugate: "conj_physical",
24+
np.fabs: "absolute", # FIXME: np.fabs raises form complex
25+
np.rint: "round",
2326
}
2427

2528

2629
# np functions which do not have a torch equivalent
2730
default_stanza = "torch.{torch_name}(x, out=out)"
2831

29-
stanzas = {np.cbrt : "torch.pow(x, 1/3, out=out)",
30-
31-
# XXX what on earth is np.positive
32-
np.positive: "+x",
33-
34-
# these three do not have an out arg
35-
np.isinf: "torch.isinf(x)",
36-
np.isnan: "torch.isnan(x)",
37-
np.isfinite: "torch.isfinite(x)",
32+
stanzas = {
33+
np.cbrt: "torch.pow(x, 1/3, out=out)",
34+
# XXX what on earth is np.positive
35+
np.positive: "+x",
36+
# these three do not have an out arg
37+
np.isinf: "torch.isinf(x)",
38+
np.isnan: "torch.isnan(x)",
39+
np.isfinite: "torch.isfinite(x)",
3840
}
3941

4042

@@ -57,13 +59,16 @@
5759
5860
"""
5961

60-
test_header = header + """\
62+
test_header = (
63+
header
64+
+ """\
6165
import numpy as np
6266
import torch
6367
6468
from _unary_ufuncs import *
6569
from testing import assert_allclose
6670
"""
71+
)
6772

6873

6974
template = """
@@ -100,7 +105,7 @@ def test_{np_name}():
100105
main_text = header
101106
test_text = test_header
102107

103-
for ufunc in dct['ufunc']:
108+
for ufunc in dct["ufunc"]:
104109
if ufunc in skip:
105110
continue
106111

@@ -115,11 +120,11 @@ def test_{np_name}():
115120
if torch_stanza is None:
116121
torch_stanza = default_stanza.format(torch_name=torch_name)
117122

118-
out_stanza= add_out_stanza if ufunc in needs_out else ""
123+
out_stanza = add_out_stanza if ufunc in needs_out else ""
119124

120-
main_text += template.format(np_name=ufunc.__name__,
121-
torch_stanza=torch_stanza,
122-
out_stanza=out_stanza)
125+
main_text += template.format(
126+
np_name=ufunc.__name__, torch_stanza=torch_stanza, out_stanza=out_stanza
127+
)
123128
test_text += test_template.format(np_name=ufunc.__name__)
124129

125130
_all_list.append(ufunc.__name__)
@@ -137,14 +142,16 @@ def test_{np_name}():
137142
###### BINARY UFUNCS ###################################
138143

139144

140-
141-
test_header = header + """\
145+
test_header = (
146+
header
147+
+ """\
142148
import numpy as np
143149
import torch
144150
145151
from _binary_ufuncs import *
146152
from testing import assert_allclose
147153
"""
154+
)
148155

149156

150157
template = """
@@ -174,43 +181,43 @@ def test_{np_name}():
174181
"""
175182

176183

177-
178-
skip = {np.divmod, # two outputs
184+
skip = {
185+
np.divmod, # two outputs
179186
}
180187

181188

182-
torch_names = {np.power: "pow",
183-
np.equal: "eq",
189+
torch_names = {
190+
np.power: "pow",
191+
np.equal: "eq",
184192
}
185193

186194

187195
_all_list = []
188196
main_text = header
189197
test_text = test_header
190198

191-
for ufunc in dct['ufunc']:
199+
for ufunc in dct["ufunc"]:
192200

193201
if ufunc in skip:
194202
continue
195203

196204
if ufunc.nin == 2:
197-
# print(get_signature(ufunc))
205+
# print(get_signature(ufunc))
198206

199207
torch_name = torch_names.get(ufunc)
200208
if torch_name is None:
201209
torch_name = ufunc.__name__
202210

203-
204-
main_text += template.format(np_name=ufunc.__name__,
205-
torch_name=torch_name,)
206-
# out_stanza=out_stanza)
211+
main_text += template.format(
212+
np_name=ufunc.__name__,
213+
torch_name=torch_name,
214+
)
215+
# out_stanza=out_stanza)
207216
test_text += test_template.format(np_name=ufunc.__name__)
208217

209218

210-
211219
with open("_binary_ufuncs.py", "w") as f:
212220
f.write(main_text)
213221

214222
with open("test_binary_ufuncs.py", "w") as f:
215223
f.write(test_text)
216-

0 commit comments

Comments
 (0)