Official implementation for the ECCV 2024 paper "AlignDiff: Aligning Diffusion Models for General Few-Shot Segmentation".
[Paper]
Clone the codebase:

```bash
cd ~
git clone --recursive https://github.com/RogerQi/aligndiff
```
Follow INSTALL.md to install the required dependencies, and PREPARE.md to set up a Hugging Face account and download the sample FSS weights.
As described in the paper, for the same category, we introduce a prompt bank to make the generated samples more diverse. There are three planned sources of prompts (a minimal sampling sketch follows the list):
- Texts (manually entered simple nouns or with detailed descriptions)
- (TODO) Augmented texts (retrieval from a database or augmented using LLMs)
- Learned per-object embedding (proposed normalized textual inversion)
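To illustrate the prompt-bank idea, here is a minimal Python sketch of sampling one prompt per generation from a per-category bank. The category name, prompt strings, and the `<cat-token>` placeholder for a learned embedding are all hypothetical and do not reflect the repo's actual data format.

```python
import random

# Hypothetical prompt bank: each category maps to simple nouns, detailed
# descriptions, and a placeholder token for a learned per-object embedding.
PROMPT_BANK = {
    "cat": [
        "a photo of a cat",                  # manually entered simple noun
        "a close-up photo of a fluffy cat",  # detailed description
        "a photo of <cat-token>",            # learned embedding placeholder
    ],
}

def sample_prompt(category: str) -> str:
    """Draw a random prompt so repeated generations stay diverse."""
    return random.choice(PROMPT_BANK[category])

print(sample_prompt("cat"))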
AlignDiff requires the prompts to be organized in a directory before they can be used. An example layout containing two categories and a JSON file is provided in `./example_dataset/`.
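For illustration only, such a layout might look like the tree below; every file and folder name here is hypothetical, so consult the actual `./example_dataset/` directory for the authoritative structure.

```
example_dataset/
├── class_a/
│   ├── image.png
│   └── mask.png
├── class_b/
│   ├── image.png
│   └── mask.png
└── metadata.json
```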
To optimize the per-object embedding for the image-mask pairs in the provided dataset, run:

```bash
python -m scripts.image_mask_inversion --dataset_dir ./example_dataset
```

On a single RTX 4090, this step takes ~3 minutes.
Tip: the inversion process is quite sensitive to `bg_lambda`. If it ends up generating degenerate samples, try tweaking this parameter.
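For intuition on where `bg_lambda` enters, here is a minimal PyTorch sketch of a mask-weighted denoising loss of the kind used in textual inversion. The function name, tensor shapes, and exact weighting are illustrative assumptions, not the repo's actual implementation.

```python
import torch

def masked_inversion_loss(noise_pred: torch.Tensor,
                          noise: torch.Tensor,
                          mask: torch.Tensor,
                          bg_lambda: float = 0.1) -> torch.Tensor:
    """Hypothetical mask-weighted diffusion loss.

    noise_pred, noise: (B, C, H, W) predicted / target noise latents.
    mask: (B, 1, H, W) binary object mask aligned to the latents.
    bg_lambda down-weights the background term so the optimized embedding
    focuses on the masked object; extreme values can degenerate samples.
    """
    per_pixel = (noise_pred - noise).pow(2)
    fg = (per_pixel * mask).sum() / mask.sum().clamp(min=1.0)
    bg = (per_pixel * (1.0 - mask)).sum() / (1.0 - mask).sum().clamp(min=1.0)
    return fg + bg_lambda * bg
```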
We use the following command to generate images and masks:

```bash
python -m scripts.diffuse_img_attention --dataset_dir ./example_dataset --image_per_class 100
```
This diffusion generation step is the most time-consuming part of the pipeline, so we added a routine that uses multiple GPUs to accelerate it. To use multiple GPUs, simply set the `CUDA_VISIBLE_DEVICES` environment variable. For example:

```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m scripts.diffuse_img_attention --dataset_dir ./example_dataset --image_per_class 100
```
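As a minimal sketch of the general multi-GPU pattern (not the repo's actual routine), per-class work can be sharded across the visible GPUs with `torch.multiprocessing`; `generate_for_class` below is a hypothetical stand-in for the real per-class generation call.

```python
import torch
import torch.multiprocessing as mp

def generate_for_class(class_name: str, device: torch.device) -> None:
    # Hypothetical stand-in for the actual per-class diffusion generation.
    print(f"[{device}] generating samples for '{class_name}'")

def worker(rank: int, world_size: int, classes: list) -> None:
    device = torch.device(f"cuda:{rank}")
    # Interleaved sharding: GPU `rank` handles every `world_size`-th class.
    for class_name in classes[rank::world_size]:
        generate_for_class(class_name, device)

if __name__ == "__main__":
    # device_count() respects CUDA_VISIBLE_DEVICES, so the same sharding
    # logic works whether 1 or 4 GPUs are exposed (assumes >= 1 GPU).
    world_size = torch.cuda.device_count()
    mp.spawn(worker, args=(world_size, ["class_a", "class_b"]), nprocs=world_size)
```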
We use the following command to refine masks with few-shot conditioning:

```bash
python -m scripts.mask_generation --dataset_dir ./example_dataset --fss_ckpt_path pretrained_weights/voc_2_resnet101_5shot_78.89.pth
```
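If the FSS checkpoint fails to load, a quick sanity check is to inspect it directly. This sketch assumes the file downloaded per PREPARE.md sits at the path below and is a standard PyTorch checkpoint.

```python
import torch

# Load on CPU just to inspect the saved contents.
ckpt = torch.load(
    "pretrained_weights/voc_2_resnet101_5shot_78.89.pth", map_location="cpu"
)
# Checkpoints are typically a dict of tensors or nested dicts
# (e.g. {"state_dict": ...}); print the top-level keys to find out.
print(type(ckpt), list(ckpt)[:10] if isinstance(ckpt, dict) else None)
```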
TODOs
- Switch to modern diffusers
- Use xformers to speed up inversion / generation
- Integrate LLM / CLIPRetrieval to augment text prompts
- Switch to a better few-shot segmenter in mask generation (more recent networks & more pre-training data?)
- Integrate SAM to further refine mask generation?
- Add image / text template augmentations to the NM inversion to further stabilize the training process.
- Make a prettier tqdm pbar during image-attention synthesis.