Skip to content

Commit f82e085

Browse files
committed
addressing reviewer feedback
1 parent ca94290 commit f82e085

File tree

2 files changed

+101
-83
lines changed

2 files changed

+101
-83
lines changed

Diff for: compiler/rustc_ast/src/expand/autodiff_attrs.rs

+91-72
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
/// This crate handles the user facing autodiff macro. For each `#[autodiff(...)]` attribute,
2+
/// we create an `AutoDiffItem` which contains the source and target function names. The source
3+
/// is the function to which the autodiff attribute is applied, and the target is the function
4+
/// getting generated by us (with a name given by the user as the first autodiff arg).
15
use std::fmt::{self, Display, Formatter};
26
use std::str::FromStr;
37

@@ -6,27 +10,91 @@ use crate::expand::{Decodable, Encodable, HashStable_Generic};
610
use crate::ptr::P;
711
use crate::{Ty, TyKind};
812

9-
#[allow(dead_code)]
1013
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
1114
pub enum DiffMode {
15+
/// No autodiff is applied (usually used during error handling).
1216
Inactive,
17+
/// The primal function which we will differentiate.
1318
Source,
19+
/// The target function, to be created using forward mode AD.
1420
Forward,
21+
/// The target function, to be created using reverse mode AD.
1522
Reverse,
23+
/// The target function, to be created using forward mode AD.
24+
/// This target function will also be used as a source for higher order derivatives,
25+
/// so compute it before all Forward/Reverse targets and optimize it through llvm.
1626
ForwardFirst,
27+
/// The target function, to be created using reverse mode AD.
28+
/// This target function will also be used as a source for higher order derivatives,
29+
/// so compute it before all Forward/Reverse targets and optimize it through llvm.
1730
ReverseFirst,
1831
}
1932

