@@ -17,7 +17,7 @@ mod llvm_enzyme {
17
17
use rustc_ast:: visit:: AssocCtxt :: * ;
18
18
use rustc_ast:: {
19
19
self as ast, AssocItemKind , BindingMode , ExprKind , FnRetTy , FnSig , Generics , ItemKind ,
20
- MetaItemInner , PatKind , QSelf , TyKind ,
20
+ MetaItemInner , PatKind , QSelf , TyKind , Visibility ,
21
21
} ;
22
22
use rustc_expand:: base:: { Annotatable , ExtCtxt } ;
23
23
use rustc_span:: { Ident , Span , Symbol , kw, sym} ;
@@ -72,6 +72,16 @@ mod llvm_enzyme {
72
72
}
73
73
}
74
74
75
+ // Get information about the function the macro is applied to
76
+ fn extract_item_info ( iitem : & P < ast:: Item > ) -> Option < ( Visibility , FnSig , Ident ) > {
77
+ match & iitem. kind {
78
+ ItemKind :: Fn ( box ast:: Fn { sig, ident, .. } ) => {
79
+ Some ( ( iitem. vis . clone ( ) , sig. clone ( ) , ident. clone ( ) ) )
80
+ }
81
+ _ => None ,
82
+ }
83
+ }
84
+
75
85
pub ( crate ) fn from_ast (
76
86
ecx : & mut ExtCtxt < ' _ > ,
77
87
meta_item : & ThinVec < MetaItemInner > ,
@@ -199,32 +209,26 @@ mod llvm_enzyme {
199
209
return vec ! [ item] ;
200
210
}
201
211
let dcx = ecx. sess . dcx ( ) ;
202
- // first get the annotable item:
203
- let ( primal, sig, is_impl) : ( Ident , FnSig , bool ) = match & item {
204
- Annotatable :: Item ( iitem) => {
205
- let ( ident, sig) = match & iitem. kind {
206
- ItemKind :: Fn ( box ast:: Fn { ident, sig, .. } ) => ( ident, sig) ,
207
- _ => {
208
- dcx. emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
209
- return vec ! [ item] ;
210
- }
211
- } ;
212
- ( * ident, sig. clone ( ) , false )
213
- }
212
+
213
+ // first get information about the annotable item:
214
+ let Some ( ( vis, sig, primal) ) = ( match & item {
215
+ Annotatable :: Item ( iitem) => extract_item_info ( iitem) ,
216
+ Annotatable :: Stmt ( stmt) => match & stmt. kind {
217
+ ast:: StmtKind :: Item ( iitem) => extract_item_info ( iitem) ,
218
+ _ => None ,
219
+ } ,
214
220
Annotatable :: AssocItem ( assoc_item, Impl { of_trait : false } ) => {
215
- let ( ident, sig) = match & assoc_item. kind {
216
- ast:: AssocItemKind :: Fn ( box ast:: Fn { ident, sig, .. } ) => ( ident, sig) ,
217
- _ => {
218
- dcx. emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
219
- return vec ! [ item] ;
221
+ match & assoc_item. kind {
222
+ ast:: AssocItemKind :: Fn ( box ast:: Fn { sig, ident, .. } ) => {
223
+ Some ( ( assoc_item. vis . clone ( ) , sig. clone ( ) , ident. clone ( ) ) )
220
224
}
221
- } ;
222
- ( * ident, sig. clone ( ) , true )
223
- }
224
- _ => {
225
- dcx. emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
226
- return vec ! [ item] ;
225
+ _ => None ,
226
+ }
227
227
}
228
+ _ => None ,
229
+ } ) else {
230
+ dcx. emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
231
+ return vec ! [ item] ;
228
232
} ;
229
233
230
234
let meta_item_vec: ThinVec < MetaItemInner > = match meta_item. kind {
@@ -238,15 +242,6 @@ mod llvm_enzyme {
238
242
let has_ret = has_ret ( & sig. decl . output ) ;
239
243
let sig_span = ecx. with_call_site_ctxt ( sig. span ) ;
240
244
241
- let vis = match & item {
242
- Annotatable :: Item ( iitem) => iitem. vis . clone ( ) ,
243
- Annotatable :: AssocItem ( assoc_item, _) => assoc_item. vis . clone ( ) ,
244
- _ => {
245
- dcx. emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
246
- return vec ! [ item] ;
247
- }
248
- } ;
249
-
250
245
// create TokenStream from vec elemtents:
251
246
// meta_item doesn't have a .tokens field
252
247
let mut ts: Vec < TokenTree > = vec ! [ ] ;
@@ -379,6 +374,22 @@ mod llvm_enzyme {
379
374
}
380
375
Annotatable :: AssocItem ( assoc_item. clone ( ) , i)
381
376
}
377
+ Annotatable :: Stmt ( ref mut stmt) => {
378
+ match stmt. kind {
379
+ ast:: StmtKind :: Item ( ref mut iitem) => {
380
+ if !iitem. attrs . iter ( ) . any ( |a| same_attribute ( & a. kind , & attr. kind ) ) {
381
+ iitem. attrs . push ( attr) ;
382
+ }
383
+ if !iitem. attrs . iter ( ) . any ( |a| same_attribute ( & a. kind , & inline_never. kind ) )
384
+ {
385
+ iitem. attrs . push ( inline_never. clone ( ) ) ;
386
+ }
387
+ }
388
+ _ => unreachable ! ( "stmt kind checked previously" ) ,
389
+ } ;
390
+
391
+ Annotatable :: Stmt ( stmt. clone ( ) )
392
+ }
382
393
_ => {
383
394
unreachable ! ( "annotatable kind checked previously" )
384
395
}
@@ -389,22 +400,40 @@ mod llvm_enzyme {
389
400
delim : rustc_ast:: token:: Delimiter :: Parenthesis ,
390
401
tokens : ts,
391
402
} ) ;
403
+
392
404
let d_attr = outer_normal_attr ( & rustc_ad_attr, new_id, span) ;
393
- let d_annotatable = if is_impl {
394
- let assoc_item: AssocItemKind = ast:: AssocItemKind :: Fn ( asdf) ;
395
- let d_fn = P ( ast:: AssocItem {
396
- attrs : thin_vec ! [ d_attr, inline_never] ,
397
- id : ast:: DUMMY_NODE_ID ,
398
- span,
399
- vis,
400
- kind : assoc_item,
401
- tokens : None ,
402
- } ) ;
403
- Annotatable :: AssocItem ( d_fn, Impl { of_trait : false } )
404
- } else {
405
- let mut d_fn = ecx. item ( span, thin_vec ! [ d_attr, inline_never] , ItemKind :: Fn ( asdf) ) ;
406
- d_fn. vis = vis;
407
- Annotatable :: Item ( d_fn)
405
+ let d_annotatable = match & item {
406
+ Annotatable :: AssocItem ( _, _) => {
407
+ let assoc_item: AssocItemKind = ast:: AssocItemKind :: Fn ( asdf) ;
408
+ let d_fn = P ( ast:: AssocItem {
409
+ attrs : thin_vec ! [ d_attr, inline_never] ,
410
+ id : ast:: DUMMY_NODE_ID ,
411
+ span,
412
+ vis,
413
+ kind : assoc_item,
414
+ tokens : None ,
415
+ } ) ;
416
+ Annotatable :: AssocItem ( d_fn, Impl { of_trait : false } )
417
+ }
418
+ Annotatable :: Item ( _) => {
419
+ let mut d_fn = ecx. item ( span, thin_vec ! [ d_attr, inline_never] , ItemKind :: Fn ( asdf) ) ;
420
+ d_fn. vis = vis;
421
+
422
+ Annotatable :: Item ( d_fn)
423
+ }
424
+ Annotatable :: Stmt ( _) => {
425
+ let mut d_fn = ecx. item ( span, thin_vec ! [ d_attr, inline_never] , ItemKind :: Fn ( asdf) ) ;
426
+ d_fn. vis = vis;
427
+
428
+ Annotatable :: Stmt ( P ( ast:: Stmt {
429
+ id : ast:: DUMMY_NODE_ID ,
430
+ kind : ast:: StmtKind :: Item ( d_fn) ,
431
+ span,
432
+ } ) )
433
+ }
434
+ _ => {
435
+ unreachable ! ( "item kind checked previously" )
436
+ }
408
437
} ;
409
438
410
439
return vec ! [ orig_annotatable, d_annotatable] ;
0 commit comments