Skip to content

Commit 443214c

Browse files
committed
implement a working autodiff frontend
1 parent b5723af commit 443214c

File tree

21 files changed

+1269
-0
lines changed

21 files changed

+1269
-0
lines changed

Diff for: compiler/rustc_ast/src/ast.rs

+7
Original file line numberDiff line numberDiff line change
@@ -2733,6 +2733,13 @@ impl FnRetTy {
27332733
FnRetTy::Ty(ty) => ty.span,
27342734
}
27352735
}
2736+
2737+
pub fn has_ret(&self) -> bool {
2738+
match self {
2739+
FnRetTy::Default(_) => false,
2740+
FnRetTy::Ty(_) => true,
2741+
}
2742+
}
27362743
}
27372744

27382745
#[derive(Clone, Copy, PartialEq, Encodable, Decodable, Debug)]

Diff for: compiler/rustc_ast/src/expand/autodiff_attrs.rs

+270
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,270 @@
1+
use std::fmt::{self, Display, Formatter};
2+
use std::str::FromStr;
3+
4+
use crate::expand::typetree::TypeTree;
5+
use crate::expand::{Decodable, Encodable, HashStable_Generic};
6+
use crate::ptr::P;
7+
use crate::{Ty, TyKind};
8+
9+
#[allow(dead_code)]
10+
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
11+
pub enum DiffMode {
12+
Inactive,
13+
Source,
14+
Forward,
15+
Reverse,
16+
ForwardFirst,
17+
ReverseFirst,
18+
}
19+
20+
pub fn is_rev(mode: DiffMode) -> bool {
21+
match mode {
22+
DiffMode::Reverse | DiffMode::ReverseFirst => true,
23+
_ => false,
24+
}
25+
}
26+
pub fn is_fwd(mode: DiffMode) -> bool {
27+
match mode {
28+
DiffMode::Forward | DiffMode::ForwardFirst => true,
29+
_ => false,
30+
}
31+
}
32+
33+
impl Display for DiffMode {
34+
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
35+
match self {
36+
DiffMode::Inactive => write!(f, "Inactive"),
37+
DiffMode::Source => write!(f, "Source"),
38+
DiffMode::Forward => write!(f, "Forward"),
39+
DiffMode::Reverse => write!(f, "Reverse"),
40+
DiffMode::ForwardFirst => write!(f, "ForwardFirst"),
41+
DiffMode::ReverseFirst => write!(f, "ReverseFirst"),
42+
}
43+
}
44+
}
45+
46+
pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
47+
if activity == DiffActivity::None {
48+
// Only valid if primal returns (), but we can't check that here.
49+
return true;
50+
}
51+
match mode {
52+
DiffMode::Inactive => false,
53+
DiffMode::Source => false,
54+
DiffMode::Forward | DiffMode::ForwardFirst => {
55+
activity == DiffActivity::Dual
56+
|| activity == DiffActivity::DualOnly
57+
|| activity == DiffActivity::Const
58+
}
59+
DiffMode::Reverse | DiffMode::ReverseFirst => {
60+
activity == DiffActivity::Const
61+
|| activity == DiffActivity::Active
62+
|| activity == DiffActivity::ActiveOnly
63+
}
64+
}
65+
}
66+
fn is_ptr_or_ref(ty: &Ty) -> bool {
67+
match ty.kind {
68+
TyKind::Ptr(_) | TyKind::Ref(_, _) => true,
69+
_ => false,
70+
}
71+
}
72+
// TODO We should make this more robust to also
73+
// accept aliases of f32 and f64
74+
//fn is_float(ty: &Ty) -> bool {
75+
// false
76+
//}
77+
pub fn valid_ty_for_activity(ty: &P<Ty>, activity: DiffActivity) -> bool {
78+
if is_ptr_or_ref(ty) {
79+
return activity == DiffActivity::Dual
80+
|| activity == DiffActivity::DualOnly
81+
|| activity == DiffActivity::Duplicated
82+
|| activity == DiffActivity::DuplicatedOnly
83+
|| activity == DiffActivity::Const;
84+
}
85+
true
86+
//if is_scalar_ty(&ty) {
87+
// return activity == DiffActivity::Active || activity == DiffActivity::ActiveOnly ||
88+
// activity == DiffActivity::Const;
89+
//}
90+
}
91+
pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool {
92+
return match mode {
93+
DiffMode::Inactive => false,
94+
DiffMode::Source => false,
95+
DiffMode::Forward | DiffMode::ForwardFirst => {
96+
// These are the only valid cases
97+
activity == DiffActivity::Dual
98+
|| activity == DiffActivity::DualOnly
99+
|| activity == DiffActivity::Const
100+
}
101+
DiffMode::Reverse | DiffMode::ReverseFirst => {
102+
// These are the only valid cases
103+
activity == DiffActivity::Active
104+
|| activity == DiffActivity::ActiveOnly
105+
|| activity == DiffActivity::Const
106+
|| activity == DiffActivity::Duplicated
107+
|| activity == DiffActivity::DuplicatedOnly
108+
}
109+
};
110+
}
111+
pub fn invalid_input_activities(mode: DiffMode, activity_vec: &[DiffActivity]) -> Option<usize> {
112+
for i in 0..activity_vec.len() {
113+
if !valid_input_activity(mode, activity_vec[i]) {
114+
return Some(i);
115+
}
116+
}
117+
None
118+
}
119+
120+
#[allow(dead_code)]
121+
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
122+
pub enum DiffActivity {
123+
None,
124+
Const,
125+
Active,
126+
ActiveOnly,
127+
Dual,
128+
DualOnly,
129+
Duplicated,
130+
DuplicatedOnly,
131+
FakeActivitySize,
132+
}
133+
134+
impl Display for DiffActivity {
135+
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
136+
match self {
137+
DiffActivity::None => write!(f, "None"),
138+
DiffActivity::Const => write!(f, "Const"),
139+
DiffActivity::Active => write!(f, "Active"),
140+
DiffActivity::ActiveOnly => write!(f, "ActiveOnly"),
141+
DiffActivity::Dual => write!(f, "Dual"),
142+
DiffActivity::DualOnly => write!(f, "DualOnly"),
143+
DiffActivity::Duplicated => write!(f, "Duplicated"),
144+
DiffActivity::DuplicatedOnly => write!(f, "DuplicatedOnly"),
145+
DiffActivity::FakeActivitySize => write!(f, "FakeActivitySize"),
146+
}
147+
}
148+
}
149+
150+
impl FromStr for DiffMode {
151+
type Err = ();
152+
153+
fn from_str(s: &str) -> Result<DiffMode, ()> {
154+
match s {
155+
"Inactive" => Ok(DiffMode::Inactive),
156+
"Source" => Ok(DiffMode::Source),
157+
"Forward" => Ok(DiffMode::Forward),
158+
"Reverse" => Ok(DiffMode::Reverse),
159+
"ForwardFirst" => Ok(DiffMode::ForwardFirst),
160+
"ReverseFirst" => Ok(DiffMode::ReverseFirst),
161+
_ => Err(()),
162+
}
163+
}
164+
}
165+
impl FromStr for DiffActivity {
166+
type Err = ();
167+
168+
fn from_str(s: &str) -> Result<DiffActivity, ()> {
169+
match s {
170+
"None" => Ok(DiffActivity::None),
171+
"Active" => Ok(DiffActivity::Active),
172+
"ActiveOnly" => Ok(DiffActivity::ActiveOnly),
173+
"Const" => Ok(DiffActivity::Const),
174+
"Dual" => Ok(DiffActivity::Dual),
175+
"DualOnly" => Ok(DiffActivity::DualOnly),
176+
"Duplicated" => Ok(DiffActivity::Duplicated),
177+
"DuplicatedOnly" => Ok(DiffActivity::DuplicatedOnly),
178+
_ => Err(()),
179+
}
180+
}
181+
}
182+
183+
#[allow(dead_code)]
184+
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
185+
pub struct AutoDiffAttrs {
186+
pub mode: DiffMode,
187+
pub ret_activity: DiffActivity,
188+
pub input_activity: Vec<DiffActivity>,
189+
}
190+
191+
impl AutoDiffAttrs {
192+
pub fn has_ret_activity(&self) -> bool {
193+
match self.ret_activity {
194+
DiffActivity::None => false,
195+
_ => true,
196+
}
197+
}
198+
pub fn has_active_only_ret(&self) -> bool {
199+
match self.ret_activity {
200+
DiffActivity::ActiveOnly => true,
201+
_ => false,
202+
}
203+
}
204+
}
205+
206+
impl AutoDiffAttrs {
207+
pub fn inactive() -> Self {
208+
AutoDiffAttrs {
209+
mode: DiffMode::Inactive,
210+
ret_activity: DiffActivity::None,
211+
input_activity: Vec::new(),
212+
}
213+
}
214+
pub fn source() -> Self {
215+
AutoDiffAttrs {
216+
mode: DiffMode::Source,
217+
ret_activity: DiffActivity::None,
218+
input_activity: Vec::new(),
219+
}
220+
}
221+
222+
pub fn is_active(&self) -> bool {
223+
match self.mode {
224+
DiffMode::Inactive => false,
225+
_ => true,
226+
}
227+
}
228+
229+
pub fn is_source(&self) -> bool {
230+
match self.mode {
231+
DiffMode::Source => true,
232+
_ => false,
233+
}
234+
}
235+
pub fn apply_autodiff(&self) -> bool {
236+
match self.mode {
237+
DiffMode::Inactive => false,
238+
DiffMode::Source => false,
239+
_ => true,
240+
}
241+
}
242+
243+
pub fn into_item(
244+
self,
245+
source: String,
246+
target: String,
247+
inputs: Vec<TypeTree>,
248+
output: TypeTree,
249+
) -> AutoDiffItem {
250+
AutoDiffItem { source, target, inputs, output, attrs: self }
251+
}
252+
}
253+
254+
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
255+
pub struct AutoDiffItem {
256+
pub source: String,
257+
pub target: String,
258+
pub attrs: AutoDiffAttrs,
259+
pub inputs: Vec<TypeTree>,
260+
pub output: TypeTree,
261+
}
262+
263+
impl fmt::Display for AutoDiffItem {
264+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
265+
write!(f, "Differentiating {} -> {}", self.source, self.target)?;
266+
write!(f, " with attributes: {:?}", self.attrs)?;
267+
write!(f, " with inputs: {:?}", self.inputs)?;
268+
write!(f, " with output: {:?}", self.output)
269+
}
270+
}

