PyTorch Mixed Precision Training
torch.zeros, GradScaler, GradCheck
PyTorch Setup
use_amp = True
net = make_model(in_size, out_size, num_layers)
opt = torch.optim.SGD(net.parameters(), lr=0.001)
# if False, th...
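
The setup above is usually followed by a training loop that wraps the forward pass in autocast and routes the backward pass through a gradient scaler. The following is a minimal sketch of that pattern, not the original author's code. The body of make_model, the layer sizes, the MSE loss, the random batch, and the step count are all assumptions made for illustration. AMP is switched on only when a CUDA device is present, so the same script also runs on CPU in plain float32.

```python
import torch


def make_model(in_size, out_size, num_layers):
    # Hypothetical stand-in for the make_model helper used above:
    # a stack of Linear + ReLU layers ending in a Linear output layer.
    layers = []
    for _ in range(num_layers - 1):
        layers.append(torch.nn.Linear(in_size, in_size))
        layers.append(torch.nn.ReLU())
    layers.append(torch.nn.Linear(in_size, out_size))
    return torch.nn.Sequential(*layers)


device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"  # assumption: only use AMP when a GPU is available
amp_dtype = torch.float16 if device == "cuda" else torch.bfloat16

in_size, out_size, num_layers = 64, 8, 3  # assumed sizes, for illustration only
net = make_model(in_size, out_size, num_layers).to(device)
opt = torch.optim.SGD(net.parameters(), lr=0.001)

# With enabled=False the scaler becomes a pass-through, so the loop below
# behaves like ordinary full-precision training.
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
loss_fn = torch.nn.MSELoss()

# A single random batch, reused every step (placeholder data).
inputs = torch.randn(32, in_size, device=device)
targets = torch.randn(32, out_size, device=device)

losses = []
for step in range(5):
    opt.zero_grad(set_to_none=True)
    # Forward pass and loss run under autocast when AMP is enabled.
    with torch.autocast(device_type=device, dtype=amp_dtype, enabled=use_amp):
        output = net(inputs)
        loss = loss_fn(output, targets)
    # Scale the loss before backward so small float16 gradients do not underflow.
    scaler.scale(loss).backward()
    # step() unscales the gradients and skips the update if any are inf/NaN.
    scaler.step(opt)
    # Adjust the scale factor for the next iteration.
    scaler.update()
    losses.append(loss.item())
```

Setting use_amp to False disables both autocast and the scaler without touching the loop itself, which makes it easy to compare against a full-precision run. Recent PyTorch releases also expose the scaler as torch.amp.GradScaler; the torch.cuda.amp spelling is used here because it is available in older versions as well.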