Skip to content

Commit 728876e

Browse files
committed
Implement SHA256 SIMD intrinsics on x86
It'd be useful to be able to verify code implementing SHA256 using SIMD since such code is a bit more complicated and at some points requires use of pointers. Until now `miri` didn't support x86 SHA256 intrinsics. This commit implements them.
1 parent b3736d6 commit 728876e

File tree

3 files changed

+497
-0
lines changed

3 files changed

+497
-0
lines changed

src/tools/miri/src/shims/x86/mod.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ mod aesni;
1515
mod avx;
1616
mod avx2;
1717
mod bmi;
18+
mod sha;
1819
mod sse;
1920
mod sse2;
2021
mod sse3;
@@ -105,6 +106,11 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
105106
this, link_name, abi, args, dest,
106107
);
107108
}
109+
name if name.starts_with("sha") => {
110+
return sha::EvalContextExt::emulate_x86_sha_intrinsic(
111+
this, link_name, abi, args, dest,
112+
);
113+
}
108114
name if name.starts_with("sse.") => {
109115
return sse::EvalContextExt::emulate_x86_sse_intrinsic(
110116
this, link_name, abi, args, dest,

src/tools/miri/src/shims/x86/sha.rs

Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
//! Implements sha256 SIMD instructions of x86 targets
2+
//!
3+
//! The functions that actually compute SHA256 were copied from [RustCrypto's sha256 module].
4+
//!
5+
//! [RustCrypto's sha256 module]: https://github.com/RustCrypto/hashes/blob/6be8466247e936c415d8aafb848697f39894a386/sha2/src/sha256/soft.rs
6+
7+
use rustc_span::Symbol;
8+
use rustc_target::spec::abi::Abi;
9+
10+
use crate::*;
11+
12+
impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
13+
pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
14+
fn emulate_x86_sha_intrinsic(
15+
&mut self,
16+
link_name: Symbol,
17+
abi: Abi,
18+
args: &[OpTy<'tcx>],
19+
dest: &MPlaceTy<'tcx>,
20+
) -> InterpResult<'tcx, EmulateItemResult> {
21+
let this = self.eval_context_mut();
22+
this.expect_target_feature_for_intrinsic(link_name, "sha")?;
23+
// Prefix should have already been checked.
24+
let unprefixed_name = link_name.as_str().strip_prefix("llvm.x86.sha").unwrap();
25+
26+
fn read<'c>(this: &mut MiriInterpCx<'c>, reg: &MPlaceTy<'c>) -> InterpResult<'c, [u32; 4]> {
27+
let mut res = [0; 4];
28+
// We reverse the order because x86 is little endian but the copied implementation uses
29+
// big endian.
30+
for (i, dst) in res.iter_mut().rev().enumerate() {
31+
let projected = &this.project_index(reg, i.try_into().unwrap())?;
32+
*dst = this.read_scalar(projected)?.to_u32()?
33+
}
34+
Ok(res)
35+
}
36+
37+
fn write<'c>(
38+
this: &mut MiriInterpCx<'c>,
39+
dest: &MPlaceTy<'c>,
40+
val: [u32; 4],
41+
) -> InterpResult<'c, ()> {
42+
// We reverse the order because x86 is little endian but the copied implementation uses
43+
// big endian.
44+
for (i, part) in val.into_iter().rev().enumerate() {
45+
let projected = &this.project_index(dest, i.try_into().unwrap())?;
46+
this.write_scalar(Scalar::from_u32(part), projected)?;
47+
}
48+
Ok(())
49+
}
50+
51+
match unprefixed_name {
52+
// Used to implement the _mm_sha256rnds2_epu32 function.
53+
"256rnds2" => {
54+
let [a, b, k] = this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
55+
56+
let (a_reg, a_len) = this.operand_to_simd(a)?;
57+
let (b_reg, b_len) = this.operand_to_simd(b)?;
58+
let (k_reg, k_len) = this.operand_to_simd(k)?;
59+
let (dest, dest_len) = this.mplace_to_simd(dest)?;
60+
61+
assert_eq!(a_len, 4);
62+
assert_eq!(b_len, 4);
63+
assert_eq!(k_len, 4);
64+
assert_eq!(dest_len, 4);
65+
66+
let a = read(this, &a_reg)?;
67+
let b = read(this, &b_reg)?;
68+
let k = read(this, &k_reg)?;
69+
70+
let result = sha256_digest_round_x2(a, b, k);
71+
write(this, &dest, result)?;
72+
}
73+
// Used to implement the _mm_sha256msg1_epu32 function.
74+
"256msg1" => {
75+
let [a, b] = this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
76+
77+
let (a_reg, a_len) = this.operand_to_simd(a)?;
78+
let (b_reg, b_len) = this.operand_to_simd(b)?;
79+
let (dest, dest_len) = this.mplace_to_simd(dest)?;
80+
81+
assert_eq!(a_len, 4);
82+
assert_eq!(b_len, 4);
83+
assert_eq!(dest_len, 4);
84+
85+
let a = read(this, &a_reg)?;
86+
let b = read(this, &b_reg)?;
87+
88+
let result = sha256msg1(a, b);
89+
write(this, &dest, result)?;
90+
}
91+
// Used to implement the _mm_sha256msg2_epu32 function.
92+
"256msg2" => {
93+
let [a, b] = this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
94+
95+
let (a_reg, a_len) = this.operand_to_simd(a)?;
96+
let (b_reg, b_len) = this.operand_to_simd(b)?;
97+
let (dest, dest_len) = this.mplace_to_simd(dest)?;
98+
99+
assert_eq!(a_len, 4);
100+
assert_eq!(b_len, 4);
101+
assert_eq!(dest_len, 4);
102+
103+
let a = read(this, &a_reg)?;
104+
let b = read(this, &b_reg)?;
105+
106+
let result = sha256msg2(a, b);
107+
write(this, &dest, result)?;
108+
}
109+
_ => return Ok(EmulateItemResult::NotSupported),
110+
}
111+
Ok(EmulateItemResult::NeedsReturn)
112+
}
113+
}
114+
115+
/// Shifts each of the four lanes of `v` right by `o` bits.
#[inline(always)]
fn shr(v: [u32; 4], o: u32) -> [u32; 4] {
    v.map(|lane| lane >> o)
}
119+
120+
/// Shifts each of the four lanes of `v` left by `o` bits.
#[inline(always)]
fn shl(v: [u32; 4], o: u32) -> [u32; 4] {
    v.map(|lane| lane << o)
}
124+
125+
/// Lane-wise bitwise OR of `a` and `b`.
#[inline(always)]
fn or(a: [u32; 4], b: [u32; 4]) -> [u32; 4] {
    std::array::from_fn(|lane| a[lane] | b[lane])
}
129+
130+
/// Lane-wise bitwise XOR of `a` and `b`.
#[inline(always)]
fn xor(a: [u32; 4], b: [u32; 4]) -> [u32; 4] {
    std::array::from_fn(|lane| a[lane] ^ b[lane])
}
134+
135+
/// Lane-wise addition of `a` and `b`, wrapping around on `u32` overflow.
#[inline(always)]
fn add(a: [u32; 4], b: [u32; 4]) -> [u32; 4] {
    std::array::from_fn(|lane| a[lane].wrapping_add(b[lane]))
}
144+
145+
/// Builds `[v3[3], v2[0], v2[1], v2[2]]`: the last lane of `v3` followed by the first three
/// lanes of `v2`.
fn sha256load(v2: [u32; 4], v3: [u32; 4]) -> [u32; 4] {
    let [first, second, third, _] = v2;
    let [.., carried] = v3;
    [carried, first, second, third]
}
148+
149+
/// Runs two consecutive SHA-256 rounds on the working state.
///
/// The state arrives split across two arrays: `abef` holds `[a, b, e, f]` and `cdgh` holds
/// `[c, d, g, h]`. The first round consumes `wk[3]` and the second consumes `wk[2]`; `wk[0]`
/// and `wk[1]` are ignored. The returned array is the `[a, b, e, f]` part of the state after
/// both rounds.
fn sha256_digest_round_x2(cdgh: [u32; 4], abef: [u32; 4], wk: [u32; 4]) -> [u32; 4] {
    fn upper_sigma0(x: u32) -> u32 {
        x.rotate_right(2) ^ x.rotate_right(13) ^ x.rotate_right(22)
    }

    fn upper_sigma1(x: u32) -> u32 {
        x.rotate_right(6) ^ x.rotate_right(11) ^ x.rotate_right(25)
    }

    // Choose: picks bits of `y` where `x` is set and bits of `z` elsewhere.
    fn choose(x: u32, y: u32, z: u32) -> u32 {
        z ^ (x & (y ^ z))
    }

    // Majority: each result bit is the value held by at least two of the inputs.
    fn majority(x: u32, y: u32, z: u32) -> u32 {
        (x & y) ^ (x & z) ^ (y & z)
    }

    // One round over the full `[a, b, c, d, e, f, g, h]` state with round word `word`.
    fn round(state: [u32; 8], word: u32) -> [u32; 8] {
        let [a, b, c, d, e, f, g, h] = state;
        let t1 = upper_sigma1(e).wrapping_add(choose(e, f, g)).wrapping_add(word).wrapping_add(h);
        let t2 = upper_sigma0(a).wrapping_add(majority(a, b, c));
        [t1.wrapping_add(t2), a, b, c, t1.wrapping_add(d), e, f, g]
    }

    let [a, b, e, f] = abef;
    let [c, d, g, h] = cdgh;

    let after_first = round([a, b, c, d, e, f, g, h], wk[3]);
    let after_second = round(after_first, wk[2]);

    [after_second[0], after_second[1], after_second[4], after_second[5]]
}
191+
192+
fn sha256msg1(v0: [u32; 4], v1: [u32; 4]) -> [u32; 4] {
193+
// sigma 0 on vectors
194+
#[inline]
195+
fn sigma0x4(x: [u32; 4]) -> [u32; 4] {
196+
let t1 = or(shr(x, 7), shl(x, 25));
197+
let t2 = or(shr(x, 18), shl(x, 14));
198+
let t3 = shr(x, 3);
199+
xor(xor(t1, t2), t3)
200+
}
201+
202+
add(v0, sigma0x4(sha256load(v0, v1)))
203+
}
204+
205+
/// Produces four values `[w19, w18, w17, w16]` from `v4` and the first two lanes of `v3`.
///
/// With `sigma1(x) = rotr(x, 17) ^ rotr(x, 19) ^ (x >> 10)` and wrapping addition:
/// `w16 = v4[3] + sigma1(v3[1])`, `w17 = v4[2] + sigma1(v3[0])`,
/// `w18 = v4[1] + sigma1(w16)` and `w19 = v4[0] + sigma1(w17)`.
/// `v3[2]` and `v3[3]` are ignored.
fn sha256msg2(v4: [u32; 4], v3: [u32; 4]) -> [u32; 4] {
    fn small_sigma1(x: u32) -> u32 {
        x.rotate_right(17) ^ x.rotate_right(19) ^ (x >> 10)
    }

    // The last two outputs depend on the first two, so the order matters.
    let w16 = v4[3].wrapping_add(small_sigma1(v3[1]));
    let w17 = v4[2].wrapping_add(small_sigma1(v3[0]));
    let w18 = v4[1].wrapping_add(small_sigma1(w16));
    let w19 = v4[0].wrapping_add(small_sigma1(w17));

    [w19, w18, w17, w16]
}

0 commit comments

Comments
 (0)