In the previous post, we improved our inference engine by caching the KVs of the previous tokens. This allowed us to process many more requests in parallel, but we're still bottlenecked.
A methodical way of finding bottlenecks is to use a profiler: that way we can see exactly where the time is spent and what to optimize. The benchmark.py script from our codebase has options for profiling with PyTorch's profiler. If you run it with the names of the inference engines you want to profile, you'll get a JSON trace that you can view in Perfetto.
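Under the hood, the profiling option does roughly the following. This is a minimal sketch rather than the exact code from benchmark.py, and the engine and prompts arguments are placeholders:

```python
from torch.profiler import profile, ProfilerActivity

def profile_generate(engine, prompts, trace_path="trace.json"):
    """Run one .generate() call under PyTorch's profiler and dump a trace."""
    # Record both CPU-side ops and the CUDA kernels they launch.
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        engine.generate(prompts)
    # The exported JSON can be opened at https://ui.perfetto.dev
    prof.export_chrome_trace(trace_path)
```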
For one .generate() call, I got the following profile:
We clearly see that the first step takes quite long: that's the prefill phase. The second through the 100th step are the decode phase, and each of those steps is much faster. Let's zoom in on one decode step:
It's a bit messy, but you can see the different layers with their attention and MLP computations. At the end we also see the top_k filtering, sorting, and top_p sampling. Let's zoom in on that part:
This part takes 1.77 ms in total, which is honestly not negligible. The bigger problem is that we're sorting, which gets more expensive with larger batch sizes and bigger vocabularies, and both are exactly the direction LLMs are heading. So we need to find a better way to sample.
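For reference, here's roughly what the sort-based path we want to replace looks like. This is a minimal sketch, not the exact code from our repo, but it shows where the full-vocabulary sort comes from:

```python
import torch

def sample_top_k_top_p(logits: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
    """Sort-based top-k / top-p sampling over (batch_size, vocab_size) logits."""
    # Sort descending once so both filters can reuse the same ordering.
    sorted_logits, sorted_idx = torch.sort(logits, dim=-1, descending=True)

    # Top-k: drop everything beyond the k-th largest logit.
    sorted_logits[:, top_k:] = float("-inf")

    # Top-p: keep the smallest prefix whose cumulative probability reaches top_p.
    probs = torch.softmax(sorted_logits, dim=-1)
    cumulative = torch.cumsum(probs, dim=-1)
    sorted_logits[cumulative - probs > top_p] = float("-inf")

    # Sample in sorted order, then map back to the original token ids.
    probs = torch.softmax(sorted_logits, dim=-1)
    sampled = torch.multinomial(probs, num_samples=1)
    return sorted_idx.gather(-1, sampled).squeeze(-1)
```

The sort and the cumulative sum both touch the entire vocabulary for every request in the batch, which is exactly the part that scales badly.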
FlashInfer implements a rejection sampling kernel that is fused with the top_k filtering and top_p sampling. The nice thing about this is that we don't need to sort anymore, which is a big win. I refer you to their blog post on top-p rejection sampling for the details, but the gist is that we draw a random number and sample a token according to its probability, with some smart rejection of low-probability tokens. On top of that, everything runs in a single fused kernel, so we can expect a nice speedup from that alone.
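On our side, the swap is tiny. Here's a sketch of what the call looks like, assuming FlashInfer is installed; the exact signature has changed between FlashInfer releases, so check the docs for the version you have:

```python
import torch
from flashinfer.sampling import top_k_top_p_sampling_from_logits

def sample_next_tokens(logits: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
    # logits: (batch_size, vocab_size) from the last decode step.
    # The fused kernel filters and samples in one pass, with no sort involved.
    # Note: older FlashInfer versions also expected pre-drawn uniform samples
    # as an argument, so double-check the signature for your installed version.
    return top_k_top_p_sampling_from_logits(logits, top_k=top_k, top_p=top_p)
```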
As you can see, the change boils down to a single call to the top_k_top_p_sampling_from_logits function. Let's see how it performs:
For smaller batches we don't see much improvement, but for larger batches we win back 5-10 tokens per second per request. With 512 requests in parallel, that adds up to more than 2k extra tokens per second in total. Not too bad for changing a single function call.
This kind of kernel fusion is quite common. Whenever you have several GPU kernels that always run one after the other, it's often possible to combine them into a single kernel, which avoids the overhead of launching each kernel and waiting for it to finish before the next one can start. This can significantly boost performance.