Torch.autograd.Function
torch.autograd.Function lets you define a custom op with explicit forward() and backward() passes. Common built-ins like MaxPool are implemented this way. It’s especially useful for wrapping a CUDA kernel — calling FurthestSampling.apply exposes it as a regular callable you can drop into your model.
```python
import torch
from torch.autograd import Function

import pointops_cuda  # compiled CUDA extension providing furthestsampling_cuda


class FurthestSampling(Function):
    @staticmethod
    def forward(ctx, xyz: torch.Tensor, m: int) -> torch.Tensor:
        """
        input: xyz: (B, N, 3)
               m: number of points to sample
        output: idx: (B, m) int32
        NOTE: kernel supports float32 and float16.
              For fp16, coordinates must be < ~181 units apart (max sq dist ~65504).
        """
        assert xyz.is_contiguous()
        b, n, _ = xyz.size()
        idx = torch.zeros(b, m, dtype=torch.int32, device=xyz.device)
        temp = torch.full((b, n), float('inf'), dtype=xyz.dtype, device=xyz.device)
        pointops_cuda.furthestsampling_cuda(b, n, m, xyz, temp, idx)
        return idx

    @staticmethod
    def backward(ctx, grad_idx):
        return None, None


furthestsampling = FurthestSampling.apply
```
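Once wrapped, the op can be called like any other function. A hypothetical usage (it requires the compiled pointops_cuda extension and a CUDA device):

```python
xyz = torch.rand(2, 1024, 3, device="cuda").contiguous()  # (B, N, 3) point coordinates
idx = furthestsampling(xyz, 256)                          # (B, 256) int32 sample indices
```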
backward() returns None for both inputs because furthest-point sampling only selects indices — it is non-differentiable, so no gradients flow back through it:
- `xyz` receives no gradients from this op — the sampling step acts as a stop-gradient.
- `m` is a plain Python `int`, so it has no gradient anyway.
Why backprop still works in practice
```python
selected = feat[..., idx, :]  # shape (B, m, C)
```
`feat` does have a gradient, and indexing (or `torch.gather`) with `idx` is differentiable w.r.t. `feat`. So gradients flow through the gathered features back into the feature extractor — the network learns what features to produce, even though it cannot learn where to sample.
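A minimal sketch of that asymmetry (shapes and names here are illustrative): gradients reach the gathered rows of `feat`, while `idx` itself gets none.

```python
import torch

feat = torch.randn(2, 8, 4, requires_grad=True)   # (B, N, C) learned features
idx = torch.tensor([[0, 3, 5], [1, 2, 7]])        # (B, m) sampled indices, no grad
# gather along the point dimension; differentiable w.r.t. feat only
sel = torch.gather(feat, 1, idx.unsqueeze(-1).expand(-1, -1, feat.size(-1)))
sel.sum().backward()
print(feat.grad.abs().sum(-1))  # nonzero only at the gathered point rows
```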
Operations
Convolution
Conv1d vs Conv2d (see the sketch below):
- `Conv1d` is used when data is `(B, C, N)` — a list of N independent per-point vectors.
- `Conv2d` is used when data is `(B, C, M, K)` — a grid of M points × K neighbors.
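A sketch of both cases with `kernel_size=1` (the sizes here are illustrative):

```python
import torch
import torch.nn as nn

B, C, N, M, K = 2, 16, 1024, 256, 32

pointwise_1d = nn.Conv1d(C, 32, kernel_size=1)      # shared MLP over N independent points
print(pointwise_1d(torch.randn(B, C, N)).shape)     # (B, 32, N)

pointwise_2d = nn.Conv2d(C, 32, kernel_size=1)      # shared MLP over M points x K neighbors
print(pointwise_2d(torch.randn(B, C, M, K)).shape)  # (B, 32, M, K)
```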
Normalization
`nn.GroupNorm(ngroups, hidden)` splits the `hidden` channels into `ngroups` groups, then normalizes each point's features within each group. This can be more stable with small batches, since the statistics do not depend on the batch size. E.g., with `hidden = 8` channels and `ngroups = 2`, the channels are split into Group 0: channels 0–3 and Group 1: channels 4–7.
- Then apply the learned per-channel scale and shift (gamma and beta).
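A small sketch with the numbers above (`hidden = 8`, `ngroups = 2`); the learned scale and shift are GroupNorm's per-channel `weight` and `bias`:

```python
import torch
import torch.nn as nn

x = torch.randn(4, 8, 1024)                    # (B, hidden=8, N) point features
gn = nn.GroupNorm(num_groups=2, num_channels=8)
y = gn(x)  # per sample: normalize over each group of 4 channels (and N), then scale/shift
print(gn.weight.shape, gn.bias.shape)          # learned per-channel gamma/beta: (8,), (8,)
```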
`nn.ReLU(inplace=True)` computes `relu(x) = max(0, x)` and writes the result back into the input tensor's memory instead of allocating a new output.
- In-place ReLU can break autograd when the original pre-ReLU values are needed for the backward pass, e.g., if they are also consumed by a skip connection. If the pre-activation values are not reused, it is usually fine.
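A quick illustration of the in-place behavior:

```python
import torch
import torch.nn as nn

x = torch.tensor([-1.0, 2.0, -3.0])
y = nn.ReLU(inplace=True)(x)
print(x)       # tensor([0., 2., 0.]) -- the input buffer was overwritten
print(y is x)  # True: no new tensor was allocated
```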
torch.gather()
`idx[:, None]` is the same as `idx.unsqueeze(-1)` for a 1-D tensor: it adds a new dimension, so an array of shape `(5,)` becomes `(5, 1)`. `idx.repeat(1, fdim, 1)` copies the tensor `fdim` times along the second dimension; e.g., a `(5, 1)` array after `repeat(1, 3)` becomes `(5, 3)`:

```python
import numpy as np
import torch

arr = np.array([1, 2, 3, 4, 5])
arr[:, None].shape              # (5, 1)
t = torch.tensor(arr[:, None])
t.repeat(1, 3)                  # shape (5, 3)
```
`torch.gather(input, dim, index)` picks input values by their indices along `dim`:
```python
import torch

# Two channels of five values each:
# - Channel 0 (c=0): [10, 11, 12, 13, 14]
# - Channel 1 (c=1): [20, 21, 22, 23, 24]
xyzs = torch.tensor(
    [
        [  # batch 0
            [10, 11, 12, 13, 14],  # channel 0
            [20, 21, 22, 23, 24],  # channel 1
        ]
    ]
)  # shape [1, 2, 5]

# Now choose indices along the last (point) dimension
idx = torch.tensor([[0, 3, 4]])  # shape [1, 3]
out = torch.gather(xyzs, dim=2, index=idx.unsqueeze(1).expand(-1, 2, -1))
# gives
# [
#     [
#         [10, 13, 14],
#         [20, 23, 24],
#     ]
# ]  # shape [1, 2, 3]
```
- Negative infinity: when masking values, use a dtype-aware very negative number rather than a hard-coded `float('-inf')`:

```python
# Use dtype-aware very negative value
neg_inf = torch.finfo(logits.dtype).min
```
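A self-contained usage sketch (the masking scenario here, e.g., attention logits before a softmax, is illustrative):

```python
import torch

logits = torch.randn(2, 4)
mask = torch.tensor([[True, True, False, False],
                     [True, False, True, False]])
neg_inf = torch.finfo(logits.dtype).min
probs = logits.masked_fill(~mask, neg_inf).softmax(dim=-1)
print(probs)  # masked positions get ~0 probability, without inf/NaN arithmetic
```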
Sum
- Weighted sum with `torch.einsum`: a dimension label that appears in both inputs but not in the output (here `k`) is summed over:
```python
import torch

attn = torch.tensor(
    [[[[0.2, 0.3, 0.5],
       [0.1, 0.6, 0.3]]]])  # (1, 1, 2, 3) = b c m k
t = torch.tensor(
    [[[[10., 20.],
       [11., 21.],
       [12., 22.]]]])       # (1, 1, 3, 2) = b c k m
res = torch.einsum('bcmk, bckm -> bcm', attn, t)
# [0.2 * 10 + 0.3 * 11 + 0.5 * 12, 0.1 * 20 + 0.6 * 21 + 0.3 * 22]
print(res)  # [[[11.3, 21.2]]]
```
- Histogram count with `scatter_add_`:
  - `idx`: indices to update
  - `src`: values to add into the counts
  - `counts`: the histogram bins
```python
import torch

sample_num = 5
idx = torch.tensor([0, 2, 2, 4, 2, 0, 1, 4])  # indices to add into (length N)
src = torch.tensor([3, 22, 1, 1, 1, 1, 1, 1], dtype=torch.long)
counts = torch.zeros(sample_num, dtype=torch.long)
counts.scatter_add_(0, idx, src)
# idx[0] = 0, so counts[0] += src[0] -> counts[0] = 3
# idx[1] = 2, so counts[2] += src[1] -> counts[2] = 22
# idx[2] = 2, so counts[2] += src[2] -> counts[2] = 23
# idx[3] = 4, so counts[4] += src[3] -> counts[4] = 1
# idx[4] = 2, so counts[2] += src[4] -> counts[2] = 24
# idx[5] = 0, so counts[0] += src[5] -> counts[0] = 4
# idx[6] = 1, so counts[1] += src[6] -> counts[1] = 1
# idx[7] = 4, so counts[4] += src[7] -> counts[4] = 2
print("counts:", counts.tolist())  # [4, 1, 24, 0, 2]
```
Infrastructure
What is ctx in an autograd Function?
`ctx` is the "backward stash" for whatever `forward()` needs to remember. It is part of the `torch.autograd.Function` API.
```python
class Gathering(Function):
    @staticmethod
    def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
        ...
```
- In `forward(ctx, ...)`, you use `ctx` to save things you'll need later for gradients.
- In `backward(ctx, ...)`, you retrieve them.
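A minimal sketch of the save/retrieve pattern, using a made-up `Square` op (not part of the code above):

```python
import torch
from torch.autograd import Function

class Square(Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)       # stash x for the backward pass
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors       # retrieve what forward() stashed
        return grad_out * 2 * x        # d(x^2)/dx = 2x

x = torch.tensor([3.0], requires_grad=True)
Square.apply(x).sum().backward()
print(x.grad)  # tensor([6.])
```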
When to Empty the CUDA Cache
PyTorch's caching allocator exists to avoid repeated cudaMalloc / cudaFree calls, which are expensive.
torch.cuda.empty_cache() releases unused cached GPU memory back to the CUDA driver so other processes can use it. It does not free memory held by live tensors: if something is still referenced, it stays allocated. It should not be called in the middle of a forward() pass, because it forces synchronization and later stalls when the memory has to be re-allocated, so in practice it is rarely needed.
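If you do use it, the typical pattern is to drop all references first and then release the cache between stages. A sketch (assumes a CUDA device is available):

```python
import torch

x = torch.randn(4096, 4096, device="cuda")
print(torch.cuda.memory_reserved())  # the allocator holds the block in its cache
del x                                # no live references remain...
torch.cuda.empty_cache()             # ...so the cached block goes back to the driver
print(torch.cuda.memory_reserved())
```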