NOTE: We reimplemented our method on top of the Swin Transformer codebase, but the released models and logs come from an older version. Loading them may fail because of mismatched module names, which can be fixed manually. We will update these resources when we have time. In the meantime, you can reproduce our work and results with the following instructions.
First, clone the repository:
git clone
cd CAT
For classification, we need PyTorch and timm:
conda create -n cat python=3.7
conda activate cat
conda install pytorch==1.7.1 torchvision==0.8.2 cudatoolkit=10.1 -c pytorch
Install other requirements:
pip install timm==0.3.2 opencv-python== termcolor==1.1.0 yacs==0.1.8
Install Apex:
git clone
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
NOTE: If installing Apex with the C++/CUDA extensions fails, install the Python-only build instead:
pip install -v --no-cache-dir ./
You can download the standard ImageNet dataset from the official ImageNet website.
The directory structure should be as follows:
├── train
│   ├── class1
│   │   ├── img1.jpeg
│   │   ├── img2.jpeg
│   │   └── ...
│   ├── class2
│   │   ├── img3.jpeg
│   │   └── ...
│   └── ...
└── val
    ├── class1
    │   ├── img4.jpeg
    │   ├── img5.jpeg
    │   └── ...
    ├── class2
    │   ├── img6.jpeg
    │   └── ...
    └── ...
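Before training, it can help to confirm that a local copy of the dataset matches this layout. The following is a minimal sketch that assumes only standard tools (find, sort, wc); the function name check_imagenet_layout is hypothetical and not part of the CAT repository.

```shell
#!/usr/bin/env bash
# Hypothetical helper (not part of CAT): sanity-check an ImageNet copy laid out as
# <root>/train/<class>/<image> and <root>/val/<class>/<image>.
check_imagenet_layout() {
  local root="$1" split train_classes val_classes
  # Both split directories must exist.
  for split in train val; do
    if [ ! -d "$root/$split" ]; then
      echo "missing directory: $root/$split" >&2
      return 1
    fi
  done
  # Collect the class folders of each split (relative paths, sorted) and compare them.
  train_classes=$(cd "$root/train" && find . -mindepth 1 -maxdepth 1 -type d | sort)
  val_classes=$(cd "$root/val" && find . -mindepth 1 -maxdepth 1 -type d | sort)
  if [ -z "$train_classes" ]; then
    echo "no class folders in $root/train" >&2
    return 1
  fi
  if [ "$train_classes" != "$val_classes" ]; then
    echo "train/ and val/ contain different class folders" >&2
    return 1
  fi
  # Report how many classes were found.
  echo "$(printf '%s\n' "$train_classes" | wc -l | tr -d ' ') classes found"
}
```

For the data path used in the commands below, run check_imagenet_layout data/CLS-LOC; a complete ImageNet-1K copy should report 1000 classes.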
To train a model, run:
python -m torch.distributed.launch --nproc_per_node <number-of-gpus> --master_port 10086 \
--cfg <config-file> --data-path <imagenet-path> --batch-size <batch-size>
For example, to train the small model on 8 GPUs:
python -m torch.distributed.launch --nproc_per_node 8 --master_port 10086 \
--cfg configs/cat_small.yaml --data-path data/CLS-LOC --batch-size 128
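Assuming CAT keeps the convention of the Swin Transformer main.py it builds on, --batch-size is the batch size per GPU, so the total batch size per optimization step is the product of the GPU count and that value. Swin Transformer's main.py also scales the base learning rate linearly with this total (relative to 512); if CAT keeps that logic, running with fewer GPUs changes the effective learning rate as well. A small sketch of the arithmetic (global_batch_size is a hypothetical helper):

```shell
#!/usr/bin/env bash
# Hypothetical helper (not part of CAT): global batch size of a torch.distributed.launch run.
# Assumes --batch-size is the per-GPU batch size, as in Swin Transformer's main.py.
global_batch_size() {
  local nproc_per_node="$1"   # value passed to --nproc_per_node
  local per_gpu_batch="$2"    # value passed to --batch-size
  local nnodes="${3:-1}"      # number of machines (1 for the commands in this README)
  echo $(( nproc_per_node * per_gpu_batch * nnodes ))
}

global_batch_size 8 128   # prints 1024 for the small-model command above
global_batch_size 4 128   # prints 512: halving the GPU count halves the global batch
```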
You can evaluate models as follows:
python -m torch.distributed.launch --nproc_per_node <number-of-gpus> --master_port 10086 \
--eval --cfg <config-file> --resume <checkpoint-file> --data-path <imagenet-path>
To measure throughput on a single GPU (with --amp-opt-level O0, i.e. without mixed precision), run:
python -m torch.distributed.launch --nproc_per_node 1 --master_port 10086 \
--cfg configs/cat_small.yaml --data-path data/CLS-LOC --batch-size 64 --throughput --amp-opt-level O0
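The launch commands above all use --master_port 10086. torch.distributed.launch uses this port for the rendezvous between worker processes, so two jobs started on the same machine at the same time need different ports. A minimal sketch, assuming python3 is on the PATH, that asks the operating system for a currently unused port (free_port is a hypothetical helper, not part of CAT):

```shell
#!/usr/bin/env bash
# Hypothetical helper (not part of CAT): print a TCP port that is currently unused.
# Binding to port 0 lets the OS pick a free port; the socket is closed immediately,
# so this is best effort: another process could still take the port before launch.
free_port() {
  python3 -c 'import socket; s = socket.socket(); s.bind(("", 0)); print(s.getsockname()[1]); s.close()'
}

PORT=$(free_port)
echo "using --master_port $PORT"
```

You would then pass --master_port "$PORT" in place of 10086 when starting a second job on the same node.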
Our detection implementation is based on mmdetection; please install mmdetection first.
To train CAT-based detection models, run:
cd detection
For example, to train RetinaNet with 8 GPUs:
bash configs/ 8 --options model.pretrained=<pretrained-model>
To evaluate the mAP of CAT-based RetinaNet on COCO with 8 GPUs, run:
bash configs/ <checkpoint-file> 8 --eval bbox
Our segmentation implementation is based on mmsegmentation; please install mmsegmentation first.
To train CAT-based segmentation models, run:
cd segmentation
For example, to train Semantic FPN with 8 GPUs:
bash configs/ 8 --options model.pretrained=<pretrained-model>
To evaluate the mIoU of CAT-based Semantic FPN with 8 GPUs, run:
bash configs/ <checkpoint-file> 8 --eval mIoU
To compute the FLOPs of a model:
cd detection # or cd segmentation
python <config-file> [--shape <evaluate-shape>]