Skip to content

Commit 624c071

Browse files
ZuseZ4 and bytesnake committed
Single commit implementing the enzyme/autodiff frontend
Co-authored-by: Lorenz Schmidt <[email protected]>
1 parent 52fd998 commit 624c071

File tree

17 files changed

+1384
-1
lines changed

17 files changed

+1384
-1
lines changed

Diff for: compiler/rustc_ast/src/expand/autodiff_attrs.rs

+283
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,283 @@
1+
//! This module handles the user facing autodiff macro. For each `#[autodiff(...)]` attribute,
2+
//! we create an [`AutoDiffItem`] which contains the source and target function names. The source
3+
//! is the function to which the autodiff attribute is applied, and the target is the function
4+
//! getting generated by us (with a name given by the user as the first autodiff arg).
5+
6+
use std::fmt::{self, Display, Formatter};
7+
use std::str::FromStr;
8+
9+
use crate::expand::typetree::TypeTree;
10+
use crate::expand::{Decodable, Encodable, HashStable_Generic};
11+
use crate::ptr::P;
12+
use crate::{Ty, TyKind};
13+
14+
/// Forward and Reverse Mode are well known names for automatic differentiation implementations.
15+
/// Enzyme supports both, but with different semantics, see [`DiffActivity`]. The `*First` variants
16+
/// are a hack to support higher order derivatives. We need to compute first order derivatives
17+
/// before we compute second order derivatives, otherwise we would differentiate our placeholder
18+
/// functions. The proper solution is to recognize and resolve this DAG of autodiff invocations,
19+
/// as is already done in the C++ and Julia frontends of Enzyme.
20+
///
21+
/// FIXME: remove the `*First` variants.
///
22+
/// Documentation for using [reverse](https://enzyme.mit.edu/rust/rev.html) and
23+
/// [forward](https://enzyme.mit.edu/rust/fwd.html) mode is available online.
24+
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
25+
pub enum DiffMode {
26+
/// No autodiff is applied (used during error handling).
27+
Error,
28+
/// The primal function which we will differentiate.
29+
Source,
30+
/// The target function, to be created using forward mode AD.
31+
Forward,
32+
/// The target function, to be created using reverse mode AD.
33+
Reverse,
34+
/// The target function, to be created using forward mode AD.
35+
/// This target function will also be used as a source for higher order derivatives,
36+
/// so compute it before all Forward/Reverse targets and optimize it through LLVM.
37+
ForwardFirst,
38+
/// The target function, to be created using reverse mode AD.
39+
/// This target function will also be used as a source for higher order derivatives,
40+
/// so compute it before all Forward/Reverse targets and optimize it through LLVM.
41+
ReverseFirst,
42+
}
43+
44+
/// Dual and Duplicated (and their Only variants) are lowered to the same Enzyme Activity.
45+
/// However, under forward mode we overwrite the previous shadow value, while for reverse mode
46+
/// we add to the previous shadow value. To not surprise users, we picked different names.
47+
/// "Dual numbers" is also a quite well known name for forward mode AD types.
48+
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
49+
pub enum DiffActivity {
50+
/// Implicit or explicit `()` return type, so a special case of Const.
51+
None,
52+
/// Don't compute derivatives with respect to this input/output.
53+
Const,
54+
/// Reverse Mode, compute derivatives for this scalar input/output.
55+
Active,
56+
/// Reverse Mode, compute derivatives for this scalar output, but don't compute
57+
/// the original return value.
58+
ActiveOnly,
59+
/// Forward Mode, compute derivatives for this input/output and *overwrite* the shadow argument
60+
/// with it.
61+
Dual,
62+
/// Forward Mode, compute derivatives for this input/output and *overwrite* the shadow argument
63+
/// with it. Drop the code which updates the original input/output for maximum performance.
64+
DualOnly,
65+
/// Reverse Mode, compute derivatives for this &T or *T input and *add* it to the shadow argument.
66+
Duplicated,
67+
/// Reverse Mode, compute derivatives for this &T or *T input and *add* it to the shadow argument.
68+
/// Drop the code which updates the original input for maximum performance.
69+
DuplicatedOnly,
70+
/// All integers must be Const, but this variant is used to mark the integer which represents
71+
/// the length of a slice/vec. This is used for safety checks on slices.
72+
FakeActivitySize,
73+
}
74+
/// We generate one of these structs for each `#[autodiff(...)]` attribute.
75+
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
76+
pub struct AutoDiffItem {
77+
/// The name of the function getting differentiated
78+
pub source: String,
79+
/// The name of the function being generated
80+
pub target: String,
81+
/// The differentiation mode together with the return and input activities.
pub attrs: AutoDiffAttrs,
82+
/// Describe the memory layout of input types
83+
pub inputs: Vec<TypeTree>,
84+
/// Describe the memory layout of the output type
85+
pub output: TypeTree,
86+
}
87+
/// The differentiation mode plus the activities requested for the return value and the inputs.
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
88+
pub struct AutoDiffAttrs {
89+
/// Conceptually either forward or reverse mode AD, as described in various autodiff papers and
90+
/// e.g. in the [JAX
91+
/// Documentation](https://jax.readthedocs.io/en/latest/_tutorials/advanced-autodiff.html#how-it-s-made-two-foundational-autodiff-functions).
92+
pub mode: DiffMode,
93+
/// Activity of the return value; `DiffActivity::None` is what `error()` and `source()` use.
pub ret_activity: DiffActivity,
94+
/// Activities of the inputs.
/// NOTE(review): presumably one entry per argument of the source function — confirm.
pub input_activity: Vec<DiffActivity>,
95+
}
96+
97+
impl DiffMode {
98+
pub fn is_rev(&self) -> bool {
99+
matches!(self, DiffMode::Reverse | DiffMode::ReverseFirst)
100+
}
101+
pub fn is_fwd(&self) -> bool {
102+
matches!(self, DiffMode::Forward | DiffMode::ForwardFirst)
103+
}
104+
}
105+
106+
impl Display for DiffMode {
107+
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
108+
match self {
109+
DiffMode::Error => write!(f, "Error"),
110+
DiffMode::Source => write!(f, "Source"),
111+
DiffMode::Forward => write!(f, "Forward"),
112+
DiffMode::Reverse => write!(f, "Reverse"),
113+
DiffMode::ForwardFirst => write!(f, "ForwardFirst"),
114+
DiffMode::ReverseFirst => write!(f, "ReverseFirst"),
115+
}
116+
}
117+
}
118+
119+
/// Active(Only) is valid in reverse-mode AD for scalar float returns (f16/f32/...).
120+
/// Dual(Only) is valid in forward-mode AD for scalar float returns (f16/f32/...).
121+
/// Const is valid for all cases and means that we don't compute derivatives wrt. this output.
122+
/// That usually means we have a &mut or *mut T output and compute derivatives wrt. that arg,
123+
/// but this is too complex to verify here. Also it's just a logic error if users get this wrong.
124+
pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
125+
if activity == DiffActivity::None {
126+
// Only valid if primal returns (), but we can't check that here.
127+
return true;
128+
}
129+
match mode {
130+
DiffMode::Error => false,
131+
DiffMode::Source => false,
132+
DiffMode::Forward | DiffMode::ForwardFirst => {
133+
activity == DiffActivity::Dual
134+
|| activity == DiffActivity::DualOnly
135+
|| activity == DiffActivity::Const
136+
}
137+
DiffMode::Reverse | DiffMode::ReverseFirst => {
138+
activity == DiffActivity::Const
139+
|| activity == DiffActivity::Active
140+
|| activity == DiffActivity::ActiveOnly
141+
}
142+
}
143+
}
144+
145+
/// For indirections (ptr/ref) we can't use Active, since Active allocates a shadow value
146+
/// for the given argument, but we generally can't know the size of such a type.
147+
/// For scalar types (f16/f32/f64/f128) we can use Active and we can't use Duplicated,
148+
/// since Duplicated expects a mutable ref/ptr and we would thus end up with a shadow value
149+
/// who is an indirect type, which doesn't match the primal scalar type. We can't prevent
150+
/// users here from marking scalars as Duplicated, due to type aliases.
151+
pub fn valid_ty_for_activity(ty: &P<Ty>, activity: DiffActivity) -> bool {
152+
use DiffActivity::*;
153+
// It's always allowed to mark something as Const, since we won't compute derivatives wrt. it.
154+
if matches!(activity, Const) {
155+
return true;
156+
}
157+
if matches!(activity, Dual | DualOnly) {
158+
return true;
159+
}
160+
// FIXME(ZuseZ4) We should make this more robust to also
161+
// handle type aliases. Once that is done, we can be more restrictive here.
162+
if matches!(activity, Active | ActiveOnly) {
163+
return true;
164+
}
165+
matches!(ty.kind, TyKind::Ptr(_) | TyKind::Ref(..))
166+
&& matches!(activity, Duplicated | DuplicatedOnly)
167+
}
168+
pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool {
169+
use DiffActivity::*;
170+
return match mode {
171+
DiffMode::Error => false,
172+
DiffMode::Source => false,
173+
DiffMode::Forward | DiffMode::ForwardFirst => {
174+
matches!(activity, Dual | DualOnly | Const)
175+
}
176+
DiffMode::Reverse | DiffMode::ReverseFirst => {
177+
matches!(activity, Active | ActiveOnly | Duplicated | DuplicatedOnly | Const)
178+
}
179+
};
180+
}
181+
182+
impl Display for DiffActivity {
183+
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
184+
match self {
185+
DiffActivity::None => write!(f, "None"),
186+
DiffActivity::Const => write!(f, "Const"),
187+
DiffActivity::Active => write!(f, "Active"),
188+
DiffActivity::ActiveOnly => write!(f, "ActiveOnly"),
189+
DiffActivity::Dual => write!(f, "Dual"),
190+
DiffActivity::DualOnly => write!(f, "DualOnly"),
191+
DiffActivity::Duplicated => write!(f, "Duplicated"),
192+
DiffActivity::DuplicatedOnly => write!(f, "DuplicatedOnly"),
193+
DiffActivity::FakeActivitySize => write!(f, "FakeActivitySize"),
194+
}
195+
}
196+
}
197+
198+
impl FromStr for DiffMode {
199+
type Err = ();
200+
201+
fn from_str(s: &str) -> Result<DiffMode, ()> {
202+
match s {
203+
"Error" => Ok(DiffMode::Error),
204+
"Source" => Ok(DiffMode::Source),
205+
"Forward" => Ok(DiffMode::Forward),
206+
"Reverse" => Ok(DiffMode::Reverse),
207+
"ForwardFirst" => Ok(DiffMode::ForwardFirst),
208+
"ReverseFirst" => Ok(DiffMode::ReverseFirst),
209+
_ => Err(()),
210+
}
211+
}
212+
}
213+
impl FromStr for DiffActivity {
214+
type Err = ();
215+
216+
fn from_str(s: &str) -> Result<DiffActivity, ()> {
217+
match s {
218+
"None" => Ok(DiffActivity::None),
219+
"Active" => Ok(DiffActivity::Active),
220+
"ActiveOnly" => Ok(DiffActivity::ActiveOnly),
221+
"Const" => Ok(DiffActivity::Const),
222+
"Dual" => Ok(DiffActivity::Dual),
223+
"DualOnly" => Ok(DiffActivity::DualOnly),
224+
"Duplicated" => Ok(DiffActivity::Duplicated),
225+
"DuplicatedOnly" => Ok(DiffActivity::DuplicatedOnly),
226+
_ => Err(()),
227+
}
228+
}
229+
}
230+
231+
impl AutoDiffAttrs {
232+
pub fn has_ret_activity(&self) -> bool {
233+
self.ret_activity != DiffActivity::None
234+
}
235+
pub fn has_active_only_ret(&self) -> bool {
236+
self.ret_activity == DiffActivity::ActiveOnly
237+
}
238+
239+
pub fn error() -> Self {
240+
AutoDiffAttrs {
241+
mode: DiffMode::Error,
242+
ret_activity: DiffActivity::None,
243+
input_activity: Vec::new(),
244+
}
245+
}
246+
pub fn source() -> Self {
247+
AutoDiffAttrs {
248+
mode: DiffMode::Source,
249+
ret_activity: DiffActivity::None,
250+
input_activity: Vec::new(),
251+
}
252+
}
253+
254+
pub fn is_active(&self) -> bool {
255+
self.mode != DiffMode::Error
256+
}
257+
258+
pub fn is_source(&self) -> bool {
259+
self.mode == DiffMode::Source
260+
}
261+
pub fn apply_autodiff(&self) -> bool {
262+
!matches!(self.mode, DiffMode::Error | DiffMode::Source)
263+
}
264+
265+
pub fn into_item(
266+
self,
267+
source: String,
268+
target: String,
269+
inputs: Vec<TypeTree>,
270+
output: TypeTree,
271+
) -> AutoDiffItem {
272+
AutoDiffItem { source, target, inputs, output, attrs: self }
273+
}
274+
}
275+
276+
impl fmt::Display for AutoDiffItem {
277+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
278+
write!(f, "Differentiating {} -> {}", self.source, self.target)?;
279+
write!(f, " with attributes: {:?}", self.attrs)?;
280+
write!(f, " with inputs: {:?}", self.inputs)?;
281+
write!(f, " with output: {:?}", self.output)
282+
}
283+
}

