Implement GPU Inference with WebGPU

Done. Lessons 1 and 2 checked off.

Lesson 3: Matrix Multiplication Kernel

This is the big one. Matmul is the heart of every neural network. When people say "GPUs are good for AI" — they mean GPUs are good at matmul.

Why matmul maps to GPUs: To multiply two matrices, every cell in the output is an independent dot product. A 128x10 output matrix = 1,280 dot products that don't depend on each other. That's 1,280 threads that can run in parallel. CPUs have 8-16 cores. GPUs have thousands.

Let's start small. Before we touch the MNIST-sized matrices, we'll multiply two tiny matrices and verify correctness.

Test case — multiply a 2x3 matrix by a 3x2 matrix:

A (2x3):          B (3x2):          Result (2x2):
[1, 2, 3]         [7,  8]           [58,  64]
[4, 5, 6]         [9, 10]           [139, 154]
                   [11, 12]

Each output element [row, col] is the dot product of the corresponding row of A and column of B. For example: result[0][0] = 1*7 + 2*9 + 3*11 = 58.
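Before touching the GPU, it's worth having a CPU reference to check results against. Here's a sketch in plain JavaScript (the function name matmulCPU is mine, not from the code we're building) that computes the same triple loop and reproduces the table above:

```javascript
// CPU reference matmul: A is M x K, B is K x N, both stored flat, row-major.
function matmulCPU(A, B, M, N, K) {
  const out = new Float32Array(M * N);
  for (let row = 0; row < M; row++) {
    for (let col = 0; col < N; col++) {
      // One dot product per output element — exactly what one GPU thread will do.
      let acc = 0;
      for (let k = 0; k < K; k++) {
        acc += A[row * K + k] * B[k * N + col];
      }
      out[row * N + col] = acc;
    }
  }
  return out;
}

// The 2x3 × 3x2 test case from above:
const A = [1, 2, 3, 4, 5, 6];
const B = [7, 8, 9, 10, 11, 12];
console.log(matmulCPU(A, B, 2, 2, 3)); // → [58, 64, 139, 154]
```

Later, comparing the GPU output buffer against this function is the quickest way to know the shader is correct.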

The shader needs to know:

  1. Matrix dimensions (M, N, K) — A is MxK, B is KxN, result is MxN
  2. Which output element this thread computes — that's @builtin(global_invocation_id), a vec3 that tells each thread its unique index

The key idea: each GPU thread computes exactly one element of the output matrix. Thread (row, col) loops over the shared dimension K, accumulating the dot product.

Here's what you need to change from the add shader:

  1. Buffers: Replace the three scalar buffers with three matrix buffers (A, B, result) plus a uniform buffer for dimensions (M, N, K). Uniform buffers are for small read-only data that every thread reads — perfect for dimensions.

  2. Shader: The WGSL uniform buffer looks like:

    struct Dimensions { M: u32, N: u32, K: u32 }
    @group(0) @binding(0) var<uniform> dims : Dimensions;

    Storage buffers for A, B, result at bindings 1, 2, 3. All array<f32> — matrices stored flat, row-major (row 0 first, then row 1, etc). To index element [row][col] in a matrix with N columns: row * N + col.

  3. Workgroup size: Use @workgroup_size(1) for now (we can optimize later). But call dispatchWorkgroups(N, M) — one thread per output element.

  4. The main function: Gets @builtin(global_invocation_id) id : vec3<u32>. Use id.x as column, id.y as row. Loop K times accumulating the dot product.
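Putting those four changes together, the shader could take a shape like this — a sketch, not the final kernel (binding numbers follow the layout above; the bounds check is a safety habit, not strictly needed when the dispatch exactly matches the output size):

```wgsl
struct Dimensions { M : u32, N : u32, K : u32 }

@group(0) @binding(0) var<uniform> dims : Dimensions;
@group(0) @binding(1) var<storage, read> a : array<f32>;
@group(0) @binding(2) var<storage, read> b : array<f32>;
@group(0) @binding(3) var<storage, read_write> result : array<f32>;

@compute @workgroup_size(1)
fn main(@builtin(global_invocation_id) id : vec3<u32>) {
  let row = id.y;
  let col = id.x;
  if (row >= dims.M || col >= dims.N) { return; }

  // Dot product of row `row` of A with column `col` of B.
  var acc = 0.0;
  for (var k = 0u; k < dims.K; k = k + 1u) {
    acc = acc + a[row * dims.K + k] * b[k * dims.N + col];
  }
  result[row * dims.N + col] = acc;
}
```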

Start fresh — new shader string, new buffers. You can keep the WebGPU init code (adapter + device). Go for it.

how do I create the matrix buffers? aren't buffers just raw memory?

Exactly right. Buffers are just raw bytes — the GPU has no concept of "matrix." It's just a flat chunk of memory.

So a 2x3 matrix like [1, 2, 3, 4, 5, 6] is stored as exactly that — 6 floats in a row. Row-major order means row 0 comes first, then row 1. The buffer is just new Float32Array([1, 2, 3, 4, 5, 6]).

The buffer creation is identical to what you did before, just bigger:

device.createBuffer({
    size: 6 * 4,  // 6 floats × 4 bytes each
    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST
})

Then write the flat array into it:

device.queue.writeBuffer(matrixA, 0, new Float32Array([1, 2, 3, 4, 5, 6]));

The "matrix-ness" — knowing that this is 2 rows and 3 columns — only exists in your shader logic and in the dimensions you pass via the uniform buffer. The shader uses the dimensions to compute the right index: row * numColumns + col.
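To make that indexing concrete, here's the row-major formula as a one-liner (flatIndex is a hypothetical helper name, just for illustration):

```javascript
// Row-major layout: element [row][col] of a matrix with numCols columns
// lives at offset row * numCols + col in the flat array.
const flatIndex = (row, col, numCols) => row * numCols + col;

// The 2x3 matrix stored flat as [1, 2, 3, 4, 5, 6]:
const m = new Float32Array([1, 2, 3, 4, 5, 6]);
console.log(m[flatIndex(0, 2, 3)]); // row 0, col 2 → 3
console.log(m[flatIndex(1, 0, 3)]); // row 1, col 0 → 4
```

The WGSL shader does the same arithmetic, just with the dimensions read from the uniform buffer instead of a JS variable.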

That's actually a key insight about GPU programming: the GPU doesn't have data structures. It has memory and math. You impose structure through how you index into flat arrays.

Go ahead and build it out.

and the uniform

[Request interrupted by user]

review what i've implemented

Good structure — nice refactoring into separate sum() and matmul() functions. A few issues:

Line 38: size: M * N is missing * 4. Buffer sizes are in bytes, and each float is 4 bytes. Should be M * N * 4.

Line 43-47: The dimensions buffer has two problems:

  1. Usage should be GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST. It's a uniform buffer (small, read-only data the shader reads), and you need COPY_DST to write values into it from JS. You have STORAGE | COPY_SRC which is wrong on both counts.
  2. Size: 4 * 3 is 12 bytes. However, uniform buffers have alignment rules in WebGPU — a struct of three u32 values is fine at 12 bytes, but the minimum binding size for uniform buffers must be a multiple of 16. So use size: 16 to be safe.
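To avoid hardcoding the padded size, a small round-up helper works (alignTo is my name for it, not part of the WebGPU API):

```javascript
// Round a byte size up to the next multiple of `alignment`.
// Useful for padding uniform buffer sizes to a 16-byte boundary.
const alignTo = (size, alignment) => Math.ceil(size / alignment) * alignment;

// Three u32 dimensions (M, N, K) occupy 12 bytes of actual data...
const dims = new Uint32Array([2, 2, 3]); // M = 2, N = 2, K = 3
console.log(dims.byteLength);              // → 12
// ...but the buffer should be created with the padded size:
console.log(alignTo(dims.byteLength, 16)); // → 16
```

When creating the buffer, pass size: alignTo(dims.byteLength, 16) together with usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST, then write the dims array in with device.queue.writeBuffer as before.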

Line 50: Typo — Dimnesions should be Dimensions.

Line 54: The main function needs the @builtin(global_invocation_id) parameter — that's how each thread knows which output element it's computing. Should be:

fn main(@builtin(global_invocation_id) id : vec3<u32>) {

Line 143: You're calling init() at the bottom but probably want to call matmul() now (or both).

The shader body is still empty — that's fine, fix the bugs above first, then we'll write the dot product loop.

claude-opus-4-6