Deep Learning - PyTorch Model Training


Posted by Rico's Nerd Cluster on March 6, 2022

Checkpointing

Checkpointing is a technique that trades compute for memory during training. Instead of storing all intermediate activations (the outputs of layers) for backpropagation, which consumes a lot of memory, checkpointing discards some of them and recomputes them during the backward pass. This saves memory at the expense of additional computation. In PyTorch, this is available through torch.utils.checkpoint:

import torch.nn as nn
from torch.utils import checkpoint

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.inc = nn.Conv2d(3, 16, kernel_size=3, padding=1)  # Example layer

    def forward(self, x):
        # Checkpoint the layer: its activations are not stored, but recomputed
        # during the backward pass
        x = checkpoint.checkpoint(self.inc, x)
        return x

Checkpointing can be applied to plain functions as well, not just modules, as the sketch below shows.
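A minimal sketch of checkpointing a function; the name custom_block is an illustrative placeholder, not something from the original post:

import torch
from torch.utils.checkpoint import checkpoint

def custom_block(x):
    # Arbitrary computation whose intermediate activations we do not want to keep
    return torch.relu(x @ x.t())

x = torch.randn(128, 128, requires_grad=True)
y = checkpoint(custom_block, x)  # Recomputed during the backward pass
y.sum().backward()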

Training

  • To inspect a model attribute during training, interpolate it inside the f-string braces: print(f'model.n_channels: {model.n_channels}'). A small sketch follows.
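A minimal sketch of where such an attribute might come from, assuming the model stores its input channel count on itself (the UNetLike name and the n_channels default are illustrative assumptions):

import torch.nn as nn

class UNetLike(nn.Module):
    def __init__(self, n_channels=3):
        super().__init__()
        self.n_channels = n_channels  # Stored so training code can report it
        self.inc = nn.Conv2d(n_channels, 16, kernel_size=3, padding=1)

model = UNetLike()
print(f'model.n_channels: {model.n_channels}')  # -> model.n_channels: 3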