Posted on 2025-11-07

NanoGPT-inference - Tensor Parallelism

Scaling inference across multiple GPUs

Up to now, we did all calculations on one GPU. However, there might be several reasons why you want to distribute the work over multiple GPUs:

  • Memory limitations: If the model is too large to fit on a single GPU, you can distribute the model across multiple GPUs. This is quite common: for instance, the popular Llama 3 70B model has 140 GB of parameters in bfloat16, while H100s and A100s only have 80 GB of VRAM. Armed with years of CS expertise, I can inform you that 140 is more than 80, so that will not fit on a single GPU. For models like DeepSeek R1, this is even more of a problem, as it has 685B parameters[1].

    Even if the model itself fits, you also need space for activations and, more importantly, the KV cache. As we saw earlier, for bigger batches our KV cache quickly becomes too large to fit on a single GPU.

  • Performance limitations: If our model and all other data fit, we still might want to distribute the work over multiple GPUs. Matrix multiplications are slow, memory bandwidth is even slower, and the users are not waiting forever.
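The arithmetic behind the memory argument can be sketched in a few lines. This is a back-of-the-envelope estimate that only counts the weights (no activations, no KV cache). The function names are made up for illustration, and 70e9 is a rounded parameter count matching the 140 GB quoted above.

```python
import math

def weight_memory_gb(n_params: float, bytes_per_param: int = 2) -> float:
    """Memory needed for the weights alone, in GB (1 GB = 1e9 bytes)."""
    return n_params * bytes_per_param / 1e9

def min_gpus_for_weights(n_params: float, vram_gb: float, bytes_per_param: int = 2) -> int:
    """Smallest number of GPUs whose combined VRAM can hold the weights."""
    return math.ceil(weight_memory_gb(n_params, bytes_per_param) / vram_gb)

# bfloat16 stores each parameter in 2 bytes.
llama_70b_gb = weight_memory_gb(70e9)
print(f"Llama 3 70B weights: {llama_70b_gb:.0f} GB")
print(f"80 GB GPUs needed for the weights alone: {min_gpus_for_weights(70e9, 80)}")
```

In practice you need headroom on top of this for activations and the KV cache, so the real number of GPUs is often higher than this lower bound.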

GPU communications

To distribute the work over multiple GPUs, we need to communicate some data between the GPUs. Just like we cared a lot about memory bandwidth with one GPU, we also care a lot about the data transfer speed between GPUs.

Row of NVIDIA V100 GPUs
Schematic diagram of DGX-1 from the White Paper "NVIDIA DGX-1 With Tesla V100 System Architecture" credits: Nvidia Corporation.

Modern GPUs are fairly well-connected. They have a PCIe bus (64 GB/s per direction for PCIe 5.0, bidirectional) that can connect them to other GPUs or to the CPU. They also have an NVLink bus (450 GB/s per direction for the H100[2], bidirectional) that can connect them directly to other GPUs without passing through the CPU or PCIe switches. Beyond a single node, they can use the InfiniBand interconnect to connect to other nodes.

As you can see, comms speed is a bit slower than our memory bandwidth, so we need to be careful what we want to transfer between GPUs.
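To get a feel for what these numbers mean, here is a rough estimate of how long it takes to move one batch of activations between two GPUs. It ignores latency and protocol overhead, and the batch size, sequence length and hidden size are made-up example values, not measurements from this post.

```python
def transfer_time_ms(n_bytes: float, bandwidth_gb_per_s: float) -> float:
    """Ideal time to move n_bytes over a link, ignoring latency and overhead."""
    return n_bytes / (bandwidth_gb_per_s * 1e9) * 1e3

# Hypothetical activation tensor of shape (batch, sequence, hidden) in bfloat16.
batch, seq_len, hidden, bytes_per_value = 64, 512, 4096, 2
activation_bytes = batch * seq_len * hidden * bytes_per_value

# Per-direction bandwidths in GB/s, as quoted in the text above.
for name, bandwidth in [("PCIe 5.0", 64), ("NVLink (H100)", 450)]:
    print(f"{name}: {transfer_time_ms(activation_bytes, bandwidth):.2f} ms")
```

A transfer like this happens in every layer, so even a few milliseconds per transfer add up quickly over a full forward pass.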

Experimental setup

Before we discuss the different layer tensor parallelism setups, some confessions:

Firstly, our gpt2-xl model is not a good fit for tensor parallelism. It has only 1.5B parameters and one of the matrices actually can't be split into two parts without padding[4]. So I used a 4B model of the same architecture. Sadly there are no weights for this size, so we're using randomly initialized weights. For the benchmark this is not a problem since we were not even detokenizing the generations anyways.

Secondly, I do not have the most interesting hardware to run this on. I use two L40S GPUs, which are not the fastest GPUs around and, more importantly, they are not directly connected to each other. So our communication is limited by the PCIe bus, which is only 64GB/s per direction. While this is sad, it actually does help as motivation for the mixture of experts (MoE) in one of the next posts.

Thirdly, you need a different way of launching this script since we need to run it on two GPUs. The easiest way is to use torchrun.

torchrun --nproc_per_node=2 benchmark.py
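torchrun starts one process per GPU and tells each one its position through environment variables (RANK, LOCAL_RANK and WORLD_SIZE). Here is a minimal sketch of how a script might read them. The helper name and the single-process fallback are my own assumptions for illustration, not taken from benchmark.py.

```python
import os
from typing import Mapping, NamedTuple

class DistEnv(NamedTuple):
    rank: int        # global index of this process
    local_rank: int  # index of this process on its node, usually the GPU index
    world_size: int  # total number of processes

def read_dist_env(env: Mapping[str, str] = os.environ) -> DistEnv:
    """Read the variables torchrun sets; fall back to a single process."""
    return DistEnv(
        rank=int(env.get("RANK", 0)),
        local_rank=int(env.get("LOCAL_RANK", 0)),
        world_size=int(env.get("WORLD_SIZE", 1)),
    )

dist_env = read_dist_env()
print(f"process {dist_env.rank} of {dist_env.world_size}, using GPU {dist_env.local_rank}")
```

The fallback means the same script still runs when launched with plain python, which is handy for debugging on one GPU.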

Tensor Parallelism

Let's dive into the most interesting way of distributing the work over multiple GPUs: tensor parallelism, since it solves the problem of our model not fitting into memory[3].

Tensor parallelism is a technique where we split the model's parameters into multiple parts and assign each part to a different GPU. This way each GPU only has to hold its own part of the model in memory. Depending on the type of layer, we can split the parameters in different ways. For instance, for the attention layer, we can move different attention heads to different GPUs.

Let's go over this layer by layer.

Number of parameters in different model layers depending on the model size.

Embedding and LM head

The embedding layer is just a simple embedding lookup and the LM head is a simple linear layer. These both happen once per model forward pass and, as you can see in the parameter distribution chart, they don't take up much space, so we can just keep a full copy on each GPU. For bigger models, it usually makes sense to split the embedding layer across GPUs, but given our poor bandwidth, I am skipping it.

MLP

Each Multi-Layer Perceptron (MLP) has two fully connected layers, at least in our NanoGPT model. These MLPs are actually taking up the bulk of the model's parameters, especially for bigger models. If we need to save VRAM, this is the quickest win.

Each matrix multiplication can be done in parts on different GPUs, and the results get aggregated back together. This is called tensor parallelism and can be done by distributing either the rows or the columns of the matrices to different GPUs. For two GPUs, that looks like this:

Turns out, there is a natural way to split the MLP's parameters over two or more GPUs: first splitting the up-projection $W_1$ column-wise and then the down-projection $W_2$ row-wise. This way, we only need to communicate the results of the second matrix multiplication $Y_2$, by doing an all_reduce operation where we just sum up the results of the two GPUs.
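To convince yourself that this split gives the same result as the unsplit MLP, here is a small pure-Python sketch that simulates the two GPUs in a single process. It assumes a GELU between the two layers (as in GPT-2) and ignores biases. All helper names are made up for illustration, and the sum at the end stands in for the real all_reduce.

```python
import math
import random

def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def gelu(m):
    """Element-wise GELU; any element-wise activation works for this argument."""
    return [[0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0))) for v in row] for row in m]

def split_columns(w, parts):
    """Cut a matrix into `parts` equally wide column blocks."""
    width = len(w[0]) // parts
    return [[row[i * width:(i + 1) * width] for row in w] for i in range(parts)]

def split_rows(w, parts):
    """Cut a matrix into `parts` equally tall row blocks."""
    height = len(w) // parts
    return [w[i * height:(i + 1) * height] for i in range(parts)]

def mlp_reference(x, w1, w2):
    """Single-GPU MLP: Y_2 = GELU(X W_1) W_2."""
    return matmul(gelu(matmul(x, w1)), w2)

def mlp_tensor_parallel(x, w1, w2, world_size=2):
    """Same MLP, with W_1 split column-wise and W_2 split row-wise."""
    w1_shards = split_columns(w1, world_size)
    w2_shards = split_rows(w2, world_size)
    # Each simulated GPU computes its partial Y_2 without talking to the others.
    partials = [matmul(gelu(matmul(x, a)), b) for a, b in zip(w1_shards, w2_shards)]
    # Stand-in for all_reduce: element-wise sum of the partial results.
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*partials)]

def random_matrix(rows, cols, rng):
    """Matrix with entries drawn uniformly from [-1, 1]."""
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]

rng = random.Random(0)
x, w1, w2 = random_matrix(3, 4, rng), random_matrix(4, 8, rng), random_matrix(8, 4, rng)
reference = mlp_reference(x, w1, w2)
parallel = mlp_tensor_parallel(x, w1, w2)
max_diff = max(
    abs(r - p)
    for ref_row, par_row in zip(reference, parallel)
    for r, p in zip(ref_row, par_row)
)
print(f"largest difference between the two results: {max_diff:.2e}")
```

The trick is that the activation is element-wise, so each GPU can apply it to its own columns of the hidden state without seeing the rest. Splitting $W_1$ row-wise instead would force a communication step before the activation.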

For a more in-depth explanation, see the Hugging Face Ultrascale Playbook.

Actually implementing this is quite simple in PyTorch thanks to its distributed tensor (DTensor) and tensor parallel support:

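from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

# One-dimensional mesh over all GPUs; world_size is the number of processes started by torchrun.
device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,))

block.mlp = parallelize_module(
    module=block.mlp,
    device_mesh=device_mesh,
    parallelize_plan={
        "c_fc": ColwiseParallel(),    # up-projection W_1: split column-wise
        "c_proj": RowwiseParallel(),  # down-projection W_2: split row-wise, outputs are summed
    },
)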

Here, we are using the parallelize_module function to parallelize the MLP layer and replace it in-place. We are splitting the up-projection $W_1$, called c_fc in the code, column-wise and the down-projection $W_2$, called c_proj in the code, row-wise. That's basically it.

Attention

Tensor parallelism for the attention layer is a bit more interesting. The naive way (i.e. how I did it on the first try) is to only split the QKV projections across GPUs; notice that this is very similar to our MLP splitting. While this is easy to implement, we only distribute the matmul (and the weights), but not the attention calculation itself or the KV cache. So in general this is not a good idea, as there are bigger bottlenecks, especially for longer sequences.

The less naive way is to split the attention heads across GPUs. Conceptually, we split each of the Q, K and V matrices across GPUs, so that each GPU gets its own share of the heads (half of them, with two GPUs). It does require a rewrite of the attention mechanism to use separate matrices, as we were using one unified QKV matrix and we don't want to mess up the splitting. Here is the code:

block.attn = parallelize_module(
    module=block.attn,
    device_mesh=device_mesh,
    parallelize_plan={
        "wq": ColwiseParallel(),
        "wk": ColwiseParallel(),
        "wv": ColwiseParallel(),
        "c_proj": RowwiseParallel(),
    }
)

block.attn.n_head = self.original_n_head // world_size
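The bookkeeping behind that last line can be sketched as follows. This assumes the heads are laid out contiguously along the output dimension of the projections, so that a column-wise split hands each GPU a contiguous block of heads. The function name and the head count of 32 are hypothetical.

```python
def local_head_range(n_head: int, world_size: int, rank: int) -> range:
    """Indices of the attention heads that live on the given rank."""
    if n_head % world_size != 0:
        raise ValueError(f"{n_head} heads cannot be split evenly over {world_size} GPUs")
    heads_per_rank = n_head // world_size
    return range(rank * heads_per_rank, (rank + 1) * heads_per_rank)

# Hypothetical model with 32 heads, split over two GPUs.
for rank in range(2):
    heads = local_head_range(n_head=32, world_size=2, rank=rank)
    print(f"rank {rank}: heads {heads.start}..{heads.stop - 1}")
```

Since every head computes its attention independently, each GPU can run attention on its own heads and keep only their keys and values in its KV cache. The only communication is the sum after c_proj, just like in the MLP.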

One thing to note is that, despite most examples online using use_local_output=False to produce a DTensor instead of a regular Tensor, we use the default ColwiseParallel() for the QKV projections. I am personally not sure why people would use the DTensor output here, since it does not play nicely with the KV cache, as that ends up using the DTensor's global shapes.

Anyway, let's take a look at how all of it fares with longer sequences as input and output, for a bigger model (4.4B). Aside from the tensor parallel results, I only include the sampling variant, as there is no point in torturing our GPUs too much with the original baseline. After all, we already know that it doesn't scale for an input length $L=32$ and output length $O=100$, let alone $L=512$ and $O=512$. So I only ran the tensor parallel results.

Throughput per request vs batch size for tensor parallelism on MLP and Attention layers.
Throughput per request for tensor parallelism on MLP and Attention layers.

So parallelizing only the MLP layer gives us a decent boost for small batch sizes and we save some GPU memory, but for long sequences we are bottlenecked by the attention calculation and the KV cache. Parallelizing the attention layer as well gives us a much bigger boost, but I don't think we're gonna make a lot of money running this inference service.

For shorter sequences (but still our 4B model), the performance is quite nice with over 10k tokens per second in total.

Throughput per request vs batch size for tensor parallelism on MLP and Attention layers.
Throughput per request for tensor parallelism on MLP and Attention layers.

However, the real improvement is found on the memory usage (per GPU) chart, where we see we basically halve the memory usage per GPU. This allows us to serve bigger models, bigger batches and longer sequences. Look at how all previous curves ended at a batch size of 256 because we ran out of memory.
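For the KV cache part of that memory, the halving follows directly from splitting the heads. Here is a rough sketch of the calculation. The model dimensions are made-up example values rather than the exact configuration benchmarked here, and the function name is hypothetical.

```python
def kv_cache_gb(batch, seq_len, n_layer, n_head, head_dim, world_size=1, bytes_per_value=2):
    """KV cache size per GPU in GB when the heads are split evenly over world_size GPUs."""
    local_heads = n_head // world_size
    # Two tensors (K and V) per layer, each of shape (batch, local_heads, seq_len, head_dim).
    n_values = 2 * n_layer * batch * local_heads * seq_len * head_dim
    return n_values * bytes_per_value / 1e9

# Hypothetical model dimensions and workload, in bfloat16.
config = dict(batch=256, seq_len=1024, n_layer=48, n_head=32, head_dim=64)
single_gpu = kv_cache_gb(**config)
two_gpus = kv_cache_gb(**config, world_size=2)
print(f"KV cache on one GPU: {single_gpu:.1f} GB, per GPU with two GPUs: {two_gpus:.1f} GB")
```

The cache grows linearly with both batch size and sequence length, which is why it ends up deciding how large a batch fits.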

While tensor parallelism is a nice way to scale inference across multiple GPUs, it does have some limitations. The more GPUs we use, the more fragmented our model becomes and the more our communication speed matters. This does not scale indefinitely, so
