@@ -199,27 +199,46 @@ mod llvm_enzyme {
199
199
return vec ! [ item] ;
200
200
}
201
201
let dcx = ecx. sess . dcx ( ) ;
202
- // first get the annotable item:
203
- let ( primal, sig, is_impl) : ( Ident , FnSig , bool ) = match & item {
202
+
203
+ // first get information about the annotable item:
204
+ let ( sig, vis, primal) = match & item {
204
205
Annotatable :: Item ( iitem) => {
205
- let ( ident , sig ) = match & iitem. kind {
206
- ItemKind :: Fn ( box ast:: Fn { ident , sig , .. } ) => ( ident , sig ) ,
206
+ let ( sig , ident ) = match & iitem. kind {
207
+ ItemKind :: Fn ( box ast:: Fn { sig , ident , .. } ) => ( sig , ident ) ,
207
208
_ => {
208
209
dcx. emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
209
210
return vec ! [ item] ;
210
211
}
211
212
} ;
212
- ( * ident , sig . clone ( ) , false )
213
+ ( sig . clone ( ) , iitem . vis . clone ( ) , ident . clone ( ) )
213
214
}
214
215
Annotatable :: AssocItem ( assoc_item, Impl { of_trait : false } ) => {
215
- let ( ident, sig) = match & assoc_item. kind {
216
- ast:: AssocItemKind :: Fn ( box ast:: Fn { ident, sig, .. } ) => ( ident, sig) ,
216
+ let ( sig, ident) = match & assoc_item. kind {
217
+ ast:: AssocItemKind :: Fn ( box ast:: Fn { sig, ident, .. } ) => ( sig, ident) ,
218
+ _ => {
219
+ dcx. emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
220
+ return vec ! [ item] ;
221
+ }
222
+ } ;
223
+ ( sig. clone ( ) , assoc_item. vis . clone ( ) , ident. clone ( ) )
224
+ }
225
+ Annotatable :: Stmt ( stmt) => {
226
+ let ( sig, vis, ident) = match & stmt. kind {
227
+ ast:: StmtKind :: Item ( iitem) => match & iitem. kind {
228
+ ast:: ItemKind :: Fn ( box ast:: Fn { sig, ident, .. } ) => {
229
+ ( sig. clone ( ) , iitem. vis . clone ( ) , ident. clone ( ) )
230
+ }
231
+ _ => {
232
+ dcx. emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
233
+ return vec ! [ item] ;
234
+ }
235
+ } ,
217
236
_ => {
218
237
dcx. emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
219
238
return vec ! [ item] ;
220
239
}
221
240
} ;
222
- ( * ident , sig . clone ( ) , true )
241
+ ( sig , vis , ident )
223
242
}
224
243
_ => {
225
244
dcx. emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
@@ -238,15 +257,6 @@ mod llvm_enzyme {
238
257
let has_ret = has_ret ( & sig. decl . output ) ;
239
258
let sig_span = ecx. with_call_site_ctxt ( sig. span ) ;
240
259
241
- let vis = match & item {
242
- Annotatable :: Item ( iitem) => iitem. vis . clone ( ) ,
243
- Annotatable :: AssocItem ( assoc_item, _) => assoc_item. vis . clone ( ) ,
244
- _ => {
245
- dcx. emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
246
- return vec ! [ item] ;
247
- }
248
- } ;
249
-
250
260
// create TokenStream from vec elemtents:
251
261
// meta_item doesn't have a .tokens field
252
262
let mut ts: Vec < TokenTree > = vec ! [ ] ;
@@ -379,6 +389,22 @@ mod llvm_enzyme {
379
389
}
380
390
Annotatable :: AssocItem ( assoc_item. clone ( ) , i)
381
391
}
392
+ Annotatable :: Stmt ( ref mut stmt) => {
393
+ match stmt. kind {
394
+ ast:: StmtKind :: Item ( ref mut iitem) => {
395
+ if !iitem. attrs . iter ( ) . any ( |a| same_attribute ( & a. kind , & attr. kind ) ) {
396
+ iitem. attrs . push ( attr) ;
397
+ }
398
+ if !iitem. attrs . iter ( ) . any ( |a| same_attribute ( & a. kind , & inline_never. kind ) )
399
+ {
400
+ iitem. attrs . push ( inline_never. clone ( ) ) ;
401
+ }
402
+ }
403
+ _ => unreachable ! ( "stmt kind checked previously" ) ,
404
+ } ;
405
+
406
+ Annotatable :: Stmt ( stmt. clone ( ) )
407
+ }
382
408
_ => {
383
409
unreachable ! ( "annotatable kind checked previously" )
384
410
}
@@ -389,22 +415,40 @@ mod llvm_enzyme {
389
415
delim : rustc_ast:: token:: Delimiter :: Parenthesis ,
390
416
tokens : ts,
391
417
} ) ;
418
+
392
419
let d_attr = outer_normal_attr ( & rustc_ad_attr, new_id, span) ;
393
- let d_annotatable = if is_impl {
394
- let assoc_item: AssocItemKind = ast:: AssocItemKind :: Fn ( asdf) ;
395
- let d_fn = P ( ast:: AssocItem {
396
- attrs : thin_vec ! [ d_attr, inline_never] ,
397
- id : ast:: DUMMY_NODE_ID ,
398
- span,
399
- vis,
400
- kind : assoc_item,
401
- tokens : None ,
402
- } ) ;
403
- Annotatable :: AssocItem ( d_fn, Impl { of_trait : false } )
404
- } else {
405
- let mut d_fn = ecx. item ( span, thin_vec ! [ d_attr, inline_never] , ItemKind :: Fn ( asdf) ) ;
406
- d_fn. vis = vis;
407
- Annotatable :: Item ( d_fn)
420
+ let d_annotatable = match & item {
421
+ Annotatable :: AssocItem ( _, _) => {
422
+ let assoc_item: AssocItemKind = ast:: AssocItemKind :: Fn ( asdf) ;
423
+ let d_fn = P ( ast:: AssocItem {
424
+ attrs : thin_vec ! [ d_attr, inline_never] ,
425
+ id : ast:: DUMMY_NODE_ID ,
426
+ span,
427
+ vis,
428
+ kind : assoc_item,
429
+ tokens : None ,
430
+ } ) ;
431
+ Annotatable :: AssocItem ( d_fn, Impl { of_trait : false } )
432
+ }
433
+ Annotatable :: Item ( _) => {
434
+ let mut d_fn = ecx. item ( span, thin_vec ! [ d_attr, inline_never] , ItemKind :: Fn ( asdf) ) ;
435
+ d_fn. vis = vis;
436
+
437
+ Annotatable :: Item ( d_fn)
438
+ }
439
+ Annotatable :: Stmt ( _) => {
440
+ let mut d_fn = ecx. item ( span, thin_vec ! [ d_attr, inline_never] , ItemKind :: Fn ( asdf) ) ;
441
+ d_fn. vis = vis;
442
+
443
+ Annotatable :: Stmt ( P ( ast:: Stmt {
444
+ id : ast:: DUMMY_NODE_ID ,
445
+ kind : ast:: StmtKind :: Item ( d_fn) ,
446
+ span,
447
+ } ) )
448
+ }
449
+ _ => {
450
+ unreachable ! ( "item kind checked previously" )
451
+ }
408
452
} ;
409
453
410
454
return vec ! [ orig_annotatable, d_annotatable] ;
0 commit comments