Questions
How do you save and load a model in PyTorch?
The Scenario
You are an ML engineer at a smart home company. You have just finished training a new model that can detect whether a person in a video feed is a resident of the home or an intruder.
You need to save the model for three different purposes:
- Continued training: You want to be able to save a checkpoint of the model during training so that you can resume training later if it is interrupted.
- Production deployment: You want to deploy the model to a production server so that it can be used for inference by the company’s mobile app.
- Edge deployment: You want to deploy the model to a small edge device (like a smart camera) that has limited resources.
The Challenge
Explain your strategy for saving the model for each of these three purposes. What are the different saving strategies that you would use, and what are the trade-offs between them?
A junior engineer might just use `torch.save(model, 'model.pt')` for all three use cases. They might not be aware of the difference between saving the entire model and saving only the state dictionary, and they might not know how to save a checkpoint for resumable training.
A senior engineer would recognize that each use case calls for a different saving strategy. They could explain the trade-offs between saving the entire model and saving only its state dictionary, write a checkpoint that supports resumable training, and export the model to a format suited to edge deployment.
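The difference between the two approaches can be seen in a small, self-contained sketch. `TinyNet` is a hypothetical stand-in for the resident/intruder classifier. The files are written to in-memory buffers instead of disk, and the `weights_only` argument to `torch.load` assumes PyTorch 1.13 or later.

```python
import io

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the resident/intruder classifier."""

    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(8, 16)
        self.fc2 = nn.Linear(16, 2)  # two classes: resident vs. intruder

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(torch.relu(self.fc1(x)))


torch.manual_seed(0)
model = TinyNet()

# Option A: pickle the whole module. The file stores a reference to the
# TinyNet class, so loading needs that class importable under the same
# name and an explicit weights_only=False on recent PyTorch versions.
whole_buf = io.BytesIO()
torch.save(model, whole_buf)
whole_buf.seek(0)
whole_model = torch.load(whole_buf, weights_only=False)

# Option B: save only the state dictionary, a mapping from parameter names
# to tensors. It loads with weights_only=True and is then copied into a
# freshly constructed model.
state_buf = io.BytesIO()
torch.save(model.state_dict(), state_buf)
state_buf.seek(0)
state = torch.load(state_buf, weights_only=True)
fresh_model = TinyNet()
fresh_model.load_state_dict(state)

print(list(state.keys()))  # ['fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias']
```

Both reloaded models reproduce the original outputs. However, only Option B produces a plain name-to-tensor mapping that does not depend on where `TinyNet` lives in the codebase.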
Step 1: Choose the Right Strategy for Each Use Case
| Use Case | Recommended Strategy | Why? |
|---|---|---|
| Continued training | Save a checkpoint dictionary containing the model’s state dictionary, the optimizer’s state dictionary, the epoch number, and the loss. | Restoring the optimizer state (for example, Adam’s moment estimates) together with the weights lets training resume exactly where it left off. |
| Production deployment | Save only the model’s state dictionary. | The file holds only named weight tensors, so it can be loaded into any instance of the same architecture and keeps the weights decoupled from the code. Saving the entire model instead pickles a reference to the model class, so loading breaks if that class is renamed or moved. |
| Edge deployment | Export the model to the ONNX format. | ONNX is an open, framework-independent model format. Runtimes such as ONNX Runtime can execute it on many kinds of hardware without PyTorch installed. |
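A short sketch can verify the claim that a state dictionary loads into any model with the same architecture. It assumes a hypothetical `make_classifier` helper whose `hidden` argument changes the layer width. `load_state_dict` accepts weights from an identically shaped model, and it raises a `RuntimeError` when a tensor shape differs.

```python
import torch
import torch.nn as nn


def make_classifier(hidden: int) -> nn.Sequential:
    # Hypothetical architecture; `hidden` controls the layer width.
    return nn.Sequential(nn.Linear(8, hidden), nn.ReLU(), nn.Linear(hidden, 2))


trained = make_classifier(hidden=16)
weights = trained.state_dict()

# Same architecture: every key and tensor shape matches, so loading succeeds.
same = make_classifier(hidden=16)
same.load_state_dict(weights)

# Different architecture: load_state_dict compares each tensor shape and
# raises RuntimeError on a size mismatch.
different = make_classifier(hidden=32)
try:
    different.load_state_dict(weights)
    mismatch_detected = False
except RuntimeError as err:
    mismatch_detected = True
    print(f"Refused to load: {type(err).__name__}")
```

This check makes it safe to ship the weights separately from the code. A mismatch between the deployed model definition and the saved weights fails loudly at load time instead of silently producing a wrong model.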
Step 2: Save the Model
Here’s how we can save the model in each format:
1. Save a checkpoint for resumable training:
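```python
import torch

# `model`, `optimizer`, `epoch`, and `loss` come from your training loop.
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': float(loss),  # works whether `loss` is a float or a 0-dim tensor
}, "my_checkpoint.pt")
```
2. Save the state dictionary for production deployment: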
```python
torch.save(model.state_dict(), "my_model_state.pt")
```
3. Export the model to the ONNX format for edge deployment:
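```python
model.eval()  # export with inference behavior for dropout and batch norm
dummy_input = torch.randn(1, 3, 224, 224)  # assumed input: one 224x224 RGB frame
torch.onnx.export(model, dummy_input, "my_model.onnx")
```
Step 3: Load the Model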
Here’s how we can load the model in each format:
1. Load a checkpoint for resumable training:
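```python
# Construct `model` and `optimizer` first, exactly as in the original run.
checkpoint = torch.load("my_checkpoint.pt")
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']  # last completed epoch; continue from epoch + 1
loss = checkpoint['loss']
model.train()  # make sure the model is in training mode before continuing
```
2. Load the state dictionary for production deployment: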
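```python
model = MyModel()  # your model class; must match the saved architecture
# map_location lets a GPU-trained file load on a CPU-only server.
model.load_state_dict(torch.load("my_model_state.pt", map_location="cpu", weights_only=True))
model.eval()  # disable dropout and use running batch-norm statistics
```
3. Load the ONNX model for edge deployment: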
Load and run the exported `.onnx` file with an inference runtime such as ONNX Runtime, which does not need PyTorch installed on the edge device.
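You can check that the checkpoint saved and loaded in item 1 of Steps 2 and 3 lets training resume where it left off. The sketch below trains a toy model for four steps without interruption. It then repeats the run with a checkpoint save and restore after step two and compares the final weights. The `nn.Linear` model, Adam optimizer, random batches, and in-memory buffer are illustrative assumptions, and each batch plays the role of one epoch.

```python
import io

import torch
import torch.nn as nn


def build(seed: int = 0):
    # Hypothetical toy model and optimizer standing in for the real ones.
    torch.manual_seed(seed)
    model = nn.Linear(4, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
    return model, optimizer


def train_step(model, optimizer, x, y):
    optimizer.zero_grad()
    loss = nn.functional.cross_entropy(model(x), y)
    loss.backward()
    optimizer.step()
    return loss


torch.manual_seed(123)
batches = [(torch.randn(8, 4), torch.randint(0, 2, (8,))) for _ in range(4)]

# Reference run: four uninterrupted training steps.
ref_model, ref_opt = build()
for x, y in batches:
    train_step(ref_model, ref_opt, x, y)

# Interrupted run: two steps, then save a checkpoint (in memory here).
model, optimizer = build()
for epoch, (x, y) in enumerate(batches[:2]):
    loss = train_step(model, optimizer, x, y)
buf = io.BytesIO()
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': float(loss),
}, buf)

# "Restart": fresh objects with different initial weights, then restore.
buf.seek(0)
model, optimizer = build(seed=1)
checkpoint = torch.load(buf, weights_only=True)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1
model.train()
for x, y in batches[start_epoch:]:
    train_step(model, optimizer, x, y)
```

Adam keeps per-parameter moment estimates. Restoring `model_state_dict` without `optimizer_state_dict` would therefore generally make the resumed run diverge from the reference, which is why the checkpoint stores both.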
Practice Question
You want to share your trained model with a colleague who is working on a different project. Which of the following would be the best way to save your model?