Diff for: compiler/rustc_ast/src/expand/mod.rs

+2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ use rustc_span::symbol::Ident;
77
use crate::MetaItem;
88

99
pub mod allocator;
10+
pub mod autodiff_attrs;
11+
pub mod typetree;
1012

1113
#[derive(Debug, Clone, Encodable, Decodable, HashStable_Generic)]
1214
pub struct StrippedCfgItem<ModId = DefId> {

Diff for: compiler/rustc_ast/src/expand/typetree.rs

+69
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
use std::fmt;
2+
3+
use crate::expand::{Decodable, Encodable, HashStable_Generic};
4+
5+
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
6+
pub enum Kind {
7+
Anything,
8+
Integer,
9+
Pointer,
10+
Half,
11+
Float,
12+
Double,
13+
Unknown,
14+
}
15+
16+
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
17+
pub struct TypeTree(pub Vec<Type>);
18+
19+
impl TypeTree {
20+
pub fn new() -> Self {
21+
Self(Vec::new())
22+
}
23+
pub fn all_ints() -> Self {
24+
Self(vec![Type { offset: -1, size: 1, kind: Kind::Integer, child: TypeTree::new() }])
25+
}
26+
pub fn int(size: usize) -> Self {
27+
let mut ints = Vec::with_capacity(size);
28+
for i in 0..size {
29+
ints.push(Type {
30+
offset: i as isize,
31+
size: 1,
32+
kind: Kind::Integer,
33+
child: TypeTree::new(),
34+
});
35+
}
36+
Self(ints)
37+
}
38+
}
39+
40+
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
41+
pub struct FncTree {
42+
pub args: Vec<TypeTree>,
43+
pub ret: TypeTree,
44+
}
45+
46+
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
47+
pub struct Type {
48+
pub offset: isize,
49+
pub size: usize,
50+
pub kind: Kind,
51+
pub child: TypeTree,
52+
}
53+
54+
impl Type {
55+
pub fn add_offset(self, add: isize) -> Self {
56+
let offset = match self.offset {
57+
-1 => add,
58+
x => add + x,
59+
};
60+
61+
Self { size: self.size, kind: self.kind, child: self.child, offset }
62+
}
63+
}
64+
65+
impl fmt::Display for Type {
66+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
67+
<Self as fmt::Debug>::fmt(self, f)
68+
}
69+
}

