Note
- Source code: Available on GitHub.
- This is my journey to optimize a single-precision floating-point (FP32) matrix multiplication (SGEMM) kernel on my Apple M2 laptop. Since I don’t have an NVIDIA card, I’m using Apple’s Metal API instead of CUDA. The core ideas are the same: making GEMM as fast as possible.
The theoretical FP32 peak of an 8-core M2 is ~2.84 TFLOPS (2,840 GFLOPS). Can we even get close? Let’s find out.
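Where does that number come from? Here is a back-of-the-envelope sketch, assuming the commonly reported (not Apple-published) figures of 128 FP32 ALUs per GPU core and a roughly 1.4 GHz GPU clock:

```cpp
#include <cstdio>

int main() {
    // Assumed M2 GPU figures (commonly reported, not official Apple specs):
    const double gpu_cores     = 8;      // 8-core GPU variant
    const double alus_per_core = 128;    // FP32 lanes per core (assumption)
    const double clock_hz      = 1.4e9;  // ~1.4 GHz GPU clock (assumption)
    const double flops_per_fma = 2;      // one fused multiply-add = 2 FLOPs

    const double peak = gpu_cores * alus_per_core * clock_hz * flops_per_fma;
    std::printf("Theoretical FP32 peak: ~%.2f TFLOPS\n", peak / 1e12); // ~2.87 TFLOPS
    return 0;
}
```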
TL;DR
How fast did we get? Here’s the summary of the journey.
| Kernel | Best performance (GFLOPS) | Percent of peak performance |
|---|---|---|
| naive | | |
| tile_16 | | |
| tile_32 | | |
| tile_threads | | |
| tile_simdgroup | | |
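For reference, every GFLOPS number in this post comes from the standard GEMM FLOP count, 2·M·N·K, divided by the measured kernel time. The actual benchmark harness lives in the repo; the sketch below just shows the formula with a made-up problem size and timing:

```cpp
#include <cstdio>

// SGEMM does one multiply and one add for every (m, n, k) triple,
// so the total work is 2 * M * N * K FLOPs.
double gemm_gflops(double M, double N, double K, double seconds) {
    return (2.0 * M * N * K) / seconds / 1.0e9;
}

int main() {
    // Hypothetical example: a 4096x4096x4096 SGEMM finishing in 0.25 s.
    std::printf("%.1f GFLOPS\n", gemm_gflops(4096, 4096, 4096, 0.25)); // ~549.8
    return 0;
}
```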
The Optimization Journey
This project was structured as a series of optimizations, with each new kernel building on the lessons of the last.
1. The Naive Kernel (naive.metal)
This kernel is the most straightforward implementation: one thread computes one element of the final matrix. It’s simple and easy to verify, but it hammers global device memory (DRAM) with no data reuse.
```metal
kernel void matmul_naive(device const float * A [[buffer(0)]],
                         device const float * B [[buffer(1)]],
                         device float * C [[buffer(2)]],
                         device const MatmulParams& params [[buffer(3)]],
                         uint2 block_pos [[ threadgroup_position_in_grid ]],
                         uint2 thread_pos [[ thread_position_in_threadgroup ]])
{
    // Thread index
    const uint thread_x = thread_pos.x; // CUDA: threadIdx.x
    const uint thread_y = thread_pos.y; // CUDA: threadIdx.y

    // Calculate global row and col
    const uint j = block_pos.x * params.BLOCK_SIZE_X + thread_x; // col
    const uint i = block_pos.y * params.BLOCK_SIZE_Y + thread_y; // row

    const uint M = params.M;
    const uint N = params.N;
    const uint K = params.K;

    if (i < M && j < N) {
        float sum = 0.f;
        for (uint p = 0; p < K; ++p) {
            // Read one row of A and one col of B from DRAM
            sum += A[i * K + p] * B[p * N + j];
        }
        C[i * N + j] = sum;
    }
}
```

Performance: GFLOPS. Not great, but it’s a start.
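Since every later kernel has to produce the same output, it helps to have a plain CPU reference to check against. The repo has its own harness; the sketch below is only an illustrative stand-in (the function names and the 1e-3 tolerance are my assumptions, not from the project):

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Plain triple-loop reference: C = A * B, row-major, A is MxK, B is KxN.
void sgemm_reference(const std::vector<float>& A, const std::vector<float>& B,
                     std::vector<float>& C, int M, int N, int K) {
    for (int i = 0; i < M; ++i)
        for (int j = 0; j < N; ++j) {
            float sum = 0.f;
            for (int p = 0; p < K; ++p)
                sum += A[i * K + p] * B[p * N + j];
            C[i * N + j] = sum;
        }
}

// Compare GPU output against the reference with a loose relative FP32 tolerance.
bool matches(const std::vector<float>& gpu, const std::vector<float>& ref,
             float tol = 1e-3f) {
    for (size_t i = 0; i < ref.size(); ++i)
        if (std::fabs(gpu[i] - ref[i]) > tol * std::max(1.0f, std::fabs(ref[i])))
            return false;
    return true;
}

int main() {
    const int M = 2, N = 2, K = 2;
    std::vector<float> A = {1, 2, 3, 4}, B = {5, 6, 7, 8}, C(M * N);
    sgemm_reference(A, B, C, M, N, K);            // C = {19, 22, 43, 50}
    return matches(C, {19, 22, 43, 50}) ? 0 : 1;
}
```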
2. Tiling (tile_16.metal & tile_32.metal)
This is the first real optimization. Instead of hitting DRAM for every multiplication, we load a “tile” of matrix A and a “tile” of matrix B into the fast, on-chip threadgroup memory (the equivalent of CUDA’s __shared__ memory).
Each thread in the threadgroup helps load a piece of the tiles, we synchronize with threadgroup_barrier (the equivalent of CUDA’s __syncthreads()), and then all threads in the group compute their part of the output using data from that fast memory.
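Before looking at the kernel, a rough sense of why this pays off: each value staged in threadgroup memory gets reused TILE_SIZE times, so DRAM traffic drops by roughly that factor. A sketch of the arithmetic (the 4096-sized problem is just an example, and this ignores the write of C and cache effects):

```cpp
#include <cstdio>

// Approximate DRAM reads (in floats) for C = A * B.
// Naive: every multiply-add pulls one element of A and one of B from DRAM.
// Tiled: each element staged in threadgroup memory is reused TILE times.
double naive_reads(double M, double N, double K)           { return 2.0 * M * N * K; }
double tiled_reads(double M, double N, double K, double T) { return 2.0 * M * N * K / T; }

int main() {
    const double M = 4096, N = 4096, K = 4096; // hypothetical problem size
    std::printf("naive  : %.1f GB read\n", naive_reads(M, N, K) * 4 / 1e9);     // ~549.8 GB
    std::printf("tile_16: %.1f GB read\n", tiled_reads(M, N, K, 16) * 4 / 1e9); // ~34.4 GB
    return 0;
}
```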
```metal
kernel void matmul_tile_16(device const float * A [[buffer(0)]],
                           /* ... buffers ... */
                           uint2 thread_pos [[ thread_position_in_threadgroup ]])
{
    // create tile block
    constexpr uint TILE_SIZE = 16;
    threadgroup float tileA[TILE_SIZE][TILE_SIZE];
    threadgroup float tileB[TILE_SIZE][TILE_SIZE];

    // ... calculate row, col, block_x, block_y ...

    float sum = 0.0f;
    for (uint t = 0; t < (params.K + TILE_SIZE - 1) / TILE_SIZE; ++t) {
        uint tiledColA = t * TILE_SIZE + thread_x;
        uint tiledRowB = t * TILE_SIZE + thread_y;

        // Load tile A from global to threadgroup memory
        if (row < params.M && tiledColA < params.K)
            tileA[thread_y][thread_x] = A[row * params.K + tiledColA];
        else
            tileA[thread_y][thread_x] = 0.0f;

        // Load tile B from global to threadgroup memory
        if (tiledRowB < params.K && col < params.N)
            tileB[thread_y][thread_x] = B[tiledRowB * params.N + col];
        else
            tileB[thread_y][thread_x] = 0.0f;

        // Wait for all threads to finish loading
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // fast matmul on tile (from fast memory)
        #pragma clang loop unroll(full)
        for (uint k = 0; k < TILE_SIZE; ++k) {
            sum += tileA[thread_y][k] * tileB[k][thread_x];
        }

        // Wait for all threads to finish computing
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    if (row < params.M && col < params.N) {
        C[row * params.N + col] = sum;
    }
}
```

Performance: tile_16 ( GFLOPS) was a big jump! But… tile_32 ( GFLOPS) was slower. Why?
Note (Why is tile_16 faster than tile_32?)
The answer is Occupancy.
- **What is Occupancy?** Occupancy is the ratio of active threadgroups to the maximum number of threadgroups that can run on a single GPU compute unit (CU, or an SM in CUDA). High occupancy is critical for hiding memory latency: when one group of threads is stalled waiting for data from DRAM, the GPU scheduler can switch to another resident group and keep the compute units busy.
- **Resource Limits:** A CU has a fixed amount of resources, including threadgroup memory.
  - tile_16 kernel (16x16 = 256 threads): threadgroup memory = (16*16 + 16*16) * 4 bytes = 2048 bytes.
  - tile_32 kernel (32x32 = 1024 threads): threadgroup memory = (32*32 + 32*32) * 4 bytes = 8192 bytes.
- **The Bottleneck:** The M2 GPU’s CUs have a limited amount of threadgroup memory (e.g., 32 KB). The tile_32 kernel’s 8 KB footprint is significant: if a single threadgroup consumes too large a chunk of the CU’s memory, the scheduler cannot fit as many concurrent threadgroups onto that CU.
With tile_32, we fit fewer groups per CU, leading to low occupancy. If those few groups stall on a memory read, the expensive ALUs sit idle. The tile_16 kernel, with its smaller 2 KB footprint, allows many more threadgroups to be resident, effectively hiding memory latency.
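To make the trade-off concrete, here is a toy occupancy calculation that considers only threadgroup memory, with an assumed 32 KB per-CU limit (real occupancy also depends on register usage and the maximum number of resident threads, which this ignores):

```cpp
#include <cstdio>

// How many threadgroups fit on one compute unit if threadgroup memory were
// the only limit? (32 KB per CU is an assumption; registers and the resident
// thread limit also matter on real hardware.)
int groups_per_cu(int tile, int cu_tg_bytes = 32 * 1024) {
    const int bytes_per_group = 2 * tile * tile * (int)sizeof(float); // tileA + tileB
    return cu_tg_bytes / bytes_per_group;
}

int main() {
    std::printf("tile_16: %2d groups per CU (2048 bytes each)\n", groups_per_cu(16)); // 16
    std::printf("tile_32: %2d groups per CU (8192 bytes each)\n", groups_per_cu(32)); // 4
    return 0;
}
```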
3. More Work Per Thread (tile_threads.metal)
The next step was to reduce synchronization overhead and increase register reuse. In the previous kernel, each thread computed only one output value. Here, we use a smaller threadgroup (8x8) but make each thread compute a 4x4 block of the output tile.
These 16 accumulator values (C_reg[4][4]) are stored in the thread’s private registers, which are even faster than threadgroup memory.
```metal
kernel void matmul_tile_threads(device const float * A [[buffer(0)]],
                                /* ... buffers ... */
                                uint2 thread_pos [[thread_position_in_threadgroup]])
{
    // ... TILE_M=32, TILE_N=32, TILE_K=16 ...
    // ... TG_M=8, TG_N=8 ...

    // Work per thread
    constexpr uint WPT_M = TILE_M / TG_M; // 4 rows per thread
    constexpr uint WPT_N = TILE_N / TG_N; // 4 cols per thread

    const uint thread_m = thread_pos.y; // 0..7
    const uint thread_n = thread_pos.x; // 0..7

    // ...

    // 16 accumulator values stored in private registers
    float C_reg[WPT_M][WPT_N] = {{0.0f}};

    for (uint t = 0; t < params.K; t += TILE_K) {
        // ... complicated loading logic to fill tileA/tileB ...

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Compute on registers
        #pragma clang loop unroll(full)
        for (uint k = 0; k < TILE_K; ++k) {
            #pragma clang loop unroll(full)
            for (uint m = 0; m < WPT_M; ++m) {
                float a_val = tileA[thread_m * WPT_M + m][k];
                #pragma clang loop unroll(full)
                for (uint n = 0; n < WPT_N; ++n) {
                    C_reg[m][n] += a_val * tileB[k][thread_n * WPT_N + n];
                }
            }
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Write 16 results from registers to global memory
    for (uint m = 0; m < WPT_M; ++m) {
        for (uint n = 0; n < WPT_N; ++n) {
            // ... calculate c_row, c_col ...
            if (c_row < params.M && c_col < params.N) {
                C[c_row * params.N + c_col] = C_reg[m][n];
            }
        }
    }
}
```

Performance: GFLOPS. Another solid jump! We’re doing more compute per memory load, and each threadgroup_barrier is now amortized over 16 outputs per thread instead of one.
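Why does a 4x4 register block help so much? Per k-step, each thread now reads 4 values of A and 4 values of B from threadgroup memory and performs 16 fused multiply-adds, versus 2 reads for a single FMA in the 1x1 case. A sketch of that ratio (assuming the compiler keeps the per-step values in registers, which the unrolled loops encourage):

```cpp
#include <cstdio>

// FLOPs per threadgroup-memory read for an RxC register block per thread,
// assuming the R A-values and C B-values of each k-step stay in registers.
double flops_per_read(double R, double C) {
    const double flops = 2.0 * R * C; // R*C fused multiply-adds = 2*R*C FLOPs
    const double reads = R + C;       // R values of A + C values of B
    return flops / reads;
}

int main() {
    std::printf("1x1 per thread: %.1f FLOPs per read\n", flops_per_read(1, 1)); // 1.0
    std::printf("4x4 per thread: %.1f FLOPs per read\n", flops_per_read(4, 4)); // 4.0
    return 0;
}
```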
4. Hardware Acceleration (tile_simdgroup.metal)
This is the game-changer. Modern GPUs have specialized hardware for matrix math (analogous to Tensor Cores on NVIDIA GPUs). Metal exposes this through simdgroup_matrix intrinsics.
Instead of writing for loops, we tell the hardware: “load an 8x8 matrix tile”, “load another 8x8 tile”, and “multiply-accumulate them”. The compiler maps this to the ultra-fast hardware units. The code becomes much simpler and much faster.
```metal
#include <metal_simdgroup_matrix>

kernel void matmul_tile_simdgroup(device const float* A [[buffer(0)]],
                                  device const float* B [[buffer(1)]],
                                  device float* C [[buffer(2)]],
                                  device const MatmulParams& params [[buffer(3)]],
                                  uint2 block_pos [[threadgroup_position_in_grid]],
                                  uint simd_id [[simdgroup_index_in_threadgroup]])
{
    const uint TILE_DIM = 8;

    // ... calculate c_row, c_col for this SIMD-group ...

    if (c_row >= params.M || c_col >= params.N) {
        return;
    }

    // Create an 8x8 accumulator matrix in SIMD-group registers
    simdgroup_float8x8 acc = make_filled_simdgroup_matrix<float, 8, 8>(0.0f);

    for (uint k = 0; k < params.K; k += TILE_DIM) {
        device const float* a_ptr = A + c_row * params.K + k;
        device const float* b_ptr = B + k * params.N + c_col;

        simdgroup_float8x8 a_tile;
        simdgroup_float8x8 b_tile;

        // Load 8x8 tiles from global memory
        simdgroup_load(a_tile, a_ptr, params.K);
        simdgroup_load(b_tile, b_ptr, params.N);

        // THE MAGIC: D = A * B + C
        simdgroup_multiply_accumulate(acc, a_tile, b_tile, acc);
    }

    // Store the 8x8 result tile to global memory
    simdgroup_store(acc, C + c_row * params.N + c_col, params.N);
}
```

Performance: GFLOPS. This is the fastest kernel yet, and the code is the cleanest. This shows that the best optimization is often to use the hardware as it was designed.
This was a fantastic journey into the guts of Apple Silicon. While 17% of peak may not sound high, it’s a 2.3x speedup over the naive baseline, and the process taught me an incredible amount about occupancy, memory hierarchies, and hardware-specific intrinsics.