Skip to content

Commit e0c8ead

Browse files
committed
add autodiff batching middle-end
1 parent 087ffd7 commit e0c8ead

File tree

1 file changed

+28
-4
lines changed

1 file changed

+28
-4
lines changed

Diff for: compiler/rustc_codegen_ssa/src/codegen_attrs.rs

+28-4
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::str::FromStr;
22

33
use rustc_abi::ExternAbi;
44
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode};
5-
use rustc_ast::{MetaItem, MetaItemInner, attr};
5+
use rustc_ast::{LitKind, MetaItem, MetaItemInner, attr};
66
use rustc_attr_parsing::ReprAttr::ReprAlign;
77
use rustc_attr_parsing::{AttributeKind, InlineAttr, InstructionSetAttr, OptimizeAttr};
88
use rustc_data_structures::fx::FxHashMap;
@@ -805,8 +805,8 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
805805
return Some(AutoDiffAttrs::source());
806806
}
807807

808-
let [mode, input_activities @ .., ret_activity] = &list[..] else {
809-
span_bug!(attr.span(), "rustc_autodiff attribute must contain mode and activities");
808+
let [mode, width_meta, input_activities @ .., ret_activity] = &list[..] else {
809+
span_bug!(attr.span(), "rustc_autodiff attribute must contain mode, width and activities");
810810
};
811811
let mode = if let MetaItemInner::MetaItem(MetaItem { path: p1, .. }) = mode {
812812
p1.segments.first().unwrap().ident
@@ -823,6 +823,30 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
823823
}
824824
};
825825

826+
let width: u32 = match width_meta {
827+
MetaItemInner::MetaItem(MetaItem { path: p1, .. }) => {
828+
let w = p1.segments.first().unwrap().ident;
829+
match w.as_str().parse() {
830+
Ok(val) => val,
831+
Err(_) => {
832+
span_bug!(w.span, "rustc_autodiff width should fit u32");
833+
}
834+
}
835+
}
836+
MetaItemInner::Lit(lit) => {
837+
if let LitKind::Int(val, _) = lit.kind {
838+
match val.get().try_into() {
839+
Ok(val) => val,
840+
Err(_) => {
841+
span_bug!(lit.span, "rustc_autodiff width should fit u32");
842+
}
843+
}
844+
} else {
845+
span_bug!(lit.span, "rustc_autodiff width should be an integer");
846+
}
847+
}
848+
};
849+
826850
// First read the ret symbol from the attribute
827851
let ret_symbol = if let MetaItemInner::MetaItem(MetaItem { path: p1, .. }) = ret_activity {
828852
p1.segments.first().unwrap().ident
@@ -860,7 +884,7 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
860884
}
861885
}
862886

863-
Some(AutoDiffAttrs { mode, width: 1, ret_activity, input_activity: arg_activities })
887+
Some(AutoDiffAttrs { mode, width, ret_activity, input_activity: arg_activities })
864888
}
865889

866890
pub(crate) fn provide(providers: &mut Providers) {

0 commit comments

Comments
 (0)