3
3
4
4
use proc_macro:: TokenStream ;
5
5
use proc_macro_error:: proc_macro_error;
6
+ use quote:: { format_ident, quote, quote_spanned} ;
7
+ use syn:: {
8
+ parse_macro_input, parse_quote, spanned:: Spanned , Data , DataEnum , DeriveInput , Fields ,
9
+ GenericParam , Generics , Ident , Index , ItemStruct ,
10
+ } ;
6
11
7
12
#[ cfg( kani_host) ]
8
13
#[ path = "kani.rs" ]
@@ -12,6 +17,135 @@ mod tool;
12
17
#[ path = "runtime.rs" ]
13
18
mod tool;
14
19
20
+ /// Expands the `#[invariant(...)]` attribute macro.
21
+ /// The macro expands to an implementation of the `is_safe` method for the `Invariant` trait.
22
+ /// This attribute is only supported for structs.
23
+ ///
24
+ /// # Example
25
+ ///
26
+ /// ```ignore
27
+ /// #[invariant(self.width == self.height)]
28
+ /// struct Square {
29
+ /// width: u32,
30
+ /// height: u32,
31
+ /// }
32
+ /// ```
33
+ ///
34
+ /// expands to:
35
+ /// ```ignore
36
+ /// impl core::ub_checks::Invariant for Square {
37
+ /// fn is_safe(&self) -> bool {
38
+ /// self.width == self.height
39
+ /// }
40
+ /// }
41
+ /// ```
42
+ /// For more information on the Invariant trait, see its documentation in core::ub_checks.
43
+ #[ proc_macro_error]
44
+ #[ proc_macro_attribute]
45
+ pub fn invariant ( attr : TokenStream , item : TokenStream ) -> TokenStream {
46
+ let safe_body = proc_macro2:: TokenStream :: from ( attr) ;
47
+ let item = parse_macro_input ! ( item as ItemStruct ) ;
48
+ let item_name = & item. ident ;
49
+ let ( impl_generics, ty_generics, where_clause) = item. generics . split_for_impl ( ) ;
50
+
51
+ let expanded = quote ! {
52
+ #item
53
+ #[ unstable( feature="invariant" , issue="none" ) ]
54
+ impl #impl_generics core:: ub_checks:: Invariant for #item_name #ty_generics #where_clause {
55
+ fn is_safe( & self ) -> bool {
56
+ #safe_body
57
+ }
58
+ }
59
+ } ;
60
+
61
+ proc_macro:: TokenStream :: from ( expanded)
62
+ }
63
+
64
+ /// Expands the derive macro for the Invariant trait.
65
+ /// The macro expands to an implementation of the `is_safe` method for the `Invariant` trait.
66
+ /// This macro is only supported for structs and enums.
67
+ ///
68
+ /// # Example
69
+ ///
70
+ /// ```ignore
71
+ /// #[derive(Invariant)]
72
+ /// struct Square {
73
+ /// width: u32,
74
+ /// height: u32,
75
+ /// }
76
+ /// ```
77
+ ///
78
+ /// expands to:
79
+ /// ```ignore
80
+ /// impl core::ub_checks::Invariant for Square {
81
+ /// fn is_safe(&self) -> bool {
82
+ /// self.width.is_safe() && self.height.is_safe()
83
+ /// }
84
+ /// }
85
+ /// ```
86
+ /// For enums, the body of `is_safe` matches on the variant and calls `is_safe` on its fields,
87
+ /// # Example
88
+ ///
89
+ /// ```ignore
90
+ /// #[derive(Invariant)]
91
+ /// enum MyEnum {
92
+ /// OptionOne(u32, u32),
93
+ /// OptionTwo(Square),
94
+ /// OptionThree
95
+ /// }
96
+ /// ```
97
+ ///
98
+ /// expands to:
99
+ /// ```ignore
100
+ /// impl core::ub_checks::Invariant for MyEnum {
101
+ /// fn is_safe(&self) -> bool {
102
+ /// match self {
103
+ /// MyEnum::OptionOne(field1, field2) => field1.is_safe() && field2.is_safe(),
104
+ /// MyEnum::OptionTwo(field1) => field1.is_safe(),
105
+ /// MyEnum::OptionThree => true,
106
+ /// }
107
+ /// }
108
+ /// }
109
+ /// ```
110
+ /// For more information on the Invariant trait, see its documentation in core::ub_checks.
111
+ #[ proc_macro_error]
112
+ #[ proc_macro_derive( Invariant ) ]
113
+ pub fn derive_invariant ( item : TokenStream ) -> TokenStream {
114
+ let derive_item = parse_macro_input ! ( item as DeriveInput ) ;
115
+ let item_name = & derive_item. ident ;
116
+ let safe_body = match derive_item. data {
117
+ Data :: Struct ( struct_data) => {
118
+ safe_body ( & struct_data. fields )
119
+ } ,
120
+ Data :: Enum ( enum_data) => {
121
+ let variant_checks = variant_checks ( enum_data, item_name) ;
122
+
123
+ quote ! {
124
+ match self {
125
+ #( #variant_checks) , *
126
+ }
127
+ }
128
+ } ,
129
+ Data :: Union ( ..) => unimplemented ! ( "Attempted to derive Invariant on a union; Invariant can only be derived for structs and enums." ) ,
130
+ } ;
131
+
132
+ // Add a bound `T: Invariant` to every type parameter T.
133
+ let generics = add_trait_bound_invariant ( derive_item. generics ) ;
134
+ // Generate an expression to sum up the heap size of each field.
135
+ let ( impl_generics, ty_generics, where_clause) = generics. split_for_impl ( ) ;
136
+
137
+ let expanded = quote ! {
138
+ // The generated implementation.
139
+ #[ unstable( feature="invariant" , issue="none" ) ]
140
+ impl #impl_generics core:: ub_checks:: Invariant for #item_name #ty_generics #where_clause {
141
+ fn is_safe( & self ) -> bool {
142
+ #safe_body
143
+ }
144
+ }
145
+ } ;
146
+ proc_macro:: TokenStream :: from ( expanded)
147
+ }
148
+
15
149
#[ proc_macro_error]
16
150
#[ proc_macro_attribute]
17
151
pub fn requires ( attr : TokenStream , item : TokenStream ) -> TokenStream {
@@ -29,3 +163,96 @@ pub fn ensures(attr: TokenStream, item: TokenStream) -> TokenStream {
29
163
pub fn loop_invariant ( attr : TokenStream , stmt_stream : TokenStream ) -> TokenStream {
30
164
tool:: loop_invariant ( attr, stmt_stream)
31
165
}
166
+
167
+ /// Add a bound `T: Invariant` to every type parameter T.
168
+ fn add_trait_bound_invariant ( mut generics : Generics ) -> Generics {
169
+ generics. params . iter_mut ( ) . for_each ( |param| {
170
+ if let GenericParam :: Type ( type_param) = param {
171
+ type_param
172
+ . bounds
173
+ . push ( parse_quote ! ( core:: ub_checks:: Invariant ) ) ;
174
+ }
175
+ } ) ;
176
+ generics
177
+ }
178
+
179
+ /// Generate safety checks for each variant of an enum
180
+ fn variant_checks ( enum_data : DataEnum , item_name : & Ident ) -> Vec < proc_macro2:: TokenStream > {
181
+ enum_data
182
+ . variants
183
+ . iter ( )
184
+ . map ( |variant| {
185
+ let variant_name = & variant. ident ;
186
+ match & variant. fields {
187
+ Fields :: Unnamed ( fields) => {
188
+ let field_names: Vec < _ > = fields
189
+ . unnamed
190
+ . iter ( )
191
+ . enumerate ( )
192
+ . map ( |( i, _) | format_ident ! ( "field{}" , i + 1 ) )
193
+ . collect ( ) ;
194
+
195
+ let field_checks: Vec < _ > = field_names
196
+ . iter ( )
197
+ . map ( |field_name| {
198
+ quote ! { #field_name. is_safe( ) }
199
+ } )
200
+ . collect ( ) ;
201
+
202
+ quote ! {
203
+ #item_name:: #variant_name( #( #field_names) , * ) => #( #field_checks) &&*
204
+ }
205
+ }
206
+ Fields :: Unit => {
207
+ quote ! {
208
+ #item_name:: #variant_name => true
209
+ }
210
+ }
211
+ Fields :: Named ( _) => unreachable ! ( "Enums do not have named fields" ) ,
212
+ }
213
+ } )
214
+ . collect ( )
215
+ }
216
+
217
+ /// Generate the body for the `is_safe` method.
218
+ /// For each field of the type, enforce that it is safe.
219
+ fn safe_body ( fields : & Fields ) -> proc_macro2:: TokenStream {
220
+ match fields {
221
+ Fields :: Named ( ref fields) => {
222
+ let field_safe_calls: Vec < proc_macro2:: TokenStream > = fields
223
+ . named
224
+ . iter ( )
225
+ . map ( |field| {
226
+ let name = & field. ident ;
227
+ quote_spanned ! { field. span( ) =>
228
+ self . #name. is_safe( )
229
+ }
230
+ } )
231
+ . collect ( ) ;
232
+ if !field_safe_calls. is_empty ( ) {
233
+ quote ! { #( #field_safe_calls ) &&* }
234
+ } else {
235
+ quote ! { true }
236
+ }
237
+ }
238
+ Fields :: Unnamed ( ref fields) => {
239
+ let field_safe_calls: Vec < proc_macro2:: TokenStream > = fields
240
+ . unnamed
241
+ . iter ( )
242
+ . enumerate ( )
243
+ . map ( |( idx, field) | {
244
+ let field_idx = Index :: from ( idx) ;
245
+ quote_spanned ! { field. span( ) =>
246
+ self . #field_idx. is_safe( )
247
+ }
248
+ } )
249
+ . collect ( ) ;
250
+ if !field_safe_calls. is_empty ( ) {
251
+ quote ! { #( #field_safe_calls ) &&* }
252
+ } else {
253
+ quote ! { true }
254
+ }
255
+ }
256
+ Fields :: Unit => quote ! { true } ,
257
+ }
258
+ }
0 commit comments