Skip to content

Commit f203b42

Browse files
committed
Auto merge of rust-lang#3752 - Kixunil:simd-sha256, r=RalfJung
Implement SHA256 SIMD intrinsics on x86.

Disclaimer: this is my first contribution to `miri`'s code, so it's quite possible I'm missing something. This code works but may not be the cleanest or best possible.

It'd be useful to be able to verify code implementing SHA256 using SIMD, since such code is a bit more complicated and at some points requires the use of pointers. Until now `miri` didn't support the x86 SHA256 intrinsics; this commit implements them.
2 parents e6e294a + 728876e commit f203b42

File tree

3 files changed

+497
-0
lines changed

3 files changed

+497
-0
lines changed

Diff for: src/tools/miri/src/shims/x86/mod.rs

+6
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ mod aesni;
1515
mod avx;
1616
mod avx2;
1717
mod bmi;
18+
mod sha;
1819
mod sse;
1920
mod sse2;
2021
mod sse3;
@@ -105,6 +106,11 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
105106
this, link_name, abi, args, dest,
106107
);
107108
}
109+
name if name.starts_with("sha") => {
110+
return sha::EvalContextExt::emulate_x86_sha_intrinsic(
111+
this, link_name, abi, args, dest,
112+
);
113+
}
108114
name if name.starts_with("sse.") => {
109115
return sse::EvalContextExt::emulate_x86_sse_intrinsic(
110116
this, link_name, abi, args, dest,

Diff for: src/tools/miri/src/shims/x86/sha.rs

+221
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
//! Implements sha256 SIMD instructions of x86 targets
2+
//!
3+
//! The functions that actually compute SHA256 were copied from [RustCrypto's sha256 module].
4+
//!
5+
//! [RustCrypto's sha256 module]: https://github.com/RustCrypto/hashes/blob/6be8466247e936c415d8aafb848697f39894a386/sha2/src/sha256/soft.rs
6+
7+
use rustc_span::Symbol;
8+
use rustc_target::spec::abi::Abi;
9+
10+
use crate::*;
11+
12+
impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
13+
pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
14+
fn emulate_x86_sha_intrinsic(
15+
&mut self,
16+
link_name: Symbol,
17+
abi: Abi,
18+
args: &[OpTy<'tcx>],
19+
dest: &MPlaceTy<'tcx>,
20+
) -> InterpResult<'tcx, EmulateItemResult> {
21+
let this = self.eval_context_mut();
22+
this.expect_target_feature_for_intrinsic(link_name, "sha")?;
23+
// Prefix should have already been checked.
24+
let unprefixed_name = link_name.as_str().strip_prefix("llvm.x86.sha").unwrap();
25+
26+
fn read<'c>(this: &mut MiriInterpCx<'c>, reg: &MPlaceTy<'c>) -> InterpResult<'c, [u32; 4]> {
27+
let mut res = [0; 4];
28+
// We reverse the order because x86 is little endian but the copied implementation uses
29+
// big endian.
30+
for (i, dst) in res.iter_mut().rev().enumerate() {
31+
let projected = &this.project_index(reg, i.try_into().unwrap())?;
32+
*dst = this.read_scalar(projected)?.to_u32()?
33+
}
34+
Ok(res)
35+
}
36+
37+
fn write<'c>(
38+
this: &mut MiriInterpCx<'c>,
39+
dest: &MPlaceTy<'c>,
40+
val: [u32; 4],
41+
) -> InterpResult<'c, ()> {
42+
// We reverse the order because x86 is little endian but the copied implementation uses
43+
// big endian.
44+
for (i, part) in val.into_iter().rev().enumerate() {
45+
let projected = &this.project_index(dest, i.try_into().unwrap())?;
46+
this.write_scalar(Scalar::from_u32(part), projected)?;
47+
}
48+
Ok(())
49+
}
50+
51+
match unprefixed_name {
52+
// Used to implement the _mm_sha256rnds2_epu32 function.
53+
"256rnds2" => {
54+
let [a, b, k] = this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
55+
56+
let (a_reg, a_len) = this.operand_to_simd(a)?;
57+
let (b_reg, b_len) = this.operand_to_simd(b)?;
58+
let (k_reg, k_len) = this.operand_to_simd(k)?;
59+
let (dest, dest_len) = this.mplace_to_simd(dest)?;
60+
61+
assert_eq!(a_len, 4);
62+
assert_eq!(b_len, 4);
63+
assert_eq!(k_len, 4);
64+
assert_eq!(dest_len, 4);
65+
66+
let a = read(this, &a_reg)?;
67+
let b = read(this, &b_reg)?;
68+
let k = read(this, &k_reg)?;
69+
70+
let result = sha256_digest_round_x2(a, b, k);
71+
write(this, &dest, result)?;
72+
}
73+
// Used to implement the _mm_sha256msg1_epu32 function.
74+
"256msg1" => {
75+
let [a, b] = this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
76+
77+
let (a_reg, a_len) = this.operand_to_simd(a)?;
78+
let (b_reg, b_len) = this.operand_to_simd(b)?;
79+
let (dest, dest_len) = this.mplace_to_simd(dest)?;
80+
81+
assert_eq!(a_len, 4);
82+
assert_eq!(b_len, 4);
83+
assert_eq!(dest_len, 4);
84+
85+
let a = read(this, &a_reg)?;
86+
let b = read(this, &b_reg)?;
87+
88+
let result = sha256msg1(a, b);
89+
write(this, &dest, result)?;
90+
}
91+
// Used to implement the _mm_sha256msg2_epu32 function.
92+
"256msg2" => {
93+
let [a, b] = this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
94+
95+
let (a_reg, a_len) = this.operand_to_simd(a)?;
96+
let (b_reg, b_len) = this.operand_to_simd(b)?;
97+
let (dest, dest_len) = this.mplace_to_simd(dest)?;
98+
99+
assert_eq!(a_len, 4);
100+
assert_eq!(b_len, 4);
101+
assert_eq!(dest_len, 4);
102+
103+
let a = read(this, &a_reg)?;
104+
let b = read(this, &b_reg)?;
105+
106+
let result = sha256msg2(a, b);
107+
write(this, &dest, result)?;
108+
}
109+
_ => return Ok(EmulateItemResult::NotSupported),
110+
}
111+
Ok(EmulateItemResult::NeedsReturn)
112+
}
113+
}
114+
115+
/// Shifts every lane of `v` right by `o` bits and returns the resulting vector.
///
/// The lanes are unsigned, so zeros are shifted in from the left.
#[inline(always)]
fn shr(v: [u32; 4], o: u32) -> [u32; 4] {
    v.map(|lane| lane >> o)
}
119+
120+
/// Shifts every lane of `v` left by `o` bits and returns the resulting vector.
///
/// Bits shifted past the top of a lane are discarded.
#[inline(always)]
fn shl(v: [u32; 4], o: u32) -> [u32; 4] {
    v.map(|lane| lane << o)
}
124+
125+
/// Lane-wise bitwise OR of `a` and `b`.
#[inline(always)]
fn or(a: [u32; 4], b: [u32; 4]) -> [u32; 4] {
    std::array::from_fn(|lane| a[lane] | b[lane])
}
129+
130+
/// Lane-wise bitwise XOR of `a` and `b`.
#[inline(always)]
fn xor(a: [u32; 4], b: [u32; 4]) -> [u32; 4] {
    std::array::from_fn(|lane| a[lane] ^ b[lane])
}
134+
135+
/// Lane-wise addition of `a` and `b`, wrapping around on overflow (arithmetic modulo 2^32).
#[inline(always)]
fn add(a: [u32; 4], b: [u32; 4]) -> [u32; 4] {
    std::array::from_fn(|lane| a[lane].wrapping_add(b[lane]))
}
144+
145+
/// Builds the vector `[v3[3], v2[0], v2[1], v2[2]]`: the last lane of `v3` followed by the
/// first three lanes of `v2`. The last lane of `v2` and the first three lanes of `v3` are
/// ignored.
fn sha256load(v2: [u32; 4], v3: [u32; 4]) -> [u32; 4] {
    let [first, second, third, _] = v2;
    let [.., last_of_v3] = v3;
    [last_of_v3, first, second, third]
}
148+
149+
/// Performs two consecutive SHA-256 compression rounds.
///
/// The eight working variables `a..h` are passed split over two vectors:
/// `abef = [a, b, e, f]` and `cdgh = [c, d, g, h]`. The first round uses `wk[3]` and the
/// second round uses `wk[2]` as the (message word + round constant) input; `wk[0]` and
/// `wk[1]` are ignored.
///
/// Returns `[a, b, e, f]` of the state after both rounds. All additions wrap modulo 2^32.
fn sha256_digest_round_x2(cdgh: [u32; 4], abef: [u32; 4], wk: [u32; 4]) -> [u32; 4] {
    fn big_sigma0(x: u32) -> u32 {
        x.rotate_right(2) ^ x.rotate_right(13) ^ x.rotate_right(22)
    }

    fn big_sigma1(x: u32) -> u32 {
        x.rotate_right(6) ^ x.rotate_right(11) ^ x.rotate_right(25)
    }

    // Choose: each bit of `x` selects the corresponding bit of `y` (if set) or `z` (if clear).
    fn choose(x: u32, y: u32, z: u32) -> u32 {
        z ^ (x & (y ^ z))
    }

    // Majority: each result bit is the value held by at least two of the three inputs.
    fn majority(x: u32, y: u32, z: u32) -> u32 {
        (x & y) ^ (x & z) ^ (y & z)
    }

    // One compression round over the full state `[a, b, c, d, e, f, g, h]`.
    fn round(state: [u32; 8], wk: u32) -> [u32; 8] {
        let [a, b, c, d, e, f, g, h] = state;
        let x = big_sigma1(e).wrapping_add(choose(e, f, g)).wrapping_add(wk).wrapping_add(h);
        let y = big_sigma0(a).wrapping_add(majority(a, b, c));
        [x.wrapping_add(y), a, b, c, x.wrapping_add(d), e, f, g]
    }

    let [_, _, wk1, wk0] = wk;
    let [a0, b0, e0, f0] = abef;
    let [c0, d0, g0, h0] = cdgh;

    let after_first = round([a0, b0, c0, d0, e0, f0, g0, h0], wk0);
    let [a2, b2, _, _, e2, f2, _, _] = round(after_first, wk1);

    [a2, b2, e2, f2]
}
191+
192+
fn sha256msg1(v0: [u32; 4], v1: [u32; 4]) -> [u32; 4] {
193+
// sigma 0 on vectors
194+
#[inline]
195+
fn sigma0x4(x: [u32; 4]) -> [u32; 4] {
196+
let t1 = or(shr(x, 7), shl(x, 25));
197+
let t2 = or(shr(x, 18), shl(x, 14));
198+
let t3 = shr(x, 3);
199+
xor(xor(t1, t2), t3)
200+
}
201+
202+
add(v0, sigma0x4(sha256load(v0, v1)))
203+
}
204+
205+
/// Computes four words `[w19, w18, w17, w16]` from `v4 = [x3, x2, x1, x0]` and the first two
/// lanes of `v3 = [w15, w14, _, _]`:
///
/// ```text
/// w16 = x0 + sigma1(w14)
/// w17 = x1 + sigma1(w15)
/// w18 = x2 + sigma1(w16)
/// w19 = x3 + sigma1(w17)
/// ```
///
/// where `sigma1(x) = rotr(x, 17) ^ rotr(x, 19) ^ (x >> 10)` and all additions wrap modulo
/// 2^32. The last two lanes of `v3` are ignored.
fn sha256msg2(v4: [u32; 4], v3: [u32; 4]) -> [u32; 4] {
    fn sigma1(x: u32) -> u32 {
        x.rotate_right(17) ^ x.rotate_right(19) ^ (x >> 10)
    }

    let [x3, x2, x1, x0] = v4;
    let [w15, w14, _, _] = v3;

    // w18 and w19 depend on the freshly computed w16 and w17, so the order matters.
    let w16 = x0.wrapping_add(sigma1(w14));
    let w17 = x1.wrapping_add(sigma1(w15));
    let w18 = x2.wrapping_add(sigma1(w16));
    let w19 = x3.wrapping_add(sigma1(w17));

    [w19, w18, w17, w16]
}

0 commit comments

Comments
 (0)