1
1
use ide_db:: defs:: { Definition , NameRefClass } ;
2
2
use syntax:: {
3
- ast:: { self , HasName } ,
3
+ ast:: { self , HasName , Name } ,
4
4
ted, AstNode , SyntaxNode ,
5
5
} ;
6
6
@@ -48,15 +48,15 @@ pub(crate) fn convert_match_to_let_else(acc: &mut Assists, ctx: &AssistContext<'
48
48
other => format ! ( "{{ {other} }}" ) ,
49
49
} ;
50
50
let extracting_arm_pat = extracting_arm. pat ( ) ?;
51
- let extracted_variable = find_extracted_variable ( ctx, & extracting_arm) ?;
51
+ let extracted_variable_positions = find_extracted_variable ( ctx, & extracting_arm) ?;
52
52
53
53
acc. add (
54
54
AssistId ( "convert_match_to_let_else" , AssistKind :: RefactorRewrite ) ,
55
55
"Convert match to let-else" ,
56
56
let_stmt. syntax ( ) . text_range ( ) ,
57
57
|builder| {
58
58
let extracting_arm_pat =
59
- rename_variable ( & extracting_arm_pat, extracted_variable , binding) ;
59
+ rename_variable ( & extracting_arm_pat, & extracted_variable_positions , binding) ;
60
60
builder. replace (
61
61
let_stmt. syntax ( ) . text_range ( ) ,
62
62
format ! ( "let {extracting_arm_pat} = {initializer_expr} else {diverging_arm_expr};" ) ,
@@ -95,14 +95,15 @@ fn find_arms(
95
95
}
96
96
97
97
// Given an extracting arm, find the extracted variable.
98
- fn find_extracted_variable ( ctx : & AssistContext < ' _ > , arm : & ast:: MatchArm ) -> Option < ast :: Name > {
98
+ fn find_extracted_variable ( ctx : & AssistContext < ' _ > , arm : & ast:: MatchArm ) -> Option < Vec < Name > > {
99
99
match arm. expr ( ) ? {
100
100
ast:: Expr :: PathExpr ( path) => {
101
101
let name_ref = path. syntax ( ) . descendants ( ) . find_map ( ast:: NameRef :: cast) ?;
102
102
match NameRefClass :: classify ( & ctx. sema , & name_ref) ? {
103
103
NameRefClass :: Definition ( Definition :: Local ( local) ) => {
104
- let source = local. primary_source ( ctx. db ( ) ) . into_ident_pat ( ) ?;
105
- Some ( source. name ( ) ?)
104
+ let source =
105
+ local. sources ( ctx. db ( ) ) . into_iter ( ) . map ( |x| x. into_ident_pat ( ) ?. name ( ) ) ;
106
+ source. collect ( )
106
107
}
107
108
_ => None ,
108
109
}
@@ -115,27 +116,34 @@ fn find_extracted_variable(ctx: &AssistContext<'_>, arm: &ast::MatchArm) -> Opti
115
116
}
116
117
117
118
// Rename `extracted` with `binding` in `pat`.
118
- fn rename_variable ( pat : & ast:: Pat , extracted : ast :: Name , binding : ast:: Pat ) -> SyntaxNode {
119
+ fn rename_variable ( pat : & ast:: Pat , extracted : & [ Name ] , binding : ast:: Pat ) -> SyntaxNode {
119
120
let syntax = pat. syntax ( ) . clone_for_update ( ) ;
120
- let extracted_syntax = syntax. covering_element ( extracted. syntax ( ) . text_range ( ) ) ;
121
-
122
- // If `extracted` variable is a record field, we should rename it to `binding`,
123
- // otherwise we just need to replace `extracted` with `binding`.
124
-
125
- if let Some ( record_pat_field) = extracted_syntax. ancestors ( ) . find_map ( ast:: RecordPatField :: cast)
126
- {
127
- if let Some ( name_ref) = record_pat_field. field_name ( ) {
128
- ted:: replace (
129
- record_pat_field. syntax ( ) ,
130
- ast:: make:: record_pat_field ( ast:: make:: name_ref ( & name_ref. text ( ) ) , binding)
121
+ let extracted = extracted
122
+ . iter ( )
123
+ . map ( |e| syntax. covering_element ( e. syntax ( ) . text_range ( ) ) )
124
+ . collect :: < Vec < _ > > ( ) ;
125
+ for extracted_syntax in extracted {
126
+ // If `extracted` variable is a record field, we should rename it to `binding`,
127
+ // otherwise we just need to replace `extracted` with `binding`.
128
+
129
+ if let Some ( record_pat_field) =
130
+ extracted_syntax. ancestors ( ) . find_map ( ast:: RecordPatField :: cast)
131
+ {
132
+ if let Some ( name_ref) = record_pat_field. field_name ( ) {
133
+ ted:: replace (
134
+ record_pat_field. syntax ( ) ,
135
+ ast:: make:: record_pat_field (
136
+ ast:: make:: name_ref ( & name_ref. text ( ) ) ,
137
+ binding. clone ( ) ,
138
+ )
131
139
. syntax ( )
132
140
. clone_for_update ( ) ,
133
- ) ;
141
+ ) ;
142
+ }
143
+ } else {
144
+ ted:: replace ( extracted_syntax, binding. clone ( ) . syntax ( ) . clone_for_update ( ) ) ;
134
145
}
135
- } else {
136
- ted:: replace ( extracted_syntax, binding. syntax ( ) . clone_for_update ( ) ) ;
137
146
}
138
-
139
147
syntax
140
148
}
141
149
@@ -162,6 +170,39 @@ fn foo(opt: Option<()>) {
162
170
) ;
163
171
}
164
172
173
+ #[ test]
174
+ fn or_pattern_multiple_binding ( ) {
175
+ check_assist (
176
+ convert_match_to_let_else,
177
+ r#"
178
+ //- minicore: option
179
+ enum Foo {
180
+ A(u32),
181
+ B(u32),
182
+ C(String),
183
+ }
184
+
185
+ fn foo(opt: Option<Foo>) -> Result<u32, ()> {
186
+ let va$0lue = match opt {
187
+ Some(Foo::A(it) | Foo::B(it)) => it,
188
+ _ => return Err(()),
189
+ };
190
+ }
191
+ "# ,
192
+ r#"
193
+ enum Foo {
194
+ A(u32),
195
+ B(u32),
196
+ C(String),
197
+ }
198
+
199
+ fn foo(opt: Option<Foo>) -> Result<u32, ()> {
200
+ let Some(Foo::A(value) | Foo::B(value)) = opt else { return Err(()) };
201
+ }
202
+ "# ,
203
+ ) ;
204
+ }
205
+
165
206
#[ test]
166
207
fn should_not_be_applicable_if_extracting_arm_is_not_an_identity_expr ( ) {
167
208
cov_mark:: check_count!( extracting_arm_is_not_an_identity_expr, 2 ) ;
0 commit comments