This is the official PyTorch implementation for our NeurIPS 2023 paper "DAW: Exploring the Better Weighting Function for Semi-supervised Semantic Segmentation".
The critical challenge of semi-supervised semantic segmentation lies in how to fully exploit a large volume of unlabeled data to improve the model’s generalization performance for robust segmentation. Existing methods tend to employ certain criteria (a weighting function) to select pixel-level pseudo-labels. However, when these methods handle pseudo-labels without careful consideration of the weighting function, a trade-off arises between inaccurate yet utilized pseudo-labels and correct yet discarded pseudo-labels, which hinders the generalization ability of the model. In this paper, we systematically analyze this trade-off in previous methods. We formally define it by explicitly modeling the confidence distributions of correct and inaccurate pseudo-labels, equipped with a unified weighting function. To this end, we propose Distribution-Aware Weighting (DAW), which strives to minimize the negative equivalence impact raised by the trade-off. We find that the optimal weighting function is a hard step function whose jump point lies at the intersection of the two confidence distributions. In addition, we devise distribution alignment to mitigate the discrepancy between the prediction distributions of labeled and unlabeled data. Extensive experimental results on multiple benchmarks, including mitochondria segmentation, demonstrate that DAW performs favorably against state-of-the-art methods.
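The step-function result described above can be illustrated with a small, standalone sketch. It is not the code in this repository: it assumes, purely for illustration, that the confidences of correct and incorrect pseudo-labels are each modeled by an untruncated, unweighted Gaussian, and it locates the crossing point of the two densities by bisection. The names `gaussian_pdf`, `intersection_threshold`, and `step_weight` are hypothetical.

```python
import math
import statistics


def gaussian_pdf(x, mean, std):
    """Density of a normal distribution N(mean, std^2) at x."""
    z = (x - mean) / std
    return math.exp(-0.5 * z * z) / (std * math.sqrt(2.0 * math.pi))


def intersection_threshold(correct_conf, wrong_conf, iters=60):
    """Find where the two fitted Gaussian densities cross between their means.

    Assumption (illustrative only): each group of confidences is summarized
    by its sample mean and population standard deviation.
    """
    mu_c, sd_c = statistics.fmean(correct_conf), statistics.pstdev(correct_conf)
    mu_w, sd_w = statistics.fmean(wrong_conf), statistics.pstdev(wrong_conf)
    if sd_c <= 0 or sd_w <= 0:
        raise ValueError("each group needs non-zero spread")

    def diff(x):
        # Positive where the 'correct' density dominates.
        return gaussian_pdf(x, mu_c, sd_c) - gaussian_pdf(x, mu_w, sd_w)

    lo, hi = mu_w, mu_c
    if not (lo < hi and diff(lo) < 0.0 < diff(hi)):
        raise ValueError("densities do not bracket a crossing between the means")

    # Plain bisection on the sign of diff.
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if diff(mid) < 0.0:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


def step_weight(confidence, threshold):
    """Hard step weighting: keep a pseudo-label only at or above the threshold."""
    return 1.0 if confidence >= threshold else 0.0


if __name__ == "__main__":
    # Toy confidences, invented for the example.
    t = intersection_threshold([0.8, 0.9, 1.0], [0.4, 0.5, 0.6])
    print(round(t, 4), step_weight(0.75, t), step_weight(0.65, t))
```

With equal spreads, as in the toy data, the crossing point falls midway between the two means; with unequal spreads it shifts toward the narrower distribution.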
cd DAW
conda create -n daw python=3.10.4
conda activate daw
pip install -r requirements.txt
pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 -f https://download.pytorch.org/whl/torch_stable.html
├── ./pretrained
    ├── resnet50.pth
    └── resnet101.pth
- Pascal: JPEGImages | SegmentationClass
- Cityscapes: leftImg8bit | gtFine
Please modify your dataset path in configuration files.
The ground-truth masks were preprocessed by UniMatch.
The final folder structure should look like this:
├── DAW
    ├── pretrained
        └── resnet50.pth, ...
    └── daw.py
    └── ...
├── data
    ├── pascal
        ├── JPEGImages
        └── SegmentationClass
    ├── cityscapes
        ├── leftImg8bit
        └── gtFine
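Before training, it can help to confirm that the folders listed above are in place. The following is a minimal sketch, not part of this repository; the helper name `missing_paths` and the `EXPECTED_DIRS` list are hypothetical, and it assumes `DAW` and `data` sit side by side under one common root directory.

```python
from pathlib import Path

# Directories taken from the folder structure above, relative to the
# directory that contains both DAW/ and data/ (an assumption of this sketch).
EXPECTED_DIRS = [
    "DAW/pretrained",
    "data/pascal/JPEGImages",
    "data/pascal/SegmentationClass",
    "data/cityscapes/leftImg8bit",
    "data/cityscapes/gtFine",
]


def missing_paths(root):
    """Return the expected directories that do not exist under root."""
    root = Path(root)
    return [rel for rel in EXPECTED_DIRS if not (root / rel).is_dir()]


if __name__ == "__main__":
    # When run from inside DAW/, the common root is the parent directory.
    missing = missing_paths("..")
    if missing:
        print("Missing directories:")
        for rel in missing:
            print("  " + rel)
    else:
        print("Folder structure looks complete.")
```

An empty result only means the directories exist; the dataset paths in the configuration files still need to point at them.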
# use torch.distributed.launch
sh scripts/train.sh <num_gpu> <port>
If you find this project useful, please consider citing:
@inproceedings{sun2023daw,
title={DAW: exploring the better weighting function for semi-supervised semantic segmentation},
author={Sun, Rui and Mai, Huayu and Zhang, Tianzhu and Wu, Feng},
booktitle={Proceedings of the 37th International Conference on Neural Information Processing Systems},
pages={61792--61805},
year={2023}
}
@inproceedings{mai2024rankmatch,
title={RankMatch: Exploring the Better Consistency Regularization for Semi-supervised Semantic Segmentation},
author={Mai, Huayu and Sun, Rui and Zhang, Tianzhu and Wu, Feng},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={3391--3401},
year={2024}
}
DAW is primarily based on UniMatch. We are grateful to its authors for open-sourcing their code.