Deep Learning - PyTorch Data Loading

RESNET-50 Data Loading, Data Transforms, Custom Data Loading

Posted by Rico's Nerd Cluster on May 11, 2022

Dataset and Data Loading

Dataset and DataLoader All Together in PyTorch

In PyTorch, data is accessed through a Dataset object. We can read all input data at once, or read samples one by one. For many tasks, we also apply transforms on every loading call for data augmentation. In a real-life application, data loading might look something like:

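from torch.utils.data import DataLoader, Dataset

class MyDataset(Dataset):
    def __init__(self, data_path: str, transform=None):
        self.transform = transform
        self.data_path = data_path
        # Lightweight per-sample metadata (e.g., file names and labels)
        # read_csv_metadata() is a placeholder for your own helper
        self.metadata = read_csv_metadata(data_path)

    def __getitem__(self, idx):
        sample_info = self.metadata[idx]
        # read_single_sample() is a placeholder: load exactly one sample from disk
        sample = read_single_sample(sample_info)
        if self.transform:
            sample = self.transform(sample)
        return sample

    def __len__(self):
        return len(self.metadata)

# train_data is a MyDataset instance; train_sampler, batch_size and device are defined elsewhere
train_dataloader = DataLoader(
    train_data,
    sampler=train_sampler,
    batch_size=batch_size,
    num_workers=2,
    pin_memory=True,
)

# In training, move each batch to the device
for src_batch in train_dataloader:
    src_batch = src_batch.to(device, non_blocking=True)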
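  • pin_memory=True is generally recommended unless host RAM is tight: the DataLoader then returns batches in pinned (page-locked) host memory, which allows faster, asynchronous copies to the GPU (see non_blocking=True below).
  • To work well with multi-worker data loading, read a single sample (e.g., one file) in __getitem__(self, idx) instead of loading the whole dataset into memory in __init__.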
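To make the lazy-loading advice concrete, here is a minimal sketch assuming one sample per .npy file; the NpyFolderDataset class, the double transform, and the temporary folder are hypothetical stand-ins for real data. __init__ only collects file paths, and __getitem__ reads one file per call, so each DataLoader worker only touches the samples it is asked for.

```python
import os
import tempfile

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


def double(x: torch.Tensor) -> torch.Tensor:
    # A module-level function (unlike a lambda) can be pickled if workers use "spawn"
    return x * 2


class NpyFolderDataset(Dataset):
    """Hypothetical dataset: one sample per .npy file, read lazily in __getitem__."""

    def __init__(self, root: str, transform=None):
        # Store only lightweight metadata (file paths) up front
        self.paths = sorted(
            os.path.join(root, name) for name in os.listdir(root) if name.endswith(".npy")
        )
        self.transform = transform

    def __len__(self) -> int:
        return len(self.paths)

    def __getitem__(self, idx: int) -> torch.Tensor:
        # Each call reads exactly one sample from disk (inside a worker if num_workers > 0)
        sample = torch.from_numpy(np.load(self.paths[idx]))
        if self.transform:
            sample = self.transform(sample)
        return sample


# Write 10 tiny samples to a temporary folder; sample i is the vector [i, i, i]
root = tempfile.mkdtemp()
for i in range(10):
    np.save(os.path.join(root, f"sample_{i:02d}.npy"), np.full(3, i, dtype=np.float32))

dataset = NpyFolderDataset(root, transform=double)
loader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=False,
    num_workers=0,  # e.g., 2 to read samples in worker processes
    pin_memory=torch.cuda.is_available(),  # pinning only pays off when copying to a GPU
)
batches = list(loader)
```

With 10 samples and batch_size=4, the loader yields batches of 4, 4 and 2 samples.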

(NOT RECOMMENDED) A naive, slower alternative is to create a dataset by wrapping already-built tensors in a TensorDataset, which keeps all data in memory up front:

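import torch
from torch.utils.data import TensorDataset

# input_ids and target_ids: equal-length lists of token ids, already in memory
train_data = TensorDataset(
    torch.LongTensor(input_ids), torch.LongTensor(target_ids)
)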
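For reference, a sketch of what TensorDataset does, using small hypothetical token-id tensors in place of input_ids and target_ids: every tensor must share the same first dimension, and indexing returns a tuple with one row from each tensor.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical token ids: 6 sequences of length 4 (int64, like torch.LongTensor)
input_ids = torch.arange(24).reshape(6, 4)
target_ids = torch.arange(24, 48).reshape(6, 4)

# All tensors must have the same size along dim 0 (the number of samples)
train_data = TensorDataset(input_ids, target_ids)

# Indexing returns (input_ids[2], target_ids[2])
x, y = train_data[2]

# Batching works like with any other Dataset
loader = DataLoader(train_data, batch_size=4)
xb, yb = next(iter(loader))
```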

Dataset is NOT Supposed to Move Tensors Onto CUDA

If you see this CUDA error:

RuntimeError: CUDA error: initialization error
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.

During data loading with worker processes, e.g.:

train_dataloader = DataLoader(
...
num_workers=2,
)

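It's because DataLoader workers are separate processes (often created with fork, e.g., on Linux), and CUDA cannot be initialized in a forked child once the parent process has already initialized CUDA. A Dataset whose __getitem__ calls .to(device) therefore fails inside the workers. Keep the Dataset on the CPU and move batches to the GPU in the training loop instead: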

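import torch

class MyDataset(torch.utils.data.Dataset):
    def __init__(self, data, targets, transform=None):
        self.data = data
        self.targets = targets
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        x = self.data[index]
        y = self.targets[index]
        if self.transform:
            x = self.transform(x)
        # Do NOT move tensors to the GPU here (this runs inside the worker processes):
        # return x.to(device), y.to(device)
        # Instead, return CPU tensors and call .to(device) on each batch in the training loop
        return x, y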
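A minimal sketch of the recommended split, assuming a toy in-memory dataset and a hypothetical train_one_epoch helper: the Dataset only ever returns CPU tensors, and the training loop, running in the main process, moves each batch to the device. It falls back to the CPU when CUDA isn't available.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset


class CPUOnlyDataset(Dataset):
    """Toy dataset whose samples always stay on the CPU."""

    def __init__(self, n: int = 32):
        self.data = torch.randn(n, 8)
        self.targets = torch.randint(0, 2, (n,))

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index):
        # No .to(device) here: this may run inside a DataLoader worker process
        return self.data[index], self.targets[index]


def train_one_epoch(model, loader, optimizer, loss_fn, device) -> float:
    model.train()
    total_loss = 0.0
    for x, y in loader:
        # Batches arrive on the CPU; move them to the device in the main process
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * x.size(0)
    return total_loss / len(loader.dataset)


torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(8, 2).to(device)
loader = DataLoader(
    CPUOnlyDataset(), batch_size=8, shuffle=True, pin_memory=device.type == "cuda"
)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
avg_loss = train_one_epoch(model, loader, optimizer, nn.CrossEntropyLoss(), device)
```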

Naive Data Loading

Next comes the DataLoader. A naive dataloader reads samples one by one, groups them into a batch, and returns the batch to the user. Some notable points include:

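  • DataLoader uses a sampler to decide which sample indices go into each batch: a SequentialSampler by default, or a RandomSampler when shuffle=True.
  • collate_fn takes a list of single samples and merges them into a batch (e.g., stacks tensors along a new batch dimension). From the PyTorch documentation, with a batch sampler this is roughly equivalent to: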
# Collate_fn equivalent
for indices in batch_sampler:
    yield collate_fn([dataset[i] for i in indices])
  • PyTorch provides a default collate_fn (default_collate), but a custom one can do things like padding variable-length text sequences (for NLP tasks):
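# Padding during Collate in NLP tasks
import torch
from torch.nn.utils.rnn import pad_sequence

def custom_collate_fn(batch):
    # batch is a list of (token_tensor, label) pairs
    inputs, labels = zip(*batch)
    # Pad every sequence to the length of the longest one in this batch
    inputs = pad_sequence(inputs, batch_first=True, padding_value=0)
    labels = torch.tensor(labels)
    return inputs, labels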
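The sampler/collate equivalence can be checked directly. This sketch (with hypothetical variable-length token sequences, and custom_collate_fn repeated so it runs on its own) builds batches once through DataLoader and once by hand with a BatchSampler over a SequentialSampler, which mirrors DataLoader(shuffle=False):

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler


def custom_collate_fn(batch):
    inputs, labels = zip(*batch)
    inputs = pad_sequence(inputs, batch_first=True, padding_value=0)
    labels = torch.tensor(labels)
    return inputs, labels


# Hypothetical (token ids, label) pairs of different lengths; a list works as a Dataset
dataset = [
    (torch.tensor([5, 6, 7]), 0),
    (torch.tensor([8]), 1),
    (torch.tensor([9, 10]), 0),
    (torch.tensor([11, 12, 13, 14]), 1),
]

# 1) The usual way
loader = DataLoader(dataset, batch_size=2, shuffle=False, collate_fn=custom_collate_fn)
from_loader = list(loader)

# 2) By hand: sampler -> lists of indices -> collate_fn on the selected samples
batch_sampler = BatchSampler(SequentialSampler(dataset), batch_size=2, drop_last=False)
manual = [custom_collate_fn([dataset[i] for i in indices]) for indices in batch_sampler]
```

Each batch is padded only to its own longest sequence, so the first batch has width 3 and the second width 4.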

Putting it all together, a naive dataloader is roughly equivalent to:

from torch.utils.data import default_collate

class NaiveDataLoader:
    def __init__(self, dataset, batch_size=64, collate_fn=default_collate):
        self.dataset = dataset
        self.batch_size = batch_size
        self.collate_fn = collate_fn
        self.index = 0
    def __iter__(self):
        """
        Called once at the start of a for loop: resets the position and returns the iterator (self).
        """
        self.index = 0
        return self
    def get(self):
        """
        Return a single next item
        """
        item = self.dataset[self.index]
        self.index += 1
        return item
    def __next__(self):
        """
        Collate the next batch_size samples (fewer for the last batch) into one batch
        """
        if self.index >= len(self.dataset):
            raise StopIteration
        batch_size = min(len(self.dataset) - self.index, self.batch_size)
        return self.collate_fn([self.get() for _ in range(batch_size)])

# Naively, the dataloader just needs a list
dataset = list(range(16))
dataloader = NaiveDataLoader(dataset, batch_size=8)
for batch in dataloader:
    print(batch)

Multi-Process Data Loading

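Data loading can be spread across multiple subprocesses (workers), so that reading and transforming samples runs in parallel with training; with pinned memory, the copies to CUDA can also overlap with computation. Here, we have a simple implementation inspired by this post.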

Some notable points include:

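  • A worker is a long-running subprocess that loads data asynchronously. It takes indices from its own input (index) queue, loads (and transforms) the corresponding samples, and puts (index, sample) pairs on an output queue shared by all workers. (By default, each worker in PyTorch's own DataLoader collates a full batch before putting it on the queue.)
  • A key feature of multi-process data loading is prefetching. Before returning each item, we make sure that indices up to prefetch_batches * num_workers * batch_size positions ahead (2 * num_workers * batch_size with the default prefetch_batches=2) have been sent to the workers, so samples are loaded in the background before they are needed.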
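import itertools
import multiprocessing
import queue

from torch.utils.data import default_collate

# NaiveDataLoader is the class defined above

def worker_fn(dataset, index_queue, output_queue):
    while True:
        index = index_queue.get()
        if index is None:  # sentinel sent by __del__: shut the worker down
            break
        # Workers return single samples; collate_fn is applied later in __next__()
        output_queue.put((index, dataset[index]))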

class DataLoader(NaiveDataLoader):
    def __init__(
        self,
        dataset,
        batch_size=64,
        num_workers=1,
        prefetch_batches=2,
        collate_fn=default_collate,
    ):
        super().__init__(dataset, batch_size, collate_fn)
        self.num_workers = num_workers
        self.prefetch_batches = prefetch_batches
        self.output_queue = multiprocessing.Queue()
        self.index_queues = []
        self.workers = []
        # Round-Robin counter
        self.worker_cycle = itertools.cycle(range(num_workers))
        self.cache = {}
        self.prefetch_index = 0
        # start all workers with a individual index_queues and one shared output queue
        for _ in range(num_workers):
            index_queue = multiprocessing.Queue()
            worker = multiprocessing.Process(
                target=worker_fn, args=(self.dataset, index_queue, self.output_queue)
            )
            worker.daemon = True
            worker.start()
            self.workers.append(worker)
            self.index_queues.append(index_queue)

        # Prefetching starts in __iter__(), which resets the state before every epoch
    def __iter__(self):
        """
        Return the object itself as an iterable. Called once right before iteration, so this function is like reset
        """
        self.index = 0
        self.cache = {}
        self.prefetch_index = 0
        self.prefetch()
        return self

    def prefetch(self):
        """
        Called in every get() call before actual fetching
        """
        while (
            self.prefetch_index < len(self.dataset)
            and self.prefetch_index
            < self.index + self.prefetch_batches * self.num_workers * self.batch_size
        ):
            # if the prefetch_index hasn't reached the end of the dataset
            # and is not yet prefetch_batches batches per worker ahead,
            # send the next index to the next worker (round-robin)
            self.index_queues[next(self.worker_cycle)].put(self.prefetch_index)
            self.prefetch_index += 1

    def get(self):
        """
        Return a single next item to the __next__() call.
        """
        self.prefetch()
        if self.index in self.cache:
            item = self.cache[self.index]
            # Delete cache to save memory
            del self.cache[self.index]
        else:
            while True:
                try:
                    # Wait briefly instead of polling with timeout=0, which spins the CPU
                    (index, data) = self.output_queue.get(timeout=0.1)
                except queue.Empty:  # nothing arrived yet, keep waiting
                    continue
                if index == self.index:  # found our item, ready to return
                    item = data
                    break
                else:  # item isn't the one we want, cache for later
                    self.cache[index] = data

        self.index += 1
        return item

    def __del__(self):
        """
        Called when the dataloader no longer has any references and is ready to be garbage collected
        """
        try:
            # Stop each worker by passing None to its index queue
            for i, w in enumerate(self.workers):
                self.index_queues[i].put(None)
                w.join(timeout=5.0)
            for q in self.index_queues:  # close all queues
                q.cancel_join_thread() 
                q.close()
            self.output_queue.cancel_join_thread()
            self.output_queue.close()
        finally:
            for w in self.workers:
                if w.is_alive():  # manually terminate worker if all else fails
                    w.terminate()
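The prefetch-window idea does not depend on processes and queues. As a swapped-in technique (threads via concurrent.futures, not what PyTorch does), this sketch keeps up to window loads in flight and yields results in index order; prefetching_iter and slow_square are hypothetical names. Because of Python's GIL, a thread pool mainly helps when loading is I/O-bound (e.g., reading files), while CPU-heavy transforms are where worker processes pay off.

```python
import time
from collections import deque
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Callable, Deque, Iterator, TypeVar

T = TypeVar("T")


def prefetching_iter(
    load: Callable[[int], T], num_items: int, num_workers: int = 2, window: int = 8
) -> Iterator[T]:
    """Yield load(0), ..., load(num_items - 1) in order, keeping up to `window` loads in flight."""
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        pending: Deque[Future] = deque()
        next_index = 0
        while next_index < num_items or pending:
            # Top up the prefetch window (like prefetch() above)
            while next_index < num_items and len(pending) < window:
                pending.append(pool.submit(load, next_index))
                next_index += 1
            # Futures leave in submission order, so results come out in index order
            # even if a later index finished loading first (no cache dict needed)
            yield pending.popleft().result()


def slow_square(i: int) -> int:
    time.sleep(0.01 * (5 - i % 5))  # later indices often finish before earlier ones
    return i * i


squares = list(prefetching_iter(slow_square, num_items=10, num_workers=4, window=6))
```

Keeping futures in a FIFO replaces the index-to-sample cache of the process-based version, at the cost of waiting on the oldest load even when newer ones are already done.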

non_blocking=True

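When pin_memory=True, the DataLoader returns batches in pinned (page-locked) host memory, and a CPU-to-GPU copy from pinned memory with non_blocking=True can run asynchronously with respect to the CPU. GPU operations that depend on the tensor still wait for the copy, because CUDA executes work queued on the same stream in order, so as users we don't need to synchronize manually. Otherwise (pageable memory or non_blocking=False), the transfer is effectively blocking.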

src_batch = src_batch.to(device, non_blocking=True)

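So src_batch.to(device, non_blocking=True) can return before the copy finishes, which means:

  • the CPU can move on to other work (e.g., preparing the next batch);
  • GPU work that doesn't depend on this tensor (e.g., work on another CUDA stream) doesn't have to wait for the copy.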
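A small sketch of the pattern, assuming a hypothetical copy_batch helper; on a machine without CUDA it simply runs on the CPU, where non_blocking has no effect.

```python
import torch


def copy_batch(batch: torch.Tensor, device: torch.device) -> torch.Tensor:
    """Copy a CPU batch to `device`, asynchronously when that is possible."""
    if device.type == "cuda":
        # DataLoader(pin_memory=True) already does this; pin_memory() does nothing extra
        # if the tensor is already pinned
        batch = batch.pin_memory()
    # For a pinned CPU -> CUDA copy this can return before the copy finishes;
    # otherwise it behaves like a normal copy
    return batch.to(device, non_blocking=True)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = torch.arange(12, dtype=torch.float32).reshape(3, 4)
x_dev = copy_batch(x, device)
# Queued on the same CUDA stream as the copy, so it only runs after the copy is done
y = (x_dev * 2).sum()
# .item() needs the value on the host, so it waits for the GPU work to complete
total = y.item()
```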

Transforms in RESNET-20 Example

Imports

%matplotlib inline
# load the autoreload extension
%load_ext autoreload
# autoreload mode 2, which reloads imported modules (if they changed)
# every time before code execution.
# So we don't need to restart the kernel upon changes to modules
%autoreload 2
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn #CUDA Deep Neural Network Library
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
from torchvision.transforms import v2
import matplotlib.pyplot as plt
import time
import os
from tempfile import TemporaryDirectory

from PIL import Image
cudnn.benchmark = True
plt.ion()
plt.rcParams['figure.figsize'] = (5.0, 4.0) # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'
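  • torch.backends.cudnn exposes settings for cuDNN, the CUDA Deep Neural Network library, which provides optimized GPU primitives such as convolution, pooling, and activation functions. Frameworks such as MXNet, TensorFlow, and PyTorch use it under the hood.
  • torch.backends.cudnn.benchmark is a bool that, if True, causes cuDNN to benchmark multiple convolution algorithms and select the fastest. This pays off when input sizes stay the same across iterations (like the fixed 32x32 CIFAR-10 images), since the benchmark runs again for each new input shape.
  • plt.ion() turns on pyplot's interactive mode: figures are drawn without explicit plt.show() calls, and show() no longer blocks, so we can see updates in real time.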

  • ResNet was originally trained on ImageNet, with 1.2M+ high-resolution images across 1000 categories. CIFAR-10 has 60K 32x32 images across 10 classes (hence the 10 in its name).

Data Loading

DATA_DIR='./data'
mean = [0.4914, 0.4822, 0.4465]
std = [0.2470, 0.2435, 0.2616]
transform_train = v2.Compose([
    # v2.RandomResizedCrop(size=(224, 224), antialias=True),
    v2.RandomCrop(32, padding=4),  # pad by 4 pixels, then take a random 32x32 crop
    v2.RandomHorizontalFlip(p=0.5),  # flip with probability 0.5
    v2.ToImage(),  # convert the PIL image (what CIFAR10 returns) to a CHW uint8 tensor image
    v2.ToDtype(torch.float32, scale=True),  # uint8 [0, 255] -> float32 [0, 1]; Normalize expects float input
    v2.Normalize(mean, std),  # normalize with CIFAR-10 per-channel mean and std
])
transform_test = v2.Compose([
    # v2.RandomResizedCrop(size=(224, 224), antialias=True),
    # v2.RandomHorizontalFlip(p=0.5), #flip given a probability
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),  # uint8 [0, 255] -> float32 [0, 1]
    v2.Normalize(mean, std),  # normalize with CIFAR-10 per-channel mean and std
])
train_data = torchvision.datasets.CIFAR10(root=DATA_DIR, train = True, transform = transform_train, download = True)
test_data = torchvision.datasets.CIFAR10(root=DATA_DIR, train = False, transform = transform_test, download = True)

train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=16, shuffle=True, num_workers=1, pin_memory=True)
test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=16, shuffle=False, num_workers=1, pin_memory=True)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("device: ", device)

Notes on the data transforms:

  • scale=True in ToDtype() rescales values to the value range of the target dtype: uint8 pixels in [0, 255] become float32 values in [0.0, 1.0].
  • When downsizing, images can show jagged edges or moiré patterns, because the lower sampling rate can no longer represent the image's high-frequency detail (aliasing, per the Nyquist-Shannon sampling theorem). Antialiasing applies a low-pass filter (smoothing out edges) to remove those frequencies before resampling at the lower rate.
  • v2.Normalize(mean, std) computes (x - mean) / std per channel and does NOT clip the result to [0, 1.0] (float) or [0, 255] (int). Normalization tends to help training converge faster, but for visualization we need to undo it and clip to [0, 1.0] or [0, 255].
  • torch.utils.data.Dataset stores the data and their labels.
  • torch.utils.data.DataLoader wraps an iterable around a Dataset. You can specify batch_size to create mini-batches from it. With the transforms above, image batches come out in NCHW format (each image is CHW, as produced by v2.ToImage()).
  • CHW (Channel-Height-Width) tensors must be permuted to HWC with tensor.permute(1, 2, 0) before plotting, because matplotlib's imshow expects the channel axis last.
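To see what ToDtype(scale=True) and permute(1, 2, 0) do to the numbers, here is a pure-torch sketch on a hypothetical 2x2 RGB image. Dividing by 255 is roughly what ToDtype(torch.float32, scale=True) does for uint8 input; this is an illustration, not torchvision's code.

```python
import torch

# Hypothetical 2x2 RGB image in CHW layout (the layout v2.ToImage() produces)
img_chw = torch.tensor(
    [
        [[0, 255], [51, 102]],      # R channel
        [[0, 0], [0, 0]],           # G channel
        [[255, 255], [255, 255]],   # B channel
    ],
    dtype=torch.uint8,
)

# Scale uint8 [0, 255] to float32 [0, 1]
img_float = img_chw.to(torch.float32) / 255.0

# matplotlib's imshow expects HWC, so move the channel axis last
img_hwc = img_float.permute(1, 2, 0)
```

After the permute, img_hwc[h, w] holds the (R, G, B) values of one pixel.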

We first scale pixel values to [0, 1], then subtract the per-channel mean and divide by the per-channel std of the CIFAR-10 dataset. It's common practice to normalize input data so images have consistent value distributions across RGB channels (rather than some channels with very high values and others with very low ones).
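Per-channel statistics like the CIFAR-10 mean and std above are typically computed over all training pixels of each channel. This sketch does that on random stand-in images (not CIFAR-10, so the numbers differ), applies (x - mean) / std with broadcasting, and then undoes it for display in the same way denorm() does.

```python
import torch

torch.manual_seed(0)
# Hypothetical stand-in dataset: 100 images in NCHW layout, already scaled to [0, 1]
images = torch.rand(100, 3, 8, 8)

# One mean and one std per channel, over all images and pixels (dims N, H, W)
mean = images.mean(dim=(0, 2, 3))
std = images.std(dim=(0, 2, 3))

# Normalize: reshape the (3,) statistics to (1, 3, 1, 1) so they broadcast over N, H, W
normalized = (images - mean.view(1, 3, 1, 1)) / std.view(1, 3, 1, 1)

# Undo the normalization for visualization and clip to the displayable range
restored = (normalized * std.view(1, 3, 1, 1) + mean.view(1, 3, 1, 1)).clamp(0.0, 1.0)
```

After normalization each channel has mean close to 0 and std close to 1, and values fall outside [0, 1], which is why visualization needs the inverse step.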

class_names = train_data.classes
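def denorm(img):
    # Undo v2.Normalize: x * std + mean, broadcast over the last (channel) axis of an HWC image
    m = np.array(mean)
    s = np.array(std)
    img = img.numpy() * s + m
    # Clip to [0, 1] so imshow doesn't complain about out-of-range floats
    return np.clip(img, 0.0, 1.0)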

ROWS, CLM=2, 2
fig, axes = plt.subplots(nrows=ROWS, ncols=CLM)
fig.suptitle('Sample Images')
features, labels=next(iter(test_dataloader))
for i, ax in enumerate(axes.flat):
    img = denorm(features[i].permute(1,2,0).squeeze())
    ax.imshow(img)
    ax.axis("off")
    ax.set_title(class_names[labels[i].item()])
plt.tight_layout()
plt.show()

(Figure: Normalized Input Data)

(Figure: Regular Input Data)

Custom Data Loading

For image classification, one common way to store a custom dataset is to save each image under a directory named after its class, then keep a label -> class name mapping.
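A minimal sketch of that layout, assuming a hypothetical scan_class_folders helper and a throwaway folder of empty placeholder files. torchvision.datasets.ImageFolder expects the same root/<class_name>/<image> structure and assigns labels in sorted class-name order; this sketch imitates that idea with the standard library only.

```python
import os
import tempfile
from typing import Dict, List, Tuple

IMG_EXTENSIONS = (".png", ".jpg", ".jpeg")


def scan_class_folders(root: str) -> Tuple[Dict[str, int], List[Tuple[str, int]]]:
    """Return (class_to_idx, samples) for a root/<class_name>/<image> layout."""
    # Sorting makes the label assignment deterministic across machines
    class_names = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    class_to_idx = {name: idx for idx, name in enumerate(class_names)}
    samples = []
    for name in class_names:
        class_dir = os.path.join(root, name)
        for fname in sorted(os.listdir(class_dir)):
            if fname.lower().endswith(IMG_EXTENSIONS):  # skip non-image files
                samples.append((os.path.join(class_dir, fname), class_to_idx[name]))
    return class_to_idx, samples


# Build a tiny hypothetical layout with empty placeholder files
root = tempfile.mkdtemp()
layout = {"cat": ["001.png", "002.jpg"], "dog": ["003.png", "readme.txt"]}
for class_name, files in layout.items():
    os.makedirs(os.path.join(root, class_name))
    for fname in files:
        open(os.path.join(root, class_name, fname), "wb").close()

class_to_idx, samples = scan_class_folders(root)
idx_to_class = {idx: name for name, idx in class_to_idx.items()}  # label -> class name
```

A Dataset built on top of this would open samples[idx][0] in __getitem__ and return it with samples[idx][1].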

Here, we load the PASCAL VOC (Visual Object Classes) 2007 dataset to test a neural net trained on CIFAR-10. Some key nuances include:

  • A CIFAR-10 model takes in 32x32 images, so VOC images are resized to 32x32, and we need to supply a VOC -> CIFAR-10 class name mapping. Input data normalization is done as usual (with the CIFAR-10 mean and std).
  • We skip images whose label doesn't appear in the class name mapping. The code below also keeps only images that contain exactly one object.
  • A custom torch.utils.data.Dataset needs to subclass Dataset and implement __len__(self) and __getitem__(self, idx).
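import torch
import torchvision
from torch.utils.data import Dataset
from torchvision.transforms import v2

voc_root = './data'
year = '2007'

# CIFAR-10 per-channel mean and std (defined before the transform that uses them)
mean = [0.4914, 0.4822, 0.4465]
std = [0.2470, 0.2435, 0.2616]

transform_voc = v2.Compose([
    v2.Resize((32, 32)),  # CIFAR-10 models expect 32x32 inputs
    v2.ToImage(),  # PIL image -> CHW uint8 tensor image (instead of the deprecated v2.ToTensor())
    v2.ToDtype(torch.float32, scale=True),  # uint8 [0, 255] -> float32 [0, 1]
    v2.Normalize(mean, std),  # normalize with CIFAR mean and std
])

# These are handpicked VOC->CIFAR-10 mappings. If VOC's label isn't in this dictionary, we shouldn't feed the image to the model.
class_additional_mapping = {'aeroplane': 'airplane', 'car': 'automobile', 'bird':'bird', 'cat':'cat', 'dog':'dog', 'frog':'frog'}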

class FilteredVOCtoCIFARDataset(Dataset):
    def __init__(self, root, year, image_set, transform=None, class_mapping=None):
        self.voc_dataset = torchvision.datasets.VOCDetection(
            root=root,
            year=year,
            image_set=image_set,
            download=True,
            transform=None  # Transform applied manually later
        )
        self.transform = transform
        self.class_mapping = class_mapping
        self.filtered_indices = self._filter_indices()

    def _filter_indices(self):
        indices = []
        for idx in range(len(self.voc_dataset)):
            target = self.voc_dataset[idx][1]  # Get the annotation
            objects = target['annotation'].get('object', [])
            if not isinstance(objects, list):
                objects = [objects]  # Ensure it's a list of objects
            if len(objects) != 1:  # keep only images with exactly one object
                continue
            obj = objects[0]
            label = obj['name']
            if label in self.class_mapping:  # Check if class is in our mapping
                indices.append(idx)
        return indices

    def __len__(self):
        return len(self.filtered_indices)

    def __getitem__(self, idx):
        actual_idx = self.filtered_indices[idx]
        image, target = self.voc_dataset[actual_idx]
        
        # Apply transformations to the image
        if self.transform:
            image = self.transform(image)

        # Map VOC labels to CIFAR-10 labels
        objects = target['annotation'].get('object', [])
        if not isinstance(objects, list):
            objects = [objects]  # Ensure it's a list of objects

        # Create a list of labels for the image
        labels = []
        for obj in objects:
            label = obj['name']
            if label in self.class_mapping:
                labels.append(self.class_mapping[label])

        # Return the image and the CIFAR-10 class name of its (single) object.
        # Labels are strings, so the default collate_fn returns them as a list of strings per batch.
        return image, labels[0]

dataset = FilteredVOCtoCIFARDataset(
    root=voc_root,
    year='2007',
    image_set='val',
    transform=transform_voc,
    class_mapping=class_additional_mapping
)

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=16,          # Adjust based on your memory constraints
    shuffle=True,
    num_workers=2,         # Adjust based on your system
    pin_memory=True
)
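The filter-then-remap pattern in FilteredVOCtoCIFARDataset can be sketched generically. In this hypothetical FilterRemapDataset, a plain list stands in for the VOC dataset, and VOC names are mapped straight to integer CIFAR-10 labels (rather than class-name strings), so default_collate produces a label tensor that can be compared with model predictions. The CIFAR-10 class order listed is the standard one (airplane through truck).

```python
import torch
from torch.utils.data import DataLoader, Dataset

# CIFAR-10 class names in label order (0 = airplane, ..., 9 = truck)
CIFAR10_CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck"]


class FilterRemapDataset(Dataset):
    """Hypothetical wrapper: keep samples whose label is in `mapping` and relabel them."""

    def __init__(self, base, mapping):
        self.base = base
        self.mapping = mapping
        # Like _filter_indices() above: scan once and remember which base indices to keep.
        # This calls base[i], so for VOC it would also decode every image.
        self.indices = [i for i in range(len(base)) if base[i][1] in mapping]

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, idx):
        x, label = self.base[self.indices[idx]]
        return x, self.mapping[label]


voc_to_cifar = {"aeroplane": "airplane", "car": "automobile", "bird": "bird",
                "cat": "cat", "dog": "dog", "frog": "frog"}
# Map VOC names straight to integer CIFAR-10 labels
voc_to_label = {voc: CIFAR10_CLASSES.index(cifar) for voc, cifar in voc_to_cifar.items()}

# A stand-in base dataset of (image, VOC label) pairs
base = [
    (torch.zeros(3, 32, 32), "aeroplane"),
    (torch.ones(3, 32, 32), "person"),  # not in the mapping, so it is dropped
    (torch.full((3, 32, 32), 2.0), "dog"),
    (torch.full((3, 32, 32), 3.0), "car"),
]
dataset = FilterRemapDataset(base, voc_to_label)
images, labels = next(iter(DataLoader(dataset, batch_size=4)))
```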