Skip to content

Commit 2a45cec

Browse files
Fix crashes with comments in parentheses (#4453)
Co-authored-by: Jelle Zijlstra <[email protected]>
1 parent b4d6d86 commit 2a45cec

File tree

6 files changed

+185
-34
lines changed

6 files changed

+185
-34
lines changed

CHANGES.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919

2020
<!-- Changes that affect Black's stable style -->
2121

22+
- Fix crashes involving comments in parenthesised return types or `X | Y` style unions.
23+
(#4453)
24+
2225
### Preview style
2326

2427
<!-- Changes that affect Black's preview style -->

src/black/linegen.py

Lines changed: 49 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1079,6 +1079,47 @@ def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None
10791079
)
10801080

10811081

1082+
def _ensure_trailing_comma(
1083+
leaves: List[Leaf], original: Line, opening_bracket: Leaf
1084+
) -> bool:
1085+
if not leaves:
1086+
return False
1087+
# Ensure a trailing comma for imports
1088+
if original.is_import:
1089+
return True
1090+
# ...and standalone function arguments
1091+
if not original.is_def:
1092+
return False
1093+
if opening_bracket.value != "(":
1094+
return False
1095+
# Don't add commas if we already have any commas
1096+
if any(
1097+
leaf.type == token.COMMA
1098+
and (
1099+
Preview.typed_params_trailing_comma not in original.mode
1100+
or not is_part_of_annotation(leaf)
1101+
)
1102+
for leaf in leaves
1103+
):
1104+
return False
1105+
1106+
# Find a leaf with a parent (comments don't have parents)
1107+
leaf_with_parent = next((leaf for leaf in leaves if leaf.parent), None)
1108+
if leaf_with_parent is None:
1109+
return True
1110+
# Don't add commas inside parenthesized return annotations
1111+
if get_annotation_type(leaf_with_parent) == "return":
1112+
return False
1113+
# Don't add commas inside PEP 604 unions
1114+
if (
1115+
leaf_with_parent.parent
1116+
and leaf_with_parent.parent.next_sibling
1117+
and leaf_with_parent.parent.next_sibling.type == token.VBAR
1118+
):
1119+
return False
1120+
return True
1121+
1122+
10821123
def bracket_split_build_line(
10831124
leaves: List[Leaf],
10841125
original: Line,
@@ -1099,40 +1140,15 @@ def bracket_split_build_line(
10991140
if component is _BracketSplitComponent.body:
11001141
result.inside_brackets = True
11011142
result.depth += 1
1102-
if leaves:
1103-
no_commas = (
1104-
# Ensure a trailing comma for imports and standalone function arguments
1105-
original.is_def
1106-
# Don't add one after any comments or within type annotations
1107-
and opening_bracket.value == "("
1108-
# Don't add one if there's already one there
1109-
and not any(
1110-
leaf.type == token.COMMA
1111-
and (
1112-
Preview.typed_params_trailing_comma not in original.mode
1113-
or not is_part_of_annotation(leaf)
1114-
)
1115-
for leaf in leaves
1116-
)
1117-
# Don't add one inside parenthesized return annotations
1118-
and get_annotation_type(leaves[0]) != "return"
1119-
# Don't add one inside PEP 604 unions
1120-
and not (
1121-
leaves[0].parent
1122-
and leaves[0].parent.next_sibling
1123-
and leaves[0].parent.next_sibling.type == token.VBAR
1124-
)
1125-
)
1126-
1127-
if original.is_import or no_commas:
1128-
for i in range(len(leaves) - 1, -1, -1):
1129-
if leaves[i].type == STANDALONE_COMMENT:
1130-
continue
1143+
if _ensure_trailing_comma(leaves, original, opening_bracket):
1144+
for i in range(len(leaves) - 1, -1, -1):
1145+
if leaves[i].type == STANDALONE_COMMENT:
1146+
continue
11311147

1132-
if leaves[i].type != token.COMMA:
1133-
new_comma = Leaf(token.COMMA, ",")
1134-
leaves.insert(i + 1, new_comma)
1135-
break
1148+
if leaves[i].type != token.COMMA:
1149+
new_comma = Leaf(token.COMMA, ",")
1150+
leaves.insert(i + 1, new_comma)
1151+
break
11361152

11371153
leaves_to_track: Set[LeafID] = set()
11381154
if component is _BracketSplitComponent.head:

src/black/nodes.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1012,6 +1012,7 @@ def get_annotation_type(leaf: Leaf) -> Literal["return", "param", None]:
10121012

10131013
def is_part_of_annotation(leaf: Leaf) -> bool:
10141014
"""Returns whether this leaf is part of a type annotation."""
1015+
assert leaf.parent is not None
10151016
return get_annotation_type(leaf) is not None
10161017

10171018

src/black/trans.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -488,7 +488,7 @@ def do_match(self, line: Line) -> TMatchResult:
488488
break
489489
i += 1
490490

491-
if not is_part_of_annotation(leaf) and not contains_comment:
491+
if not contains_comment and not is_part_of_annotation(leaf):
492492
string_indices.append(idx)
493493

494494
# Advance to the next non-STRING leaf.

tests/data/cases/funcdef_return_type_trailing_comma.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ def SimplePyFn(
142142
Buffer[UInt8, 2],
143143
Buffer[UInt8, 2],
144144
]: ...
145+
145146
# output
146147
# normal, short, function definition
147148
def foo(a, b) -> tuple[int, float]: ...

tests/data/cases/function_trailing_comma.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,64 @@ def func() -> ((also_super_long_type_annotation_that_may_cause_an_AST_related_cr
6060
argument1, (one, two,), argument4, argument5, argument6
6161
)
6262

63+
def foo() -> (
64+
# comment inside parenthesised return type
65+
int
66+
):
67+
...
68+
69+
def foo() -> (
70+
# comment inside parenthesised return type
71+
# more
72+
int
73+
# another
74+
):
75+
...
76+
77+
def foo() -> (
78+
# comment inside parenthesised new union return type
79+
int | str | bytes
80+
):
81+
...
82+
83+
def foo() -> (
84+
# comment inside plain tuple
85+
):
86+
pass
87+
88+
def foo(arg: (# comment with non-return annotation
89+
int
90+
# comment with non-return annotation
91+
)):
92+
pass
93+
94+
def foo(arg: (# comment with non-return annotation
95+
int | range | memoryview
96+
# comment with non-return annotation
97+
)):
98+
pass
99+
100+
def foo(arg: (# only before
101+
int
102+
)):
103+
pass
104+
105+
def foo(arg: (
106+
int
107+
# only after
108+
)):
109+
pass
110+
111+
variable: ( # annotation
112+
because
113+
# why not
114+
)
115+
116+
variable: (
117+
because
118+
# why not
119+
)
120+
63121
# output
64122

65123
def f(
@@ -176,3 +234,75 @@ def func() -> (
176234
argument5,
177235
argument6,
178236
)
237+
238+
239+
def foo() -> (
240+
# comment inside parenthesised return type
241+
int
242+
): ...
243+
244+
245+
def foo() -> (
246+
# comment inside parenthesised return type
247+
# more
248+
int
249+
# another
250+
): ...
251+
252+
253+
def foo() -> (
254+
# comment inside parenthesised new union return type
255+
int
256+
| str
257+
| bytes
258+
): ...
259+
260+
261+
def foo() -> (
262+
# comment inside plain tuple
263+
):
264+
pass
265+
266+
267+
def foo(
268+
arg: ( # comment with non-return annotation
269+
int
270+
# comment with non-return annotation
271+
),
272+
):
273+
pass
274+
275+
276+
def foo(
277+
arg: ( # comment with non-return annotation
278+
int
279+
| range
280+
| memoryview
281+
# comment with non-return annotation
282+
),
283+
):
284+
pass
285+
286+
287+
def foo(arg: int): # only before
288+
pass
289+
290+
291+
def foo(
292+
arg: (
293+
int
294+
# only after
295+
),
296+
):
297+
pass
298+
299+
300+
variable: ( # annotation
301+
because
302+
# why not
303+
)
304+
305+
variable: (
306+
because
307+
# why not
308+
)

0 commit comments

Comments
 (0)