
How do you do distributed training in TensorFlow?


The Scenario

You are an ML engineer at a research lab. You are working on a new language model that has over 100 billion parameters. The model is too large to fit on a single GPU, and the dataset is too large to train on a single machine.

You have access to a cluster of 8 machines, each with 8 NVIDIA A100 GPUs. Your task is to come up with a strategy for training the model on this cluster.

The Challenge

Explain your strategy for training this model on the cluster. What are the different distributed training strategies that you would use, and how would you combine them to achieve the best results?

Wrong Approach

A junior engineer might only be aware of data parallelism, have no plan for a model that cannot fit on a single GPU, and be unfamiliar with the range of distributed training strategies that TensorFlow provides.

Right Approach

A senior engineer would know that this problem requires a combination of data parallelism and model parallelism. They would be able to explain how to use TensorFlow's `tf.distribute` API to implement both types of parallelism, and they would have a clear plan for how to combine them to achieve the best results.

Step 1: Choose the Right Strategy

The first step is to choose the right distributed training strategy. TensorFlow provides several through its `tf.distribute` API:

| Strategy | Description | When to use it |
| --- | --- | --- |
| `MirroredStrategy` | Implements data parallelism on a single machine with multiple GPUs. | When your model fits on a single GPU, but you want to speed up training. |
| `MultiWorkerMirroredStrategy` | Implements data parallelism on multiple machines. | When you want to speed up training by using multiple machines. |
| `TPUStrategy` | Designed for training on TPUs. | When you have access to TPUs. |
| `ParameterServerStrategy` | A data parallelism strategy where the model's variables are placed on a central parameter server. | When you have a large number of workers and a slow network. |
| Model parallelism | Splits the model itself across multiple devices. | When the model is too large to fit on a single device. |
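
As a quick reference, this is roughly how each strategy is constructed; the TPU name is a placeholder, and each strategy needs the corresponding hardware or cluster to actually run:

import tensorflow as tf

# Data parallelism on a single machine; uses all visible GPUs by default.
mirrored = tf.distribute.MirroredStrategy()

# Synchronous data parallelism across machines; reads the cluster layout
# from the TF_CONFIG environment variable (see Step 3).
multi_worker = tf.distribute.MultiWorkerMirroredStrategy()

# TPU training requires connecting to the TPU system first.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="my-tpu")  # placeholder name
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
tpu_strategy = tf.distribute.TPUStrategy(resolver)

# Parameter-server training needs a resolver describing worker and ps tasks.
ps_strategy = tf.distribute.ParameterServerStrategy(
    tf.distribute.cluster_resolver.TFConfigClusterResolver())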

For our scenario we need both: MultiWorkerMirroredStrategy to distribute training across the 8 machines, and model parallelism to split the model across the 8 GPUs within each machine.

Step 2: Implement Model Parallelism

To implement model parallelism, we manually assign the model's layers to different GPUs with `tf.device`:

import tensorflow as tf

class MyModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        # Pin each block of layers to a specific GPU so that its
        # variables are created on that device.
        with tf.device("/gpu:0"):
            self.layer1 = tf.keras.layers.Dense(4096, activation="relu")
        with tf.device("/gpu:1"):
            self.layer2 = tf.keras.layers.Dense(4096, activation="relu")
        # ... assign the remaining layers to /gpu:2 through /gpu:7

    def call(self, inputs):
        # Run each stage on the device that holds its weights; TensorFlow
        # copies the activations between GPUs at the stage boundaries.
        with tf.device("/gpu:0"):
            x = self.layer1(inputs)
        with tf.device("/gpu:1"):
            x = self.layer2(x)
        # ...
        return x
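
As a sanity check, here is a forward pass on a dummy batch; the 4096-unit Dense layers above are stand-ins for the real model's layers, and with TensorFlow's default soft device placement this also runs on a machine with fewer GPUs:

model = MyModel()
x = tf.random.normal([2, 4096])  # dummy batch of 2 examples
y = model(x)
print(y.shape)  # (2, 4096)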

Step 3: Implement Data Parallelism

To implement data parallelism across the 8 machines, we use MultiWorkerMirroredStrategy. Each worker discovers the cluster layout through the `TF_CONFIG` environment variable:

import tensorflow as tf
import os
import json

# Every machine runs the same script; only the task "index" differs
# (0 through 7 for our 8 workers). TF_CONFIG must be set before the
# strategy is created.
os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {
        # One "host:port" entry per worker machine.
        "worker": ["host1:port", "host2:port"]
    },
    "task": {"type": "worker", "index": 0}
})

strategy = tf.distribute.MultiWorkerMirroredStrategy()

# Variables created inside the scope are mirrored and kept in sync
# across all workers.
with strategy.scope():
    model = MyModel()
    model.compile(optimizer=tf.keras.optimizers.Adam(),
                  loss="sparse_categorical_crossentropy")  # loss depends on the task

# Train as usual; gradients are all-reduced across workers each step.
# train_dataset is a tf.data.Dataset built elsewhere.
model.fit(train_dataset, epochs=5)
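
A point worth raising in the interview: the input pipeline. Keras auto-shards a tf.data.Dataset across workers by default, and the batch size you set is the global batch size, split across replicas. A minimal sketch, assuming features and labels are in-memory arrays standing in for the real data pipeline:

# 64 examples per replica; the global batch is split across all replicas.
GLOBAL_BATCH_SIZE = 64 * strategy.num_replicas_in_sync

train_dataset = (
    tf.data.Dataset.from_tensor_slices((features, labels))
    .shuffle(10_000)
    .batch(GLOBAL_BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)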

Step 4: Combine the Strategies

By combining the two strategies, each of the 8 machines holds one copy of the model split across its 8 GPUs (model parallelism), and MultiWorkerMirroredStrategy keeps those copies in sync by all-reducing gradients after every step (data parallelism). Together, this lets us train the 100-billion-parameter model on the cluster.
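
If the interviewer pushes past model.fit, a custom training loop makes the combination explicit. A minimal sketch, assuming the model-parallel MyModel from Step 2 and the TF_CONFIG setup from Step 3; in practice, keeping exactly one model-parallel replica per machine takes additional device configuration:

strategy = tf.distribute.MultiWorkerMirroredStrategy()

with strategy.scope():
    model = MyModel()  # layers pinned to /gpu:0 ... /gpu:7 (Step 2)
    optimizer = tf.keras.optimizers.Adam()

@tf.function
def train_step(inputs, labels):
    def step_fn(x, y):
        with tf.GradientTape() as tape:
            # The forward pass flows across the GPUs the layers are pinned to.
            predictions = model(x)
            per_example_loss = tf.keras.losses.mse(y, predictions)
            loss = tf.nn.compute_average_loss(per_example_loss)
        grads = tape.gradient(loss, model.trainable_variables)
        # apply_gradients all-reduces the gradients across workers.
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        return loss
    # strategy.run executes step_fn on every replica in lockstep.
    return strategy.run(step_fn, args=(inputs, labels))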

Practice Question

You are training a model that is too large to fit on a single GPU. Which distributed training strategy would you use?