Skip to content
This repository was archived by the owner on May 28, 2025. It is now read-only.

Commit ac749a1

Browse files
authored
add matrix_inversion example (rust-lang#131)
* add matrix_inversion example
1 parent 871d588 commit ac749a1

File tree

1 file changed

+316
-0
lines changed

1 file changed

+316
-0
lines changed
Lines changed: 316 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,316 @@
1+
//! 4x4 matrix inverse
2+
// Code ported from the `packed_simd` crate
3+
// Run this code with `cargo test --example matrix_inversion`
4+
#![feature(array_chunks)]
5+
use core_simd::*;
6+
7+
/// A 4x4 matrix of `f32` values, indexed as `matrix.0[row][col]`.
///
/// Rust has no built-in multidimensional array type, so this example wraps a
/// nested array in its own newtype. The derived traits let the tests copy
/// matrices freely and compare results with `assert_eq!`.
#[derive(Copy, Clone, Debug, PartialEq, PartialOrd)]
pub struct Matrix4x4([[f32; 4]; 4]);
10+
11+
#[allow(clippy::too_many_lines)]
12+
pub fn scalar_inv4x4(m: Matrix4x4) -> Option<Matrix4x4> {
13+
let m = m.0;
14+
15+
let mut inv = [
16+
[ // row 0:
17+
// 0,0:
18+
m[1][1] * m[2][2] * m[3][3] -
19+
m[1][1] * m[2][3] * m[3][2] -
20+
m[2][1] * m[1][2] * m[3][3] +
21+
m[2][1] * m[1][3] * m[3][2] +
22+
m[3][1] * m[1][2] * m[2][3] -
23+
m[3][1] * m[1][3] * m[2][2],
24+
// 0,1:
25+
-m[0][1] * m[2][2] * m[3][3] +
26+
m[0][1] * m[2][3] * m[3][2] +
27+
m[2][1] * m[0][2] * m[3][3] -
28+
m[2][1] * m[0][3] * m[3][2] -
29+
m[3][1] * m[0][2] * m[2][3] +
30+
m[3][1] * m[0][3] * m[2][2],
31+
// 0,2:
32+
m[0][1] * m[1][2] * m[3][3] -
33+
m[0][1] * m[1][3] * m[3][2] -
34+
m[1][1] * m[0][2] * m[3][3] +
35+
m[1][1] * m[0][3] * m[3][2] +
36+
m[3][1] * m[0][2] * m[1][3] -
37+
m[3][1] * m[0][3] * m[1][2],
38+
// 0,3:
39+
-m[0][1] * m[1][2] * m[2][3] +
40+
m[0][1] * m[1][3] * m[2][2] +
41+
m[1][1] * m[0][2] * m[2][3] -
42+
m[1][1] * m[0][3] * m[2][2] -
43+
m[2][1] * m[0][2] * m[1][3] +
44+
m[2][1] * m[0][3] * m[1][2],
45+
],
46+
[ // row 1
47+
// 1,0:
48+
-m[1][0] * m[2][2] * m[3][3] +
49+
m[1][0] * m[2][3] * m[3][2] +
50+
m[2][0] * m[1][2] * m[3][3] -
51+
m[2][0] * m[1][3] * m[3][2] -
52+
m[3][0] * m[1][2] * m[2][3] +
53+
m[3][0] * m[1][3] * m[2][2],
54+
// 1,1:
55+
m[0][0] * m[2][2] * m[3][3] -
56+
m[0][0] * m[2][3] * m[3][2] -
57+
m[2][0] * m[0][2] * m[3][3] +
58+
m[2][0] * m[0][3] * m[3][2] +
59+
m[3][0] * m[0][2] * m[2][3] -
60+
m[3][0] * m[0][3] * m[2][2],
61+
// 1,2:
62+
-m[0][0] * m[1][2] * m[3][3] +
63+
m[0][0] * m[1][3] * m[3][2] +
64+
m[1][0] * m[0][2] * m[3][3] -
65+
m[1][0] * m[0][3] * m[3][2] -
66+
m[3][0] * m[0][2] * m[1][3] +
67+
m[3][0] * m[0][3] * m[1][2],
68+
// 1,3:
69+
m[0][0] * m[1][2] * m[2][3] -
70+
m[0][0] * m[1][3] * m[2][2] -
71+
m[1][0] * m[0][2] * m[2][3] +
72+
m[1][0] * m[0][3] * m[2][2] +
73+
m[2][0] * m[0][2] * m[1][3] -
74+
m[2][0] * m[0][3] * m[1][2],
75+
],
76+
[ // row 2
77+
// 2,0:
78+
m[1][0] * m[2][1] * m[3][3] -
79+
m[1][0] * m[2][3] * m[3][1] -
80+
m[2][0] * m[1][1] * m[3][3] +
81+
m[2][0] * m[1][3] * m[3][1] +
82+
m[3][0] * m[1][1] * m[2][3] -
83+
m[3][0] * m[1][3] * m[2][1],
84+
// 2,1:
85+
-m[0][0] * m[2][1] * m[3][3] +
86+
m[0][0] * m[2][3] * m[3][1] +
87+
m[2][0] * m[0][1] * m[3][3] -
88+
m[2][0] * m[0][3] * m[3][1] -
89+
m[3][0] * m[0][1] * m[2][3] +
90+
m[3][0] * m[0][3] * m[2][1],
91+
// 2,2:
92+
m[0][0] * m[1][1] * m[3][3] -
93+
m[0][0] * m[1][3] * m[3][1] -
94+
m[1][0] * m[0][1] * m[3][3] +
95+
m[1][0] * m[0][3] * m[3][1] +
96+
m[3][0] * m[0][1] * m[1][3] -
97+
m[3][0] * m[0][3] * m[1][1],
98+
// 2,3:
99+
-m[0][0] * m[1][1] * m[2][3] +
100+
m[0][0] * m[1][3] * m[2][1] +
101+
m[1][0] * m[0][1] * m[2][3] -
102+
m[1][0] * m[0][3] * m[2][1] -
103+
m[2][0] * m[0][1] * m[1][3] +
104+
m[2][0] * m[0][3] * m[1][1],
105+
],
106+
[ // row 3
107+
// 3,0:
108+
-m[1][0] * m[2][1] * m[3][2] +
109+
m[1][0] * m[2][2] * m[3][1] +
110+
m[2][0] * m[1][1] * m[3][2] -
111+
m[2][0] * m[1][2] * m[3][1] -
112+
m[3][0] * m[1][1] * m[2][2] +
113+
m[3][0] * m[1][2] * m[2][1],
114+
// 3,1:
115+
m[0][0] * m[2][1] * m[3][2] -
116+
m[0][0] * m[2][2] * m[3][1] -
117+
m[2][0] * m[0][1] * m[3][2] +
118+
m[2][0] * m[0][2] * m[3][1] +
119+
m[3][0] * m[0][1] * m[2][2] -
120+
m[3][0] * m[0][2] * m[2][1],
121+
// 3,2:
122+
-m[0][0] * m[1][1] * m[3][2] +
123+
m[0][0] * m[1][2] * m[3][1] +
124+
m[1][0] * m[0][1] * m[3][2] -
125+
m[1][0] * m[0][2] * m[3][1] -
126+
m[3][0] * m[0][1] * m[1][2] +
127+
m[3][0] * m[0][2] * m[1][1],
128+
// 3,3:
129+
m[0][0] * m[1][1] * m[2][2] -
130+
m[0][0] * m[1][2] * m[2][1] -
131+
m[1][0] * m[0][1] * m[2][2] +
132+
m[1][0] * m[0][2] * m[2][1] +
133+
m[2][0] * m[0][1] * m[1][2] -
134+
m[2][0] * m[0][2] * m[1][1],
135+
],
136+
];
137+
138+
let det = m[0][0] * inv[0][0] + m[0][1] * inv[1][0] +
139+
m[0][2] * inv[2][0] + m[0][3] * inv[3][0];
140+
if det == 0. { return None; }
141+
142+
let det_inv = 1. / det;
143+
144+
for row in &mut inv {
145+
for elem in row.iter_mut() {
146+
*elem *= det_inv;
147+
}
148+
}
149+
150+
Some(Matrix4x4(inv))
151+
}
152+
153+
pub fn simd_inv4x4(m: Matrix4x4) -> Option<Matrix4x4> {
154+
let m = m.0;
155+
let m_0 = f32x4::from_array(m[0]);
156+
let m_1 = f32x4::from_array(m[1]);
157+
let m_2 = f32x4::from_array(m[2]);
158+
let m_3 = f32x4::from_array(m[3]);
159+
160+
// 2 argument shuffle, returns an f32x4
161+
// the first f32x4 is indexes 0..=3
162+
// the second f32x4 is indexed 4..=7
163+
let tmp1 = f32x4::shuffle::<{[0, 1, 4, 5]}>(m_0, m_1);
164+
let row1 = f32x4::shuffle::<{[0, 1, 4, 5]}>(m_2, m_3,);
165+
166+
let row0 = f32x4::shuffle::<{[0, 2, 4, 6]}>(tmp1, row1);
167+
let row1 = f32x4::shuffle::<{[1, 3, 5, 7]}>(row1, tmp1);
168+
169+
let tmp1 = f32x4::shuffle::<{[2, 3, 6, 7]}>(m_0, m_1);
170+
let row3 = f32x4::shuffle::<{[2, 3, 6, 7]}>(m_2, m_3);
171+
let row2 = f32x4::shuffle::<{[0, 2, 4, 6]}>(tmp1, row3);
172+
let row3 = f32x4::shuffle::<{[1, 3, 5, 7]}>(row3, tmp1);
173+
174+
let tmp1 = row2 * row3;
175+
// there's no syntax for a 1 arg shuffle yet,
176+
// so we just pass the same f32x4 twice
177+
let tmp1 = f32x4::shuffle::<{[1, 0, 3, 2]}>(tmp1, tmp1);
178+
179+
let minor0 = row1 * tmp1;
180+
let minor1 = row0 * tmp1;
181+
let tmp1 = f32x4::shuffle::<{[2, 3, 0, 1]}>(tmp1, tmp1);
182+
let minor0 = (row1 * tmp1) - minor0;
183+
let minor1 = (row0 * tmp1) - minor1;
184+
let minor1 = f32x4::shuffle::<{[2, 3, 0, 1]}>(minor1, minor1);
185+
186+
let tmp1 = row1 * row2;
187+
let tmp1 = f32x4::shuffle::<{[1, 0, 3, 2]}>(tmp1, tmp1);
188+
let minor0 = (row3 * tmp1) + minor0;
189+
let minor3 = row0 * tmp1;
190+
let tmp1 = f32x4::shuffle::<{[2, 3, 0, 1]}>(tmp1, tmp1);
191+
192+
let minor0 = minor0 - row3 * tmp1;
193+
let minor3 = row0 * tmp1 - minor3;
194+
let minor3 = f32x4::shuffle::<{[2, 3, 0, 1]}>(minor3, minor3);
195+
196+
let tmp1 = row3 * f32x4::shuffle::<{[2, 3, 0, 1]}>(row1, row1);
197+
let tmp1 = f32x4::shuffle::<{[1, 0, 3, 2]}>(tmp1, tmp1);
198+
let row2 = f32x4::shuffle::<{[2, 3, 0, 1]}>(row2, row2);
199+
let minor0 = row2 * tmp1 + minor0;
200+
let minor2 = row0 * tmp1;
201+
let tmp1 = f32x4::shuffle::<{[2, 3, 0, 1]}>(tmp1, tmp1);
202+
let minor0 = minor0 - row2 * tmp1;
203+
let minor2 = row0 * tmp1 - minor2;
204+
let minor2 = f32x4::shuffle::<{[2, 3, 0, 1]}>(minor2, minor2);
205+
206+
let tmp1 = row0 * row1;
207+
let tmp1 = f32x4::shuffle::<{[1, 0, 3, 2]}>(tmp1, tmp1);
208+
let minor2 = minor2 + row3 * tmp1;
209+
let minor3 = row2 * tmp1 - minor3;
210+
let tmp1 = f32x4::shuffle::<{[2, 3, 0, 1]}>(tmp1, tmp1);
211+
let minor2 = row3 * tmp1 - minor2;
212+
let minor3 = minor3 - row2 * tmp1;
213+
214+
let tmp1 = row0 * row3;
215+
let tmp1 = f32x4::shuffle::<{[1, 0, 3, 2]}>(tmp1, tmp1);
216+
let minor1 = minor1 - row2 * tmp1;
217+
let minor2 = row1 * tmp1 + minor2;
218+
let tmp1 = f32x4::shuffle::<{[2, 3, 0, 1]}>(tmp1, tmp1);
219+
let minor1 = row2 * tmp1 + minor1;
220+
let minor2 = minor2 - row1 * tmp1;
221+
222+
let tmp1 = row0 * row2;
223+
let tmp1 = f32x4::shuffle::<{[1, 0, 3, 2]}>(tmp1, tmp1);
224+
let minor1 = row3 * tmp1 + minor1;
225+
let minor3 = minor3 - row1 * tmp1;
226+
let tmp1 = f32x4::shuffle::<{[2, 3, 0, 1]}>(tmp1, tmp1);
227+
let minor1 = minor1 - row3 * tmp1;
228+
let minor3 = row1 * tmp1 + minor3;
229+
230+
let det = row0 * minor0;
231+
let det = f32x4::shuffle::<{[2, 3, 0, 1]}>(det, det) + det;
232+
let det = f32x4::shuffle::<{[1, 0, 3, 2]}>(det, det) + det;
233+
234+
if det.horizontal_sum() == 0. {
235+
return None;
236+
}
237+
// calculate the reciprocal
238+
let tmp1 = f32x4::splat(1.0) / det;
239+
let det = tmp1 + tmp1 - det * tmp1 * tmp1;
240+
241+
let res0 = minor0 * det;
242+
let res1 = minor1 * det;
243+
let res2 = minor2 * det;
244+
let res3 = minor3 * det;
245+
246+
let mut m = m;
247+
248+
m[0] = res0.to_array();
249+
m[1] = res1.to_array();
250+
m[2] = res2.to_array();
251+
m[3] = res3.to_array();
252+
253+
Some(Matrix4x4(m))
254+
}
255+
256+
257+
// `rustfmt::skip` keeps the hand-aligned matrix literals readable.
#[cfg(test)]
#[rustfmt::skip]
mod tests {
    use super::*;

    /// Runs the scalar and the SIMD implementation over the same table of
    /// `(input, expected)` pairs; both must agree with the expectation.
    #[test]
    fn test() {
        let tests: &[(Matrix4x4, Option<Matrix4x4>)] = &[
            // Identity: its own inverse.
            (Matrix4x4([
                [1.,  0.,  0.,  0.],
                [0.,  1.,  0.,  0.],
                [0.,  0.,  1.,  0.],
                [0.,  0.,  0.,  1.],
             ]),
             Some(Matrix4x4([
                 [1.,  0.,  0.,  0.],
                 [0.,  1.,  0.,  0.],
                 [0.,  0.,  1.,  0.],
                 [0.,  0.,  0.,  1.],
             ]))
            ),
            // None: rows are linearly dependent (row 2 - row 0 == row 3 - row 1),
            // so the matrix is singular.
            (Matrix4x4([
                [1.,  2.,  3.,  4.],
                [12., 11., 10., 9.],
                [5.,  6.,  7.,  8.],
                [16., 15., 14., 13.],
            ]),
             None
            ),
            // Other: a general invertible matrix.
            (Matrix4x4([
                [1., 1., 1., 0.],
                [0., 3., 1., 2.],
                [2., 3., 1., 0.],
                [1., 0., 2., 1.],
            ]),
             Some(Matrix4x4([
                 [-3., -0.5,   1.5,  1.0],
                 [ 1.,  0.25, -0.25, -0.5],
                 [ 3.,  0.25, -1.25, -0.5],
                 [-3.,  0.0,   1.0,  1.0],
             ]))
            ),
        ];

        // Report the case index and the implementation so a failure can be
        // traced back to a table entry.
        for (case, &(input, output)) in tests.iter().enumerate() {
            assert_eq!(scalar_inv4x4(input), output, "scalar_inv4x4 failed on case {}", case);
            assert_eq!(simd_inv4x4(input), output, "simd_inv4x4 failed on case {}", case);
        }
    }
}
312+
313+
314+
/// Entry point required because this file is built as a cargo example.
///
/// It does nothing: the example is exercised through its tests
/// (`cargo test --example matrix_inversion`).
fn main() {
    // Empty main to make cargo happy
}

0 commit comments

Comments
 (0)