Skip to content

Commit 47c5c5a

Browse files
committed
Fix 1D tiling
The original code from the blog looks wrong. The code in the repo has these checks and they make tests pass.
1 parent 7dbd06a commit 47c5c5a

File tree

4 files changed

+58
-10
lines changed

4 files changed

+58
-10
lines changed

blog/2024-11-21-optimizing-matrix-mul/code/bin/blog/src/bin.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ fn main() {
3030
run_tests(matmul::naive::wgpu(), &sizes);
3131
run_tests(matmul::workgroup_256::wgpu(), &sizes);
3232
run_tests(matmul::workgroup_2d::wgpu(), &sizes);
33-
//run_tests(matmul::tiling_1d::wgpu(), &sizes);
33+
run_tests(matmul::tiling_1d::wgpu(), &sizes);
3434
run_tests(matmul::tiling_2d_simd::wgpu(), &sizes);
3535

3636
run_tests(matmul::isomorphic::wgpu(), &sizes);

blog/2024-11-21-optimizing-matrix-mul/code/crates/cpu/matmul/src/backends/cpu.rs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,37 @@ mod tests {
171171
assert_eq!(result, expected);
172172
}
173173

174+
#[test]
175+
fn test_single_threaded_matmul_4x4() {
176+
let m = 4;
177+
let k = 4;
178+
let n = 4;
179+
180+
// Define matrix `a` (4x4) in row-major order
181+
let a = vec![
182+
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
183+
];
184+
185+
// Define matrix `b` (4x4) in row-major order
186+
let b = vec![
187+
17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0,
188+
31.0, 32.0,
189+
];
190+
191+
// Expected result (4x4) after multiplying `a` and `b`
192+
let expected = vec![
193+
250.0, 260.0, 270.0, 280.0, 618.0, 644.0, 670.0, 696.0, 986.0, 1028.0, 1070.0, 1112.0,
194+
1354.0, 1412.0, 1470.0, 1528.0,
195+
];
196+
197+
let variant = crate::variants::Isomorphic;
198+
let matrix_multiplier = futures::executor::block_on(SingleThreadedMatMul::new(variant));
199+
200+
let result = matrix_multiplier.multiply(&a, &b, m, k, n);
201+
202+
assert_eq!(result, expected);
203+
}
204+
174205
#[test]
175206
fn test_multithreaded_matmul_2x1x1() {
176207
let m = 2;

blog/2024-11-21-optimizing-matrix-mul/code/crates/gpu/tiling_1d/src/lib.rs

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,30 @@ pub fn matmul(
2727

2828
for i in 0..dimensions.k as usize {
2929
let a_elem = a[row * dimensions.k as usize + i];
30-
sum00 += a_elem * b[i * dimensions.n as usize + col];
31-
sum01 += a_elem * b[i * dimensions.n as usize + col + 1];
32-
sum02 += a_elem * b[i * dimensions.n as usize + col + 2];
33-
sum03 += a_elem * b[i * dimensions.n as usize + col + 3];
30+
if col < dimensions.n as usize {
31+
sum00 += a_elem * b[i * dimensions.n as usize + col];
32+
}
33+
if col + 1 < dimensions.n as usize {
34+
sum01 += a_elem * b[i * dimensions.n as usize + col + 1];
35+
}
36+
if col + 2 < dimensions.n as usize {
37+
sum02 += a_elem * b[i * dimensions.n as usize + col + 2];
38+
}
39+
if col + 3 < dimensions.n as usize {
40+
sum03 += a_elem * b[i * dimensions.n as usize + col + 3];
41+
}
3442
}
3543

36-
result[row * dimensions.n as usize + col] = sum00;
37-
result[row * dimensions.n as usize + col + 1] = sum01;
38-
result[row * dimensions.n as usize + col + 2] = sum02;
39-
result[row * dimensions.n as usize + col + 3] = sum03;
44+
if col < dimensions.n as usize {
45+
result[row * dimensions.n as usize + col] = sum00;
46+
}
47+
if col + 1 < dimensions.n as usize {
48+
result[row * dimensions.n as usize + col + 1] = sum01;
49+
}
50+
if col + 2 < dimensions.n as usize {
51+
result[row * dimensions.n as usize + col + 2] = sum02;
52+
}
53+
if col + 3 < dimensions.n as usize {
54+
result[row * dimensions.n as usize + col + 3] = sum03;
55+
}
4056
}

blog/2024-11-21-optimizing-matrix-mul/index.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,8 @@ import { RustTiling1d } from './snippets/tiling_1d.tsx';
275275
<RustTiling1d />
276276

277277
The kernel looks roughly the same as before except we've unrolled the computation and
278-
are calculating `TILE_SIZE` results per thread.
278+
are calculating `TILE_SIZE` results per thread. We also need some error checking for
279+
when our matrices don't fit nicely.
279280

280281
We can take this a step further and calculate 2D results per thread! Instead of
281282
calculating 4 elements per single row, we can calculate 4 elements for 4 rows (e.g. a 2D

0 commit comments

Comments
 (0)