@@ -2,7 +2,7 @@ use std::str::FromStr;
2
2
3
3
use rustc_abi:: ExternAbi ;
4
4
use rustc_ast:: expand:: autodiff_attrs:: { AutoDiffAttrs , DiffActivity , DiffMode } ;
5
- use rustc_ast:: { MetaItem , MetaItemInner , attr} ;
5
+ use rustc_ast:: { LitKind , MetaItem , MetaItemInner , attr} ;
6
6
use rustc_attr_parsing:: ReprAttr :: ReprAlign ;
7
7
use rustc_attr_parsing:: { AttributeKind , InlineAttr , InstructionSetAttr , OptimizeAttr } ;
8
8
use rustc_data_structures:: fx:: FxHashMap ;
@@ -805,8 +805,8 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
805
805
return Some ( AutoDiffAttrs :: source ( ) ) ;
806
806
}
807
807
808
- let [ mode, input_activities @ .., ret_activity] = & list[ ..] else {
809
- span_bug ! ( attr. span( ) , "rustc_autodiff attribute must contain mode and activities" ) ;
808
+ let [ mode, width_meta , input_activities @ .., ret_activity] = & list[ ..] else {
809
+ span_bug ! ( attr. span( ) , "rustc_autodiff attribute must contain mode, width and activities" ) ;
810
810
} ;
811
811
let mode = if let MetaItemInner :: MetaItem ( MetaItem { path : p1, .. } ) = mode {
812
812
p1. segments . first ( ) . unwrap ( ) . ident
@@ -823,6 +823,30 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
823
823
}
824
824
} ;
825
825
826
+ let width: u32 = match width_meta {
827
+ MetaItemInner :: MetaItem ( MetaItem { path : p1, .. } ) => {
828
+ let w = p1. segments . first ( ) . unwrap ( ) . ident ;
829
+ match w. as_str ( ) . parse ( ) {
830
+ Ok ( val) => val,
831
+ Err ( _) => {
832
+ span_bug ! ( w. span, "rustc_autodiff width should fit u32" ) ;
833
+ }
834
+ }
835
+ }
836
+ MetaItemInner :: Lit ( lit) => {
837
+ if let LitKind :: Int ( val, _) = lit. kind {
838
+ match val. get ( ) . try_into ( ) {
839
+ Ok ( val) => val,
840
+ Err ( _) => {
841
+ span_bug ! ( lit. span, "rustc_autodiff width should fit u32" ) ;
842
+ }
843
+ }
844
+ } else {
845
+ span_bug ! ( lit. span, "rustc_autodiff width should be an integer" ) ;
846
+ }
847
+ }
848
+ } ;
849
+
826
850
// First read the ret symbol from the attribute
827
851
let ret_symbol = if let MetaItemInner :: MetaItem ( MetaItem { path : p1, .. } ) = ret_activity {
828
852
p1. segments . first ( ) . unwrap ( ) . ident
@@ -860,7 +884,7 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
860
884
}
861
885
}
862
886
863
- Some ( AutoDiffAttrs { mode, width : 1 , ret_activity, input_activity : arg_activities } )
887
+ Some ( AutoDiffAttrs { mode, width, ret_activity, input_activity : arg_activities } )
864
888
}
865
889
866
890
pub ( crate ) fn provide ( providers : & mut Providers ) {
0 commit comments