This is the official code repository for Gradient-Map-Guided Adaptive Domain Generalization for Cross Modality MRI Segmentation by Bingnan Li, Zhitong Gao, Xuming He (ML4H 2023).
Cross-modal MRI segmentation is of great value for computer-aided medical diagnosis, enabling flexible data acquisition and model generalization. However, most existing methods have difficulty in handling local variations in domain shift and typically require a significant amount of data for training, which hinders their usage in practice. To address these problems, we propose a novel adaptive domain generalization framework, which integrates a learning-free cross-domain representation based on image gradient maps and a class prior-informed test-time adaptation strategy for mitigating local domain shift. We validate our approach on two multi-modal MRI datasets with six cross-modal segmentation tasks. Across all the task settings, our method consistently outperforms competing approaches and shows a stable performance even with limited training data.
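The learning-free cross-domain representation above is built on image gradient maps. As an illustration only (not the repository's exact implementation), a central-difference gradient-magnitude map over a 2D image can be sketched in pure Python:

```python
def gradient_map(img):
    """Gradient-magnitude map of a 2D image (list of rows of floats),
    using central differences with zero padding at the borders.
    Illustrative sketch only; the repository's implementation may differ."""
    h, w = len(img), len(img[0])

    def px(y, x):
        # Zero padding outside the image bounds.
        if 0 <= y < h and 0 <= x < w:
            return img[y][x]
        return 0.0

    out = []
    for y in range(h):
        row = []
        for x in range(w):
            gx = (px(y, x + 1) - px(y, x - 1)) / 2.0  # horizontal difference
            gy = (px(y + 1, x) - px(y - 1, x)) / 2.0  # vertical difference
            row.append((gx * gx + gy * gy) ** 0.5)    # gradient magnitude
        out.append(row)
    return out
```

Because the map depends only on local intensity differences, it is less sensitive to the global intensity shifts that distinguish MRI modalities.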
We only guarantee the correctness of the code on the following platforms:

- Linux (with CUDA acceleration)
- macOS (with MPS acceleration)
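Since both CUDA and MPS are supported, a device-selection helper can keep scripts portable across the two platforms. This is a pure-logic sketch of my own (the helper name and fallback order are not part of the repository):

```python
def pick_device(cuda_available: bool, mps_available: bool) -> str:
    """Choose an accelerator string, preferring CUDA, then MPS, then CPU.
    Illustrative helper only; not part of the repository."""
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"

# Typical use with PyTorch (assumes torch is installed):
# import torch
# device = torch.device(pick_device(torch.cuda.is_available(),
#                                   torch.backends.mps.is_available()))
```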
We use Python 3.11; feel free to use `conda` or `venv` to create the environment. Once you have created the environment, install the dependencies with the following command:

```shell
pip install -r requirements.txt
```
You can download the datasets used in our experiments by following the instructions at the links below:
Once you have downloaded the datasets, place the folders into `datasets` under the names `BraTS2018_Raw` and `MS-CMRSeg2019_Raw` respectively. The folder structure should look like this:
```
BraTS2018_Raw
├── HGG
│   ├── Brats18_2013_10_1
│   ├── Brats18_2013_11_1
│   ├── ...
│   └── Brats18_TCIA08_469_1
└── LGG
    ├── Brats18_2013_0_1
    ├── Brats18_2013_15_1
    ├── ...
    └── Brats18_TCIA13_654_1
```

```
MS-CMRSeg2019_Raw
├── C0LET2_gt_for_challenge19
│   ├── C0_manual_10
│   ├── LGE_manual_35_TestData
│   └── T2_manual_10
└── C0LET2_nii45_for_challenge19
    ├── c0gt
    ├── c0t2lge
    ├── lgegt
    └── t2gt
```
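Before preprocessing, it can help to confirm the raw data landed in the right place. A small stdlib-only check against the layout shown above (illustrative helper, not part of the repository):

```python
import os

# Expected top-level sub-folders of each raw dataset, as described above.
EXPECTED = {
    "BraTS2018_Raw": ["HGG", "LGG"],
    "MS-CMRSeg2019_Raw": [
        "C0LET2_gt_for_challenge19",
        "C0LET2_nii45_for_challenge19",
    ],
}

def missing_folders(root="datasets"):
    """Return the expected dataset sub-folders that are absent under `root`."""
    missing = []
    for dataset, subdirs in EXPECTED.items():
        for sub in subdirs:
            path = os.path.join(root, dataset, sub)
            if not os.path.isdir(path):
                missing.append(path)
    return missing

# An empty return value means the raw datasets are laid out as expected.
```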
To preprocess the BraTS2018 data for all four cross-modal settings, run:

```shell
declare -a SOURCE=("t2" "flair")
declare -a TARGET=("t1" "t1ce")
for source in "${SOURCE[@]}"; do
  for target in "${TARGET[@]}"; do
    python datasets/BraTS_2018.py \
      --root datasets/BraTS2018_Raw \
      --save_dir datasets/BraTS_2018 \
      --source "$source" \
      --target "$target" \
      --train_source True \
      --val_target True
  done
done
```
For MS-CMRSeg2019, run:

```shell
source="C0"
declare -a TARGET=("T2" "LGE")
for target in "${TARGET[@]}"; do
  python datasets/MSCMRSeg2019.py \
    --root datasets/MS-CMRSeg2019_Raw \
    --save_dir datasets/MS-CMRSeg2019 \
    --source "$source" \
    --target "$target" \
    --train_source True \
    --val_target True
done
```
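The two loops above together cover the paper's six cross-modal segmentation tasks (four on BraTS2018, two on MS-CMRSeg2019). A small sketch enumerating the same source-to-target pairs:

```python
from itertools import product

def cross_modal_settings():
    """Enumerate the six source->target settings used in the paper:
    four on BraTS2018 and two on MS-CMRSeg2019. Illustrative only."""
    brats = list(product(["t2", "flair"], ["t1", "t1ce"]))   # 4 settings
    mscmr = [("C0", t) for t in ["T2", "LGE"]]               # 2 settings
    return brats + mscmr
```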
To train a model, run:

```shell
cd scripts
bash train_<source_domain>.sh
```
To visualize the intermediate results, use the following command:

```shell
tensorboard --logdir ./saved_models/<DATASET>/<SOURCE_DOMAIN>/<EXP_NAME>
```
- Only set `--use_fp16 True` when using an NVIDIA GPU or MPS.
To test a model, run:

```shell
cd scripts
bash test_<SETTING>.sh
```
We only provide test scripts for `C02LGE` and `t22t1`, but you can easily modify them to test any setting in our paper.
If you want to see the segmentation results and formal evaluation metrics, use the following command:

```shell
tensorboard --logdir ./val_res/<DATASET>/<SETTING>/<EXP_NAME>
```
| Dataset | Source Domain | Download | Avg_Dice (T2) | Avg_Dice (LGE) | Avg_Dice (t1) | Avg_Dice (t1ce) | Size |
|---|---|---|---|---|---|---|---|
| MS-CMRSeg2019 | C0 | ckpt | 0.8555 | 0.8562 | - | - | 105.6M |
| BraTS2018 | t2 | ckpt | - | - | 0.6813 | 0.6914 | 105.6M |
| BraTS2018 | flair | ckpt | - | - | 0.4189 | 0.5986 | 105.6M |
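The Avg_Dice columns above report Dice similarity coefficients. For reference, a minimal Dice computation on flat binary masks (a sketch, not the repository's evaluation code):

```python
def dice(pred, target, eps=1e-7):
    """Dice similarity coefficient, 2|A∩B| / (|A| + |B|), for two flat
    binary masks given as 0/1 lists. `eps` avoids division by zero
    when both masks are empty. Illustrative sketch only."""
    intersection = sum(p * t for p, t in zip(pred, target))
    total = sum(pred) + sum(target)
    return (2.0 * intersection + eps) / (total + eps)
```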
If you find our work helpful, please cite:

```bibtex
@InProceedings{pmlr-v225-li23a,
  title     = {Gradient-Map-Guided Adaptive Domain Generalization for Cross Modality MRI Segmentation},
  author    = {Li, Bingnan and Gao, Zhitong and He, Xuming},
  booktitle = {Proceedings of the 3rd Machine Learning for Health Symposium},
  pages     = {292--306},
  year      = {2023},
  editor    = {Hegselmann, Stefan and Parziale, Antonio and Shanmugam, Divya and Tang, Shengpu and Asiedu, Mercy Nyamewaa and Chang, Serina and Hartvigsen, Tom and Singh, Harvineet},
  volume    = {225},
  series    = {Proceedings of Machine Learning Research},
  month     = {10 Dec},
  publisher = {PMLR},
  pdf       = {https://proceedings.mlr.press/v225/li23a/li23a.pdf},
  url       = {https://proceedings.mlr.press/v225/li23a.html},
  abstract  = {Cross-modal MRI segmentation is of great value for computer-aided medical diagnosis, enabling flexible data acquisition and model generalization. However, most existing methods have difficulty in handling local variations in domain shift and typically require a significant amount of data for training, which hinders their usage in practice. To address these problems, we propose a novel adaptive domain generalization framework, which integrates a learning-free cross-domain representation based on image gradient maps and a class prior-informed test-time adaptation strategy for mitigating local domain shift. We validate our approach on two multi-modal MRI datasets with six cross-modal segmentation tasks. Across all the task settings, our method consistently outperforms competing approaches and shows a stable performance even with limited training data. Our Codes are available now at https://github.com/cuttle-fish-my/GM-Guided-DG .}
}

@misc{li2023gradientmapguided,
  title         = {Gradient-Map-Guided Adaptive Domain Generalization for Cross Modality MRI Segmentation},
  author        = {Bingnan Li and Zhitong Gao and Xuming He},
  year          = {2023},
  eprint        = {2311.09737},
  archivePrefix = {arXiv},
  primaryClass  = {cs.CV}
}
```