Skip to content

Commit 2c1926b

Browse files
Insert parentheses for multi-argument generators (#12422)
## Summary Closes #12420.
1 parent 4bcc96a commit 2c1926b

File tree

3 files changed

+71
-10
lines changed

3 files changed

+71
-10
lines changed

crates/ruff_linter/resources/test/fixtures/flake8_comprehensions/C419_1.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,18 @@
33
max([x.val for x in bar])
44
sum([x.val for x in bar], 0)
55

6-
# Ok
6+
# OK
77
sum(x.val for x in bar)
88
min(x.val for x in bar)
99
max(x.val for x in bar)
1010
sum(x.val for x in bar, 0)
11+
12+
# Multi-line
13+
sum(
14+
[
15+
delta
16+
for delta in timedelta_list
17+
if delta
18+
],
19+
dt.timedelta(),
20+
)

crates/ruff_linter/src/rules/flake8_comprehensions/rules/unnecessary_comprehension_in_call.rs

+26-5
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
use ruff_python_ast::{self as ast, Expr, Keyword};
22

3-
use ruff_diagnostics::Violation;
43
use ruff_diagnostics::{Diagnostic, FixAvailability};
4+
use ruff_diagnostics::{Edit, Fix, Violation};
55
use ruff_macros::{derive_message_formats, violation};
66
use ruff_python_ast::helpers::any_over_expr;
7-
use ruff_text_size::Ranged;
7+
use ruff_text_size::{Ranged, TextSize};
88

99
use crate::checkers::ast::Checker;
1010

@@ -112,9 +112,30 @@ pub(crate) fn unnecessary_comprehension_in_call(
112112
}
113113

114114
let mut diagnostic = Diagnostic::new(UnnecessaryComprehensionInCall, arg.range());
115-
diagnostic.try_set_fix(|| {
116-
fixes::fix_unnecessary_comprehension_in_call(expr, checker.locator(), checker.stylist())
117-
});
115+
116+
if args.len() == 1 {
117+
// If there's only one argument, remove the list or set brackets.
118+
diagnostic.try_set_fix(|| {
119+
fixes::fix_unnecessary_comprehension_in_call(expr, checker.locator(), checker.stylist())
120+
});
121+
} else {
122+
// If there are multiple arguments, replace the list or set brackets with parentheses.
123+
// If a function call has multiple arguments, one of which is a generator, then the
124+
// generator must be parenthesized.
125+
126+
// Replace `[` with `(`.
127+
let collection_start = Edit::replacement(
128+
"(".to_string(),
129+
arg.start(),
130+
arg.start() + TextSize::from(1),
131+
);
132+
133+
// Replace `]` with `)`.
134+
let collection_end =
135+
Edit::replacement(")".to_string(), arg.end() - TextSize::from(1), arg.end());
136+
137+
diagnostic.set_fix(Fix::unsafe_edits(collection_start, [collection_end]));
138+
}
118139
checker.diagnostics.push(diagnostic);
119140
}
120141

crates/ruff_linter/src/rules/flake8_comprehensions/snapshots/ruff_linter__rules__flake8_comprehensions__tests__preview__C419_C419_1.py.snap

+34-4
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ C419_1.py:3:5: C419 [*] Unnecessary list comprehension
5252
3 |+max(x.val for x in bar)
5353
4 4 | sum([x.val for x in bar], 0)
5454
5 5 |
55-
6 6 | # Ok
55+
6 6 | # OK
5656

5757
C419_1.py:4:5: C419 [*] Unnecessary list comprehension
5858
|
@@ -61,7 +61,7 @@ C419_1.py:4:5: C419 [*] Unnecessary list comprehension
6161
4 | sum([x.val for x in bar], 0)
6262
| ^^^^^^^^^^^^^^^^^^^^ C419
6363
5 |
64-
6 | # Ok
64+
6 | # OK
6565
|
6666
= help: Remove unnecessary list comprehension
6767

@@ -70,7 +70,37 @@ C419_1.py:4:5: C419 [*] Unnecessary list comprehension
7070
2 2 | min([x.val for x in bar])
7171
3 3 | max([x.val for x in bar])
7272
4 |-sum([x.val for x in bar], 0)
73-
4 |+sum(x.val for x in bar, 0)
73+
4 |+sum((x.val for x in bar), 0)
7474
5 5 |
75-
6 6 | # Ok
75+
6 6 | # OK
7676
7 7 | sum(x.val for x in bar)
77+
78+
C419_1.py:14:5: C419 [*] Unnecessary list comprehension
79+
|
80+
12 | # Multi-line
81+
13 | sum(
82+
14 | [
83+
| _____^
84+
15 | | delta
85+
16 | | for delta in timedelta_list
86+
17 | | if delta
87+
18 | | ],
88+
| |_____^ C419
89+
19 | dt.timedelta(),
90+
20 | )
91+
|
92+
= help: Remove unnecessary list comprehension
93+
94+
Unsafe fix
95+
11 11 |
96+
12 12 | # Multi-line
97+
13 13 | sum(
98+
14 |- [
99+
14 |+ (
100+
15 15 | delta
101+
16 16 | for delta in timedelta_list
102+
17 17 | if delta
103+
18 |- ],
104+
18 |+ ),
105+
19 19 | dt.timedelta(),
106+
20 20 | )

0 commit comments

Comments
 (0)