You're running out of memory while fine-tuning a large model. How do you solve this?


The Scenario

You are an ML engineer at a startup that is building a new AI-powered code generation tool. You are trying to fine-tune the bigcode/starcoder model (15.5B parameters) on a custom dataset of Python code.

You have access to a single server with an NVIDIA A100 GPU with 40GB of memory. However, when you try to fine-tune the model, you get the following error:

torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 GiB (GPU 0; 40.00 GiB total capacity; 37.12 GiB already allocated; 1.88 GiB free; 37.12 GiB reserved in total by PyTorch)

Your manager has told you that you cannot get a bigger GPU, so you need to find a way to fine-tune the model with the resources you have.

The Challenge

What is your strategy for fine-tuning this large model on a single GPU? Explain the different memory-saving techniques you would use and how you would combine them to achieve the best results.

Wrong Approach

A junior engineer might give up and say that it's not possible to fine-tune such a large model on a single GPU. They might not be aware of the latest memory-saving techniques, or they might not know how to combine them effectively.

Right Approach

A senior engineer would know that it is possible to fine-tune a large model on a single GPU with the right combination of techniques. The memory arithmetic shows why brute force fails and where the savings have to come from: full fine-tuning with AdamW in mixed precision costs roughly 16 bytes per parameter, far more than 40GB for a 15.5B-parameter model (see the back-of-the-envelope sketch below). With that in mind, they would lay out a clear plan that combines quantization, parameter-efficient fine-tuning, gradient accumulation, and a memory-efficient optimizer.
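
A rough calculation makes the gap concrete. This is only a sketch using the usual mixed-precision AdamW accounting (fp16 weights and gradients, fp32 master weights, two fp32 Adam states); exact numbers vary by setup.

params = 15.5e9  # StarCoder parameter count

# Full fine-tuning with AdamW in mixed precision:
# fp16 weights (2 B) + fp16 grads (2 B) + fp32 master weights (4 B) + Adam m/v (4 B + 4 B)
full_ft_gb = params * (2 + 2 + 4 + 4 + 4) / 1e9
print(f"Full fine-tuning needs roughly {full_ft_gb:.0f} GB")   # ~248 GB vs. the 40 GB card

# 4-bit quantized weights alone (the starting point for the plan below)
weights_4bit_gb = params * 0.5 / 1e9
print(f"4-bit weights take roughly {weights_4bit_gb:.0f} GB")  # ~8 GB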

Step 1: Quantization

The first step is to load the model's weights in a lower-precision format. Using 4-bit quantization (through the bitsandbytes integration in transformers), the roughly 31 GB of fp16 weights shrink to around 8 GB, which is what gets the model onto the card in the first place.

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "bigcode/starcoder"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# 4-bit loading via bitsandbytes shrinks the ~31 GB of fp16 weights to roughly 8 GB,
# leaving headroom for the activations and LoRA/optimizer state added later.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),  # quantize weights to 4-bit on load
    device_map="auto",   # place the quantized layers on the available GPU
)

Step 2: Parameter-Efficient Fine-tuning (PEFT)

Instead of fine-tuning all 15.5B weights, we can use a PEFT technique such as LoRA and train only a small set of added low-rank adapter matrices. Because the base model stays frozen, we never need gradients or optimizer state for its weights, only for the adapters; combined with the 4-bit base model from Step 1, this is the QLoRA recipe. As shown below, it is also good practice to call prepare_model_for_kbit_training before attaching the adapters.

from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# Recommended when stacking LoRA on a 4-bit/8-bit base model: enables gradient
# checkpointing and casts layer norms to fp32 for numerical stability.
model = prepare_model_for_kbit_training(model)

lora_config = LoraConfig(
    r=16,                       # rank of the low-rank update matrices
    lora_alpha=32,              # scaling factor applied to the LoRA updates
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["c_attn"],  # attention projection in StarCoder's GPTBigCode blocks
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # sanity check: only a small fraction of weights train

Step 3: Gradient Accumulation

Activation memory grows with the batch size, so we keep the per-device batch size as small as possible (here, 1) and use gradient accumulation to recover a larger effective batch size: the gradients from several small forward/backward passes are accumulated before each optimizer step.

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="starcoder-finetuned",
    per_device_train_batch_size=1,   # only one sequence in GPU memory at a time
    gradient_accumulation_steps=4,   # effective batch size = 1 * 4 = 4
    # ...
)

Step 4: Memory-Efficient Optimizer

Finally, we can switch to a memory-efficient optimizer such as AdamW8bit from the bitsandbytes library, which stores the optimizer state (the Adam momentum and variance terms) in 8-bit instead of 32-bit.
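
With the Hugging Face Trainer this is a one-line change to the TrainingArguments from Step 3. The sketch below uses the "paged_adamw_8bit" optimizer name, which assumes a recent transformers version and that bitsandbytes is installed:

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="starcoder-finetuned",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    optim="paged_adamw_8bit",   # bitsandbytes 8-bit AdamW; optimizer state kept in 8-bit
    # ...
)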

Summary of Techniques

| Technique | How it works | Pros | Cons |
| --- | --- | --- | --- |
| Quantization | Loads the model’s weights in a lower-precision format (e.g., 4-bit or 8-bit). | Drastically reduces the model’s memory footprint. | Can lead to a small drop in model quality. |
| PEFT | Freezes the weights of the pre-trained model and only trains a small number of additional parameters. | Much more memory-efficient and faster to train than full fine-tuning. | Might not perform as well as full fine-tuning on some tasks. |
| Gradient Accumulation | Accumulates the gradients over several smaller batches and then performs a single update. | Allows a very small per-device batch size without sacrificing the effective batch size. | Can slow down training if the accumulation steps are too high. |
| Memory-Efficient Optimizer | Uses a more memory-efficient representation for the optimizer’s state. | Can significantly reduce the memory usage of the optimizer. | Might not be as well-tested as the standard AdamW optimizer. |

By combining these techniques, we can successfully fine-tune the 15.5B parameter StarCoder model on a single 40GB A100 GPU.
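
Putting the pieces together, a minimal training loop with the Hugging Face Trainer might look like the sketch below. The train_dataset variable is assumed to be a tokenized Dataset built from the Python corpus and is not defined here:

from transformers import Trainer, DataCollatorForLanguageModeling

trainer = Trainer(
    model=model,                   # 4-bit base model with LoRA adapters (Steps 1-2)
    args=training_args,            # tiny batch, gradient accumulation, 8-bit optimizer (Steps 3-4)
    train_dataset=train_dataset,   # assumed: a tokenized dataset of Python code
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),  # causal LM batching
)
trainer.train()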

Practice Question

You are trying to fine-tune a large model, but you are still running out of memory even after using 4-bit quantization and PEFT. What is the next thing you should try?