Skip to content

Commit 13bf79c

Browse files
committed
fix usage of autodiff macro with inner functions
- fix errors caused by the move of `ast::Item::ident` (see rust-lang#138740) - move the logic of getting `sig`, `vis`, and `ident` from two seperate `match` statements into one (less repetition especially with the nested `match`)
1 parent 175dcc7 commit 13bf79c

File tree

1 file changed

+76
-32
lines changed

1 file changed

+76
-32
lines changed

Diff for: compiler/rustc_builtin_macros/src/autodiff.rs

+76-32
Original file line numberDiff line numberDiff line change
@@ -199,27 +199,46 @@ mod llvm_enzyme {
199199
return vec![item];
200200
}
201201
let dcx = ecx.sess.dcx();
202-
// first get the annotable item:
203-
let (primal, sig, is_impl): (Ident, FnSig, bool) = match &item {
202+
203+
// first get information about the annotable item:
204+
let (sig, vis, primal) = match &item {
204205
Annotatable::Item(iitem) => {
205-
let (ident, sig) = match &iitem.kind {
206-
ItemKind::Fn(box ast::Fn { ident, sig, .. }) => (ident, sig),
206+
let (sig, ident) = match &iitem.kind {
207+
ItemKind::Fn(box ast::Fn { sig, ident, .. }) => (sig, ident),
207208
_ => {
208209
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
209210
return vec![item];
210211
}
211212
};
212-
(*ident, sig.clone(), false)
213+
(sig.clone(), iitem.vis.clone(), ident.clone())
213214
}
214215
Annotatable::AssocItem(assoc_item, Impl { of_trait: false }) => {
215-
let (ident, sig) = match &assoc_item.kind {
216-
ast::AssocItemKind::Fn(box ast::Fn { ident, sig, .. }) => (ident, sig),
216+
let (sig, ident) = match &assoc_item.kind {
217+
ast::AssocItemKind::Fn(box ast::Fn { sig, ident, .. }) => (sig, ident),
218+
_ => {
219+
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
220+
return vec![item];
221+
}
222+
};
223+
(sig.clone(), assoc_item.vis.clone(), ident.clone())
224+
}
225+
Annotatable::Stmt(stmt) => {
226+
let (sig, vis, ident) = match &stmt.kind {
227+
ast::StmtKind::Item(iitem) => match &iitem.kind {
228+
ast::ItemKind::Fn(box ast::Fn { sig, ident, .. }) => {
229+
(sig.clone(), iitem.vis.clone(), ident.clone())
230+
}
231+
_ => {
232+
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
233+
return vec![item];
234+
}
235+
},
217236
_ => {
218237
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
219238
return vec![item];
220239
}
221240
};
222-
(*ident, sig.clone(), true)
241+
(sig, vis, ident)
223242
}
224243
_ => {
225244
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
@@ -238,15 +257,6 @@ mod llvm_enzyme {
238257
let has_ret = has_ret(&sig.decl.output);
239258
let sig_span = ecx.with_call_site_ctxt(sig.span);
240259

241-
let vis = match &item {
242-
Annotatable::Item(iitem) => iitem.vis.clone(),
243-
Annotatable::AssocItem(assoc_item, _) => assoc_item.vis.clone(),
244-
_ => {
245-
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
246-
return vec![item];
247-
}
248-
};
249-
250260
// create TokenStream from vec elemtents:
251261
// meta_item doesn't have a .tokens field
252262
let mut ts: Vec<TokenTree> = vec![];
@@ -379,6 +389,22 @@ mod llvm_enzyme {
379389
}
380390
Annotatable::AssocItem(assoc_item.clone(), i)
381391
}
392+
Annotatable::Stmt(ref mut stmt) => {
393+
match stmt.kind {
394+
ast::StmtKind::Item(ref mut iitem) => {
395+
if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
396+
iitem.attrs.push(attr);
397+
}
398+
if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind))
399+
{
400+
iitem.attrs.push(inline_never.clone());
401+
}
402+
}
403+
_ => unreachable!("stmt kind checked previously"),
404+
};
405+
406+
Annotatable::Stmt(stmt.clone())
407+
}
382408
_ => {
383409
unreachable!("annotatable kind checked previously")
384410
}
@@ -389,22 +415,40 @@ mod llvm_enzyme {
389415
delim: rustc_ast::token::Delimiter::Parenthesis,
390416
tokens: ts,
391417
});
418+
392419
let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
393-
let d_annotatable = if is_impl {
394-
let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf);
395-
let d_fn = P(ast::AssocItem {
396-
attrs: thin_vec![d_attr, inline_never],
397-
id: ast::DUMMY_NODE_ID,
398-
span,
399-
vis,
400-
kind: assoc_item,
401-
tokens: None,
402-
});
403-
Annotatable::AssocItem(d_fn, Impl { of_trait: false })
404-
} else {
405-
let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
406-
d_fn.vis = vis;
407-
Annotatable::Item(d_fn)
420+
let d_annotatable = match &item {
421+
Annotatable::AssocItem(_, _) => {
422+
let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf);
423+
let d_fn = P(ast::AssocItem {
424+
attrs: thin_vec![d_attr, inline_never],
425+
id: ast::DUMMY_NODE_ID,
426+
span,
427+
vis,
428+
kind: assoc_item,
429+
tokens: None,
430+
});
431+
Annotatable::AssocItem(d_fn, Impl { of_trait: false })
432+
}
433+
Annotatable::Item(_) => {
434+
let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
435+
d_fn.vis = vis;
436+
437+
Annotatable::Item(d_fn)
438+
}
439+
Annotatable::Stmt(_) => {
440+
let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
441+
d_fn.vis = vis;
442+
443+
Annotatable::Stmt(P(ast::Stmt {
444+
id: ast::DUMMY_NODE_ID,
445+
kind: ast::StmtKind::Item(d_fn),
446+
span,
447+
}))
448+
}
449+
_ => {
450+
unreachable!("item kind checked previously")
451+
}
408452
};
409453

410454
return vec![orig_annotatable, d_annotatable];

0 commit comments

Comments
 (0)