
How do you save and load a model in PyTorch?


The Scenario

You are an ML engineer at a smart home company. You have just finished training a new model that can detect whether a person in a video feed is a resident of the home or an intruder.

You need to save the model for three different purposes:

  1. Continued training: You want to be able to save a checkpoint of the model during training so that you can resume training later if it is interrupted.
  2. Production deployment: You want to deploy the model to a production server so that it can be used for inference by the company’s mobile app.
  3. Edge deployment: You want to deploy the model to a small edge device (like a smart camera) that has limited resources.

The Challenge

Explain your strategy for saving the model for each of these three purposes. What are the different saving strategies that you would use, and what are the trade-offs between them?

Wrong Approach

A junior engineer might use `torch.save(model, 'model.pt')` for all three use cases. This pickles the entire model object, which ties the saved file to the exact class definition and module path that existed at save time. They might not be aware of the difference between saving the whole model and saving only the state dictionary, and they might not know how to save a checkpoint for resumable training.
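To make the difference concrete, here is a small, self-contained sketch (assuming PyTorch is installed; `TinyNet` is a hypothetical stand-in for the intruder-detection model):

```python
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

model = TinyNet()

# Whole-model save: pickles the TinyNet class *by reference*, so loading
# later requires the same class importable at the same module path.
torch.save(model, "whole_model.pt")
loaded = torch.load("whole_model.pt", weights_only=False)

# State-dict save: just a dict of named tensors, independent of code layout.
torch.save(model.state_dict(), "state_only.pt")
state = torch.load("state_only.pt")
print(sorted(state.keys()))  # ['fc.bias', 'fc.weight']
```

If `TinyNet` is later renamed or moved to another module, the whole-model file becomes unloadable, while the state-dict file still loads into any architecture with matching parameter names and shapes.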

Right Approach

A senior engineer knows that different use cases call for different saving strategies. They can explain the trade-offs between saving the entire model and saving only the state dictionary, save a checkpoint for resumable training, and export the model to a format suitable for edge deployment.

Step 1: Choose the Right Strategy for Each Use Case

| Use Case | Recommended Strategy | Why? |
| --- | --- | --- |
| Continued training | Save a checkpoint dictionary with the model's state dictionary, the optimizer's state dictionary, the epoch number, and the loss. | This allows you to resume training from exactly where you left off. |
| Production deployment | Save only the model's state dictionary. | This is the most flexible approach: the weights can be loaded into any model with the same architecture, and it decouples the model's code from its weights. |
| Edge deployment | Export the model to the ONNX format. | ONNX is a standard, framework-neutral format for representing machine learning models that can be run on a variety of different devices. |

Step 2: Save the Model

Here’s how we can save the model in each format:

1. Save a checkpoint for resumable training:

```python
# Bundle everything needed to resume training into one checkpoint dict.
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
}, "my_checkpoint.pt")
```

2. Save the state dictionary for production deployment:

```python
# Weights only -- the model class itself lives in version-controlled code.
torch.save(model.state_dict(), "my_model_state.pt")
```

3. Export the model to the ONNX format for edge deployment:

```python
model.eval()  # export in inference mode
dummy_input = torch.randn(1, 3, 224, 224)  # example input shape (N, C, H, W)
torch.onnx.export(model, dummy_input, "my_model.onnx",
                  input_names=["input"], output_names=["output"])
```

Step 3: Load the Model

Here’s how we can load the model in each format:

1. Load a checkpoint for resumable training:

```python
# Restore model and optimizer state, then continue from the saved epoch.
checkpoint = torch.load("my_checkpoint.pt", weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
model.train()  # back to training mode before resuming
```
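Putting the save and load halves together, here is a minimal, self-contained round trip (a sketch in which a single `nn.Linear` layer stands in for the real intruder-detection model):

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the real model.
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# One training step, then checkpoint.
loss = model(torch.randn(8, 4)).sum()
loss.backward()
optimizer.step()
torch.save({
    'epoch': 3,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss.item(),
}, "my_checkpoint.pt")

# In a fresh run: rebuild the same architecture, then restore its state.
model2 = nn.Linear(4, 2)
optimizer2 = torch.optim.SGD(model2.parameters(), lr=0.1)
checkpoint = torch.load("my_checkpoint.pt", weights_only=False)
model2.load_state_dict(checkpoint['model_state_dict'])
optimizer2.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1  # resume at the next epoch
print(start_epoch)  # 4
```

Restoring the optimizer state matters: optimizers like Adam carry per-parameter running statistics, and dropping them can destabilize the first few steps after a resume.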

2. Load the state dictionary for production deployment:

```python
# Rebuild the architecture from code, then load the saved weights into it.
model = MyModel()
model.load_state_dict(torch.load("my_model_state.pt"))
model.eval()  # switch dropout/batch-norm layers to inference behavior
```
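Once loaded, inference should run without building an autograd graph, via `torch.inference_mode()` (or `torch.no_grad()`). A minimal sketch, with a hypothetical single-layer model standing in for `MyModel`:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # hypothetical stand-in for MyModel
model.eval()

with torch.inference_mode():  # no gradient tracking during inference
    logits = model(torch.randn(1, 4))
    pred = logits.argmax(dim=1)

print(pred.shape)  # torch.Size([1])
```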

3. Load the ONNX model for edge deployment:

You can use a runtime like ONNX Runtime to load and run the model on an edge device.

Practice Question

You want to share your trained model with a colleague who is working on a different project. Which of the following would be the best way to save your model?