Questions
How do you do distributed training in PyTorch?
The Scenario
You are an ML engineer at a research lab. You are working on a new language model that has over 100 billion parameters. The model is too large to fit on a single GPU, and the dataset is too large to train on a single machine.
You have access to a cluster of 8 machines, each with 8 NVIDIA A100 GPUs. Your task is to come up with a strategy for training the model on this cluster.
The Challenge
Explain your strategy for training this model on the cluster. What are the different distributed training strategies that you would use, and how would you combine them to achieve the best results?
A junior engineer might only be aware of `DataParallel`. They might not know how to handle a model that is too large to fit on a single GPU, or which other distributed training strategies PyTorch offers.
A senior engineer would recognize that this problem requires combining data parallelism with model parallelism. They would be able to explain how to use PyTorch's `DistributedDataParallel`, how to split the model itself across devices, and how to combine the two approaches on the given cluster.
Step 1: Choose the Right Strategy
The first step is to choose the right distributed training strategy.
| Strategy | Description | When to use it |
|---|---|---|
| `DataParallel` | Single-process, multi-threaded data parallelism on a single machine with multiple GPUs. | Quick experiments with a small model on one machine (a minimal sketch follows the table); generally slower than `DistributedDataParallel`. |
| `DistributedDataParallel` | Multi-process data parallelism on a single machine or across multiple machines. | Whenever you want the best performance for data parallelism. |
| Model Parallelism | Splits the model itself across multiple devices. | When the model is too large to fit on a single device. |
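For reference, here is a minimal sketch of the `DataParallel` row; the placeholder `nn.Linear` layer and batch size are arbitrary illustrative choices. It wraps a model so each forward pass scatters the batch across the visible GPUs, but it stays single-process, which is part of why `DistributedDataParallel` usually performs better.

```python
import torch
import torch.nn as nn

# minimal sketch: any nn.Module could stand in for this placeholder Linear layer
model = nn.DataParallel(nn.Linear(10, 1).cuda())
outputs = model(torch.randn(64, 10).cuda())  # the batch of 64 is split across the GPUs
```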
For our use case, we need a combination of the two. We will use model parallelism to split the model across the 8 GPUs within each machine, and `DistributedDataParallel` to run data parallelism across the 8 machines, so that each machine holds one complete replica of the model.
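Concretely, one way to realize this layout (a sketch, not the only option) is to launch one process per machine, for example with `torchrun --nnodes=8 --nproc_per_node=1 train.py`, and let each process read its role from the environment variables that `torchrun` sets; the comments mapping them onto our cluster are assumptions about this scenario.

```python
import os
import torch

# per-process view under the planned layout: one process per machine
rank = int(os.environ["RANK"])              # 0..7: which data-parallel replica am I?
world_size = int(os.environ["WORLD_SIZE"])  # 8 replicas, one per machine
num_local_gpus = torch.cuda.device_count()  # the 8 A100s this replica is split across
```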
Step 2: Implement Model Parallelism
To implement model parallelism, we need to manually assign the different layers of the model to different GPUs.
```python
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # place each part of the model on a different GPU
        self.fc1 = nn.Linear(10, 20).to("cuda:0")
        self.fc2 = nn.Linear(20, 1).to("cuda:1")

    def forward(self, x):
        # move the activations to whichever GPU owns the next layer
        x = self.fc1(x.to("cuda:0"))
        x = self.fc2(x.to("cuda:1"))
        return x
```
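As a quick sanity check of the model-parallel module, here is a hedged sketch of a single training step; the optimizer, loss function, and tensor shapes are illustrative assumptions. The key detail is that the labels must live on the same GPU as the model's output, `cuda:1` in this case.

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = MyModel()
loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.001)

optimizer.zero_grad()
outputs = model(torch.randn(20, 10))      # input can start on the CPU; forward() moves it
labels = torch.randn(20, 1).to("cuda:1")  # labels go to the device that holds the output
loss_fn(outputs, labels).backward()
optimizer.step()
```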
Step 3: Implement Data Parallelism
To implement data parallelism, we can use `DistributedDataParallel`.
```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup(rank, world_size):
    # in a real multi-machine job the launcher (e.g. torchrun) sets these
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

def demo_basic(rank, world_size):
    setup(rank, world_size)
    # MyModel already pins its layers to specific GPUs, so DDP is constructed
    # without device_ids (required when a module spans multiple devices)
    ddp_model = DDP(MyModel())
    # ... (your training loop) ...
    cleanup()
```
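To try `demo_basic` on a single machine, one option is a `torch.multiprocessing.spawn` smoke test like the sketch below; because `MyModel` pins its layers to `cuda:0` and `cuda:1`, this sketch spawns only one process. The real 8-machine run would instead be started by a launcher such as `torchrun`, with one process per machine.

```python
import torch.multiprocessing as mp

if __name__ == "__main__":
    # single-machine smoke test: one process driving the two pinned GPUs
    world_size = 1
    mp.spawn(demo_basic, args=(world_size,), nprocs=world_size, join=True)
```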
Step 4: Combine the Strategies
By combining data parallelism and model parallelism, we can train our 100-billion-parameter model on the cluster of 8 machines. Each machine runs one process that holds a complete replica of the model split across its 8 GPUs; `DistributedDataParallel` keeps the 8 replicas in sync by averaging gradients across machines, and a `DistributedSampler` feeds a different shard of the data to each replica.
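Putting the pieces together, here is a minimal sketch of what the combined training loop might look like, reusing `setup`, `cleanup`, `MyModel`, and the `DDP` import from the previous steps; `train_dataset`, the batch size, the optimizer, and the loss function are all illustrative assumptions rather than anything prescribed by the scenario.

```python
import torch
from torch.utils.data import DataLoader, DistributedSampler

def train(rank, world_size, train_dataset, num_epochs=10):
    setup(rank, world_size)

    # one model-parallel replica per process; no device_ids because the module
    # spans multiple GPUs
    ddp_model = DDP(MyModel())
    optimizer = torch.optim.AdamW(ddp_model.parameters(), lr=1e-4)
    loss_fn = torch.nn.MSELoss()

    # each replica sees a different shard of the dataset
    sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
    loader = DataLoader(train_dataset, batch_size=32, sampler=sampler)

    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)  # reshuffle the shards each epoch
        for inputs, labels in loader:
            optimizer.zero_grad()
            outputs = ddp_model(inputs)                   # forward() moves inputs to cuda:0
            loss = loss_fn(outputs, labels.to("cuda:1"))  # outputs live on cuda:1
            loss.backward()                               # DDP all-reduces gradients here
            optimizer.step()

    cleanup()
```

With this layout, the effective global batch per step is 8 × 32 samples, since each of the 8 replicas processes its own shard of 32.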
Practice Question
You are training a model on a single machine with multiple GPUs and want the best possible performance. Which distributed training strategy would you use?