PyTorch Mixed Precision Training

torch.zeros

Posted by Rico's Nerd Cluster on January 17, 2026

torch.zeros

By default:

out = torch.zeros(b, c, m).to(features.device)  # allocates on the CPU, then copies to features.device
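  • torch.zeros(...) creates a tensor with the default dtype (torch.float32 unless changed with torch.set_default_dtype) on the default device (normally the CPU), unless you pass dtype=... or device=... explicitly.
  • .to(features.device) then copies it to the same device (CPU/GPU) as features. The dtype is left unchanged.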
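A minimal sketch of that default behavior, assuming a hypothetical fp16 `features` tensor of shape (b, c, m) kept on the CPU so it runs without a GPU. Only the device is matched; the dtype stays float32, and mixing the two tensors silently promotes results back to float32:

```python
import torch

# Hypothetical fp16 features; on the CPU so the sketch runs anywhere.
b, c, m = 2, 3, 4
features = torch.randn(b, c, m).half()

# The pattern above: default-dtype allocation, then a device-only move.
out = torch.zeros(b, c, m).to(features.device)

print(features.dtype)                 # torch.float16
print(out.dtype)                      # torch.float32 (the default dtype)
print(out.device == features.device)  # True: only the device was matched
# Type promotion: float32 + float16 computes and returns float32.
print((out + features).dtype)         # torch.float32
```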

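However, the result stays float32 even when features is fp16, bf16, or another dtype, so it does not match features. It also allocates on the CPU first and then pays for an extra host-to-device copy when features lives on the GPU. To match features exactly (device and dtype) and allocate directly on the target device, do: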

out = torch.zeros((b, c, m), device=features.device, dtype=features.dtype)
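To see why this matters in mixed precision training, here is a sketch using CPU autocast with bfloat16 (chosen only so it runs without a GPU; on CUDA you would typically use device_type="cuda" with float16). The `layer` and input `x` are hypothetical stand-ins for whatever produces `features`. Factory functions such as torch.zeros are not autocast ops, so the naive allocation stays float32 even inside the autocast region, while the explicit version and Tensor.new_zeros follow features:

```python
import torch

b, c, m = 2, 3, 4
layer = torch.nn.Linear(m, m)  # hypothetical layer that produces `features`
x = torch.randn(b, c, m)       # hypothetical input

# CPU autocast with bfloat16 so the sketch runs without a GPU.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    features = layer(x)  # linear is autocast to bf16 on CPU
    # Naive pattern: factory ops ignore autocast, so this is float32.
    naive = torch.zeros(b, c, m).to(features.device)
    # Explicit pattern: matches features' device and dtype.
    out = torch.zeros((b, c, m), device=features.device, dtype=features.dtype)
    # Shorthand: new_zeros inherits dtype and device from `features`.
    out_alt = features.new_zeros((b, c, m))

print(features.dtype)  # torch.bfloat16
print(naive.dtype)     # torch.float32
print(out.dtype)       # torch.bfloat16
print(out_alt.dtype)   # torch.bfloat16
```

`features.new_zeros(...)` is simply a shorter spelling of the explicit call; `torch.zeros_like(features)` also works when the output has exactly the same shape as features.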