Skip to content

Commit 3f91b13

Browse files
authored
Fix 1D tiling (#29)
The original code from the blog looks wrong. The code in the repo has these checks and they make tests pass.
1 parent 7dbd06a commit 3f91b13

File tree

4 files changed

+58
-10
lines changed

4 files changed

+58
-10
lines changed

blog/2024-11-21-optimizing-matrix-mul/code/bin/blog/src/bin.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ fn main() {
3030
run_tests(matmul::naive::wgpu(), &sizes);
3131
run_tests(matmul::workgroup_256::wgpu(), &sizes);
3232
run_tests(matmul::workgroup_2d::wgpu(), &sizes);
33-
//run_tests(matmul::tiling_1d::wgpu(), &sizes);
33+
run_tests(matmul::tiling_1d::wgpu(), &sizes);
3434
run_tests(matmul::tiling_2d_simd::wgpu(), &sizes);
3535

3636
run_tests(matmul::isomorphic::wgpu(), &sizes);

blog/2024-11-21-optimizing-matrix-mul/code/crates/cpu/matmul/src/backends/cpu.rs

+31
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,37 @@ mod tests {
171171
assert_eq!(result, expected);
172172
}
173173

174+
// Checks `SingleThreadedMatMul` (with the `Isomorphic` variant) against a
// hand-computed 4x4 * 4x4 product, with all matrices stored row-major.
#[test]
175+
fn test_single_threaded_matmul_4x4() {
176+
// Dimensions are passed to `multiply` in the order (m, k, n); presumably
// `a` is m x k and `b` is k x n -- all three are 4 in this square case.
let m = 4;
177+
let k = 4;
178+
let n = 4;
179+
180+
// Define matrix `a` (4x4) in row-major order
181+
let a = vec![
182+
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
183+
];
184+
185+
// Define matrix `b` (4x4) in row-major order
186+
let b = vec![
187+
17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0,
188+
31.0, 32.0,
189+
];
190+
191+
// Expected result (4x4) after multiplying `a` and `b`
192+
let expected = vec![
193+
250.0, 260.0, 270.0, 280.0, 618.0, 644.0, 670.0, 696.0, 986.0, 1028.0, 1070.0, 1112.0,
194+
1354.0, 1412.0, 1470.0, 1528.0,
195+
];
196+
197+
let variant = crate::variants::Isomorphic;
198+
// `new` yields a future, so it is driven to completion synchronously here.
let matrix_multiplier = futures::executor::block_on(SingleThreadedMatMul::new(variant));
199+
200+
let result = matrix_multiplier.multiply(&a, &b, m, k, n);
201+
202+
// Exact equality is intentional: every input, partial sum and output is a
// small whole number, so no floating-point rounding should occur.
assert_eq!(result, expected);
203+
}
204+
174205
#[test]
175206
fn test_multithreaded_matmul_2x1x1() {
176207
let m = 2;

blog/2024-11-21-optimizing-matrix-mul/code/crates/gpu/tiling_1d/src/lib.rs

+24-8
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,30 @@ pub fn matmul(
2727

2828
for i in 0..dimensions.k as usize {
2929
let a_elem = a[row * dimensions.k as usize + i];
30-
sum00 += a_elem * b[i * dimensions.n as usize + col];
31-
sum01 += a_elem * b[i * dimensions.n as usize + col + 1];
32-
sum02 += a_elem * b[i * dimensions.n as usize + col + 2];
33-
sum03 += a_elem * b[i * dimensions.n as usize + col + 3];
30+
if col < dimensions.n as usize {
31+
sum00 += a_elem * b[i * dimensions.n as usize + col];
32+
}
33+
if col + 1 < dimensions.n as usize {
34+
sum01 += a_elem * b[i * dimensions.n as usize + col + 1];
35+
}
36+
if col + 2 < dimensions.n as usize {
37+
sum02 += a_elem * b[i * dimensions.n as usize + col + 2];
38+
}
39+
if col + 3 < dimensions.n as usize {
40+
sum03 += a_elem * b[i * dimensions.n as usize + col + 3];
41+
}
3442
}
3543

36-
result[row * dimensions.n as usize + col] = sum00;
37-
result[row * dimensions.n as usize + col + 1] = sum01;
38-
result[row * dimensions.n as usize + col + 2] = sum02;
39-
result[row * dimensions.n as usize + col + 3] = sum03;
44+
if col < dimensions.n as usize {
45+
result[row * dimensions.n as usize + col] = sum00;
46+
}
47+
if col + 1 < dimensions.n as usize {
48+
result[row * dimensions.n as usize + col + 1] = sum01;
49+
}
50+
if col + 2 < dimensions.n as usize {
51+
result[row * dimensions.n as usize + col + 2] = sum02;
52+
}
53+
if col + 3 < dimensions.n as usize {
54+
result[row * dimensions.n as usize + col + 3] = sum03;
55+
}
4056
}

blog/2024-11-21-optimizing-matrix-mul/index.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,8 @@ import { RustTiling1d } from './snippets/tiling_1d.tsx';
275275
<RustTiling1d />
276276

277277
The kernel looks roughly the same as before except we've unrolled the computation and
278-
are calculating `TILE_SIZE` results per thread.
278+
are calculating `TILE_SIZE` results per thread. We also need some bounds checking for
279+
when the number of result columns isn't a multiple of `TILE_SIZE`.
279280

280281
We can take this a step further and calculate 2D results per thread! Instead of
281282
calculating 4 elements per single row, we can calculate 4 elements for 4 rows (e.g. a 2D

0 commit comments

Comments
 (0)