This article is inspired by this reference. The source code of this project can be found here.
## Data Loading
When making a dataset, do NOT use JPEG for the segmentation masks; stick to PNG instead. JPEG's lossy compression will corrupt the label values.
- Pascal VOC has 20 classes, including humans.
- Normalizing with ImageNet's mean and standard deviation is fine for Pascal VOC, since ImageNet is a large enough dataset to be representative.
- Since image segmentation targets are `uint8`, we should add `v2.PILToTensor()` to retain the data type (`uint8`). Then `v2.Lambda(lambda x: replace_tensor_val(x.long(), 255, 21))` replaces the value 255, which is usually the ignore index, with 21. This is because 255 might cause issues in the loss function (say, softmax) when you have only 20 classes.
- Alternatively, `loss_function = torch.nn.CrossEntropyLoss(ignore_index=255)` can be used.
- Using `interpolation=InterpolationMode.NEAREST` is important when downsizing the masks. Why? Because downsizing has to combine multiple pixels into one, and the interpolation mode decides how. Other interpolation methods like bicubic will create continuous float values, which are not valid class labels.
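A minimal sketch of the two points above, assuming a `replace_tensor_val` helper like the one referenced (the exact implementation in the repo may differ):

```python
import torch
import torch.nn.functional as F

def replace_tensor_val(tensor, old_val, new_val):
    # Remap one label value to another, e.g. the 255 ignore index -> 21.
    tensor[tensor == old_val] = new_val
    return tensor

mask = torch.tensor([[0, 255], [15, 255]], dtype=torch.long)
mask = replace_tensor_val(mask, 255, 21)  # 255 -> 21 everywhere

# Nearest-neighbor downsizing picks one source pixel per output pixel,
# so the result contains only values that already exist in the input mask.
big = torch.tensor([[0, 0], [21, 21]], dtype=torch.float32).view(1, 1, 2, 2)
small = F.interpolate(big, size=(1, 1), mode="nearest")
# Bilinear/bicubic would instead blend 0 and 21 into a fractional "class".
```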
## Training
THE BIGGEST PROBLEM I ENCOUNTERED was that the output labels were mostly zero. This is because the dataset is imbalanced and has far more zeros (background pixels) than other classes. In that case, do not use plain cross entropy.
- Pascal VOC 2007 has only 209 images for training and 213 for validation, which we thought would be far from enough. Pascal VOC 2012 has 1,464 images for training and 1,449 for validation. However, trying Pascal VOC 2012 did not solve the problem.
- `torchvision.transforms.CenterCrop(size)` was necessary because, after the convolutions, the skip connections are slightly larger than their upsampled peers.
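That crop-and-concatenate step can be sketched as follows (hypothetical helper names; `torchvision.transforms.CenterCrop` does the same cropping, this version uses plain slicing to stay self-contained):

```python
import torch

def center_crop(t, target_hw):
    # Crop a (N, C, H, W) tensor symmetrically to the target (H, W),
    # mirroring what torchvision.transforms.CenterCrop does.
    th, tw = target_hw
    top = (t.shape[-2] - th) // 2
    left = (t.shape[-1] - tw) // 2
    return t[..., top:top + th, left:left + tw]

def crop_and_concat(upsampled, skip):
    # The skip connection is slightly larger than the upsampled feature
    # map, so crop it before concatenating along the channel dimension.
    skip = center_crop(skip, upsampled.shape[-2:])
    return torch.cat([upsampled, skip], dim=1)

up = torch.randn(1, 64, 56, 56)    # upsampled decoder features
skip = torch.randn(1, 64, 64, 64)  # slightly larger encoder features
out = crop_and_concat(up, skip)    # shape: (1, 128, 56, 56)
```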
- Focal loss seems useful, but it was a bit tricky to check. I fed the one-hot version of the output labels as the prediction and compared against the labels themselves - that was supposed to be the "perfect" example, so I expected a loss of 0. However, due to the softmax operation in the focal loss, I got 0.78. This was resolved by using `100 * one_hot_labels`. Still, using focal loss alone did NOT get around the imbalance issue for Pascal VOC 2007.
- However, UNet works pretty well on the Carvana dataset, decently on the GTA5 dataset, but on Pascal VOC 2007 and 2012 it is pretty bad - it just learns the background.
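The sanity check above can be reproduced with a common focal-loss formulation (an assumption - the repo's implementation may differ): feeding one-hot labels as logits leaves the softmax far from a one-hot distribution, while scaling them by 100 saturates it and drives the loss toward 0.

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, targets, gamma=2.0):
    # Multi-class focal loss: FL = (1 - p_t)^gamma * (-log p_t),
    # with p_t the softmax probability of the true class.
    log_p = F.log_softmax(logits, dim=1)
    ce = F.nll_loss(log_p, targets, reduction="none")
    p_t = torch.exp(-ce)
    return ((1 - p_t) ** gamma * ce).mean()

num_classes = 21
targets = torch.tensor([0, 5, 20])
one_hot = F.one_hot(targets, num_classes).float()

# Softmax over 0/1 "logits" is far from a delta distribution, so even
# the "perfect" prediction yields a clearly nonzero loss...
loss_raw = focal_loss(one_hot, targets)
# ...while scaling the logits saturates the softmax and the loss ~ 0.
loss_scaled = focal_loss(100 * one_hot, targets)
```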
## Performance Profiling
- For the GTA5 dataset, my train/dev/test split is 70%/15%/15%. My accuracies are:
  - train: `68.1%`
  - dev: `68.7%`
  - test: `67.88%`
- For the Carvana dataset, my train/dev/test split is 70%/15%/15%. My accuracies are:
  - Mixed precision training (average 383s/batch):
    - train: `90.51%`
    - dev: `90.46%`
    - test: `90.61%`
  - FP32 full precision training (time):
    - train: `90.55%`
    - dev: `90.63%`
    - test: `90.66%`
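The mixed-precision runs above use PyTorch's AMP; a minimal training-step sketch (illustrative model and shapes, not the repo's actual setup):

```python
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"  # autocast to fp16 only where it helps

model = torch.nn.Conv2d(3, 21, kernel_size=1).to(device)  # stand-in for UNet
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)  # no-op when disabled

x = torch.randn(2, 3, 16, 16, device=device)
y = torch.randint(0, 21, (2, 16, 16), device=device)

opt.zero_grad()
with torch.autocast(device_type=device, enabled=use_amp):
    loss = F.cross_entropy(model(x), y)  # forward pass in fp16 on GPU
scaler.scale(loss).backward()            # scale to avoid fp16 underflow
scaler.step(opt)                         # unscale grads, then optimizer step
scaler.update()
```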
- Pascal VOC 2007
  - Mixed precision training (average 383s/batch):
    - train: `72.97%`
    - dev: `73.61%`
    - test: `74.27%`
  - Examples:
<div style="text-align: center;"> <p align="center"> <figure> <img src="https://github.com/user-attachments/assets/c3353c3f-308c-4b81-9798-8873b2488b39" height="200" alt=""/> </figure> </p> </div> <div style="text-align: center;"> <p align="center"> <figure> <img src="https://github.com/user-attachments/assets/2f2538f8-0822-4da8-b8b3-b5460531b20d" height="200" alt=""/> </figure> </p> </div> <div style="text-align: center;"> <p align="center"> <figure> <img src="https://github.com/user-attachments/assets/00e629bb-d286-4d28-bea1-8e74c553eb36" height="200" alt=""/> </figure> </p> </div>
- Pascal VOC 2012
  - Mixed precision training (average 383s/batch):
    - train: `64.70%`
    - dev: `64.27%`
    - test: `64.69%`