torch.zeros
By default:
out = torch.zeros(b, c, m).to(features.device)  # allocates on the default device (usually CPU), then copies to features.device
torch.zeros(...) creates a tensor with dtype torch.float32 (the default dtype) unless you specify dtype=... explicitly.
The .to(features.device) call then moves it to the same device (CPU/GPU) as features.
However, it does not guarantee the same dtype as features if features is fp16/bf16/etc. If you want it to match features exactly (device + dtype), do:
out = torch.zeros((b, c, m), device=features.device, dtype=features.dtype)
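To see the difference concretely, here is a minimal sketch. It assumes a half-precision features tensor built only for illustration, made-up values for b, c and m, and a default dtype that has not been changed from float32. It also shows features.new_zeros(...), a shorthand that takes dtype and device from features.

```python
import torch

# Hypothetical stand-in for `features`: a half-precision tensor (illustration only).
features = torch.ones(2, 3, 4, dtype=torch.float16)
b, c, m = 2, 3, 5  # made-up sizes

# Only the device is matched; dtype stays at the default (float32 unless changed).
out_default = torch.zeros(b, c, m).to(features.device)

# Device and dtype are both taken from `features`, with no intermediate copy.
out_matched = torch.zeros((b, c, m), device=features.device, dtype=features.dtype)

# Shorthand: new_zeros inherits dtype and device from the tensor it is called on.
out_new = features.new_zeros((b, c, m))

print(out_default.dtype, out_matched.dtype, out_new.dtype)
```

If only the dtype differs, out_default would still work in many operations through type promotion, but the results may be silently upcast to float32, so matching the dtype up front is usually the safer choice.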