Posted on 2025-10-28

NanoGPT-inference - CUDA Graphs

Optimize GPU execution with CUDA graphs

In the previous post, we gained a few thousand tokens/sec by using a fused kernel for rejection sampling. We can take this idea further by trying to "fuse" our entire forward pass into a single CUDA graph. This way we launch one graph instead of paying the launch overhead of many individual kernels.

While this is nice in theory, CUDA graphs in PyTorch are not as flexible as just calling the model's forward pass. The most noticeable limitations are that we can't use dynamic shapes and we can't use control flow that depends on tensor values (i.e. if statements).

The lack of dynamic shapes is quite unfortunate, as we want to be able to handle variable-length inputs and our input length increases with every token we generate. Aside from the input, our KV cache is also dynamic, as we add more tokens to it with every step.

Proper inference engines like vLLM solve this by using fixed, pre-allocated tensors. [1] We are not building a fully-featured inference engine, so for educational purposes we'll resort to allocating the entire KV cache at once and using a boolean tensor as the attention mask.
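To make the idea of a fixed-size cache plus a boolean mask a bit more concrete, here is a minimal sketch. The sizes, the (B, H, T, D) layout and the names make_cache and decode_attention are assumptions made up for this example, not the actual engine code.

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes, chosen only for illustration.
B, H, MAX_T, D = 2, 4, 16, 8  # batch, heads, max sequence length, head dim


def make_cache(device="cpu"):
    """Allocate the whole KV cache once; its shape never changes afterwards."""
    k_cache = torch.zeros(B, H, MAX_T, D, device=device)
    v_cache = torch.zeros(B, H, MAX_T, D, device=device)
    return k_cache, v_cache


def decode_attention(q, k_new, v_new, k_cache, v_cache, pos):
    """One decode step. q, k_new, v_new are (B, H, 1, D); pos is a long tensor of shape (1,)."""
    # Write the new key/value in place. The position is a tensor, not a
    # Python int, so the same code works for every step.
    k_cache.index_copy_(2, pos, k_new)
    v_cache.index_copy_(2, pos, v_new)
    # Boolean mask over the full cache: True for slots that are already filled.
    mask = (torch.arange(MAX_T, device=pos.device) <= pos).view(1, 1, 1, MAX_T)
    return F.scaled_dot_product_attention(q, k_cache, v_cache, attn_mask=mask)
```

Every tensor here has the same shape at every step; only the contents of the cache and the mask change, which is what a static graph needs.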

For the input itself, we're actually quite lucky. Since we only care about the last token, our shape during decode is always $(B, 1)$. During prefill, it's $(B, T)$ with $T$ being the input length. That is always the same for all requests in the batch in our crappy engine, so yay.

CUDA graphs in PyTorch

The way these CUDA graphs work is that we first have to capture our graph: we run our model.forward() call as usual while PyTorch records all the operations we want to execute in that graph. We can then replay this graph and enjoy a speedup.

One thing to note is that PyTorch has both torch.compile() and torch.cuda.CUDAGraph(). The compile variant is a bit more flexible and can do quite some magic (especially with torch.compile(mode="max-autotune")), but it's less clear what is happening under the hood. The CUDAGraph variant is explicit: we record the graph and then replay it. We have to manually keep track of the input shapes and replay the correct graph, but that's part of the fun.

While the .generate() function did not change a lot, we had to change quite a few things on the side of the model/engine architecture. The KV cache is moved out of the attention class and is now completely separated, so we allocate it once and pass it along during the forward pass. And similarly, we calculate an attention mask (either a triangular mask for prefill or a rectangular mask for decode) and pass that as input as well.
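The two mask shapes can come out of one helper. This is a sketch under the assumption that the mask spans the full pre-allocated cache length; build_mask and its parameters are hypothetical names for this example.

```python
import torch


def build_mask(q_len: int, start: int, max_t: int) -> torch.Tensor:
    """Boolean mask of shape (q_len, max_t); True means "may attend".

    Query i sits at absolute position start + i and may see every cache
    slot up to and including its own position.
    """
    q_pos = start + torch.arange(q_len).unsqueeze(1)  # (q_len, 1)
    k_pos = torch.arange(max_t).unsqueeze(0)          # (1, max_t)
    return k_pos <= q_pos


# Prefill over T = 3 tokens: a triangular block, padded with False.
prefill_mask = build_mask(q_len=3, start=0, max_t=5)
# Decode at position 3: a single rectangular row.
decode_mask = build_mask(q_len=1, start=3, max_t=5)
```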

The generation itself is otherwise quite similar. We mostly need to track which CUDA graphs we have based on the input shape (so $(B, 1)$ for decode and $(B, T)$ for prefill) and replay the correct one. That's literally a dictionary lookup and a call to graph.replay().
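The bookkeeping could look roughly like the sketch below. ShapeKeyedGraphs and capture_fn are hypothetical names: capture_fn stands for whatever records a graph for a given input and returns the graph together with its static input and output tensors. A real engine would also pass the KV cache and mask, which are left out here.

```python
import torch


class ShapeKeyedGraphs:
    """Keep one captured graph per input shape and replay the matching one."""

    def __init__(self, capture_fn):
        # capture_fn(x) -> (graph, static_in, static_out); an assumption of this sketch.
        self.capture_fn = capture_fn
        self.entries = {}

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        key = tuple(x.shape)  # e.g. (B, 1) for decode, (B, T) for prefill
        if key not in self.entries:
            # First time we see this shape: capture a graph for it.
            self.entries[key] = self.capture_fn(x)
        graph, static_in, static_out = self.entries[key]
        static_in.copy_(x)  # the graph reads from this fixed tensor
        graph.replay()
        # Clone, because the next replay overwrites static_out.
        return static_out.clone()
```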

To capture the graph, we use the context manager with torch.cuda.graph(graph): ..., which captures all the operations inside the block into the graph. On top of that, we also need to keep references to the input tensors, as those memory addresses will be used when replaying the graph, so new inputs have to be copied into those same tensors before each replay.
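A minimal capture and replay cycle might look like this. It follows the general pattern from the PyTorch documentation (warm up on a side stream, then capture); capture_graph and replay are hypothetical helper names, and the single-tensor input is a simplification of what the engine passes around.

```python
import torch


def capture_graph(model, example):
    """Record one forward pass of `model` for inputs shaped like `example` (a CUDA tensor)."""
    static_in = example.clone()  # the graph will always read from this tensor

    # Warm up on a side stream so lazy initialisation is not recorded.
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side), torch.no_grad():
        for _ in range(3):
            model(static_in)
    torch.cuda.current_stream().wait_stream(side)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph), torch.no_grad():
        static_out = model(static_in)  # recorded into the graph
    return graph, static_in, static_out


def replay(graph, static_in, static_out, x):
    """Run the recorded graph on new data with the same shape."""
    static_in.copy_(x)         # overwrite the captured input in place
    graph.replay()             # launches the whole recorded forward pass
    return static_out.clone()  # static_out is overwritten on the next replay
```

Note that static_in and static_out must stay alive for as long as the graph is used, which is why both are returned alongside it.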

Throughput vs batch size with CUDA graphs.

Linked publications