Skip to content

Commit 6e9471d

Browse files
committed
Move parse_op_special_cases functionality to parse_special_cases
1 parent 0ed66ef commit 6e9471d

File tree

1 file changed

+10
-44
lines changed

1 file changed

+10
-44
lines changed

generate_stubs.py

+10-44
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,11 @@ def {annotated_sig}:{doc}
361361
with open(py_path, 'w') as f:
362362
f.write(code)
363363
elif filename == 'array_object.md':
364-
special_cases = parse_op_special_cases(text, verbose=args.verbose)
364+
special_cases = parse_special_cases(text, verbose=args.verbose)
365+
for name in IN_PLACE_OPERATOR_RE.findall(text):
366+
op = f"__{name}__"
367+
iop = f"__i{name}__"
368+
special_cases[iop] = special_cases[op]
365369
for func in special_cases:
366370
py_path = os.path.join('array_api_tests', 'special_cases', f'test_{func}.py')
367371
tests = []
@@ -972,47 +976,15 @@ def generate_special_case_test(func, typ, m, test_name_extra, sigs):
972976
raise RuntimeError(f"Unexpected type {typ}")
973977

974978
def parse_special_cases(spec_text, verbose=False) -> Dict[str, DefaultDict[str, List[regex.Match]]]:
975-
special_cases = {}
976-
in_block = False
977-
for line in spec_text.splitlines():
978-
m = FUNCTION_HEADER_RE.match(line)
979-
if m:
980-
name = m.group(1)
981-
special_cases[name] = defaultdict(list)
982-
continue
983-
if line == '#### Special Cases':
984-
in_block = True
985-
continue
986-
elif line.startswith('#'):
987-
in_block = False
988-
continue
989-
if in_block:
990-
if '- ' not in line:
991-
continue
992-
for typ, reg in SPECIAL_CASE_REGEXS.items():
993-
m = reg.match(line)
994-
if m:
995-
if verbose:
996-
print(f"Matched {typ} for {name}: {m.groups()}")
997-
special_cases[name][typ].append(m)
998-
break
999-
else:
1000-
raise ValueError(f"Unrecognized special case string for '{name}':\n{line}")
1001-
1002-
return special_cases
1003-
1004-
def parse_op_special_cases(spec_text, verbose=False) -> Dict[str, DefaultDict[str, List[regex.Match]]]:
1005979
special_cases = {}
1006980
in_block = False
1007981
name = None
1008982
for line in spec_text.splitlines():
1009-
m = METHOD_HEADER_RE.match(line)
1010-
if m:
1011-
name = m.group(1)
1012-
if name in OPS:
1013-
special_cases[name] = defaultdict(list)
1014-
else:
1015-
name = None
983+
func_m = FUNCTION_HEADER_RE.match(line)
984+
meth_m = METHOD_HEADER_RE.match(line)
985+
if func_m or meth_m:
986+
name = func_m.group(1) if func_m else meth_m.group(1)
987+
special_cases[name] = defaultdict(list)
1016988
continue
1017989
if line == '#### Special Cases':
1018990
in_block = True
@@ -1032,12 +1004,6 @@ def parse_op_special_cases(spec_text, verbose=False) -> Dict[str, DefaultDict[st
10321004
break
10331005
else:
10341006
raise ValueError(f"Unrecognized special case string for '{name}':\n{line}")
1035-
for line in spec_text.splitlines():
1036-
for name in IN_PLACE_OPERATOR_RE.findall(spec_text):
1037-
op = f"__{name}__"
1038-
iop = f"__i{name}__"
1039-
special_cases[iop] = special_cases[op]
1040-
continue
10411007

10421008
return special_cases
10431009

0 commit comments

Comments
 (0)