1
+ /// This crate handles the user facing autodiff macro. For each `#[autodiff(...)]` attribute,
2
+ /// we create an `AutoDiffItem` which contains the source and target function names. The source
3
+ /// is the function to which the autodiff attribute is applied, and the target is the function
4
+ /// getting generated by us (with a name given by the user as the first autodiff arg).
1
5
use std:: fmt:: { self , Display , Formatter } ;
2
6
use std:: str:: FromStr ;
3
7
@@ -6,27 +10,91 @@ use crate::expand::{Decodable, Encodable, HashStable_Generic};
6
10
use crate :: ptr:: P ;
7
11
use crate :: { Ty , TyKind } ;
8
12
9
- #[ allow( dead_code) ]
10
13
#[ derive( Clone , Copy , Eq , PartialEq , Encodable , Decodable , Debug , HashStable_Generic ) ]
11
14
pub enum DiffMode {
15
+ /// No autodiff is applied (usually used during error handling).
12
16
Inactive ,
17
+ /// The primal function which we will differentiate.
13
18
Source ,
19
+ /// The target function, to be created using forward mode AD.
14
20
Forward ,
21
+ /// The target function, to be created using reverse mode AD.
15
22
Reverse ,
23
+ /// The target function, to be created using forward mode AD.
24
+ /// This target function will also be used as a source for higher order derivatives,
25
+ /// so compute it before all Forward/Reverse targets and optimize it through llvm.
16
26
ForwardFirst ,
27
+ /// The target function, to be created using reverse mode AD.
28
+ /// This target function will also be used as a source for higher order derivatives,
29
+ /// so compute it before all Forward/Reverse targets and optimize it through llvm.
17
30
ReverseFirst ,
18
31
}
19
32
20
- pub fn is_rev ( mode : DiffMode ) -> bool {
21
- match mode {
22
- DiffMode :: Reverse | DiffMode :: ReverseFirst => true ,
23
- _ => false ,
24
- }
33
+ /// Dual and Duplicated (and their Only variants) are getting lowered to the same Enzyme Activity.
34
+ /// However, under forward mode we overwrite the previous shadow value, while for reverse mode
35
+ /// we add to the previous shadow value. To not surprise users, we picked different names.
36
+ /// Dual numbers is also a quite well known name for forward mode AD types.
37
+ #[ derive( Clone , Copy , Eq , PartialEq , Encodable , Decodable , Debug , HashStable_Generic ) ]
38
+ pub enum DiffActivity {
39
+ /// Implicit or Explicit () return type, so a special case of Const.
40
+ None ,
41
+ /// Don't compute derivatives with respect to this input/output.
42
+ Const ,
43
+ /// Reverse Mode, Compute derivatives for this scalar input/output.
44
+ Active ,
45
+ /// Reverse Mode, Compute derivatives for this scalar output, but don't compute
46
+ /// the original return value.
47
+ ActiveOnly ,
48
+ /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
49
+ /// with it.
50
+ Dual ,
51
+ /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
52
+ /// with it. Drop the code which updates the original input/output for maximum performance.
53
+ DualOnly ,
54
+ /// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
55
+ Duplicated ,
56
+ /// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
57
+ /// Drop the code which updates the original input for maximum performance.
58
+ DuplicatedOnly ,
59
+ /// All Integers must be Const, but these are used to mark the integer which represents the
60
+ /// length of a slice/vec. This is used for safety checks on slices.
61
+ FakeActivitySize ,
25
62
}
26
- pub fn is_fwd ( mode : DiffMode ) -> bool {
27
- match mode {
28
- DiffMode :: Forward | DiffMode :: ForwardFirst => true ,
29
- _ => false ,
63
+ /// We generate one of these structs for each `#[autodiff(...)]` attribute.
64
+ #[ derive( Clone , Eq , PartialEq , Encodable , Decodable , Debug , HashStable_Generic ) ]
65
+ pub struct AutoDiffItem {
66
+ /// The name of the function getting differentiated
67
+ pub source : String ,
68
+ /// The name of the function being generated
69
+ pub target : String ,
70
+ pub attrs : AutoDiffAttrs ,
71
+ /// Despribe the memory layout of input types
72
+ pub inputs : Vec < TypeTree > ,
73
+ /// Despribe the memory layout of the output type
74
+ pub output : TypeTree ,
75
+ }
76
+ #[ derive( Clone , Eq , PartialEq , Encodable , Decodable , Debug , HashStable_Generic ) ]
77
+ pub struct AutoDiffAttrs {
78
+ /// Conceptually either forward or reverse mode AD, as described in various autodiff papers and
79
+ /// e.g. in the [JAX
80
+ /// Documentation](https://jax.readthedocs.io/en/latest/_tutorials/advanced-autodiff.html#how-it-s-made-two-foundational-autodiff-functions).
81
+ pub mode : DiffMode ,
82
+ pub ret_activity : DiffActivity ,
83
+ pub input_activity : Vec < DiffActivity > ,
84
+ }
85
+
86
+ impl DiffMode {
87
+ pub fn is_rev ( & self ) -> bool {
88
+ match self {
89
+ DiffMode :: Reverse | DiffMode :: ReverseFirst => true ,
90
+ _ => false ,
91
+ }
92
+ }
93
+ pub fn is_fwd ( & self ) -> bool {
94
+ match self {
95
+ DiffMode :: Forward | DiffMode :: ForwardFirst => true ,
96
+ _ => false ,
97
+ }
30
98
}
31
99
}
32
100
@@ -63,30 +131,20 @@ pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
63
131
}
64
132
}
65
133
}
66
- fn is_ptr_or_ref ( ty : & Ty ) -> bool {
67
- match ty. kind {
68
- TyKind :: Ptr ( _) | TyKind :: Ref ( _, _) => true ,
69
- _ => false ,
70
- }
71
- }
72
- // TODO We should make this more robust to also
134
+
135
+ // FIXME(ZuseZ4) We should make this more robust to also
73
136
// accept aliases of f32 and f64
74
- //fn is_float(ty: &Ty) -> bool {
75
- // false
76
- //}
77
137
pub fn valid_ty_for_activity ( ty : & P < Ty > , activity : DiffActivity ) -> bool {
78
- if is_ptr_or_ref ( ty) {
79
- return activity == DiffActivity :: Dual
80
- || activity == DiffActivity :: DualOnly
81
- || activity == DiffActivity :: Duplicated
82
- || activity == DiffActivity :: DuplicatedOnly
83
- || activity == DiffActivity :: Const ;
138
+ match ty. kind {
139
+ TyKind :: Ptr ( _) | TyKind :: Ref ( ..) => {
140
+ return activity == DiffActivity :: Dual
141
+ || activity == DiffActivity :: DualOnly
142
+ || activity == DiffActivity :: Duplicated
143
+ || activity == DiffActivity :: DuplicatedOnly
144
+ || activity == DiffActivity :: Const ;
145
+ }
146
+ _ => false ,
84
147
}
85
- true
86
- //if is_scalar_ty(&ty) {
87
- // return activity == DiffActivity::Active || activity == DiffActivity::ActiveOnly ||
88
- // activity == DiffActivity::Const;
89
- //}
90
148
}
91
149
pub fn valid_input_activity ( mode : DiffMode , activity : DiffActivity ) -> bool {
92
150
return match mode {
@@ -117,20 +175,6 @@ pub fn invalid_input_activities(mode: DiffMode, activity_vec: &[DiffActivity]) -
117
175
None
118
176
}
119
177
120
- #[ allow( dead_code) ]
121
- #[ derive( Clone , Copy , Eq , PartialEq , Encodable , Decodable , Debug , HashStable_Generic ) ]
122
- pub enum DiffActivity {
123
- None ,
124
- Const ,
125
- Active ,
126
- ActiveOnly ,
127
- Dual ,
128
- DualOnly ,
129
- Duplicated ,
130
- DuplicatedOnly ,
131
- FakeActivitySize ,
132
- }
133
-
134
178
impl Display for DiffActivity {
135
179
fn fmt ( & self , f : & mut Formatter < ' _ > ) -> std:: fmt:: Result {
136
180
match self {
@@ -180,30 +224,14 @@ impl FromStr for DiffActivity {
180
224
}
181
225
}
182
226
183
- #[ allow( dead_code) ]
184
- #[ derive( Clone , Eq , PartialEq , Encodable , Decodable , Debug , HashStable_Generic ) ]
185
- pub struct AutoDiffAttrs {
186
- pub mode : DiffMode ,
187
- pub ret_activity : DiffActivity ,
188
- pub input_activity : Vec < DiffActivity > ,
189
- }
190
-
191
227
impl AutoDiffAttrs {
192
228
pub fn has_ret_activity ( & self ) -> bool {
193
- match self . ret_activity {
194
- DiffActivity :: None => false ,
195
- _ => true ,
196
- }
229
+ self . ret_activity != DiffActivity :: None
197
230
}
198
231
pub fn has_active_only_ret ( & self ) -> bool {
199
- match self . ret_activity {
200
- DiffActivity :: ActiveOnly => true ,
201
- _ => false ,
202
- }
232
+ self . ret_activity == DiffActivity :: ActiveOnly
203
233
}
204
- }
205
234
206
- impl AutoDiffAttrs {
207
235
pub fn inactive ( ) -> Self {
208
236
AutoDiffAttrs {
209
237
mode : DiffMode :: Inactive ,
@@ -251,15 +279,6 @@ impl AutoDiffAttrs {
251
279
}
252
280
}
253
281
254
- #[ derive( Clone , Eq , PartialEq , Encodable , Decodable , Debug , HashStable_Generic ) ]
255
- pub struct AutoDiffItem {
256
- pub source : String ,
257
- pub target : String ,
258
- pub attrs : AutoDiffAttrs ,
259
- pub inputs : Vec < TypeTree > ,
260
- pub output : TypeTree ,
261
- }
262
-
263
282
impl fmt:: Display for AutoDiffItem {
264
283
fn fmt ( & self , f : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
265
284
write ! ( f, "Differentiating {} -> {}" , self . source, self . target) ?;
0 commit comments