
How do you do distributed training in PyTorch?


The Scenario

You are an ML engineer at a research lab, working on a new language model with over 100 billion parameters. The model is too large to fit on a single GPU, and the dataset is too large to process on a single machine in a reasonable amount of time.

You have access to a cluster of 8 machines, each with 8 NVIDIA A100 GPUs. Your task is to come up with a strategy for training the model on this cluster.

The Challenge

Explain how you would train this model on the cluster. Which distributed training strategies would you use, and how would you combine them to get the best results?

Wrong Approach

A junior engineer might only be aware of `DataParallel`, which does not help when the model itself cannot fit on a single GPU. They might also be unaware of the other distributed training strategies available in PyTorch.

Right Approach

A senior engineer would recognize that this problem requires combining data parallelism with model parallelism. They would be able to explain how to use PyTorch's `DistributedDataParallel`, how to implement model parallelism by splitting the model across devices, and how to combine the two on the full cluster.

Step 1: Choose the Right Strategy

The first step is to choose the right distributed training strategy.

Strategy | Description | When to use it
DataParallel | Implements data parallelism on a single machine with multiple GPUs. | When you have a small model and want a simple way to speed up training on a single machine.
DistributedDataParallel | Implements data parallelism on a single machine or across multiple machines. | When you want the best performance for data parallelism.
Model parallelism | Splits the model itself across multiple devices. | When the model is too large to fit on a single device.
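For comparison, `DataParallel` takes only one line to adopt, but it runs in a single process and re-replicates the model to every GPU on each forward pass, which is why `DistributedDataParallel` is recommended whenever performance matters. A minimal sketch, assuming one machine with at least two GPUs and a small placeholder model:

import torch
import torch.nn as nn

# Placeholder model for illustration; any nn.Module that fits on one GPU works.
model = nn.Linear(10, 1).to("cuda:0")

# DataParallel splits each input batch across all visible GPUs in one process
# and gathers the outputs back onto cuda:0.
dp_model = nn.DataParallel(model)

inputs = torch.randn(32, 10).to("cuda:0")
outputs = dp_model(inputs)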

For our use case, we need to use a combination of data parallelism and model parallelism. We will use DistributedDataParallel to distribute the training across the 8 machines, and we will use model parallelism to split the model across the 8 GPUs on each machine.

Step 2: Implement Model Parallelism

To implement model parallelism, we manually assign the layers of the model to different GPUs and move the activations between devices during the forward pass.

import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Place the first layer on GPU 0 and the second on GPU 1.
        self.fc1 = nn.Linear(10, 20).to("cuda:0")
        self.fc2 = nn.Linear(20, 1).to("cuda:1")

    def forward(self, x):
        # Move the input to GPU 0, then move the intermediate activation
        # to GPU 1 before applying the second layer.
        x = self.fc1(x.to("cuda:0"))
        x = self.fc2(x.to("cuda:1"))
        return x
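A quick usage sketch for this two-GPU model (assuming at least two visible GPUs). Note that the labels must be placed on the same device as the model's output, here cuda:1:

import torch
import torch.nn as nn

model = MyModel()
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

inputs = torch.randn(32, 10)              # forward() moves this to cuda:0
labels = torch.randn(32, 1).to("cuda:1")  # must live on the output device

optimizer.zero_grad()
loss = loss_fn(model(inputs), labels)
loss.backward()                           # gradients flow back across both GPUs
optimizer.step()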

Step 3: Implement Data Parallelism

To implement data parallelism, we use DistributedDataParallel (DDP): each process holds a replica of the model, and gradients are synchronized across processes during the backward pass.

import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup(rank, world_size):
    # Every process must agree on where the rank-0 process is reachable.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

def demo_basic(rank, world_size):
    setup(rank, world_size)

    # Create the model and move it to the GPU with id `rank`.
    # This assumes a model that fits on a single GPU; wrapping the
    # model-parallel MyModel from Step 2 is covered in Step 4.
    model = nn.Linear(10, 1).to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    # ... (your training loop using ddp_model) ...

    cleanup()
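To actually run demo_basic, each rank needs its own process. One common way on a single machine is torch.multiprocessing.spawn, sketched below; on a real multi-node cluster you would typically launch with torchrun (or a scheduler) and point MASTER_ADDR at the rank-0 node:

import torch
import torch.multiprocessing as mp

if __name__ == "__main__":
    world_size = torch.cuda.device_count()  # one process per local GPU
    mp.spawn(demo_basic, args=(world_size,), nprocs=world_size, join=True)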

Step 4: Combine the Strategies

By combining data parallelism and model parallelism, we can train the 100-billion-parameter model on the cluster of 8 machines. Within each machine, the model is split across the 8 local GPUs (model parallelism); DistributedDataParallel then replicates this model-parallel setup once per machine (data parallelism), and a DistributedSampler feeds a different shard of the data to each replica, as sketched below.
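Below is a minimal sketch of that combination, assuming the two-GPU MyModel from Step 2, one process per machine, a hypothetical train_dataset, and the same setup/cleanup helpers and launch as in Step 3. The key details are that DDP is constructed without device_ids when the wrapped module already spans multiple GPUs, and that DistributedSampler hands each process a disjoint shard of the data.

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

def demo_model_parallel_ddp(rank, world_size, train_dataset):
    setup(rank, world_size)

    # MyModel already places its layers on cuda:0 and cuda:1 inside this
    # process, so DDP is constructed WITHOUT device_ids.
    ddp_model = DDP(MyModel())

    # Each process sees a different shard of the dataset.
    sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
    loader = DataLoader(train_dataset, batch_size=32, sampler=sampler)

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.001)

    for epoch in range(10):
        sampler.set_epoch(epoch)  # different shuffle each epoch
        for inputs, labels in loader:
            optimizer.zero_grad()
            outputs = ddp_model(inputs)
            loss = loss_fn(outputs, labels.to(outputs.device))
            loss.backward()
            optimizer.step()

    cleanup()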

Practice Question

You are training a model on a single machine with multiple GPUs and want the best possible performance. Which distributed training strategy would you use?