How do you use `Dataset` and `DataLoader` to load and process data?


The Scenario

You are an ML engineer at a healthcare company. You are working on a project to build a model that can detect tumors in MRI scans. You have been given a dataset of 100,000 MRI scans in the DICOM format.

The data is very messy:

  • The scans are of different sizes and resolutions.
  • Some of the scans are corrupted and cannot be opened.
  • The labels are stored in a separate CSV file and are not always consistent with the file names.

Your task is to build a data pipeline that cleans and pre-processes this data so it can be used to train a deep learning model. The pipeline must be fast enough to keep training fed with data, and it must scale to a dataset too large to load into memory at once.

The Challenge

Explain how you would use torch.utils.data.Dataset and torch.utils.data.DataLoader to build a data processing pipeline for this task. What are the key features of these classes that you would use, and how would you use them to address the challenges of this dataset?

Wrong Approach

A junior engineer might try to load the entire dataset into memory at once, which is impractical at this scale. They might also write their own data loading, batching, and parallel-loading code from scratch, which is time-consuming and error-prone.
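To put a rough number on "this scale", here is a back-of-the-envelope estimate. It assumes, purely for illustration, that each DICOM file holds a single 512x512 slice stored as 16-bit integers; real scans in this dataset vary in size, and multi-slice volumes would multiply the total.

```python
# Rough memory estimate for loading every scan into RAM at once.
# ASSUMPTION: one 512x512 slice of 16-bit pixels per DICOM file.
NUM_SCANS = 100_000
HEIGHT, WIDTH = 512, 512
BYTES_PER_PIXEL_UINT16 = 2   # raw DICOM pixel data
BYTES_PER_PIXEL_FLOAT32 = 4  # after converting to float32 for training

raw_bytes = NUM_SCANS * HEIGHT * WIDTH * BYTES_PER_PIXEL_UINT16
float_bytes = NUM_SCANS * HEIGHT * WIDTH * BYTES_PER_PIXEL_FLOAT32

print(f"raw uint16: {raw_bytes / 1e9:.1f} GB")
print(f"as float32: {float_bytes / 1e9:.1f} GB")
```

Even this conservative estimate comes to about 52 GB of raw pixels, and roughly twice that once converted to float32, which is why the pipeline reads scans lazily, one at a time.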

Right Approach

A senior engineer would reach for `Dataset` and `DataLoader`. A `Dataset` loads and cleans one scan at a time, on demand, so memory use stays flat regardless of dataset size; a `DataLoader` batches, shuffles, and loads those scans in parallel worker processes. The steps below apply these two classes to the specific problems in this dataset.

Step 1: Create a Custom Dataset

The first step is to create a custom Dataset class to handle the messy data.

| Method | Description |
| --- | --- |
| `__init__(self, csv_file, root_dir, transform=None)` | Initializes the dataset. This is where you load the labels from the CSV file and build the list of file paths. |
| `__len__(self)` | Returns the total number of examples in the dataset. |
| `__getitem__(self, idx)` | Returns the example at the given index. This is where you load the DICOM file, handle corrupted files, and apply transforms. |

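import os

import pandas as pd
import pydicom
import torch
from torch.utils.data import Dataset

class MRIDataset(Dataset):
    def __init__(self, csv_file, root_dir, transform=None):
        # Column 0: DICOM file name (relative to root_dir), column 1: label
        self.labels = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.root_dir, self.labels.iloc[idx, 0])
        try:
            # Files are read lazily, one per call, so memory use stays flat.
            # float32 avoids uint16 pixel data that torch handles poorly.
            image = pydicom.dcmread(img_path).pixel_array.astype("float32")
        except Exception:
            # Corrupted or unreadable file: return a dummy image with the
            # sentinel label -1 so that collate_fn can drop it
            return torch.zeros((1, 256, 256)), -1

        label = self.labels.iloc[idx, 1]

        if self.transform:
            # A resizing transform is needed so default_collate can stack
            # scans that originally had different sizes
            image = self.transform(image)

        return image, label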
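`MRIDataset` leaves the `transform` argument open, but that is where the different sizes and resolutions have to be resolved, because `default_collate` can only stack tensors of identical shape. Below is a minimal sketch of such a transform, assuming one 2D pixel array per file and a 256x256 target size (matching the dummy image above). `ResizeAndNormalize` is a hypothetical name, and bilinear resizing with per-scan z-score normalization is one reasonable choice, not the only one.

```python
import numpy as np
import torch
import torch.nn.functional as F


class ResizeAndNormalize:
    """Hypothetical transform: 2D pixel array -> (1, size, size) float32 tensor."""

    def __init__(self, size=256):
        self.size = size

    def __call__(self, image):
        # DICOM pixel data is often uint16; convert to float32 first
        array = np.asarray(image, dtype=np.float32)
        tensor = torch.from_numpy(array)
        # F.interpolate expects (N, C, H, W): add batch and channel dims
        tensor = tensor.unsqueeze(0).unsqueeze(0)
        tensor = F.interpolate(
            tensor, size=(self.size, self.size), mode="bilinear", align_corners=False
        )
        tensor = tensor.squeeze(0)  # drop the batch dim -> (1, size, size)
        # Per-scan z-score normalization so intensities share a common scale;
        # the epsilon keeps constant (e.g. blank) images from dividing by zero
        mean, std = tensor.mean(), tensor.std()
        return (tensor - mean) / (std + 1e-6)
```

It would be passed as `MRIDataset(csv_file, root_dir, transform=ResizeAndNormalize(256))`. Note that plain resizing ignores physical pixel spacing; if that matters for the task, resampling to a common spacing using the DICOM `PixelSpacing` attribute before resizing is a common refinement.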

Step 2: Create a DataLoader

The next step is to create a DataLoader to iterate over the dataset in batches.

| Argument | Description |
| --- | --- |
| `batch_size` | Number of examples in each batch. |
| `shuffle` | Whether to reshuffle the data at the start of every epoch. |
| `num_workers` | Number of worker subprocesses that load data in parallel (0 means loading happens in the main process). Multiple workers let DICOM reading and decoding overlap with training. |
| `collate_fn` | Function that merges a list of samples into a mini-batch. |
| `pin_memory` | If `True`, the data loader copies tensors into CUDA pinned (page-locked) memory before returning them, which speeds up host-to-GPU transfers. |

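from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate

def collate_fn(batch):
    # Drop corrupted samples, which MRIDataset flags with the label -1.
    # An all-corrupted batch leaves an empty list, which default_collate cannot stack.
    batch = [sample for sample in batch if sample[1] != -1]
    return default_collate(batch)

dataset = MRIDataset(...)
dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,          # read and decode DICOM files in 4 worker processes
    collate_fn=collate_fn,
    pin_memory=True,        # page-locked host memory for faster GPU copies
)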
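To see the corrupted-sample filtering in isolation, here is a small self-contained sketch. `ToyScans` is a hypothetical stand-in for `MRIDataset` that needs no DICOM files: every third index pretends to be corrupted and returns the same `(dummy image, -1)` sentinel, and `collate_fn` uses the same filtering rule as above.

```python
import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataloader import default_collate


class ToyScans(Dataset):
    """Hypothetical stand-in for MRIDataset: indices 0, 3, 6, 9 are 'corrupted'."""

    def __len__(self):
        return 10

    def __getitem__(self, idx):
        if idx % 3 == 0:
            # Same sentinel convention as MRIDataset: dummy image, label -1
            return torch.zeros((1, 256, 256)), -1
        return torch.full((1, 256, 256), float(idx)), idx


def collate_fn(batch):
    # Drop samples carrying the -1 sentinel, then stack the rest
    batch = [sample for sample in batch if sample[1] != -1]
    return default_collate(batch)


loader = DataLoader(ToyScans(), batch_size=4, shuffle=False, collate_fn=collate_fn)
for images, labels in loader:
    print(images.shape, labels.tolist())
```

One design consequence: filtering shrinks batches, so this loader yields batches of 2, 3, and 1 samples instead of 4, 4, and 2. Code that assumes a fixed batch size has to tolerate that; an alternative is to detect unreadable files once, up front, and exclude them from the index so every batch stays full.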

Step 3: Iterate Over the Data

Finally, we can iterate over the data in the DataLoader.

for images, labels in dataloader:
    # ... your training code ...
    pass
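The loop above leaves the training step open. The sketch below shows one common shape for it; it swaps in a toy `TensorDataset` of random tensors and a hypothetical single-layer classifier so it runs without DICOM files or a GPU. The point it illustrates is the device transfer: `pin_memory=True` pays off when paired with `.to(device, non_blocking=True)`, which allows the host-to-GPU copy to run asynchronously.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Toy stand-in data: 64 single-channel 32x32 "scans" with binary labels
images = torch.randn(64, 1, 32, 32)
labels = torch.randint(0, 2, (64,))
dataloader = DataLoader(
    TensorDataset(images, labels),
    batch_size=16,
    shuffle=True,
    pin_memory=torch.cuda.is_available(),  # pinning only helps GPU transfers
)

# Hypothetical tiny classifier, just to make the loop runnable
model = nn.Sequential(nn.Flatten(), nn.Linear(32 * 32, 2)).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

model.train()
num_steps = 0
for batch_images, batch_labels in dataloader:
    # With pinned source memory, non_blocking=True lets the copy overlap compute
    batch_images = batch_images.to(device, non_blocking=True)
    batch_labels = batch_labels.to(device, non_blocking=True)

    optimizer.zero_grad()
    loss = loss_fn(model(batch_images), batch_labels)
    loss.backward()
    optimizer.step()
    num_steps += 1
```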

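Together, `Dataset` and `DataLoader` give us a pipeline that reads each scan lazily, so the 100,000 files never have to fit in memory; skips corrupted files through a sentinel label and a custom `collate_fn`; normalizes scan sizes through the `transform` hook; and loads shuffled batches in parallel, with pinned memory for faster GPU transfers.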
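One problem from the scenario that the code above does not handle is labels whose file names do not match the files on disk. Below is a hedged sketch of a one-time reconciliation pass that could run before the dataset is built. It assumes the CSV has `filename` and `label` columns and that mismatches come from stray whitespace, letter case, or a missing `.dcm` extension; `normalize_name` and `build_index` are hypothetical helpers, and real mismatches may need other rules.

```python
import os

import pandas as pd


def normalize_name(name):
    """Hypothetical rule: trim whitespace, lowercase, ensure a .dcm suffix."""
    name = os.path.basename(str(name).strip()).lower()
    return name if name.endswith(".dcm") else name + ".dcm"


def build_index(csv_file, root_dir):
    """Match CSV labels to files on disk; return (matched, unmatched) frames."""
    # ASSUMPTION: the CSV has 'filename' and 'label' columns
    labels = pd.read_csv(csv_file)
    # Map normalized name -> the file name exactly as it exists on disk
    on_disk = {normalize_name(f): f for f in os.listdir(root_dir)}

    labels["key"] = labels["filename"].map(normalize_name)
    found = labels["key"].isin(list(on_disk))
    matched = labels[found].copy()
    matched["path"] = matched["key"].map(lambda k: os.path.join(root_dir, on_disk[k]))
    # Rows with no matching file are returned for manual review
    unmatched = labels[~found]
    return matched[["path", "label"]].reset_index(drop=True), unmatched
```

`MRIDataset` could then read paths and labels from `matched` instead of joining `root_dir` with column 0, and the `unmatched` rows become a report for manual review rather than silent label noise.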

Practice Question

You are working with a dataset of variable-sized images. Which `DataLoader` argument would you use to handle this?