forked from rust-lang/rust
-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathautodiff_attrs.rs
269 lines (251 loc) · 7.95 KB
/
autodiff_attrs.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
use std::fmt::{self, Display, Formatter};
use std::str::FromStr;
use crate::expand::typetree::TypeTree;
use crate::expand::{Decodable, Encodable, HashStable_Generic};
use crate::ptr::P;
use crate::{Ty, TyKind};
/// The differentiation mode attached to an item.
///
/// The helpers in this file group the variants as follows: `is_fwd` accepts
/// `Forward` and `ForwardFirst`, `is_rev` accepts `Reverse` and `ReverseFirst`,
/// and `AutoDiffAttrs::apply_autodiff` is `false` only for `Inactive` and `Source`.
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub enum DiffMode {
/// Mode used by `AutoDiffAttrs::inactive()`; the only mode for which
/// `AutoDiffAttrs::is_active` returns `false`.
Inactive,
/// Mode used by `AutoDiffAttrs::source()`; reported by `AutoDiffAttrs::is_source`.
/// NOTE(review): presumably marks the primal function that gets differentiated - confirm.
Source,
/// Forward mode.
Forward,
/// Reverse mode.
Reverse,
/// Treated like `Forward` by every check in this file.
/// NOTE(review): presumably processed before other targets - confirm.
ForwardFirst,
/// Treated like `Reverse` by every check in this file.
/// NOTE(review): presumably processed before other targets - confirm.
ReverseFirst,
}
/// Returns `true` when `mode` is one of the reverse-mode variants
/// (`DiffMode::Reverse` or `DiffMode::ReverseFirst`).
pub fn is_rev(mode: DiffMode) -> bool {
    matches!(mode, DiffMode::Reverse | DiffMode::ReverseFirst)
}
/// Returns `true` when `mode` is one of the forward-mode variants
/// (`DiffMode::Forward` or `DiffMode::ForwardFirst`).
pub fn is_fwd(mode: DiffMode) -> bool {
    matches!(mode, DiffMode::Forward | DiffMode::ForwardFirst)
}
impl Display for DiffMode {
    /// Writes the variant name exactly as spelled in the enum definition.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let name = match *self {
            DiffMode::Inactive => "Inactive",
            DiffMode::Source => "Source",
            DiffMode::Forward => "Forward",
            DiffMode::Reverse => "Reverse",
            DiffMode::ForwardFirst => "ForwardFirst",
            DiffMode::ReverseFirst => "ReverseFirst",
        };
        // `write_str` emits the text verbatim, just like `write!` with a plain literal.
        f.write_str(name)
    }
}
/// Checks whether `activity` is an acceptable return activity for `mode`.
///
/// * `DiffActivity::None` is accepted for every mode.
/// * `Inactive` and `Source` accept nothing else.
/// * Forward modes accept `Dual`, `DualOnly` and `Const`.
/// * Reverse modes accept `Const`, `Active` and `ActiveOnly`.
pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
    match (mode, activity) {
        // `None` is only really valid if the primal returns `()`, but that
        // cannot be checked here.
        (_, DiffActivity::None) => true,
        (DiffMode::Inactive | DiffMode::Source, _) => false,
        (DiffMode::Forward | DiffMode::ForwardFirst, ret) => {
            matches!(ret, DiffActivity::Dual | DiffActivity::DualOnly | DiffActivity::Const)
        }
        (DiffMode::Reverse | DiffMode::ReverseFirst, ret) => {
            matches!(ret, DiffActivity::Const | DiffActivity::Active | DiffActivity::ActiveOnly)
        }
    }
}
/// Returns `true` when the kind of `ty` is `TyKind::Ptr` or `TyKind::Ref`.
fn is_ptr_or_ref(ty: &Ty) -> bool {
    matches!(ty.kind, TyKind::Ptr(..) | TyKind::Ref(..))
}
// TODO: We should make this more robust, so that it also
// accepts aliases of f32 and f64.
//fn is_float(ty: &Ty) -> bool {
// false
//}
/// Checks whether `activity` may be used for an argument of type `ty`.
///
/// Pointer and reference types accept only `Dual`, `DualOnly`, `Duplicated`,
/// `DuplicatedOnly` and `Const`. Every other type is currently accepted with
/// any activity.
pub fn valid_ty_for_activity(ty: &P<Ty>, activity: DiffActivity) -> bool {
    // TODO: scalar types are not restricted yet.
    // NOTE(review): earlier commented-out code suggested limiting them to
    // `Active`, `ActiveOnly` and `Const` - confirm before enabling.
    !is_ptr_or_ref(ty)
        || matches!(
            activity,
            DiffActivity::Dual
                | DiffActivity::DualOnly
                | DiffActivity::Duplicated
                | DiffActivity::DuplicatedOnly
                | DiffActivity::Const
        )
}
/// Checks whether `activity` is an acceptable input activity for `mode`.
///
/// * `Inactive` and `Source` accept no activity at all.
/// * Forward modes accept `Dual`, `DualOnly` and `Const`.
/// * Reverse modes accept `Active`, `ActiveOnly`, `Const`, `Duplicated` and
///   `DuplicatedOnly`.
pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool {
    match mode {
        DiffMode::Inactive | DiffMode::Source => false,
        DiffMode::Forward | DiffMode::ForwardFirst => matches!(
            activity,
            DiffActivity::Dual | DiffActivity::DualOnly | DiffActivity::Const
        ),
        DiffMode::Reverse | DiffMode::ReverseFirst => matches!(
            activity,
            DiffActivity::Active
                | DiffActivity::ActiveOnly
                | DiffActivity::Const
                | DiffActivity::Duplicated
                | DiffActivity::DuplicatedOnly
        ),
    }
}
/// Finds the first input activity that is not valid for `mode`.
///
/// Returns `Some(index)` of the first element of `activity_vec` rejected by
/// `valid_input_activity`, or `None` when every element is accepted (which
/// includes the empty slice).
pub fn invalid_input_activities(mode: DiffMode, activity_vec: &[DiffActivity]) -> Option<usize> {
    // `position` stops at the first match and avoids manual, bounds-checked indexing.
    activity_vec.iter().position(|&activity| !valid_input_activity(mode, activity))
}
/// The activity assigned to a single input or to the return value.
///
/// The per-variant notes describe which activities `valid_ret_activity`,
/// `valid_input_activity` and `valid_ty_for_activity` in this file accept.
/// NOTE(review): the differentiation semantics behind each variant are not
/// visible in this file - confirm against the autodiff documentation.
#[allow(dead_code)]
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub enum DiffActivity {
/// Accepted as a return activity for every mode (meant for primals returning `()`);
/// never accepted as an input activity.
None,
/// Accepted for inputs and returns in both forward and reverse modes.
Const,
/// Accepted for inputs and returns in reverse modes only; rejected for
/// pointer and reference types.
Active,
/// Same validity as `Active`; checked by `AutoDiffAttrs::has_active_only_ret`.
ActiveOnly,
/// Accepted for inputs and returns in forward modes only.
Dual,
/// Same validity as `Dual`.
DualOnly,
/// Accepted for inputs in reverse modes only; never accepted as a return activity.
Duplicated,
/// Same validity as `Duplicated`.
DuplicatedOnly,
/// Rejected by all validity checks and not parseable through `FromStr`.
/// NOTE(review): looks like an internal placeholder - confirm.
FakeActivitySize,
}
impl Display for DiffActivity {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
DiffActivity::None => write!(f, "None"),
DiffActivity::Const => write!(f, "Const"),
DiffActivity::Active => write!(f, "Active"),
DiffActivity::ActiveOnly => write!(f, "ActiveOnly"),
DiffActivity::Dual => write!(f, "Dual"),
DiffActivity::DualOnly => write!(f, "DualOnly"),
DiffActivity::Duplicated => write!(f, "Duplicated"),
DiffActivity::DuplicatedOnly => write!(f, "DuplicatedOnly"),
DiffActivity::FakeActivitySize => write!(f, "FakeActivitySize"),
}
}
}
impl FromStr for DiffMode {
    type Err = ();