20-
pub fn is_rev(mode: DiffMode) -> bool {
21-
match mode {
22-
DiffMode::Reverse | DiffMode::ReverseFirst => true,
23-
_ => false,
24-
}
33+
/// Dual and Duplicated (and their Only variants) are getting lowered to the same Enzyme Activity.
34+
/// However, under forward mode we overwrite the previous shadow value, while for reverse mode
35+
/// we add to the previous shadow value. To not surprise users, we picked different names.
36+
/// Dual numbers is also a quite well known name for forward mode AD types.
37+
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
38+
pub enum DiffActivity {
39+
/// Implicit or Explicit () return type, so a special case of Const.
40+
None,
41+
/// Don't compute derivatives with respect to this input/output.
42+
Const,
43+
/// Reverse Mode, Compute derivatives for this scalar input/output.
44+
Active,
45+
/// Reverse Mode, Compute derivatives for this scalar output, but don't compute
46+
/// the original return value.
47+
ActiveOnly,
48+
/// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
49+
/// with it.
50+
Dual,
51+
/// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
52+
/// with it. Drop the code which updates the original input/output for maximum performance.
53+
DualOnly,
54+
/// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
55+
Duplicated,
56+
/// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
57+
/// Drop the code which updates the original input for maximum performance.
58+
DuplicatedOnly,
59+
/// All Integers must be Const, but these are used to mark the integer which represents the
60+
/// length of a slice/vec. This is used for safety checks on slices.
61+
FakeActivitySize,
2562
}
26-
pub fn is_fwd(mode: DiffMode) -> bool {
27-
match mode {
28-
DiffMode::Forward | DiffMode::ForwardFirst => true,
29-
_ => false,
63+
/// We generate one of these structs for each `#[autodiff(...)]` attribute.
64+
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
65+
pub struct AutoDiffItem {
66+
/// The name of the function getting differentiated
67+
pub source: String,
68+
/// The name of the function being generated
69+
pub target: String,
70+
pub attrs: AutoDiffAttrs,
71+
/// Despribe the memory layout of input types
72+
pub inputs: Vec<TypeTree>,
73+
/// Despribe the memory layout of the output type
74+
pub output: TypeTree,
75+
}
76+
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
77+
pub struct AutoDiffAttrs {
78+
/// Conceptually either forward or reverse mode AD, as described in various autodiff papers and
79+
/// e.g. in the [JAX
80+
/// Documentation](https://jax.readthedocs.io/en/latest/_tutorials/advanced-autodiff.html#how-it-s-made-two-foundational-autodiff-functions).
81+
pub mode: DiffMode,
82+
pub ret_activity: DiffActivity,
83+
pub input_activity: Vec<DiffActivity>,
84+
}
85+
86+
impl DiffMode {
87+
pub fn is_rev(&self) -> bool {
88+
match self {
89+
DiffMode::Reverse | DiffMode::ReverseFirst => true,
90+
_ => false,
91+
}
92+
}
93+
pub fn is_fwd(&self) -> bool {
94+
match self {
95+
DiffMode::Forward | DiffMode::ForwardFirst => true,
96+
_ => false,
97+
}
3098
}
3199
}
32100

@@ -63,30 +131,20 @@ pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
63131
}
64132
}
65133
}
66-
fn is_ptr_or_ref(ty: &Ty) -> bool {
67-
match ty.kind {
68-
TyKind::Ptr(_) | TyKind::Ref(_, _) => true,
69-
_ => false,
70-
}
71-
}
72-
// TODO We should make this more robust to also
134+
135+
// FIXME(ZuseZ4) We should make this more robust to also
73136
// accept aliases of f32 and f64
74-
//fn is_float(ty: &Ty) -> bool {
75-
// false
76-
//}
77137
pub fn valid_ty_for_activity(ty: &P<Ty>, activity: DiffActivity) -> bool {
78-
if is_ptr_or_ref(ty) {
79-
return activity == DiffActivity::Dual
80-
|| activity == DiffActivity::DualOnly
81-
|| activity == DiffActivity::Duplicated
82-
|| activity == DiffActivity::DuplicatedOnly
83-
|| activity == DiffActivity::Const;
138+
match ty.kind {
139+
TyKind::Ptr(_) | TyKind::Ref(..) => {
140+
return activity == DiffActivity::Dual
141+
|| activity == DiffActivity::DualOnly
142+
|| activity == DiffActivity::Duplicated
143+
|| activity == DiffActivity::DuplicatedOnly
144+
|| activity == DiffActivity::Const;
145+
}
146+
_ => false,
84147
}
85-
true
86-
//if is_scalar_ty(&ty) {
87-
// return activity == DiffActivity::Active || activity == DiffActivity::ActiveOnly ||
88-
// activity == DiffActivity::Const;
89-
//}
90148
}
91149
pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool {
92150
return match mode {
@@ -117,20 +175,6 @@ pub fn invalid_input_activities(mode: DiffMode, activity_vec: &[DiffActivity]) -
117175
None
118176
}
119177

120-
#[allow(dead_code)]
121-
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
122-
pub enum DiffActivity {
123-
None,
124-
Const,
125-
Active,
126-
ActiveOnly,
127-
Dual,
128-
DualOnly,
129-
Duplicated,
130-
DuplicatedOnly,
131-
FakeActivitySize,
132-
}
133-
134178
impl Display for DiffActivity {
135179
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
136180
match self {
@@ -180,30 +224,14 @@ impl FromStr for DiffActivity {
180224
}
181225
}
182226

183-
#[allow(dead_code)]
184-
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
185-
pub struct AutoDiffAttrs {
186-
pub mode: DiffMode,
187-
pub ret_activity: DiffActivity,
188-
pub input_activity: Vec<DiffActivity>,
189-
}
190-
191227
impl AutoDiffAttrs {
192228
pub fn has_ret_activity(&self) -> bool {
193-
match self.ret_activity {
194-
DiffActivity::None => false,
195-
_ => true,
196-
}
229+
self.ret_activity != DiffActivity::None
197230
}
198231
pub fn has_active_only_ret(&self) -> bool {
199-
match self.ret_activity {
200-
DiffActivity::ActiveOnly => true,
201-
_ => false,
202-
}
232+
self.ret_activity == DiffActivity::ActiveOnly
203233
}
204-
}
205234

206-
impl AutoDiffAttrs {
207235
pub fn inactive() -> Self {
208236
AutoDiffAttrs {
209237
mode: DiffMode::Inactive,
@@ -251,15 +279,6 @@ impl AutoDiffAttrs {
251279
}
252280
}
253281

254-
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
255-
pub struct AutoDiffItem {
256-
pub source: String,
257-
pub target: String,
258-
pub attrs: AutoDiffAttrs,
259-
pub inputs: Vec<TypeTree>,
260-
pub output: TypeTree,
261-
}
262-
263282
impl fmt::Display for AutoDiffItem {
264283
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
265284
write!(f, "Differentiating {} -> {}", self.source, self.target)?;

Diff for: compiler/rustc_builtin_macros/src/autodiff.rs

+10-11
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@ use std::str::FromStr;
33
use std::string::String;
44

55
use rustc_ast::expand::autodiff_attrs::{
6-
is_fwd, is_rev, valid_input_activity, valid_ty_for_activity, AutoDiffAttrs, DiffActivity,
7-
DiffMode,
6+
valid_input_activity, valid_ty_for_activity, AutoDiffAttrs, DiffActivity, DiffMode,
87
};
98
use rustc_ast::ptr::P;
109
use rustc_ast::token::{Token, TokenKind};
@@ -382,7 +381,7 @@ fn gen_enzyme_body(
382381
// So that can be treated identical to not having one in the first place.
383382
let primal_ret = sig.decl.output.has_ret() && !x.has_active_only_ret();
384383

385-
if primal_ret && n_active == 0 && is_rev(x.mode) {
384+
if primal_ret && n_active == 0 && x.mode.is_rev() {
386385
// We only have the primal ret.
387386
body.stmts.push(ecx.stmt_expr(black_box_primal_call.clone()));
388387
return body;
@@ -437,7 +436,7 @@ fn gen_enzyme_body(
437436
panic!("Did not expect non-tuple ret ty: {:?}", d_ret_ty);
438437
}
439438
};
440-
if is_fwd(x.mode) {
439+
if x.mode.is_fwd() {
441440
if x.ret_activity == DiffActivity::Dual {
442441
assert!(d_ret_ty.len() == 2);
443442
// both should be identical, by construction
@@ -451,7 +450,7 @@ fn gen_enzyme_body(
451450
exprs.push(default_call_expr);
452451
}
453452
} else {
454-
assert!(is_rev(x.mode));
453+
assert!(x.mode.is_rev());
455454

456455
if primal_ret {
457456
// We have extra handling above for the primal ret
@@ -562,7 +561,7 @@ fn gen_enzyme_decl(
562561
let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
563562
ident.name
564563
} else {
565-
trace!("{:#?}", &shadow_arg.pat);
564+
dbg!(&shadow_arg.pat);
566565
panic!("not an ident?");
567566
};
568567
let name: String = format!("d{}", old_name);
@@ -582,7 +581,7 @@ fn gen_enzyme_decl(
582581
let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
583582
ident.name
584583
} else {
585-
trace!("{:#?}", &shadow_arg.pat);
584+
dbg!(&shadow_arg.pat);
586585
panic!("not an ident?");
587586
};
588587
let name: String = format!("b{}", old_name);
@@ -601,7 +600,7 @@ fn gen_enzyme_decl(
601600
// Nothing to do here.
602601
}
603602
_ => {
604-
trace!{"{:#?}", &activity};
603+
dbg!(&activity);
605604
panic!("Not implemented");
606605
}
607606
}
@@ -614,12 +613,12 @@ fn gen_enzyme_decl(
614613

615614
let active_only_ret = x.ret_activity == DiffActivity::ActiveOnly;
616615
if active_only_ret {
617-
assert!(is_rev(x.mode));
616+
assert!(x.mode.is_rev());
618617
}
619618

620619
// If we return a scalar in the primal and the scalar is active,
621620
// then add it as last arg to the inputs.
622-
if is_rev(x.mode) {
621+
if x.mode.is_rev() {
623622
match x.ret_activity {
624623
DiffActivity::Active | DiffActivity::ActiveOnly => {
625624
let ty = match d_decl.output {
@@ -651,7 +650,7 @@ fn gen_enzyme_decl(
651650
}
652651
d_decl.inputs = d_inputs.into();
653652

654-
if is_fwd(x.mode) {
653+
if x.mode.is_fwd() {
655654
if let DiffActivity::Dual = x.ret_activity {
656655
let ty = match d_decl.output {
657656
FnRetTy::Ty(ref ty) => ty.clone(),

0 commit comments

Comments
 (0)