Questions
You're running out of memory while fine-tuning a large model. How do you solve this?
The Scenario
You are an ML engineer at a startup that is building a new AI-powered code generation tool. You are trying to fine-tune the bigcode/starcoder model (15.5B parameters) on a custom dataset of Python code.
You have access to a single server with an NVIDIA A100 GPU with 40GB of memory. However, when you try to fine-tune the model, you get the following error:
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 GiB (GPU 0; 40.00 GiB total capacity; 37.12 GiB already allocated; 1.88 GiB free; 37.12 GiB reserved in total by PyTorch)
Your manager has told you that you cannot get a bigger GPU, so you need to find a way to fine-tune the model with the resources you have.
The Challenge
What is your strategy for fine-tuning this large model on a single GPU? Explain the different memory-saving techniques you would use and how you would combine them to achieve the best results.
A junior engineer might give up and say that it's not possible to fine-tune such a large model on a single GPU. They might not be aware of the latest memory-saving techniques, or they might not know how to combine them effectively.
A senior engineer would know that it is possible to fine-tune a large model on a single GPU with the right combination of techniques. They would be able to explain the different memory-saving techniques that are available and would have a clear plan for how to use them to solve this problem.
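To see why naive full fine-tuning is out of reach, a rough back-of-the-envelope estimate helps. The exact figures depend on the precision and optimizer details, so treat the numbers in this sketch as approximations:

```python
# Rough memory estimate for FULL fine-tuning of a 15.5B-parameter model
# with AdamW in fp16 (activations and buffers would add even more)
params = 15.5e9
weights    = params * 2      # fp16 weights:              ~31 GB
gradients  = params * 2      # fp16 gradients:            ~31 GB
adam_state = params * 2 * 4  # fp32 momentum + variance: ~124 GB
print(f"~{(weights + gradients + adam_state) / 1e9:.0f} GB")  # ~186 GB, far beyond 40 GB
```

Even before counting activations, full fine-tuning needs several times the 40 GB available, which is why each technique below attacks one of these memory components.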
Step 1: Quantization
The first step is to load the model in a lower-precision format. We can use 4-bit quantization to significantly reduce the memory footprint of the model.
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "bigcode/starcoder"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# load_in_4bit quantizes the weights as they are loaded, shrinking the 15.5B model
# from ~31 GB in fp16 to roughly 8 GB; device_map="auto" places it on the GPU
model = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True, device_map="auto")
```

Step 2: Parameter-Efficient Fine-tuning (PEFT)
Instead of fine-tuning all the weights of the model, we can use a PEFT technique like LoRA to only train a small number of additional parameters.
```python
from peft import LoraConfig, get_peft_model

# LoRA freezes the quantized base weights and trains only small low-rank adapter matrices
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # typically well under 1% of the 15.5B parameters
```

Step 3: Gradient Accumulation
Activation memory grows with the batch size, so we keep the per-device batch size at 1 and use gradient accumulation: gradients from several micro-batches are accumulated before each optimizer update, simulating a larger effective batch size without the extra memory.
```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="starcoder-finetuned",
    per_device_train_batch_size=1,   # micro-batch of 1 keeps activation memory small
    gradient_accumulation_steps=4,   # effective batch size = 1 x 4 = 4
    # ...
)
```

Step 4: Memory-Efficient Optimizer
Finally, we can replace the standard AdamW optimizer with a memory-efficient variant such as 8-bit AdamW from the bitsandbytes library, which stores the optimizer state (momentum and variance) in 8-bit precision and cuts its memory footprint roughly fourfold.
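A minimal sketch of enabling it, assuming the Trainer setup from Step 3 and a reasonably recent transformers release with bitsandbytes installed:

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="starcoder-finetuned",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    optim="adamw_bnb_8bit",  # use bitsandbytes' 8-bit AdamW instead of the default AdamW
    # ...
)
```

Outside the Trainer, the same optimizer is available directly as bitsandbytes.optim.AdamW8bit and can be passed to any training loop that accepts a PyTorch optimizer.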
Summary of Techniques
| Technique | How it works | Pros | Cons |
|---|---|---|---|
| Quantization | Loads the model’s weights in a lower-precision format (e.g., 4-bit or 8-bit). | Drastically reduces the model’s memory footprint. | Can lead to a small drop in performance. |
| PEFT | Freezes the weights of the pre-trained model and only trains a small number of additional parameters. | Much more memory-efficient and faster to train than full fine-tuning. | Might not perform as well as full fine-tuning on some tasks. |
| Gradient Accumulation | Accumulates the gradients for several smaller batches and then performs a single update. | Allows you to use a very small batch size without sacrificing the effective batch size. | Slows wall-clock training: tiny micro-batches underutilize the GPU and weight updates happen less often. |
| Memory-Efficient Optimizer | Stores the optimizer's state in a lower-precision format (e.g., 8-bit). | Can significantly reduce the memory usage of the optimizer. | The lower-precision state can slightly affect convergence on some tasks. |
By combining these techniques, we can fine-tune the 15.5B-parameter StarCoder model on a single 40GB A100 GPU: the 4-bit weights take roughly 8 GB, leaving ample room for the LoRA parameters, their optimizer state, and the activations of a micro-batch of 1.
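A minimal sketch of wiring the pieces together with the Hugging Face Trainer; train_dataset is a placeholder for your already-tokenized Python-code dataset and is not defined above:

```python
from transformers import Trainer, DataCollatorForLanguageModeling

trainer = Trainer(
    model=model,              # 4-bit StarCoder wrapped with LoRA adapters (Steps 1-2)
    args=training_args,       # micro-batch of 1, gradient accumulation, 8-bit AdamW (Steps 3-4)
    train_dataset=train_dataset,  # placeholder: your tokenized dataset
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()
```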
Practice Question
You are trying to fine-tune a large model, but you are still running out of memory even after using 4-bit quantization and PEFT. What is the next thing you should try?