    /// Parses the exact, case-sensitive variant name; any other input yields `Err(())`.
    fn from_str(s: &str) -> Result<DiffMode, ()> {
        let mode = match s {
            "Inactive" => DiffMode::Inactive,
            "Source" => DiffMode::Source,
            "Forward" => DiffMode::Forward,
            "Reverse" => DiffMode::Reverse,
            "ForwardFirst" => DiffMode::ForwardFirst,
            "ReverseFirst" => DiffMode::ReverseFirst,
            _ => return Err(()),
        };
        Ok(mode)
    }
}
impl FromStr for DiffActivity {
    type Err = ();

    /// Parses the exact, case-sensitive variant name; any other input yields `Err(())`.
    ///
    /// `FakeActivitySize` is not accepted here, so it cannot be obtained by parsing.
    fn from_str(s: &str) -> Result<DiffActivity, ()> {
        let activity = match s {
            "None" => DiffActivity::None,
            "Active" => DiffActivity::Active,
            "ActiveOnly" => DiffActivity::ActiveOnly,
            "Const" => DiffActivity::Const,
            "Dual" => DiffActivity::Dual,
            "DualOnly" => DiffActivity::DualOnly,
            "Duplicated" => DiffActivity::Duplicated,
            "DuplicatedOnly" => DiffActivity::DuplicatedOnly,
            _ => return Err(()),
        };
        Ok(activity)
    }
}
/// Differentiation settings for one item: the mode together with the
/// activities of the return value and of the inputs.
#[allow(dead_code)]
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub struct AutoDiffAttrs {
/// Differentiation mode; `DiffMode::Inactive` makes `is_active()` return `false`.
pub mode: DiffMode,
/// Activity of the return value; `DiffActivity::None` makes
/// `has_ret_activity()` return `false`.
pub ret_activity: DiffActivity,
/// Activities of the inputs.
/// NOTE(review): presumably one entry per argument, in declaration order - confirm.
pub input_activity: Vec<DiffActivity>,
}
impl AutoDiffAttrs {
    /// Returns `true` unless the return activity is `DiffActivity::None`.
    pub fn has_ret_activity(&self) -> bool {
        self.ret_activity != DiffActivity::None
    }

