Skip to content

Commit c15294e

Browse files
authored
Add support to array-based SIMD (#2633)
Originally, repr(simd) supported only multi-field form. An array based version was later added and it's likely to become the only supported way (rust-lang/compiler-team#621). The array-based version is currently used in the standard library, and it is used to implement `portable-simd`. This change adds support to instantiating and using the array-based version.
1 parent 06f0b5c commit c15294e

File tree

6 files changed

+225
-54
lines changed

6 files changed

+225
-54
lines changed

kani-compiler/src/codegen_cprover_gotoc/codegen/place.rs

Lines changed: 85 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ use tracing::{debug, trace, warn};
2323

2424
/// A projection in Kani can either be to a type (the normal case),
2525
/// or a variant in the case of a downcast.
26-
#[derive(Debug)]
26+
#[derive(Copy, Clone, Debug)]
2727
pub enum TypeOrVariant<'tcx> {
2828
Type(Ty<'tcx>),
2929
Variant(&'tcx VariantDef),
@@ -235,15 +235,21 @@ impl<'tcx> TypeOrVariant<'tcx> {
235235
}
236236

237237
impl<'tcx> GotocCtx<'tcx> {
238+
/// Codegen field access for types that allow direct field projection.
239+
///
240+
/// I.e.: Algebraic data types, closures, and generators.
241+
///
242+
/// Other composite types such as array only support index projection.
238243
fn codegen_field(
239244
&mut self,
240-
res: Expr,
241-
t: TypeOrVariant<'tcx>,
242-
f: &FieldIdx,
245+
parent_expr: Expr,
246+
parent_ty_or_var: TypeOrVariant<'tcx>,
247+
field: &FieldIdx,
248+
field_ty_or_var: TypeOrVariant<'tcx>,
243249
) -> Result<Expr, UnimplementedData> {
244-
match t {
245-
TypeOrVariant::Type(t) => {
246-
match t.kind() {
250+
match parent_ty_or_var {
251+
TypeOrVariant::Type(parent_ty) => {
252+
match parent_ty.kind() {
247253
ty::Alias(..)
248254
| ty::Bool
249255
| ty::Char
@@ -254,56 +260,98 @@ impl<'tcx> GotocCtx<'tcx> {
254260
| ty::Never
255261
| ty::FnDef(..)
256262
| ty::GeneratorWitness(..)
263+
| ty::GeneratorWitnessMIR(..)
257264
| ty::Foreign(..)
258265
| ty::Dynamic(..)
259266
| ty::Bound(..)
260267
| ty::Placeholder(..)
261268
| ty::Param(_)
262269
| ty::Infer(_)
263-
| ty::Error(_) => unreachable!("type {:?} does not have a field", t),
264-
ty::Tuple(_) => {
265-
Ok(res.member(&Self::tuple_fld_name(f.index()), &self.symbol_table))
266-
}
267-
ty::Adt(def, _) if def.repr().simd() => {
268-
// this is a SIMD vector - the index represents one
269-
// of the elements, so we want to index as an array
270-
// Example:
271-
// pub struct i64x2(i64, i64);
272-
// fn main() {
273-
// let v = i64x2(1, 2);
274-
// assert!(v.0 == 1); // refers to the first i64
275-
// assert!(v.1 == 2);
276-
// }
277-
let size_index = Expr::int_constant(f.index(), Type::size_t());
278-
Ok(res.index_array(size_index))
279-
}
270+
| ty::Error(_) => unreachable!("type {parent_ty:?} does not have a field"),
271+
ty::Tuple(_) => Ok(parent_expr
272+
.member(&Self::tuple_fld_name(field.index()), &self.symbol_table)),
273+
ty::Adt(def, _) if def.repr().simd() => Ok(self.codegen_simd_field(
274+
parent_expr,
275+
*field,
276+
field_ty_or_var.expect_type(),
277+
)),
280278
// if we fall here, then we are handling either a struct or a union
281279
ty::Adt(def, _) => {
282-
let field = &def.variants().raw[0].fields[*f];
283-
Ok(res.member(&field.name.to_string(), &self.symbol_table))
280+
let field = &def.variants().raw[0].fields[*field];
281+
Ok(parent_expr.member(&field.name.to_string(), &self.symbol_table))
282+
}
283+
ty::Closure(..) => {
284+
Ok(parent_expr.member(&field.index().to_string(), &self.symbol_table))
284285
}
285-
ty::Closure(..) => Ok(res.member(&f.index().to_string(), &self.symbol_table)),
286286
ty::Generator(..) => {
287-
let field_name = self.generator_field_name(f.as_usize());
288-
Ok(res
287+
let field_name = self.generator_field_name(field.as_usize());
288+
Ok(parent_expr
289289
.member("direct_fields", &self.symbol_table)
290290
.member(field_name, &self.symbol_table))
291291
}
292-
_ => unimplemented!(),
292+
ty::Str | ty::Array(_, _) | ty::Slice(_) | ty::RawPtr(_) | ty::Ref(_, _, _) => {
293+
unreachable!(
294+
"element of {parent_ty:?} is not accessed via field projection"
295+
)
296+
}
293297
}
294298
}
295299
// if we fall here, then we are handling an enum
296-
TypeOrVariant::Variant(v) => {
297-
let field = &v.fields[*f];
298-
Ok(res.member(&field.name.to_string(), &self.symbol_table))
300+
TypeOrVariant::Variant(parent_var) => {
301+
let field = &parent_var.fields[*field];
302+
Ok(parent_expr.member(&field.name.to_string(), &self.symbol_table))
299303
}
300304
TypeOrVariant::GeneratorVariant(_var_idx) => {
301-
let field_name = self.generator_field_name(f.index());
302-
Ok(res.member(field_name, &self.symbol_table))
305+
let field_name = self.generator_field_name(field.index());
306+
Ok(parent_expr.member(field_name, &self.symbol_table))
303307
}
304308
}
305309
}
306310

311+
/// This is a SIMD vector, which has 2 possible internal representations:
312+
/// 1- Multi-field representation (original and currently deprecated)
313+
/// In this case, a field is one lane (i.e.: one element)
314+
/// Example:
315+
/// ```ignore
316+
/// pub struct i64x2(i64, i64);
317+
/// fn main() {
318+
/// let v = i64x2(1, 2);
319+
/// assert!(v.0 == 1); // refers to the first i64
320+
/// assert!(v.1 == 2);
321+
/// }
322+
/// ```
323+
/// 2- Array-based representation
324+
/// In this case, the projection refers to the entire array.
325+
/// ```ignore
326+
/// pub struct i64x2([i64; 2]);
327+
/// fn main() {
328+
/// let v = i64x2([1, 2]);
329+
/// assert!(v.0 == [1, 2]); // refers to the entire array
330+
/// }
331+
/// ```
332+
/// * Note that projection inside SIMD structs may eventually become illegal.
333+
/// See <https://github.com/rust-lang/stdarch/pull/1422#discussion_r1176415609> thread.
334+
///
335+
/// Since the goto representation for both is the same, we use the expected type to decide
336+
/// what to return.
337+
fn codegen_simd_field(
338+
&mut self,
339+
parent_expr: Expr,
340+
field: FieldIdx,
341+
field_ty: Ty<'tcx>,
342+
) -> Expr {
343+
if matches!(field_ty.kind(), ty::Array { .. }) {
344+
// Array based
345+
assert_eq!(field.index(), 0);
346+
let field_typ = self.codegen_ty(field_ty);
347+
parent_expr.reinterpret_cast(field_typ)
348+
} else {
349+
// Return the given field.
350+
let index_expr = Expr::int_constant(field.index(), Type::size_t());
351+
parent_expr.index_array(index_expr)
352+
}
353+
}
354+
307355
/// If a local is a function definition, ignore the local variable name and
308356
/// generate a function call based on the def id.
309357
///
@@ -424,7 +472,8 @@ impl<'tcx> GotocCtx<'tcx> {
424472
}
425473
ProjectionElem::Field(f, t) => {
426474
let typ = TypeOrVariant::Type(t);
427-
let expr = self.codegen_field(before.goto_expr, before.mir_typ_or_variant, &f)?;
475+
let expr =
476+
self.codegen_field(before.goto_expr, before.mir_typ_or_variant, &f, typ)?;
428477
ProjectedPlace::try_new(
429478
expr,
430479
typ,

kani-compiler/src/codegen_cprover_gotoc/codegen/rvalue.rs

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -582,24 +582,28 @@ impl<'tcx> GotocCtx<'tcx> {
582582
AggregateKind::Adt(_, _, _, _, _) if res_ty.is_simd() => {
583583
let typ = self.codegen_ty(res_ty);
584584
let layout = self.layout_of(res_ty);
585-
let vector_element_type = typ.base_type().unwrap().clone();
586-
Expr::vector_expr(
587-
typ,
588-
layout
589-
.fields
590-
.index_by_increasing_offset()
591-
.map(|idx| {
592-
let cgo = self.codegen_operand(&operands[idx.into()]);
593-
// The input operand might actually be a one-element array, as seen
594-
// when running assess on firecracker.
595-
if *cgo.typ() == vector_element_type {
596-
cgo
597-
} else {
598-
cgo.transmute_to(vector_element_type.clone(), &self.symbol_table)
599-
}
600-
})
601-
.collect(),
602-
)
585+
trace!(shape=?layout.fields, "codegen_rvalue_aggregate");
586+
assert!(operands.len() > 0, "SIMD vector cannot be empty");
587+
if operands.len() == 1 {
588+
let data = self.codegen_operand(&operands[0u32.into()]);
589+
if data.typ().is_array() {
590+
// Array-based SIMD representation.
591+
data.transmute_to(typ, &self.symbol_table)
592+
} else {
593+
// Multi field-based representation with one field.
594+
Expr::vector_expr(typ, vec![data])
595+
}
596+
} else {
597+
// Multi field SIMD representation.
598+
Expr::vector_expr(
599+
typ,
600+
layout
601+
.fields
602+
.index_by_increasing_offset()
603+
.map(|idx| self.codegen_operand(&operands[idx.into()]))
604+
.collect(),
605+
)
606+
}
603607
}
604608
AggregateKind::Adt(_, variant_index, ..) if res_ty.is_enum() => {
605609
self.codegen_rvalue_enum_aggregate(variant_index, operands, res_ty, loc)

tests/kani/SIMD/array_simd_repr.rs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// Copyright Kani Contributors
2+
// SPDX-License-Identifier: Apache-2.0 OR MIT
3+
//! Verify that Kani can properly handle SIMD declaration and field access using array syntax.
4+
5+
#![allow(non_camel_case_types)]
6+
#![feature(repr_simd)]
7+
8+
#[repr(simd)]
9+
#[derive(Clone, PartialEq, Eq, PartialOrd, kani::Arbitrary)]
10+
pub struct i64x2([i64; 2]);
11+
12+
#[kani::proof]
13+
fn check_diff() {
14+
let x = i64x2([1, 2]);
15+
let y = i64x2([3, 4]);
16+
assert!(x != y);
17+
}
18+
19+
#[kani::proof]
20+
fn check_ge() {
21+
let x: i64x2 = kani::any();
22+
kani::assume(x.0[0] > 0);
23+
kani::assume(x.0[1] > 0);
24+
assert!(x > i64x2([0, 0]));
25+
}
26+
27+
#[derive(Clone, Debug)]
28+
#[repr(simd)]
29+
struct CustomSimd<T, const LANES: usize>([T; LANES]);
30+
31+
#[kani::proof]
32+
fn simd_vec() {
33+
let simd = CustomSimd([0u8; 10]);
34+
let idx: usize = kani::any_where(|x: &usize| *x < 10);
35+
assert_eq!(simd.0[idx], 0);
36+
}

tests/kani/SIMD/multi_field_simd.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// Copyright Kani Contributors
2+
// SPDX-License-Identifier: Apache-2.0 OR MIT
3+
//! Verify that Kani can properly handle SIMD declaration and field access using multi-field syntax.
4+
//! Note: Multi-field SIMD is actually being deprecated, but until it's removed, we might
5+
//! as well keep supporting it.
6+
//! See <https://github.com/rust-lang/compiler-team/issues/621> for more details.
7+
8+
#![allow(non_camel_case_types)]
9+
#![feature(repr_simd)]
10+
11+
#[repr(simd)]
12+
#[derive(PartialEq, Eq, PartialOrd, kani::Arbitrary)]
13+
pub struct i64x2(i64, i64);
14+
15+
#[kani::proof]
16+
fn check_diff() {
17+
let x = i64x2(1, 2);
18+
let y = i64x2(3, 4);
19+
assert!(x != y);
20+
}
21+
22+
#[kani::proof]
23+
fn check_ge() {
24+
let x: i64x2 = kani::any();
25+
kani::assume(x.0 > 0);
26+
kani::assume(x.1 > 0);
27+
assert!(x > i64x2(0, 0));
28+
}

tests/kani/SIMD/portable_simd.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// Copyright Kani Contributors
2+
// SPDX-License-Identifier: Apache-2.0 OR MIT
3+
4+
//! Ensure we have basic support of portable SIMD.
5+
#![feature(portable_simd)]
6+
7+
use std::simd::u64x16;
8+
9+
#[kani::proof]
10+
fn check_sum_any() {
11+
let a = u64x16::splat(0);
12+
let b = u64x16::from_array(kani::any());
13+
// Cannot compare them directly: https://github.com/model-checking/kani/issues/2632
14+
assert_eq!((a + b).as_array(), b.as_array());
15+
}
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// Copyright Kani Contributors
2+
// SPDX-License-Identifier: Apache-2.0 OR MIT
3+
4+
//! Ensure we can handle SIMD defined in the standard library
5+
//! FIXME: <https://github.com/model-checking/kani/issues/2631>
6+
#![allow(non_camel_case_types)]
7+
#![feature(repr_simd, platform_intrinsics, portable_simd)]
8+
use std::simd::f32x4;
9+
10+
extern "platform-intrinsic" {
11+
fn simd_add<T>(x: T, y: T) -> T;
12+
fn simd_eq<T, U>(x: T, y: T) -> U;
13+
}
14+
15+
#[repr(simd)]
16+
#[derive(Clone, PartialEq, kani::Arbitrary)]
17+
pub struct f32x2(f32, f32);
18+
19+
impl f32x2 {
20+
fn as_array(&self) -> &[f32; 2] {
21+
unsafe { &*(self as *const f32x2 as *const [f32; 2]) }
22+
}
23+
}
24+
25+
#[kani::proof]
26+
fn check_sum() {
27+
let a = f32x2(0.0, 0.0);
28+
let b = kani::any::<f32x2>();
29+
let sum = unsafe { simd_add(a.clone(), b) };
30+
assert_eq!(sum.as_array(), a.as_array());
31+
}
32+
33+
#[kani::proof]
34+
fn check_sum_portable() {
35+
let a = f32x4::splat(0.0);
36+
let b = f32x4::from_array(kani::any());
37+
// Cannot compare them directly: https://github.com/model-checking/kani/issues/2632
38+
assert_eq!((a + b).as_array(), b.as_array());
39+
}

0 commit comments

Comments
 (0)