What is the difference between `torch.nn.Module` and `torch.nn.Sequential`?

The Scenario

You are an ML engineer at a self-driving car company. You are building a new computer vision model to detect pedestrians in a video feed. You have decided to use a ResNet architecture, which is known for its good performance on computer vision tasks.

A ResNet model consists of a series of residual blocks. Each residual block has a “skip connection” that adds the input of the block to its output. This allows the model to learn residual functions, which can make it easier to train very deep neural networks.

You need to decide whether to use torch.nn.Module or torch.nn.Sequential to build the ResNet model.

The Challenge

Explain the difference between torch.nn.Module and torch.nn.Sequential. Which one would you use to build the ResNet model, and why? Provide a code example that shows how to build a residual block using your chosen approach.

Wrong Approach

A junior engineer might try to build the whole ResNet model with `torch.nn.Sequential`, not realizing that `nn.Sequential` only chains layers one after another. It has no way to add a block's input back to its output, which is exactly what a skip connection requires.

Right Approach

A senior engineer would know that subclassing `torch.nn.Module` is the correct choice for this task. They would be able to explain that `nn.Sequential` (itself an `nn.Module` subclass) is only suitable for plain linear stacks of layers, while a custom `nn.Module` with its own `forward` method can express any architecture, including models with skip connections.

Step 1: nn.Module vs. nn.Sequential

| Feature | nn.Sequential | nn.Module |
| --- | --- | --- |
| Flexibility | Limited, only works for linear stacks of layers. | Very flexible, can be used to build any kind of model. |
| Control Flow | Does not support complex control flow. | Supports complex control flow, such as if statements and for loops. |
| Use Cases | Simple classification and regression tasks. | Complex models with multiple inputs/outputs, shared layers, skip connections, etc. |
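
To make the comparison concrete, here is a minimal sketch that builds the same small feed-forward network both ways. The class name `TinyMLP` and the layer sizes are assumptions chosen for illustration, not part of the original scenario.

```python
import torch
import torch.nn as nn

# Style 1: nn.Sequential. Layers run in the order listed; no forward() to write.
seq_model = nn.Sequential(
    nn.Linear(8, 16),
    nn.ReLU(),
    nn.Linear(16, 2),
)

# Style 2: an nn.Module subclass (hypothetical name). Same layers, explicit forward().
class TinyMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 16)
        self.fc2 = nn.Linear(16, 2)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

module_model = TinyMLP()

# Copy the weights across so both models compute the same function.
module_model.fc1.load_state_dict(seq_model[0].state_dict())
module_model.fc2.load_state_dict(seq_model[2].state_dict())

x = torch.randn(4, 8)
same_output = torch.allclose(seq_model(x), module_model(x))

# nn.Sequential is itself a subclass of nn.Module.
is_module = isinstance(seq_model, nn.Module)
```

For a plain stack like this, the two styles are interchangeable and `nn.Sequential` is simply shorter. The difference only matters once `forward` needs to do something other than call the layers in order.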

Step 2: Choose the Right Tool for the Job

For our ResNet model, we need to subclass torch.nn.Module. A skip connection makes the data flow branch: the block's input x feeds the convolutional path and is also added to that path's output. torch.nn.Sequential only feeds each layer's output into the next layer, so on its own it cannot express this.
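
The following sketch illustrates the point on a toy scale. It uses small linear layers instead of convolutions, and the wrapper class `SkipWrapper` is a hypothetical name introduced only for this example.

```python
import torch
import torch.nn as nn

# A plain linear stack: each layer only sees the previous layer's output.
body = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 4))

class SkipWrapper(nn.Module):  # hypothetical helper, for illustration only
    def __init__(self, body):
        super().__init__()
        self.body = body

    def forward(self, x):
        # x is used twice: as the input to body and again in the addition.
        return self.body(x) + x

x = torch.randn(2, 4)
plain_out = body(x)              # what nn.Sequential alone computes: body(x)
skip_out = SkipWrapper(body)(x)  # what a skip connection computes: body(x) + x

skip_adds_input = torch.allclose(skip_out, plain_out + x)
```

The addition lives in a hand-written `forward`, which is the part `nn.Sequential` does not let you customize.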

Step 3: Build the Residual Block

Here’s how we can build a residual block using torch.nn.Module:

import torch.nn as nn

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # Main path: two 3x3 convolutions, each followed by batch normalization.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Shortcut path: pass the input through unchanged by default.
        self.shortcut = nn.Identity()
        # If the main path changes the spatial size or the channel count,
        # project the input with a 1x1 convolution so the shapes match.
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Skip connection: add the (possibly projected) input to the main-path output.
        out += self.shortcut(x)
        out = self.relu(out)
        return out

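The forward method of the ResidualBlock class implements the skip connection by adding the block's input (x), passed through the shortcut, to the output of the main path (out). torch.nn.Sequential alone cannot express this, because it only passes each layer's output to the next layer and never keeps a reference to x. Note that the block still uses nn.Sequential internally for the shortcut projection, since that part is a plain linear stack.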
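
As a quick sanity check, the sketch below runs a dummy tensor through the block and then stacks two blocks with `nn.Sequential`. The class is a condensed restatement of the block above so the snippet runs on its own; the channel counts, input size, and variable names are assumptions chosen for illustration.

```python
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    # Condensed restatement of the block above, so this snippet is self-contained.
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.shortcut = nn.Identity()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        return self.relu(out)

# Dummy batch: 2 images, 16 channels, 32x32 pixels (illustrative sizes).
x = torch.randn(2, 16, 32, 32)

# Same channels, stride 1: the identity shortcut is used and the shape is preserved.
y_same = ResidualBlock(16, 16)(x)

# More channels, stride 2: the 1x1 projection shortcut is used and the
# spatial size is halved.
y_down = ResidualBlock(16, 32, stride=2)(x)

# Blocks themselves form a linear stack, so nn.Sequential is a fine
# container for chaining them into a stage.
stage = nn.Sequential(ResidualBlock(16, 32, stride=2), ResidualBlock(32, 32))
y_stage = stage(x)
```

This shows how the two tools combine: `nn.Module` defines the non-sequential logic inside a block, and `nn.Sequential` chains the finished blocks.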

Practice Question

You are building a simple feed-forward neural network with a linear stack of layers. Which of the following would be the most appropriate tool for the job?