Skip to content

Commit abe2354

Browse files
Invariant trait (#87)
Adds an `Invariant` trait to `core::ub_checks`, and adds two invariants for `Alignment` and `Layout`. One call-out: Kani's invariant trait [enforces](https://github.com/model-checking/kani/blob/d2051b77437a0032120f0513e0e9c3c4766d8562/library/kani/src/invariant.rs#L63) that the type is sized, but I wasn't sure why that would be necessary, so I didn't add it here. By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 and MIT licenses. --------- Co-authored-by: Celina G. Val <[email protected]>
1 parent 800a8e7 commit abe2354

File tree

5 files changed

+307
-9
lines changed

5 files changed

+307
-9
lines changed

library/contracts/safety/src/kani.rs

+7-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
use proc_macro::{TokenStream};
2-
use quote::{quote, format_ident};
3-
use syn::{ItemFn, parse_macro_input, Stmt};
1+
use proc_macro::TokenStream;
2+
use quote::{format_ident, quote};
3+
use syn::{parse_macro_input, ItemFn, Stmt};
44

55
pub(crate) fn requires(attr: TokenStream, item: TokenStream) -> TokenStream {
66
rewrite_attr(attr, item, "requires")
@@ -21,7 +21,8 @@ fn rewrite_stmt_attr(attr: TokenStream, stmt_stream: TokenStream, name: &str) ->
2121
quote!(
2222
#[kani_core::#attribute(#args)]
2323
#stmt
24-
).into()
24+
)
25+
.into()
2526
}
2627

2728
fn rewrite_attr(attr: TokenStream, item: TokenStream, name: &str) -> TokenStream {
@@ -31,5 +32,6 @@ fn rewrite_attr(attr: TokenStream, item: TokenStream, name: &str) -> TokenStream
3132
quote!(
3233
#[kani_core::#attribute(#args)]
3334
#fn_item
34-
).into()
35+
)
36+
.into()
3537
}

library/contracts/safety/src/lib.rs

+227
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,11 @@
33
44
use proc_macro::TokenStream;
55
use proc_macro_error::proc_macro_error;
6+
use quote::{format_ident, quote, quote_spanned};
7+
use syn::{
8+
parse_macro_input, parse_quote, spanned::Spanned, Data, DataEnum, DeriveInput, Fields,
9+
GenericParam, Generics, Ident, Index, ItemStruct,
10+
};
611

712
#[cfg(kani_host)]
813
#[path = "kani.rs"]
@@ -12,6 +17,135 @@ mod tool;
1217
#[path = "runtime.rs"]
1318
mod tool;
1419