Diff for: compiler/rustc_ast/src/expand/mod.rs

+2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ use rustc_span::symbol::Ident;
77
use crate::MetaItem;
88

99
pub mod allocator;
10+
pub mod autodiff_attrs;
11+
pub mod typetree;
1012

1113
#[derive(Debug, Clone, Encodable, Decodable, HashStable_Generic)]
1214
pub struct StrippedCfgItem<ModId = DefId> {

Diff for: compiler/rustc_ast/src/expand/typetree.rs

+90
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
//! This module contains the definition of the `TypeTree` and `Type` structs.
2+
//! They are thin Rust wrappers around the TypeTrees used by Enzyme as the LLVM based autodiff
3+
//! backend. The Enzyme TypeTrees currently have various limitations and should be rewritten, so the
4+
//! Rust frontend obviously has the same limitations. The main motivation of TypeTrees is to
5+
//! represent what a type looks like "in memory". Enzyme can deduce this based on usage patterns in
6+
//! the user code, but this is extremely slow and not even always sufficient. As such we lower some
7+
//! information from rustc to help Enzyme. For a full explanation of their design it is necessary to
8+
//! analyze the implementation in Enzyme core itself. As a rough summary, `-1` in Enzyme speech means
9+
//! everywhere. That is `{0:-1: Float}` means at index 0 you have a ptr, if you dereference it it
10+
//! will be floats everywhere. Thus `* f32`. If you have `{-1:int}` it means int's everywhere,
11+
//! e.g. [i32; N]. `{0:-1:-1 float}` then means one pointer at offset 0, if you dereference it there
12+
//! will be only pointers, if you dereference these new pointers they will point to array of floats.
13+
//! Generally, it allows byte-specific descriptions.
14+
//! FIXME: This description might be partly inaccurate and should be extended, along with
15+
//! adding documentation to the corresponding Enzyme core code.
16+
//! FIXME: Rewrite the TypeTree logic in Enzyme core to reduce the need for the rustc frontend to
17+
//! provide typetree information.
18+
//! FIXME: We should also re-evaluate where we create TypeTrees from Rust types, since MIR
19+
//! representations of some types might not be accurate. For example a vector of floats might be
20+
//! represented as a vector of u8s in MIR in some cases.
21+
22+
use std::fmt;
23+
24+
use crate::expand::{Decodable, Encodable, HashStable_Generic};
25+
26+
/// The kind of data stored in the `kind` field of a [`Type`] entry.
/// NOTE(review): the variants presumably mirror Enzyme's concrete type kinds, with
/// `Half`/`Float`/`Double` standing for 16/32/64-bit floats — confirm against Enzyme core.
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
27+
pub enum Kind {
28+
Anything,
29+
Integer,
30+
Pointer,
31+
Half,
32+
Float,
33+
Double,
34+
Unknown,
35+
}
36+
37+
/// A list of [`Type`] entries; see the module documentation for the Enzyme notation.
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
38+
pub struct TypeTree(pub Vec<Type>);
39+
40+
impl TypeTree {
41+
pub fn new() -> Self {
42+
Self(Vec::new())
43+
}
44+
pub fn all_ints() -> Self {
45+
Self(vec![Type { offset: -1, size: 1, kind: Kind::Integer, child: TypeTree::new() }])
46+
}
47+
pub fn int(size: usize) -> Self {
48+
let mut ints = Vec::with_capacity(size);
49+
for i in 0..size {
50+
ints.push(Type {
51+
offset: i as isize,
52+
size: 1,
53+
kind: Kind::Integer,
54+
child: TypeTree::new(),
55+
});
56+
}
57+
Self(ints)
58+
}
59+
}
60+
61+
/// A collection of [`TypeTree`]s in `args` together with a single one in `ret`.
/// NOTE(review): going by the field names this presumably describes a function's arguments
/// and its return value — confirm against the users of this type.
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
62+
pub struct FncTree {
63+
pub args: Vec<TypeTree>,
64+
pub ret: TypeTree,
65+
}
66+
67+
/// A single entry of a [`TypeTree`].
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
68+
pub struct Type {
69+
/// Position this entry applies to. `-1` is treated specially by `Type::add_offset`; the
/// module documentation describes `-1` as meaning "everywhere" in Enzyme's notation.
pub offset: isize,
70+
/// NOTE(review): presumably the number of bytes covered by this entry — confirm.
pub size: usize,
71+
/// The kind of data this entry describes.
pub kind: Kind,
72+
/// Nested entries.
/// NOTE(review): presumably what is found behind a pointer after dereferencing — confirm.
pub child: TypeTree,
73+
}
74+
75+
impl Type {
76+
pub fn add_offset(self, add: isize) -> Self {
77+
let offset = match self.offset {
78+
-1 => add,
79+
x => add + x,
80+
};
81+
82+
Self { size: self.size, kind: self.kind, child: self.child, offset }
83+
}
84+
}
85+
86+
impl fmt::Display for Type {
87+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88+
<Self as fmt::Debug>::fmt(self, f)
89+
}
90+
}

Diff for: compiler/rustc_builtin_macros/Cargo.toml

+4
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@ name = "rustc_builtin_macros"
33
version = "0.0.0"
44
edition = "2021"
55

6+
7+
[lints.rust]
8+
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(llvm_enzyme)'] }
9+
610
[lib]
711
doctest = false
812

0 commit comments

Comments
 (0)