    /// Returns `true` only when the return activity is `DiffActivity::ActiveOnly`.
    pub fn has_active_only_ret(&self) -> bool {
        self.ret_activity == DiffActivity::ActiveOnly
    }
}
impl AutoDiffAttrs {
    /// Builds attributes in `DiffMode::Inactive` with no return activity and no
    /// input activities.
    pub fn inactive() -> Self {
        Self {
            mode: DiffMode::Inactive,
            ret_activity: DiffActivity::None,
            input_activity: vec![],
        }
    }

    /// Builds attributes in `DiffMode::Source` with no return activity and no
    /// input activities.
    pub fn source() -> Self {
        Self {
            mode: DiffMode::Source,
            ret_activity: DiffActivity::None,
            input_activity: vec![],
        }
    }

    /// Returns `true` for every mode except `DiffMode::Inactive`.
    pub fn is_active(&self) -> bool {
        self.mode != DiffMode::Inactive
    }

    /// Returns `true` only for `DiffMode::Source`.
    pub fn is_source(&self) -> bool {
        self.mode == DiffMode::Source
    }

    /// Returns `true` for every mode except `DiffMode::Inactive` and
    /// `DiffMode::Source`.
    pub fn apply_autodiff(&self) -> bool {
        !matches!(self.mode, DiffMode::Inactive | DiffMode::Source)
    }

    /// Consumes these attributes and bundles them with the given source and
    /// target names and type trees into an `AutoDiffItem`.
    pub fn into_item(
        self,
        source: String,
        target: String,
        inputs: Vec<TypeTree>,
        output: TypeTree,
    ) -> AutoDiffItem {
        AutoDiffItem { attrs: self, source, target, inputs, output }
    }
}
/// A differentiation request: source and target names plus the attributes and
/// type trees that go with them. Built by `AutoDiffAttrs::into_item`.
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub struct AutoDiffItem {
/// Name printed on the left of `->` by the `Display` impl.
/// NOTE(review): presumably the function being differentiated - confirm.
pub source: String,
/// Name printed on the right of `->` by the `Display` impl.
/// NOTE(review): presumably the generated derivative function - confirm.
pub target: String,
/// Mode and activities this item was created from.
pub attrs: AutoDiffAttrs,
/// Type trees of the inputs.
/// NOTE(review): presumably one per argument - confirm.
pub inputs: Vec<TypeTree>,
/// Type tree of the output.
pub output: TypeTree,
}
impl fmt::Display for AutoDiffItem {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Differentiating {} -> {}", self.source, self.target)?;
write!(f, " with attributes: {:?}", self.attrs)?;
write!(f, " with inputs: {:?}", self.inputs)?;
write!(f, " with output: {:?}", self.output)
}
}