Skip to content

Commit 7ba2e13

Browse files
committed
fix: add checks for overwriting incorrect ancestor
1 parent 2e13aed commit 7ba2e13

File tree

1 file changed

+165
-1
lines changed

1 file changed

+165
-1
lines changed

crates/ide-assists/src/handlers/bool_to_enum.rs

Lines changed: 165 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,11 @@ fn replace_usages(
263263
fn find_assignment_usage(name_ref: &ast::NameRef) -> Option<ast::Expr> {
264264
let bin_expr = name_ref.syntax().ancestors().find_map(ast::BinExpr::cast)?;
265265

266+
if !bin_expr.lhs()?.syntax().descendants().contains(name_ref.syntax()) {
267+
cov_mark::hit!(dont_assign_incorrect_ref);
268+
return None;
269+
}
270+
266271
if let Some(ast::BinaryOp::Assignment { op: None }) = bin_expr.op_kind() {
267272
bin_expr.rhs()
268273
} else {
@@ -273,6 +278,11 @@ fn find_assignment_usage(name_ref: &ast::NameRef) -> Option<ast::Expr> {
273278
fn find_negated_usage(name_ref: &ast::NameRef) -> Option<(ast::PrefixExpr, ast::Expr)> {
274279
let prefix_expr = name_ref.syntax().ancestors().find_map(ast::PrefixExpr::cast)?;
275280

281+
if !matches!(prefix_expr.expr()?, ast::Expr::PathExpr(_) | ast::Expr::FieldExpr(_)) {
282+
cov_mark::hit!(dont_overwrite_expression_inside_negation);
283+
return None;
284+
}
285+
276286
if let Some(ast::UnaryOp::Not) = prefix_expr.op_kind() {
277287
let inner_expr = prefix_expr.expr()?;
278288
Some((prefix_expr, inner_expr))
@@ -285,7 +295,12 @@ fn find_record_expr_usage(name_ref: &ast::NameRef) -> Option<(ast::RecordExprFie
285295
let record_field = name_ref.syntax().ancestors().find_map(ast::RecordExprField::cast)?;
286296
let initializer = record_field.expr()?;
287297

288-
Some((record_field, initializer))
298+
if record_field.field_name()?.syntax().descendants().contains(name_ref.syntax()) {
299+
Some((record_field, initializer))
300+
} else {
301+
cov_mark::hit!(dont_overwrite_wrong_record_field);
302+
None
303+
}
289304
}
290305

291306
/// Adds the definition of the new enum before the target node.
@@ -561,6 +576,37 @@ fn main() {
561576
)
562577
}
563578

579+
#[test]
580+
fn local_variable_nested_in_negation() {
581+
cov_mark::check!(dont_overwrite_expression_inside_negation);
582+
check_assist(
583+
bool_to_enum,
584+
r#"
585+
fn main() {
586+
if !"foo".chars().any(|c| {
587+
let $0foo = true;
588+
foo
589+
}) {
590+
println!("foo");
591+
}
592+
}
593+
"#,
594+
r#"
595+
fn main() {
596+
if !"foo".chars().any(|c| {
597+
#[derive(PartialEq, Eq)]
598+
enum Bool { True, False }
599+
600+
let foo = Bool::True;
601+
foo == Bool::True
602+
}) {
603+
println!("foo");
604+
}
605+
}
606+
"#,
607+
)
608+
}
609+
564610
#[test]
565611
fn local_variable_non_bool() {
566612
cov_mark::check!(not_applicable_non_bool_local);
@@ -638,6 +684,42 @@ fn main() {
638684
)
639685
}
640686

687+
#[test]
688+
fn field_negated() {
689+
check_assist(
690+
bool_to_enum,
691+
r#"
692+
struct Foo {
693+
$0bar: bool,
694+
}
695+
696+
fn main() {
697+
let foo = Foo { bar: false };
698+
699+
if !foo.bar {
700+
println!("foo");
701+
}
702+
}
703+
"#,
704+
r#"
705+
#[derive(PartialEq, Eq)]
706+
enum Bool { True, False }
707+
708+
struct Foo {
709+
bar: Bool,
710+
}
711+
712+
fn main() {
713+
let foo = Foo { bar: Bool::False };
714+
715+
if foo.bar == Bool::False {
716+
println!("foo");
717+
}
718+
}
719+
"#,
720+
)
721+
}
722+
641723
#[test]
642724
fn field_in_mod_properly_indented() {
643725
check_assist(
@@ -714,6 +796,88 @@ fn main() {
714796
)
715797
}
716798

799+
#[test]
800+
fn field_assigned_to_another() {
801+
cov_mark::check!(dont_assign_incorrect_ref);
802+
check_assist(
803+
bool_to_enum,
804+
r#"
805+
struct Foo {
806+
$0foo: bool,
807+
}
808+
809+
struct Bar {
810+
bar: bool,
811+
}
812+
813+
fn main() {
814+
let foo = Foo { foo: true };
815+
let mut bar = Bar { bar: true };
816+
817+
bar.bar = foo.foo;
818+
}
819+
"#,
820+
r#"
821+
#[derive(PartialEq, Eq)]
822+
enum Bool { True, False }
823+
824+
struct Foo {
825+
foo: Bool,
826+
}
827+
828+
struct Bar {
829+
bar: bool,
830+
}
831+
832+
fn main() {
833+
let foo = Foo { foo: Bool::True };
834+
let mut bar = Bar { bar: true };
835+
836+
bar.bar = foo.foo == Bool::True;
837+
}
838+
"#,
839+
)
840+
}
841+
842+
#[test]
843+
fn field_initialized_with_other() {
844+
cov_mark::check!(dont_overwrite_wrong_record_field);
845+
check_assist(
846+
bool_to_enum,
847+
r#"
848+
struct Foo {
849+
$0foo: bool,
850+
}
851+
852+
struct Bar {
853+
bar: bool,
854+
}
855+
856+
fn main() {
857+
let foo = Foo { foo: true };
858+
let bar = Bar { bar: foo.foo };
859+
}
860+
"#,
861+
r#"
862+
#[derive(PartialEq, Eq)]
863+
enum Bool { True, False }
864+
865+
struct Foo {
866+
foo: Bool,
867+
}
868+
869+
struct Bar {
870+
bar: bool,
871+
}
872+
873+
fn main() {
874+
let foo = Foo { foo: Bool::True };
875+
let bar = Bar { bar: foo.foo == Bool::True };
876+
}
877+
"#,
878+
)
879+
}
880+
717881
#[test]
718882
fn field_non_bool() {
719883
cov_mark::check!(not_applicable_non_bool_field);

0 commit comments

Comments
 (0)