Skip to content

Commit f79b4fd

Browse files
AdamNiederer authored and alexcrichton committed
Add vroundps, vceilps, vfloorps, vsqrtps, vsqrtpd (rust-lang#53)
* Add vroundps, vceilps, vfloorps, vsqrtps, vsqrtpd * Uninhibit assert_instr on non-expanded intrinsics Also use the new simd_test macro * Use simd_test where possible * Add automated tests for vround* * Add target_feature guards to automated tests * Move automated tests below their functions
1 parent 9d3e1d2 commit f79b4fd

File tree

1 file changed

+120
-1
lines changed

1 file changed

+120
-1
lines changed

src/x86/avx.rs

Lines changed: 120 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,72 @@ pub unsafe fn _mm256_floor_pd(a: f64x4) -> f64x4 {
110110
roundpd256(a, 0x01)
111111
}
112112

113+
/// Round packed single-precision (32-bit) floating point elements in `a`
114+
/// according to the flag `b`. The value of `b` may be as follows:
115+
/// 0x00: Round to the nearest whole number.
116+
/// 0x01: Round down, toward negative infinity.
117+
/// 0x02: Round up, toward positive infinity.
118+
/// 0x03: Truncate the values.
119+
/// For a few additional value options, check the LLVM docs:
120+
/// https://github.com/llvm-mirror/clang/blob/dcd8d797b20291f1a6b3e0ddda085aa2bbb382a8/lib/Headers/avxintrin.h#L382
121+
#[inline(always)]
122+
#[target_feature = "+avx"]
123+
// #[cfg_attr(test, assert_instr(vroundps))]
124+
// TODO: Replace with assert_expanded_instr https://github.com/rust-lang-nursery/stdsimd/issues/49
125+
pub fn _mm256_round_ps(a: f32x8, b: i32) -> f32x8 {
126+
macro_rules! call {
127+
($imm8:expr) => {
128+
unsafe { roundps256(a, $imm8) }
129+
}
130+
}
131+
constify_imm8!(b, call)
132+
}
133+
134+
// TODO: Remove once a macro is implemented to automate these tests
135+
// https://github.com/rust-lang-nursery/stdsimd/issues/49
136+
#[cfg(test)]
137+
#[target_feature = "+avx"]
138+
#[cfg_attr(test, assert_instr(vroundps))]
139+
fn test_mm256_round_ps(a: f32x8) -> f32x8 {
140+
_mm256_round_ps(a, 0x00)
141+
}
142+
143+
/// Round packed single-precision (32-bit) floating point elements in `a` toward
144+
/// positive infinity.
145+
#[inline(always)]
146+
#[target_feature = "+avx"]
147+
#[cfg_attr(test, assert_instr(vroundps))]
148+
pub fn _mm256_ceil_ps(a: f32x8) -> f32x8 {
149+
unsafe { roundps256(a, 0x02) }
150+
}
151+
152+
/// Round packed single-precision (32-bit) floating point elements in `a` toward
153+
/// negative infinity.
154+
#[inline(always)]
155+
#[target_feature = "+avx"]
156+
#[cfg_attr(test, assert_instr(vroundps))]
157+
pub fn _mm256_floor_ps(a: f32x8) -> f32x8 {
158+
unsafe { roundps256(a, 0x01) }
159+
}
160+
161+
/// Return the square root of packed single-precision (32-bit) floating point
162+
/// elements in `a`.
163+
#[inline(always)]
164+
#[target_feature = "+avx"]
165+
#[cfg_attr(test, assert_instr(vsqrtps))]
166+
pub fn _mm256_sqrt_ps(a: f32x8) -> f32x8 {
167+
unsafe { sqrtps256(a) }
168+
}
169+
170+
/// Return the square root of packed double-precision (64-bit) floating point
171+
/// elements in `a`.
172+
#[inline(always)]
173+
#[target_feature = "+avx"]
174+
#[cfg_attr(test, assert_instr(vsqrtpd))]
175+
pub fn _mm256_sqrt_pd(a: f64x4) -> f64x4 {
176+
unsafe { sqrtpd256(a) }
177+
}
178+
113179
/// LLVM intrinsics used in the above functions
114180
#[allow(improper_ctypes)]
115181
extern "C" {
@@ -119,9 +185,15 @@ extern "C" {
119185
fn addsubps256(a: f32x8, b: f32x8) -> f32x8;
120186
#[link_name = "llvm.x86.avx.round.pd.256"]
121187
fn roundpd256(a: f64x4, b: i32) -> f64x4;
188+
#[link_name = "llvm.x86.avx.round.ps.256"]
189+
fn roundps256(a: f32x8, b: i32) -> f32x8;
190+
#[link_name = "llvm.x86.avx.sqrt.pd.256"]
191+
fn sqrtpd256(a: f64x4) -> f64x4;
192+
#[link_name = "llvm.x86.avx.sqrt.ps.256"]
193+
fn sqrtps256(a: f32x8) -> f32x8;
122194
}
123195

124-
#[cfg(test)]
196+
#[cfg(all(test, target_feature = "avx", any(target_arch = "x86", target_arch = "x86_64")))]
125197
mod tests {
126198
use stdsimd_test::simd_test;
127199

@@ -229,4 +301,51 @@ mod tests {
229301
let expected_up = f64x4::new(2.0, 3.0, 4.0, -1.0);
230302
assert_eq!(result_up, expected_up);
231303
}
304+
305+
#[simd_test = "avx"]
306+
fn _mm256_round_ps() {
307+
let a = f32x8::new(1.55, 2.2, 3.99, -1.2, 1.55, 2.2, 3.99, -1.2);
308+
let result_closest = avx::_mm256_round_ps(a, 0b00000000);
309+
let result_down = avx::_mm256_round_ps(a, 0b00000001);
310+
let result_up = avx::_mm256_round_ps(a, 0b00000010);
311+
let expected_closest = f32x8::new(2.0, 2.0, 4.0, -1.0, 2.0, 2.0, 4.0, -1.0);
312+
let expected_down = f32x8::new(1.0, 2.0, 3.0, -2.0, 1.0, 2.0, 3.0, -2.0);
313+
let expected_up = f32x8::new(2.0, 3.0, 4.0, -1.0, 2.0, 3.0, 4.0, -1.0);
314+
assert_eq!(result_closest, expected_closest);
315+
assert_eq!(result_down, expected_down);
316+
assert_eq!(result_up, expected_up);
317+
}
318+
319+
#[simd_test = "avx"]
320+
fn _mm256_floor_ps() {
321+
let a = f32x8::new(1.55, 2.2, 3.99, -1.2, 1.55, 2.2, 3.99, -1.2);
322+
let result_down = avx::_mm256_floor_ps(a);
323+
let expected_down = f32x8::new(1.0, 2.0, 3.0, -2.0, 1.0, 2.0, 3.0, -2.0);
324+
assert_eq!(result_down, expected_down);
325+
}
326+
327+
#[simd_test = "avx"]
328+
fn _mm256_ceil_ps() {
329+
let a = f32x8::new(1.55, 2.2, 3.99, -1.2, 1.55, 2.2, 3.99, -1.2);
330+
let result_up = avx::_mm256_ceil_ps(a);
331+
let expected_up = f32x8::new(2.0, 3.0, 4.0, -1.0, 2.0, 3.0, 4.0, -1.0);
332+
assert_eq!(result_up, expected_up);
333+
}
334+
335+
#[simd_test = "avx"]
336+
fn _mm256_sqrt_pd() {
337+
let a = f64x4::new(4.0, 9.0, 16.0, 25.0);
338+
let r = avx::_mm256_sqrt_pd(a, );
339+
let e = f64x4::new(2.0, 3.0, 4.0, 5.0);
340+
assert_eq!(r, e);
341+
}
342+
343+
#[simd_test = "avx"]
344+
fn _mm256_sqrt_ps() {
345+
let a = f32x8::new(4.0, 9.0, 16.0, 25.0, 4.0, 9.0, 16.0, 25.0);
346+
let r = avx::_mm256_sqrt_ps(a);
347+
let e = f32x8::new(2.0, 3.0, 4.0, 5.0, 2.0, 3.0, 4.0, 5.0);
348+
assert_eq!(r, e);
349+
}
350+
232351
}

0 commit comments

Comments
 (0)