Questions
How do you do distributed training in TensorFlow?
The Scenario
You are an ML engineer at a research lab. You are working on a new language model that has over 100 billion parameters. The model is too large to fit on a single GPU, and the dataset is too large to train on a single machine.
You have access to a cluster of 8 machines, each with 8 NVIDIA A100 GPUs. Your task is to come up with a strategy for training the model on this cluster.
The Challenge
Explain your strategy for training this model on the cluster. What are the different distributed training strategies that you would use, and how would you combine them to achieve the best results?
A junior engineer might only be aware of data parallelism. They might not know how to handle a model that is too large to fit on a single GPU. They might also not be aware of the different distributed training strategies that are available in TensorFlow.
A senior engineer would know that this problem requires a combination of data parallelism and model parallelism. They would be able to explain how to use TensorFlow's `tf.distribute` API to implement both types of parallelism, and they would have a clear plan for how to combine them to achieve the best results.
Step 1: Choose the Right Strategy
The first step is to choose the right distributed training strategy.
| Strategy | Description | When to use it |
|---|---|---|
| MirroredStrategy | Implements data parallelism on a single machine with multiple GPUs. | When your model fits on a single GPU and you want to speed up training with several local GPUs. |
| MultiWorkerMirroredStrategy | Implements synchronous data parallelism across multiple machines. | When you want to speed up training by using multiple machines. |
| TPUStrategy | Designed for training on TPUs. | When you have access to TPUs. |
| ParameterServerStrategy | A data parallelism strategy where the model's variables are placed on parameter servers. | When you have a large number of workers (for example, preemptible VMs) and want asynchronous updates. |
| Model parallelism (manual device placement) | Splits the model itself across multiple devices. | When the model is too large to fit on a single device. |
For our use case, we need to use a combination of data parallelism and model parallelism. We will use MultiWorkerMirroredStrategy to distribute the training across the 8 machines, and we will use model parallelism to split the model across the 8 GPUs on each machine.
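Before committing to this topology, it helps to confirm what each machine actually exposes to TensorFlow. The short check below is a minimal sketch; the expectation of 8 GPUs per machine comes from the scenario, not from anything TensorFlow enforces.

```python
import tensorflow as tf

# List the GPUs visible to this process; in our scenario each machine
# should expose 8 A100s. The assertion is just a sanity check.
gpus = tf.config.list_physical_devices("GPU")
print(f"Visible GPUs: {len(gpus)}")
assert len(gpus) == 8, "Expected 8 GPUs per machine for this training plan"
```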
Step 2: Implement Model Parallelism
To implement model parallelism, we need to manually assign the different layers of the model to different GPUs.
```python
import tensorflow as tf

class MyModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        # MyLayer1 and MyLayer2 are placeholders for blocks of the network.
        with tf.device("/gpu:0"):
            self.layer1 = MyLayer1()
        with tf.device("/gpu:1"):
            self.layer2 = MyLayer2()
        # ... continue assigning blocks to /gpu:2 through /gpu:7

    def call(self, inputs):
        # Run each block on its assigned GPU; TensorFlow copies tensors
        # between devices as needed.
        with tf.device("/gpu:0"):
            x = self.layer1(inputs)
        with tf.device("/gpu:1"):
            x = self.layer2(x)
        # ... run the remaining blocks on /gpu:2 through /gpu:7
        return x
```
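As a quick sanity check, a forward pass on random data shows how the placement behaves. This is a hypothetical smoke test: the batch size and feature dimension are made up for illustration, and MyLayer1/MyLayer2 are the placeholder layers from the sketch above.

```python
# Hypothetical smoke test on a single machine (dimensions are illustrative).
model = MyModel()
dummy_batch = tf.random.uniform((8, 1024))
outputs = model(dummy_batch)  # layer1 runs on /gpu:0, layer2 on /gpu:1
print(outputs.shape)
```

TensorFlow inserts the device-to-device copies automatically when a tensor produced on one GPU is consumed on another, so no explicit transfers are needed in the model code.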
Step 3: Implement Data Parallelism
To implement data parallelism, we can use the MultiWorkerMirroredStrategy.
```python
import tensorflow as tf
import os
import json

# Each machine sets TF_CONFIG with the full worker list and its own index.
os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {
        "worker": ["host1:port", "host2:port", ...]
    },
    "task": {"type": "worker", "index": 0}
})

strategy = tf.distribute.MultiWorkerMirroredStrategy()

with strategy.scope():
    # The model and optimizer must be created inside the strategy scope.
    model = MyModel()
    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),  # illustrative loss
    )

# Train the model as usual; train_dataset is a tf.data.Dataset defined
# elsewhere, and each worker processes its shard of the data.
model.fit(train_dataset, epochs=5)
```
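One practical detail: the TF_CONFIG above hard-codes index 0, but each of the 8 machines must set its own index (0 through 7) against the same worker list. A common pattern is to derive both from environment variables provided by the cluster scheduler; the variable names WORKER_HOSTS and WORKER_INDEX below are hypothetical examples, not a TensorFlow convention.

```python
import os
import json

# Hypothetical scheduler-provided variables: a comma-separated host list and
# this machine's position in it (names are illustrative, not a TF standard).
worker_hosts = os.environ["WORKER_HOSTS"].split(",")
worker_index = int(os.environ["WORKER_INDEX"])

os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {"worker": worker_hosts},
    "task": {"type": "worker", "index": worker_index},
})
```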
Step 4: Combine the Strategies
By combining data parallelism and model parallelism, we can train our 100-billion-parameter model on the cluster of 8 machines: each machine holds one copy of the model sharded across its 8 GPUs (model parallelism), and the 8 machines each process a different shard of the data and synchronize gradients after every step (data parallelism).
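A rough back-of-the-envelope calculation makes the need for this split concrete. The sketch below assumes 16-bit weights and ignores optimizer state and activations, which add considerably more memory.

```python
params = 100e9            # 100 billion parameters
bytes_per_param = 2       # assuming bf16/fp16 weights
weights_gb = params * bytes_per_param / 1e9
print(f"Weights alone: ~{weights_gb:.0f} GB")                       # ~200 GB
print(f"Per GPU when sharded across 8: ~{weights_gb / 8:.0f} GB")   # ~25 GB
```

Since a single A100 has 40 or 80 GB of memory, the weights alone cannot fit on one GPU, but they do fit across the 8 GPUs in one machine, which is why each machine acts as one data-parallel replica with the model split internally.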
Practice Question
You are training a model that is too large to fit on a single GPU. Which distributed training strategy would you use?