From a76698bd636e5124e51e65b02b9e20d47a865def Mon Sep 17 00:00:00 2001 From: Christian Legnitto <christian@legnitto.com> Date: Wed, 20 Nov 2024 22:13:13 -0400 Subject: [PATCH] Remove unnecessary usize casts --- .../code/crates/gpu/tiling_2d/src/lib.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/blog/2024-11-21-optimizing-matrix-mul/code/crates/gpu/tiling_2d/src/lib.rs b/blog/2024-11-21-optimizing-matrix-mul/code/crates/gpu/tiling_2d/src/lib.rs index 4f8d4fc..9d1105b 100644 --- a/blog/2024-11-21-optimizing-matrix-mul/code/crates/gpu/tiling_2d/src/lib.rs +++ b/blog/2024-11-21-optimizing-matrix-mul/code/crates/gpu/tiling_2d/src/lib.rs @@ -14,8 +14,8 @@ pub fn matmul( #[spirv(storage_buffer, descriptor_set = 0, binding = 2)] b: &[f32], #[spirv(storage_buffer, descriptor_set = 0, binding = 3)] result: &mut [f32], ) { - let row = (global_id.y * TILE_M as u32) as usize; - let col = (global_id.x * TILE_N as u32) as usize; + let row = (global_id.y * TILE_M) as usize; + let col = (global_id.x * TILE_N) as usize; // Initialize sums array to zeros // Note: This is uglier than it needs to be to work around @@ -33,7 +33,7 @@ pub fn matmul( for j in 0..TILE_N as usize { let b_element = if col + j < dimensions.n as usize { - b[k * dimensions.n as usize + (col + j as usize)] + b[k * dimensions.n as usize + (col + j)] } else { 0.0 }; @@ -46,8 +46,8 @@ pub fn matmul( // Write results for i in 0..TILE_M as usize { for j in 0..TILE_N as usize { - let output_row = row + i as usize; - let output_col = col + j as usize; + let output_row = row + i; + let output_col = col + j; if output_row < dimensions.m as usize && output_col < dimensions.n as usize { result[output_row * dimensions.n as usize + output_col] = sums[i][j];