Skip to content

Commit b800211

Browse files
committed
Auto merge of rust-lang#3118 - eduardosm:intrinsics-x86-sse41, r=RalfJung
Implement `llvm.x86.sse41.*` intrinsics
2 parents 99d6cd4 + 949bb64 commit b800211

File tree

3 files changed

+581
-0
lines changed

3 files changed

+581
-0
lines changed

src/tools/miri/src/shims/x86/mod.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ mod aesni;
1111
mod sse;
1212
mod sse2;
1313
mod sse3;
14+
mod sse41;
1415
mod ssse3;
1516

1617
impl<'mir, 'tcx: 'mir> EvalContextExt<'mir, 'tcx> for crate::MiriInterpCx<'mir, 'tcx> {}
@@ -101,6 +102,11 @@ pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>:
101102
this, link_name, abi, args, dest,
102103
);
103104
}
105+
name if name.starts_with("sse41.") => {
106+
return sse41::EvalContextExt::emulate_x86_sse41_intrinsic(
107+
this, link_name, abi, args, dest,
108+
);
109+
}
104110
name if name.starts_with("aesni.") => {
105111
return aesni::EvalContextExt::emulate_x86_aesni_intrinsic(
106112
this, link_name, abi, args, dest,

src/tools/miri/src/shims/x86/sse41.rs

Lines changed: 310 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,310 @@
1+
use rustc_middle::mir;
2+
use rustc_span::Symbol;
3+
use rustc_target::abi::Size;
4+
use rustc_target::spec::abi::Abi;
5+
6+
use crate::*;
7+
use shims::foreign_items::EmulateForeignItemResult;
8+
9+
impl<'mir, 'tcx: 'mir> EvalContextExt<'mir, 'tcx> for crate::MiriInterpCx<'mir, 'tcx> {}
10+
pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>:
11+
crate::MiriInterpCxExt<'mir, 'tcx>
12+
{
13+
fn emulate_x86_sse41_intrinsic(
14+
&mut self,
15+
link_name: Symbol,
16+
abi: Abi,
17+
args: &[OpTy<'tcx, Provenance>],
18+
dest: &PlaceTy<'tcx, Provenance>,
19+
) -> InterpResult<'tcx, EmulateForeignItemResult> {
20+
let this = self.eval_context_mut();
21+
// Prefix should have already been checked.
22+
let unprefixed_name = link_name.as_str().strip_prefix("llvm.x86.sse41.").unwrap();
23+
24+
match unprefixed_name {
25+
// Used to implement the _mm_insert_ps function.
26+
// Takes one element of `right` and inserts it into `left` and
27+
// optionally zero some elements. Source index is specified
28+
// in bits `6..=7` of `imm`, destination index is specified in
29+
// bits `4..=5` if `imm`, and `i`th bit specifies whether element
30+
// `i` is zeroed.
31+
"insertps" => {
32+
let [left, right, imm] =
33+
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
34+
35+
let (left, left_len) = this.operand_to_simd(left)?;
36+
let (right, right_len) = this.operand_to_simd(right)?;
37+
let (dest, dest_len) = this.place_to_simd(dest)?;
38+
39+
assert_eq!(dest_len, left_len);
40+
assert_eq!(dest_len, right_len);
41+
assert!(dest_len <= 4);
42+
43+
let imm = this.read_scalar(imm)?.to_u8()?;
44+
let src_index = u64::from((imm >> 6) & 0b11);
45+
let dst_index = u64::from((imm >> 4) & 0b11);
46+
47+
let src_value = this.read_immediate(&this.project_index(&right, src_index)?)?;
48+
49+
for i in 0..dest_len {
50+
let dest = this.project_index(&dest, i)?;
51+
52+
if imm & (1 << i) != 0 {
53+
// zeroed
54+
this.write_scalar(Scalar::from_u32(0), &dest)?;
55+
} else if i == dst_index {
56+
// copy from `right` at specified index
57+
this.write_immediate(*src_value, &dest)?;
58+
} else {
59+
// copy from `left`
60+
this.copy_op(
61+
&this.project_index(&left, i)?,
62+
&dest,
63+
/*allow_transmute*/ false,
64+
)?;
65+
}
66+
}
67+
}
68+
// Used to implement the _mm_packus_epi32 function.
69+
// Concatenates two 32-bit signed integer vectors and converts
70+
// the result to a 16-bit unsigned integer vector with saturation.
71+
"packusdw" => {
72+
let [left, right] =
73+
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
74+
75+
let (left, left_len) = this.operand_to_simd(left)?;
76+
let (right, right_len) = this.operand_to_simd(right)?;
77+
let (dest, dest_len) = this.place_to_simd(dest)?;
78+
79+
assert_eq!(left_len, right_len);
80+
assert_eq!(dest_len, left_len.checked_mul(2).unwrap());
81+
82+
for i in 0..left_len {
83+
let left = this.read_scalar(&this.project_index(&left, i)?)?.to_i32()?;
84+
let right = this.read_scalar(&this.project_index(&right, i)?)?.to_i32()?;
85+
let left_dest = this.project_index(&dest, i)?;
86+
let right_dest = this.project_index(&dest, i.checked_add(left_len).unwrap())?;
87+
88+
let left_res =
89+
u16::try_from(left).unwrap_or(if left < 0 { 0 } else { u16::MAX });
90+
let right_res =
91+
u16::try_from(right).unwrap_or(if right < 0 { 0 } else { u16::MAX });
92+
93+
this.write_scalar(Scalar::from_u16(left_res), &left_dest)?;
94+
this.write_scalar(Scalar::from_u16(right_res), &right_dest)?;
95+
}
96+
}
97+
// Used to implement the _mm_dp_ps and _mm_dp_pd functions.
98+
// Conditionally multiplies the packed floating-point elements in
99+
// `left` and `right` using the high 4 bits in `imm`, sums the four
100+
// products, and conditionally stores the sum in `dest` using the low
101+
// 4 bits of `imm`.
102+
"dpps" | "dppd" => {
103+
let [left, right, imm] =
104+
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
105+
106+
let (left, left_len) = this.operand_to_simd(left)?;
107+
let (right, right_len) = this.operand_to_simd(right)?;
108+
let (dest, dest_len) = this.place_to_simd(dest)?;
109+
110+
assert_eq!(left_len, right_len);
111+
assert!(dest_len <= 4);
112+
113+
let imm = this.read_scalar(imm)?.to_u8()?;
114+
115+
let element_layout = left.layout.field(this, 0);
116+
117+
// Calculate dot product
118+
// Elements are floating point numbers, but we can use `from_int`
119+
// because the representation of 0.0 is all zero bits.
120+
let mut sum = ImmTy::from_int(0u8, element_layout);
121+
for i in 0..left_len {
122+
if imm & (1 << i.checked_add(4).unwrap()) != 0 {
123+
let left = this.read_immediate(&this.project_index(&left, i)?)?;
124+
let right = this.read_immediate(&this.project_index(&right, i)?)?;
125+
126+
let mul = this.wrapping_binary_op(mir::BinOp::Mul, &left, &right)?;
127+
sum = this.wrapping_binary_op(mir::BinOp::Add, &sum, &mul)?;
128+
}
129+
}
130+
131+
// Write to destination (conditioned to imm)
132+
for i in 0..dest_len {
133+
let dest = this.project_index(&dest, i)?;
134+
135+
if imm & (1 << i) != 0 {
136+
this.write_immediate(*sum, &dest)?;
137+
} else {
138+
this.write_scalar(Scalar::from_int(0u8, element_layout.size), &dest)?;
139+
}
140+
}
141+
}
142+
// Used to implement the _mm_floor_ss, _mm_ceil_ss and _mm_round_ss
143+
// functions. Rounds the first element of `right` according to `rounding`
144+
// and copies the remaining elements from `left`.
145+
"round.ss" => {
146+
let [left, right, rounding] =
147+
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
148+
149+
round_first::<rustc_apfloat::ieee::Single>(this, left, right, rounding, dest)?;
150+
}
151+
// Used to implement the _mm_floor_sd, _mm_ceil_sd and _mm_round_sd
152+
// functions. Rounds the first element of `right` according to `rounding`
153+
// and copies the remaining elements from `left`.
154+
"round.sd" => {
155+
let [left, right, rounding] =
156+
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
157+
158+
round_first::<rustc_apfloat::ieee::Double>(this, left, right, rounding, dest)?;
159+
}
160+
// Used to implement the _mm_minpos_epu16 function.
161+
// Find the minimum unsinged 16-bit integer in `op` and
162+
// returns its value and position.
163+
"phminposuw" => {
164+
let [op] = this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
165+
166+
let (op, op_len) = this.operand_to_simd(op)?;
167+
let (dest, dest_len) = this.place_to_simd(dest)?;
168+
169+
// Find minimum
170+
let mut min_value = u16::MAX;
171+
let mut min_index = 0;
172+
for i in 0..op_len {
173+
let op = this.read_scalar(&this.project_index(&op, i)?)?.to_u16()?;
174+
if op < min_value {
175+
min_value = op;
176+
min_index = i;
177+
}
178+
}
179+
180+
// Write value and index
181+
this.write_scalar(Scalar::from_u16(min_value), &this.project_index(&dest, 0)?)?;
182+
this.write_scalar(
183+
Scalar::from_u16(min_index.try_into().unwrap()),
184+
&this.project_index(&dest, 1)?,
185+
)?;
186+
// Fill remaining with zeros
187+
for i in 2..dest_len {
188+
this.write_scalar(Scalar::from_u16(0), &this.project_index(&dest, i)?)?;
189+
}
190+
}
191+
// Used to implement the _mm_mpsadbw_epu8 function.
192+
// Compute the sum of absolute differences of quadruplets of unsigned
193+
// 8-bit integers in `left` and `right`, and store the 16-bit results
194+
// in `right`. Quadruplets are selected from `left` and `right` with
195+
// offsets specified in `imm`.
196+
// https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_mpsadbw_epu8
197+
"mpsadbw" => {
198+
let [left, right, imm] =
199+
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
200+
201+
let (left, left_len) = this.operand_to_simd(left)?;
202+
let (right, right_len) = this.operand_to_simd(right)?;
203+
let (dest, dest_len) = this.place_to_simd(dest)?;
204+
205+
assert_eq!(left_len, right_len);
206+
assert_eq!(left_len, dest_len.checked_mul(2).unwrap());
207+
208+
let imm = this.read_scalar(imm)?.to_u8()?;
209+
// Bit 2 of `imm` specifies the offset for indices of `left`.
210+
// The offset is 0 when the bit is 0 or 4 when the bit is 1.
211+
let left_offset = u64::from((imm >> 2) & 1).checked_mul(4).unwrap();
212+
// Bits 0..=1 of `imm` specify the offset for indices of
213+
// `right` in blocks of 4 elements.
214+
let right_offset = u64::from(imm & 0b11).checked_mul(4).unwrap();
215+
216+
for i in 0..dest_len {
217+
let left_offset = left_offset.checked_add(i).unwrap();
218+
let mut res: u16 = 0;
219+
for j in 0..4 {
220+
let left = this
221+
.read_scalar(
222+
&this.project_index(&left, left_offset.checked_add(j).unwrap())?,
223+
)?
224+
.to_u8()?;
225+
let right = this
226+
.read_scalar(
227+
&this
228+
.project_index(&right, right_offset.checked_add(j).unwrap())?,
229+
)?
230+
.to_u8()?;
231+
res = res.checked_add(left.abs_diff(right).into()).unwrap();
232+
}
233+
this.write_scalar(Scalar::from_u16(res), &this.project_index(&dest, i)?)?;
234+
}
235+
}
236+
// Used to implement the _mm_testz_si128, _mm_testc_si128
237+
// and _mm_testnzc_si128 functions.
238+
// Tests `op & mask == 0`, `op & mask == mask` or
239+
// `op & mask != 0 && op & mask != mask`
240+
"ptestz" | "ptestc" | "ptestnzc" => {
241+
let [op, mask] = this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
242+
243+
let (op, op_len) = this.operand_to_simd(op)?;
244+
let (mask, mask_len) = this.operand_to_simd(mask)?;
245+
246+
assert_eq!(op_len, mask_len);
247+
248+
let f = match unprefixed_name {
249+
"ptestz" => |op, mask| op & mask == 0,
250+
"ptestc" => |op, mask| op & mask == mask,
251+
"ptestnzc" => |op, mask| op & mask != 0 && op & mask != mask,
252+
_ => unreachable!(),
253+
};
254+
255+
let mut all_zero = true;
256+
for i in 0..op_len {
257+
let op = this.read_scalar(&this.project_index(&op, i)?)?.to_u64()?;
258+
let mask = this.read_scalar(&this.project_index(&mask, i)?)?.to_u64()?;
259+
all_zero &= f(op, mask);
260+
}
261+
262+
this.write_scalar(Scalar::from_i32(all_zero.into()), dest)?;
263+
}
264+
_ => return Ok(EmulateForeignItemResult::NotSupported),
265+
}
266+
Ok(EmulateForeignItemResult::NeedsJumping)
267+
}
268+
}
269+
270+
// Rounds the first element of `right` according to `rounding`
271+
// and copies the remaining elements from `left`.
272+
fn round_first<'tcx, F: rustc_apfloat::Float>(
273+
this: &mut crate::MiriInterpCx<'_, 'tcx>,
274+
left: &OpTy<'tcx, Provenance>,
275+
right: &OpTy<'tcx, Provenance>,
276+
rounding: &OpTy<'tcx, Provenance>,
277+
dest: &PlaceTy<'tcx, Provenance>,
278+
) -> InterpResult<'tcx, ()> {
279+
let (left, left_len) = this.operand_to_simd(left)?;
280+
let (right, right_len) = this.operand_to_simd(right)?;
281+
let (dest, dest_len) = this.place_to_simd(dest)?;
282+
283+
assert_eq!(dest_len, left_len);
284+
assert_eq!(dest_len, right_len);
285+
286+
let rounding = match this.read_scalar(rounding)?.to_i32()? & !0x80 {
287+
0x00 => rustc_apfloat::Round::NearestTiesToEven,
288+
0x01 => rustc_apfloat::Round::TowardNegative,
289+
0x02 => rustc_apfloat::Round::TowardPositive,
290+
0x03 => rustc_apfloat::Round::TowardZero,
291+
rounding => throw_unsup_format!("unsupported rounding mode 0x{rounding:02x}"),
292+
};
293+
294+
let op0: F = this.read_scalar(&this.project_index(&right, 0)?)?.to_float()?;
295+
let res = op0.round_to_integral(rounding).value;
296+
this.write_scalar(
297+
Scalar::from_uint(res.to_bits(), Size::from_bits(F::BITS)),
298+
&this.project_index(&dest, 0)?,
299+
)?;
300+
301+
for i in 1..dest_len {
302+
this.copy_op(
303+
&this.project_index(&left, i)?,
304+
&this.project_index(&dest, i)?,
305+
/*allow_transmute*/ false,
306+
)?;
307+
}
308+
309+
Ok(())
310+
}

0 commit comments

Comments
 (0)