Skip to content

Commit 34d3490

Browse files
committed
feat: add assist for applying De Morgan's law to iterators
1 parent 0840038 commit 34d3490

File tree

3 files changed

+352
-2
lines changed

3 files changed

+352
-2
lines changed

crates/ide-assists/src/handlers/apply_demorgan.rs

Lines changed: 327 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
use std::collections::VecDeque;
22

3+
use ide_db::{
4+
assists::GroupLabel,
5+
famous_defs::FamousDefs,
6+
source_change::SourceChangeBuilder,
7+
syntax_helpers::node_ext::{for_each_tail_expr, walk_expr},
8+
};
39
use syntax::{
4-
ast::{self, AstNode, Expr::BinExpr},
10+
ast::{self, make, AstNode, Expr::BinExpr, HasArgList},
511
ted::{self, Position},
612
SyntaxKind,
713
};
@@ -89,7 +95,8 @@ pub(crate) fn apply_demorgan(acc: &mut Assists, ctx: &AssistContext<'_>) -> Opti
8995

9096
let dm_lhs = demorganed.lhs()?;
9197

92-
acc.add(
98+
acc.add_group(
99+
&GroupLabel("Apply De Morgan's law".to_string()),
93100
AssistId("apply_demorgan", AssistKind::RefactorRewrite),
94101
"Apply De Morgan's law",
95102
op_range,
@@ -143,6 +150,122 @@ pub(crate) fn apply_demorgan(acc: &mut Assists, ctx: &AssistContext<'_>) -> Opti
143150
)
144151
}
145152

153+
// Assist: apply_demorgan_iterator
154+
//
155+
// Apply https://en.wikipedia.org/wiki/De_Morgan%27s_laws[De Morgan's law] to
156+
// `Iterator::all` and `Iterator::any`.
157+
//
158+
// This transforms expressions of the form `!iter.any(|x| predicate(x))` into
159+
// `iter.all(|x| !predicate(x))` and vice versa. This also works the other way for
160+
// `Iterator::all` into `Iterator::any`.
161+
//
162+
// ```
163+
// # //- minicore: iterator
164+
// fn main() {
165+
// let arr = [1, 2, 3];
166+
// if !arr.into_iter().$0any(|num| num == 4) {
167+
// println!("foo");
168+
// }
169+
// }
170+
// ```
171+
// ->
172+
// ```
173+
// fn main() {
174+
// let arr = [1, 2, 3];
175+
// if arr.into_iter().all(|num| num != 4) {
176+
// println!("foo");
177+
// }
178+
// }
179+
// ```
180+
pub(crate) fn apply_demorgan_iterator(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
181+
let method_call: ast::MethodCallExpr = ctx.find_node_at_offset()?;
182+
let (name, arg_expr) = validate_method_call_expr(ctx, &method_call)?;
183+
184+
let ast::Expr::ClosureExpr(closure_expr) = arg_expr else { return None };
185+
let closure_body = closure_expr.body()?;
186+
187+
let op_range = method_call.syntax().text_range();
188+
let label = format!("Apply De Morgan's law to `Iterator::{}`", name.text().as_str());
189+
acc.add_group(
190+
&GroupLabel("Apply De Morgan's law".to_string()),
191+
AssistId("apply_demorgan_iterator", AssistKind::RefactorRewrite),
192+
label,
193+
op_range,
194+
|edit| {
195+
// replace the method name
196+
let new_name = match name.text().as_str() {
197+
"all" => make::name_ref("any"),
198+
"any" => make::name_ref("all"),
199+
_ => unreachable!(),
200+
}
201+
.clone_for_update();
202+
edit.replace_ast(name, new_name);
203+
204+
// negate all tail expressions in the closure body
205+
let tail_cb = &mut |e: &_| tail_cb_impl(edit, e);
206+
walk_expr(&closure_body, &mut |expr| {
207+
if let ast::Expr::ReturnExpr(ret_expr) = expr {
208+
if let Some(ret_expr_arg) = &ret_expr.expr() {
209+
for_each_tail_expr(ret_expr_arg, tail_cb);
210+
}
211+
}
212+
});
213+
for_each_tail_expr(&closure_body, tail_cb);
214+
215+
// negate the whole method call
216+
if let Some(prefix_expr) = method_call
217+
.syntax()
218+
.parent()
219+
.and_then(ast::PrefixExpr::cast)
220+
.filter(|prefix_expr| matches!(prefix_expr.op_kind(), Some(ast::UnaryOp::Not)))
221+
{
222+
edit.delete(prefix_expr.op_token().unwrap().text_range());
223+
} else {
224+
edit.insert(method_call.syntax().text_range().start(), "!");
225+
}
226+
},
227+
)
228+
}
229+
230+
/// Ensures that the method call is to `Iterator::all` or `Iterator::any`.
231+
fn validate_method_call_expr(
232+
ctx: &AssistContext<'_>,
233+
method_call: &ast::MethodCallExpr,
234+
) -> Option<(ast::NameRef, ast::Expr)> {
235+
let name_ref = method_call.name_ref()?;
236+
if name_ref.text() != "all" && name_ref.text() != "any" {
237+
return None;
238+
}
239+
let arg_expr = method_call.arg_list()?.args().next()?;
240+
241+
let sema = &ctx.sema;
242+
243+
let receiver = method_call.receiver()?;
244+
let it_type = sema.type_of_expr(&receiver)?.adjusted();
245+
let module = sema.scope(receiver.syntax())?.module();
246+
let krate = module.krate();
247+
248+
let iter_trait = FamousDefs(sema, krate).core_iter_Iterator()?;
249+
it_type.impls_trait(sema.db, iter_trait, &[]).then_some((name_ref, arg_expr))
250+
}
251+
252+
fn tail_cb_impl(edit: &mut SourceChangeBuilder, e: &ast::Expr) {
253+
match e {
254+
ast::Expr::BreakExpr(break_expr) => {
255+
if let Some(break_expr_arg) = break_expr.expr() {
256+
for_each_tail_expr(&break_expr_arg, &mut |e| tail_cb_impl(edit, e))
257+
}
258+
}
259+
ast::Expr::ReturnExpr(_) => {
260+
// all return expressions have already been handled by the walk loop
261+
}
262+
e => {
263+
let inverted_body = invert_boolean_expression(e.clone());
264+
edit.replace(e.syntax().text_range(), inverted_body.syntax().text());
265+
}
266+
}
267+
}
268+
146269
#[cfg(test)]
147270
mod tests {
148271
use super::*;
@@ -255,4 +378,206 @@ fn f() { !(S <= S || S < S) }
255378
"fn() { let x = a && b && c; }",
256379
)
257380
}
381+
382+
#[test]
383+
fn demorgan_iterator_any_all_reverse() {
384+
check_assist(
385+
apply_demorgan_iterator,
386+
r#"
387+
//- minicore: iterator
388+
fn main() {
389+
let arr = [1, 2, 3];
390+
if arr.into_iter().all(|num| num $0!= 4) {
391+
println!("foo");
392+
}
393+
}
394+
"#,
395+
r#"
396+
fn main() {
397+
let arr = [1, 2, 3];
398+
if !arr.into_iter().any(|num| num == 4) {
399+
println!("foo");
400+
}
401+
}
402+
"#,
403+
);
404+
}
405+
406+
#[test]
407+
fn demorgan_iterator_all_any() {
408+
check_assist(
409+
apply_demorgan_iterator,
410+
r#"
411+
//- minicore: iterator
412+
fn main() {
413+
let arr = [1, 2, 3];
414+
if !arr.into_iter().$0all(|num| num > 3) {
415+
println!("foo");
416+
}
417+
}
418+
"#,
419+
r#"
420+
fn main() {
421+
let arr = [1, 2, 3];
422+
if arr.into_iter().any(|num| num <= 3) {
423+
println!("foo");
424+
}
425+
}
426+
"#,
427+
);
428+
}
429+
430+
#[test]
431+
fn demorgan_iterator_multiple_terms() {
432+
check_assist(
433+
apply_demorgan_iterator,
434+
r#"
435+
//- minicore: iterator
436+
fn main() {
437+
let arr = [1, 2, 3];
438+
if !arr.into_iter().$0any(|num| num > 3 && num == 23 && num <= 30) {
439+
println!("foo");
440+
}
441+
}
442+
"#,
443+
r#"
444+
fn main() {
445+
let arr = [1, 2, 3];
446+
if arr.into_iter().all(|num| !(num > 3 && num == 23 && num <= 30)) {
447+
println!("foo");
448+
}
449+
}
450+
"#,
451+
);
452+
}
453+
454+
#[test]
455+
fn demorgan_iterator_double_negation() {
456+
check_assist(
457+
apply_demorgan_iterator,
458+
r#"
459+
//- minicore: iterator
460+
fn main() {
461+
let arr = [1, 2, 3];
462+
if !arr.into_iter().$0all(|num| !(num > 3)) {
463+
println!("foo");
464+
}
465+
}
466+
"#,
467+
r#"
468+
fn main() {
469+
let arr = [1, 2, 3];
470+
if arr.into_iter().any(|num| num > 3) {
471+
println!("foo");
472+
}
473+
}
474+
"#,
475+
);
476+
}
477+
478+
#[test]
479+
fn demorgan_iterator_double_parens() {
480+
check_assist(
481+
apply_demorgan_iterator,
482+
r#"
483+
//- minicore: iterator
484+
fn main() {
485+
let arr = [1, 2, 3];
486+
if !arr.into_iter().$0any(|num| (num > 3 && (num == 1 || num == 2))) {
487+
println!("foo");
488+
}
489+
}
490+
"#,
491+
r#"
492+
fn main() {
493+
let arr = [1, 2, 3];
494+
if arr.into_iter().all(|num| !(num > 3 && (num == 1 || num == 2))) {
495+
println!("foo");
496+
}
497+
}
498+
"#,
499+
);
500+
}
501+
502+
#[test]
503+
fn demorgan_iterator_multiline() {
504+
check_assist(
505+
apply_demorgan_iterator,
506+
r#"
507+
//- minicore: iterator
508+
fn main() {
509+
let arr = [1, 2, 3];
510+
if arr
511+
.into_iter()
512+
.all$0(|num| !num.is_negative())
513+
{
514+
println!("foo");
515+
}
516+
}
517+
"#,
518+
r#"
519+
fn main() {
520+
let arr = [1, 2, 3];
521+
if !arr
522+
.into_iter()
523+
.any(|num| num.is_negative())
524+
{
525+
println!("foo");
526+
}
527+
}
528+
"#,
529+
);
530+
}
531+
532+
#[test]
533+
fn demorgan_iterator_block_closure() {
534+
check_assist(
535+
apply_demorgan_iterator,
536+
r#"
537+
//- minicore: iterator
538+
fn main() {
539+
let arr = [-1, 1, 2, 3];
540+
if arr.into_iter().all(|num: i32| {
541+
$0if num.is_positive() {
542+
num <= 3
543+
} else {
544+
num >= -1
545+
}
546+
}) {
547+
println!("foo");
548+
}
549+
}
550+
"#,
551+
r#"
552+
fn main() {
553+
let arr = [-1, 1, 2, 3];
554+
if !arr.into_iter().any(|num: i32| {
555+
if num.is_positive() {
556+
num > 3
557+
} else {
558+
num < -1
559+
}
560+
}) {
561+
println!("foo");
562+
}
563+
}
564+
"#,
565+
);
566+
}
567+
568+
#[test]
569+
fn demorgan_iterator_wrong_method() {
570+
check_assist_not_applicable(
571+
apply_demorgan_iterator,
572+
r#"
573+
//- minicore: iterator
574+
fn main() {
575+
let arr = [1, 2, 3];
576+
if !arr.into_iter().$0map(|num| num > 3) {
577+
println!("foo");
578+
}
579+
}
580+
"#,
581+
);
582+
}
258583
}

crates/ide-assists/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ mod handlers {
226226
add_return_type::add_return_type,
227227
add_turbo_fish::add_turbo_fish,
228228
apply_demorgan::apply_demorgan,
229+
apply_demorgan::apply_demorgan_iterator,
229230
auto_import::auto_import,
230231
bind_unused_param::bind_unused_param,
231232
bool_to_enum::bool_to_enum,

crates/ide-assists/src/tests/generated.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,30 @@ fn main() {
244244
)
245245
}
246246

247+
#[test]
248+
fn doctest_apply_demorgan_iterator() {
249+
check_doc_test(
250+
"apply_demorgan_iterator",
251+
r#####"
252+
//- minicore: iterator
253+
fn main() {
254+
let arr = [1, 2, 3];
255+
if !arr.into_iter().$0any(|num| num == 4) {
256+
println!("foo");
257+
}
258+
}
259+
"#####,
260+
r#####"
261+
fn main() {
262+
let arr = [1, 2, 3];
263+
if arr.into_iter().all(|num| num != 4) {
264+
println!("foo");
265+
}
266+
}
267+
"#####,
268+
)
269+
}
270+
247271
#[test]
248272
fn doctest_auto_import() {
249273
check_doc_test(

0 commit comments

Comments
 (0)