Diff for: compiler/rustc_builtin_macros/Cargo.toml

+4
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@ name = "rustc_builtin_macros"
33
version = "0.0.0"
44
edition = "2021"
55

6+
7+
[lints.rust]
8+
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(llvm_enzyme)'] }
9+
610
[lib]
711
doctest = false
812

Diff for: compiler/rustc_builtin_macros/messages.ftl

+8
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,14 @@
11
builtin_macros_alloc_error_must_be_fn = alloc_error_handler must be a function
22
builtin_macros_alloc_must_statics = allocators must be statics
33
4+
builtin_macros_autodiff_unknown_activity = did not recognize activity {$act}
5+
builtin_macros_autodiff = autodiff must be applied to function
6+
builtin_macros_autodiff_not_build = this rustc version does not support autodiff
7+
builtin_macros_autodiff_mode_activity = {$act} can not be used in {$mode} Mode
8+
builtin_macros_autodiff_number_activities = expected {$expected} activities, but found {$found}
9+
builtin_macros_autodiff_mode = unknown Mode: `{$mode}`. Use `Forward` or `Reverse`
10+
builtin_macros_autodiff_ty_activity = {$act} can not be used for this type
11+
412
builtin_macros_asm_clobber_abi = clobber_abi
513
builtin_macros_asm_clobber_no_reg = asm with `clobber_abi` must specify explicit registers for outputs
614
builtin_macros_asm_clobber_outputs = generic outputs

0 commit comments

Comments
 (0)