@@ -3,13 +3,15 @@ use clippy_utils::diagnostics::{span_lint_and_sugg, span_lint_hir_and_then};
3
3
use clippy_utils:: source:: { snippet, snippet_with_applicability} ;
4
4
use clippy_utils:: sugg:: Sugg ;
5
5
use clippy_utils:: ty:: is_type_diagnostic_item;
6
- use clippy_utils:: { is_trait_method, path_to_local_id} ;
6
+ use clippy_utils:: { can_move_expr_to_closure , is_trait_method, path_to_local_id, CaptureKind } ;
7
7
use if_chain:: if_chain;
8
8
use rustc_errors:: Applicability ;
9
9
use rustc_hir:: intravisit:: { walk_block, walk_expr, NestedVisitorMap , Visitor } ;
10
- use rustc_hir:: { Block , Expr , ExprKind , HirId , PatKind , StmtKind } ;
10
+ use rustc_hir:: { Block , Expr , ExprKind , HirId , HirIdSet , Local , Mutability , Node , PatKind , Stmt , StmtKind } ;
11
11
use rustc_lint:: LateContext ;
12
12
use rustc_middle:: hir:: map:: Map ;
13
+ use rustc_middle:: ty:: subst:: GenericArgKind ;
14
+ use rustc_middle:: ty:: { TyKind , TyS } ;
13
15
use rustc_span:: sym;
14
16
use rustc_span:: { MultiSpan , Span } ;
15
17
@@ -83,7 +85,8 @@ fn check_needless_collect_indirect_usage<'tcx>(expr: &'tcx Expr<'_>, cx: &LateCo
83
85
is_type_diagnostic_item( cx, ty, sym:: VecDeque ) ||
84
86
is_type_diagnostic_item( cx, ty, sym:: BinaryHeap ) ||
85
87
is_type_diagnostic_item( cx, ty, sym:: LinkedList ) ;
86
- if let Some ( iter_calls) = detect_iter_and_into_iters( block, id) ;
88
+ let iter_ty = cx. typeck_results( ) . expr_ty( iter_source) ;
89
+ if let Some ( iter_calls) = detect_iter_and_into_iters( block, id, cx, get_captured_ids( cx, iter_ty) ) ;
87
90
if let [ iter_call] = & * iter_calls;
88
91
then {
89
92
let mut used_count_visitor = UsedCountVisitor {
@@ -167,34 +170,57 @@ enum IterFunctionKind {
167
170
Contains ( Span ) ,
168
171
}
169
172
170
- struct IterFunctionVisitor {
173
+ struct IterFunctionVisitor < ' b , ' a > {
174
+ illegal_mutable_capture_ids : HirIdSet ,
175
+ current_mutably_captured_ids : HirIdSet ,
176
+ cx : & ' a LateContext < ' b > ,
171
177
uses : Vec < IterFunction > ,
172
178
seen_other : bool ,
173
179
target : HirId ,
174
180
}
175
- impl < ' tcx > Visitor < ' tcx > for IterFunctionVisitor {
181
+ impl < ' tcx > Visitor < ' tcx > for IterFunctionVisitor < ' _ , ' tcx > {
182
+ fn visit_block ( & mut self , block : & ' txc Block < ' tcx > ) {
183
+ for elem in block. stmts . iter ( ) . filter_map ( get_expr_from_stmt) . chain ( block. expr ) {
184
+ self . current_mutably_captured_ids = HirIdSet :: default ( ) ;
185
+ self . visit_expr ( elem) ;
186
+ }
187
+ }
188
+
176
189
fn visit_expr ( & mut self , expr : & ' tcx Expr < ' tcx > ) {
177
190
// Check function calls on our collection
178
191
if let ExprKind :: MethodCall ( method_name, _, [ recv, args @ ..] , _) = & expr. kind {
192
+ if method_name. ident . name == sym ! ( collect) && is_trait_method ( self . cx , expr, sym:: Iterator ) {
193
+ self . current_mutably_captured_ids = get_captured_ids ( self . cx , self . cx . typeck_results ( ) . expr_ty ( recv) ) ;
194
+ self . visit_expr ( recv) ;
195
+ return ;
196
+ }
197
+
179
198
if path_to_local_id ( recv, self . target ) {
180
- match & * method_name. ident . name . as_str ( ) {
181
- "into_iter" => self . uses . push ( IterFunction {
182
- func : IterFunctionKind :: IntoIter ,
183
- span : expr. span ,
184
- } ) ,
185
- "len" => self . uses . push ( IterFunction {
186
- func : IterFunctionKind :: Len ,
187
- span : expr. span ,
188
- } ) ,
189
- "is_empty" => self . uses . push ( IterFunction {
190
- func : IterFunctionKind :: IsEmpty ,
191
- span : expr. span ,
192
- } ) ,
193
- "contains" => self . uses . push ( IterFunction {
194
- func : IterFunctionKind :: Contains ( args[ 0 ] . span ) ,
195
- span : expr. span ,
196
- } ) ,
197
- _ => self . seen_other = true ,
199
+ if self
200
+ . illegal_mutable_capture_ids
201
+ . intersection ( & self . current_mutably_captured_ids )
202
+ . next ( )
203
+ . is_none ( )
204
+ {
205
+ match & * method_name. ident . name . as_str ( ) {
206
+ "into_iter" => self . uses . push ( IterFunction {
207
+ func : IterFunctionKind :: IntoIter ,
208
+ span : expr. span ,
209
+ } ) ,
210
+ "len" => self . uses . push ( IterFunction {
211
+ func : IterFunctionKind :: Len ,
212
+ span : expr. span ,
213
+ } ) ,
214
+ "is_empty" => self . uses . push ( IterFunction {
215
+ func : IterFunctionKind :: IsEmpty ,
216
+ span : expr. span ,
217
+ } ) ,
218
+ "contains" => self . uses . push ( IterFunction {
219
+ func : IterFunctionKind :: Contains ( args[ 0 ] . span ) ,
220
+ span : expr. span ,
221
+ } ) ,
222
+ _ => self . seen_other = true ,
223
+ }
198
224
}
199
225
return ;
200
226
}
@@ -213,6 +239,14 @@ impl<'tcx> Visitor<'tcx> for IterFunctionVisitor {
213
239
}
214
240
}
215
241
242
+ fn get_expr_from_stmt < ' v > ( stmt : & ' v Stmt < ' v > ) -> Option < & ' v Expr < ' v > > {
243
+ match stmt. kind {
244
+ StmtKind :: Expr ( expr) | StmtKind :: Semi ( expr) => Some ( expr) ,
245
+ StmtKind :: Item ( ..) => None ,
246
+ StmtKind :: Local ( Local { init, .. } ) => * init,
247
+ }
248
+ }
249
+
216
250
struct UsedCountVisitor < ' a , ' tcx > {
217
251
cx : & ' a LateContext < ' tcx > ,
218
252
id : HirId ,
@@ -237,12 +271,55 @@ impl<'a, 'tcx> Visitor<'tcx> for UsedCountVisitor<'a, 'tcx> {
237
271
238
272
/// Detect the occurrences of calls to `iter` or `into_iter` for the
239
273
/// given identifier
240
- fn detect_iter_and_into_iters < ' tcx > ( block : & ' tcx Block < ' tcx > , id : HirId ) -> Option < Vec < IterFunction > > {
274
+ fn detect_iter_and_into_iters < ' tcx : ' a , ' a > (
275
+ block : & ' tcx Block < ' tcx > ,
276
+ id : HirId ,
277
+ cx : & ' a LateContext < ' tcx > ,
278
+ captured_ids : HirIdSet ,
279
+ ) -> Option < Vec < IterFunction > > {
241
280
let mut visitor = IterFunctionVisitor {
242
281
uses : Vec :: new ( ) ,
243
282
target : id,
244
283
seen_other : false ,
284
+ cx,
285
+ current_mutably_captured_ids : HirIdSet :: default ( ) ,
286
+ illegal_mutable_capture_ids : captured_ids,
245
287
} ;
246
288
visitor. visit_block ( block) ;
247
289
if visitor. seen_other { None } else { Some ( visitor. uses ) }
248
290
}
291
+
292
+ #[ allow( rustc:: usage_of_ty_tykind) ]
293
+ fn get_captured_ids ( cx : & LateContext < ' tcx > , ty : & ' _ TyS < ' _ > ) -> HirIdSet {
294
+ fn get_captured_ids_recursive ( cx : & LateContext < ' tcx > , ty : & ' _ TyS < ' _ > , set : & mut HirIdSet ) {
295
+ match ty. kind ( ) {
296
+ TyKind :: Adt ( _, generics) => {
297
+ for generic in * generics {
298
+ if let GenericArgKind :: Type ( ty) = generic. unpack ( ) {
299
+ get_captured_ids_recursive ( cx, ty, set) ;
300
+ }
301
+ }
302
+ } ,
303
+ TyKind :: Closure ( def_id, _) => {
304
+ let closure_hir_node = cx. tcx . hir ( ) . get_if_local ( * def_id) . unwrap ( ) ;
305
+ if let Node :: Expr ( closure_expr) = closure_hir_node {
306
+ can_move_expr_to_closure ( cx, closure_expr)
307
+ . unwrap ( )
308
+ . into_iter ( )
309
+ . for_each ( |( hir_id, capture_kind) | {
310
+ if matches ! ( capture_kind, CaptureKind :: Ref ( Mutability :: Mut ) ) {
311
+ set. insert ( hir_id) ;
312
+ }
313
+ } ) ;
314
+ }
315
+ } ,
316
+ _ => ( ) ,
317
+ }
318
+ }
319
+
320
+ let mut set = HirIdSet :: default ( ) ;
321
+
322
+ get_captured_ids_recursive ( cx, ty, & mut set) ;
323
+
324
+ set
325
+ }
0 commit comments