20+
/// Expands the `#[invariant(...)]` attribute macro.
21+
/// The macro expands to an implementation of the `is_safe` method for the `Invariant` trait.
22+
/// This attribute is only supported for structs.
23+
///
24+
/// # Example
25+
///
26+
/// ```ignore
27+
/// #[invariant(self.width == self.height)]
28+
/// struct Square {
29+
/// width: u32,
30+
/// height: u32,
31+
/// }
32+
/// ```
33+
///
34+
/// expands to:
35+
/// ```ignore
36+
/// impl core::ub_checks::Invariant for Square {
37+
/// fn is_safe(&self) -> bool {
38+
/// self.width == self.height
39+
/// }
40+
/// }
41+
/// ```
42+
/// For more information on the Invariant trait, see its documentation in core::ub_checks.
43+
#[proc_macro_error]
44+
#[proc_macro_attribute]
45+
pub fn invariant(attr: TokenStream, item: TokenStream) -> TokenStream {
46+
let safe_body = proc_macro2::TokenStream::from(attr);
47+
let item = parse_macro_input!(item as ItemStruct);
48+
let item_name = &item.ident;
49+
let (impl_generics, ty_generics, where_clause) = item.generics.split_for_impl();
50+
51+
let expanded = quote! {
52+
#item
53+
#[unstable(feature="invariant", issue="none")]
54+
impl #impl_generics core::ub_checks::Invariant for #item_name #ty_generics #where_clause {
55+
fn is_safe(&self) -> bool {
56+
#safe_body
57+
}
58+
}
59+
};
60+
61+
proc_macro::TokenStream::from(expanded)
62+
}
63+
64+
/// Expands the derive macro for the Invariant trait.
65+
/// The macro expands to an implementation of the `is_safe` method for the `Invariant` trait.
66+
/// This macro is only supported for structs and enums.
67+
///
68+
/// # Example
69+
///
70+
/// ```ignore
71+
/// #[derive(Invariant)]
72+
/// struct Square {
73+
/// width: u32,
74+
/// height: u32,
75+
/// }
76+
/// ```
77+
///
78+
/// expands to:
79+
/// ```ignore
80+
/// impl core::ub_checks::Invariant for Square {
81+
/// fn is_safe(&self) -> bool {
82+
/// self.width.is_safe() && self.height.is_safe()
83+
/// }
84+
/// }
85+
/// ```
86+
/// For enums, the body of `is_safe` matches on the variant and calls `is_safe` on its fields,
87+
/// # Example
88+
///
89+
/// ```ignore
90+
/// #[derive(Invariant)]
91+
/// enum MyEnum {
92+
/// OptionOne(u32, u32),
93+
/// OptionTwo(Square),
94+
/// OptionThree
95+
/// }
96+
/// ```
97+
///
98+
/// expands to:
99+
/// ```ignore
100+
/// impl core::ub_checks::Invariant for MyEnum {
101+
/// fn is_safe(&self) -> bool {
102+
/// match self {
103+
/// MyEnum::OptionOne(field1, field2) => field1.is_safe() && field2.is_safe(),
104+
/// MyEnum::OptionTwo(field1) => field1.is_safe(),
105+
/// MyEnum::OptionThree => true,
106+
/// }
107+
/// }
108+
/// }
109+
/// ```
110+
/// For more information on the Invariant trait, see its documentation in core::ub_checks.
111+
#[proc_macro_error]
112+
#[proc_macro_derive(Invariant)]
113+
pub fn derive_invariant(item: TokenStream) -> TokenStream {
114+
let derive_item = parse_macro_input!(item as DeriveInput);
115+
let item_name = &derive_item.ident;
116+
let safe_body = match derive_item.data {
117+
Data::Struct(struct_data) => {
118+
safe_body(&struct_data.fields)
119+
},
120+
Data::Enum(enum_data) => {
121+
let variant_checks = variant_checks(enum_data, item_name);
122+
123+
quote! {
124+
match self {
125+
#(#variant_checks),*
126+
}
127+
}
128+
},
129+
Data::Union(..) => unimplemented!("Attempted to derive Invariant on a union; Invariant can only be derived for structs and enums."),
130+
};
131+
132+
// Add a bound `T: Invariant` to every type parameter T.
133+
let generics = add_trait_bound_invariant(derive_item.generics);
134+
// Generate an expression to sum up the heap size of each field.
135+
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
136+
137+
let expanded = quote! {
138+
// The generated implementation.
139+
#[unstable(feature="invariant", issue="none")]
140+
impl #impl_generics core::ub_checks::Invariant for #item_name #ty_generics #where_clause {
141+
fn is_safe(&self) -> bool {
142+
#safe_body
143+
}
144+
}
145+
};
146+
proc_macro::TokenStream::from(expanded)
147+
}
148+
15149
#[proc_macro_error]
16150
#[proc_macro_attribute]
17151
pub fn requires(attr: TokenStream, item: TokenStream) -> TokenStream {
@@ -29,3 +163,96 @@ pub fn ensures(attr: TokenStream, item: TokenStream) -> TokenStream {
29163
pub fn loop_invariant(attr: TokenStream, stmt_stream: TokenStream) -> TokenStream {
30164
tool::loop_invariant(attr, stmt_stream)
31165
}
166+
167+
/// Add a bound `T: Invariant` to every type parameter T.
168+
fn add_trait_bound_invariant(mut generics: Generics) -> Generics {
169+
generics.params.iter_mut().for_each(|param| {
170+
if let GenericParam::Type(type_param) = param {
171+
type_param
172+
.bounds
173+
.push(parse_quote!(core::ub_checks::Invariant));
174+
}
175+
});
176+
generics
177+
}
178+
179+
/// Generate safety checks for each variant of an enum
180+
fn variant_checks(enum_data: DataEnum, item_name: &Ident) -> Vec<proc_macro2::TokenStream> {
181+
enum_data
182+
.variants
183+
.iter()
184+
.map(|variant| {
185+
let variant_name = &variant.ident;
186+
match &variant.fields {
187+
Fields::Unnamed(fields) => {
188+
let field_names: Vec<_> = fields
189+
.unnamed
190+
.iter()
191+
.enumerate()
192+
.map(|(i, _)| format_ident!("field{}", i + 1))
193+
.collect();
194+
195+
let field_checks: Vec<_> = field_names
196+
.iter()
197+
.map(|field_name| {
198+
quote! { #field_name.is_safe() }
199+
})
200+
.collect();
201+
202+
quote! {
203+
#item_name::#variant_name(#(#field_names),*) => #(#field_checks)&&*
204+
}
205+
}
206+
Fields::Unit => {
207+
quote! {
208+
#item_name::#variant_name => true
209+
}
210+
}
211+
Fields::Named(_) => unreachable!("Enums do not have named fields"),
212+
}
213+
})
214+
.collect()
215+
}
216+
217+
/// Generate the body for the `is_safe` method.
218+
/// For each field of the type, enforce that it is safe.
219+
fn safe_body(fields: &Fields) -> proc_macro2::TokenStream {
220+
match fields {
221+
Fields::Named(ref fields) => {
222+
let field_safe_calls: Vec<proc_macro2::TokenStream> = fields
223+
.named
224+
.iter()
225+
.map(|field| {
226+
let name = &field.ident;
227+
quote_spanned! {field.span()=>
228+
self.#name.is_safe()
229+
}
230+
})
231+
.collect();
232+
if !field_safe_calls.is_empty() {
233+
quote! { #( #field_safe_calls )&&* }
234+
} else {
235+
quote! { true }
236+
}
237+
}
238+
Fields::Unnamed(ref fields) => {
239+
let field_safe_calls: Vec<proc_macro2::TokenStream> = fields
240+
.unnamed
241+
.iter()
242+
.enumerate()
243+
.map(|(idx, field)| {
244+
let field_idx = Index::from(idx);
245+
quote_spanned! {field.span()=>
246+
self.#field_idx.is_safe()
247+
}
248+
})
249+
.collect();
250+
if !field_safe_calls.is_empty() {
251+
quote! { #( #field_safe_calls )&&* }
252+
} else {
253+
quote! { true }
254+
}
255+
}
256+
Fields::Unit => quote! { true },
257+
}
258+
}

library/core/src/alloc/layout.rs

+7-1
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,18 @@
44
// collections, resulting in having to optimize down excess IR multiple times.
55
// Your performance intuition is useless. Run perf.
66

7-
use safety::{ensures, requires};
7+
use safety::{ensures, Invariant, requires};
88
use crate::error::Error;
99
use crate::ptr::{Alignment, NonNull};
1010
use crate::{assert_unsafe_precondition, cmp, fmt, mem};
1111

1212
#[cfg(kani)]
1313
use crate::kani;
1414

15+
// Used only for contract verification.
16+
#[allow(unused_imports)]
17+
use crate::ub_checks::Invariant;
18+
1519
// While this function is used in one place and its implementation
1620
// could be inlined, the previous attempts to do so made rustc
1721
// slower:
@@ -39,6 +43,7 @@ const fn size_align<T>() -> (usize, usize) {
3943
#[stable(feature = "alloc_layout", since = "1.28.0")]
4044
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
4145
#[lang = "alloc_layout"]
46+
#[derive(Invariant)]
4247
pub struct Layout {
4348
// size of the requested block of memory, measured in bytes.
4449
size: usize,
@@ -132,6 +137,7 @@ impl Layout {
132137
#[inline]
133138
#[rustc_allow_const_fn_unstable(ptr_alignment_type)]
134139
#[requires(Layout::from_size_align(size, align).is_ok())]
140+
#[ensures(|result| result.is_safe())]
135141
#[ensures(|result| result.size() == size)]
136142
#[ensures(|result| result.align() == align)]
137143
pub const unsafe fn from_size_align_unchecked(size: usize, align: usize) -> Self {

library/core/src/ptr/alignment.rs

+11-3
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
1-
use safety::{ensures, requires};
1+
use safety::{ensures, invariant, requires};
22
use crate::num::NonZero;
33
use crate::ub_checks::assert_unsafe_precondition;
44
use crate::{cmp, fmt, hash, mem, num};
55

66
#[cfg(kani)]
77
use crate::kani;
88

9+
#[cfg(kani)]
10+
use crate::ub_checks::Invariant;
11+
912
/// A type storing a `usize` which is a power of two, and thus
1013
/// represents a possible alignment in the Rust abstract machine.
1114
///
@@ -14,6 +17,7 @@ use crate::kani;
1417
#[unstable(feature = "ptr_alignment_type", issue = "102070")]
1518
#[derive(Copy, Clone, PartialEq, Eq)]
1619
#[repr(transparent)]
20+
#[invariant(self.as_usize().is_power_of_two())]
1721
pub struct Alignment(AlignmentEnum);
1822

1923
// Alignment is `repr(usize)`, but via extra steps.
@@ -256,6 +260,7 @@ impl Default for Alignment {
256260
#[cfg(target_pointer_width = "16")]
257261
#[derive(Copy, Clone, PartialEq, Eq)]
258262
#[repr(u16)]
263+
#[cfg_attr(kani, derive(kani::Arbitrary))]
259264
enum AlignmentEnum {
260265
_Align1Shl0 = 1 << 0,
261266
_Align1Shl1 = 1 << 1,
@@ -278,6 +283,7 @@ enum AlignmentEnum {
278283
#[cfg(target_pointer_width = "32")]
279284
#[derive(Copy, Clone, PartialEq, Eq)]
280285
#[repr(u32)]
286+
#[cfg_attr(kani, derive(kani::Arbitrary))]
281287
enum AlignmentEnum {
282288
_Align1Shl0 = 1 << 0,
283289
_Align1Shl1 = 1 << 1,
@@ -316,6 +322,7 @@ enum AlignmentEnum {
316322
#[cfg(target_pointer_width = "64")]
317323
#[derive(Copy, Clone, PartialEq, Eq)]
318324
#[repr(u64)]
325+
#[cfg_attr(kani, derive(kani::Arbitrary))]
319326
enum AlignmentEnum {
320327
_Align1Shl0 = 1 << 0,
321328
_Align1Shl1 = 1 << 1,
@@ -390,8 +397,9 @@ mod verify {
390397

391398
impl kani::Arbitrary for Alignment {
392399
fn any() -> Self {
393-
let align = kani::any_where(|a: &usize| a.is_power_of_two());
394-
unsafe { mem::transmute::<usize, Alignment>(align) }
400+
let obj = Self { 0: kani::any() };
401+
kani::assume(obj.is_safe());
402+
obj
395403
}
396404
}
397405

0 commit comments

Comments
 (0)