"""Run training-free Tip-Adapter and fine-tuned Tip-Adapter-F on top of CLIP.

The script reads a YAML config, builds a few-shot cache model from the
training split, evaluates the training-free variant, then fine-tunes the
cache keys and evaluates again.
"""
import os
import random
import argparse
import yaml
from tqdm import tqdm

import torch
import torch.nn.functional as F
import torch.nn as nn
import torchvision.transforms as transforms

from datasets import build_dataset
from datasets.utils import build_data_loader
import clip
from utils import *


def get_arguments():
    """Parse command-line arguments and return the resulting namespace."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', dest='config', help='settings of Tip-Adapter in yaml format')
    args = parser.parse_args()

    return args


def run_tip_adapter(cfg, cache_keys, cache_values, val_features, val_labels, test_features, test_labels, clip_weights):
    """Evaluate the training-free Tip-Adapter on the val and test features.

    Prints zero-shot CLIP accuracy and Tip-Adapter accuracy on the val set
    using ``cfg['init_beta']`` / ``cfg['init_alpha']``, searches for better
    hyperparameters with ``search_hp``, then prints both accuracies on the
    test set using the values found. Returns nothing.

    Tensor arguments are assumed to be pre-extracted features/labels whose
    shapes make the matrix products below valid -- TODO confirm against the
    helpers in ``utils``.
    """
    print("\n-------- Searching hyperparameters on the val set. --------")

    # Zero-shot CLIP
    clip_logits = 100. * val_features @ clip_weights
    acc = cls_acc(clip_logits, val_labels)
    print("\n**** Zero-shot CLIP's val accuracy: {:.2f}. ****\n".format(acc))

    # Tip-Adapter
    beta, alpha = cfg['init_beta'], cfg['init_alpha']

    affinity = val_features @ cache_keys
    # exp(-beta * (1 - affinity)) weights the cached values by affinity.
    cache_logits = ((-1) * (beta - beta * affinity)).exp() @ cache_values

    tip_logits = clip_logits + cache_logits * alpha
    acc = cls_acc(tip_logits, val_labels)
    print("**** Tip-Adapter's val accuracy: {:.2f}. ****\n".format(acc))

    # Search Hyperparameters
    best_beta, best_alpha = search_hp(cfg, cache_keys, cache_values, val_features, val_labels, clip_weights)

    print("\n-------- Evaluating on the test set. --------")

    # Zero-shot CLIP
    clip_logits = 100. * test_features @ clip_weights
    acc = cls_acc(clip_logits, test_labels)
    print("\n**** Zero-shot CLIP's test accuracy: {:.2f}. ****\n".format(acc))

    # Tip-Adapter
    affinity = test_features @ cache_keys
    cache_logits = ((-1) * (best_beta - best_beta * affinity)).exp() @ cache_values

    tip_logits = clip_logits + cache_logits * best_alpha
    acc = cls_acc(tip_logits, test_labels)
    print("**** Tip-Adapter's test accuracy: {:.2f}. ****\n".format(acc))


def run_tip_adapter_F(cfg, cache_keys, cache_values, val_features, val_labels, test_features, test_labels, clip_weights, clip_model, train_loader_F):
    """Fine-tune the cache keys (Tip-Adapter-F) and report test accuracy.

    A bias-free linear layer is initialised with ``cache_keys.t()`` and
    trained with AdamW under a cosine schedule stepped once per batch. After
    every epoch the adapter is evaluated on the test features and the best
    weights are saved under ``cfg['cache_dir']``. The best weights are then
    restored, ``search_hp`` is run on the val set, and the final test
    accuracy is printed. Returns nothing.

    Requires CUDA: the adapter and each training batch are moved with
    ``.cuda()``.
    """
    # Enable the cached keys to be learnable
    adapter = nn.Linear(cache_keys.shape[0], cache_keys.shape[1], bias=False).to(clip_model.dtype).cuda()
    adapter.weight = nn.Parameter(cache_keys.t())

    optimizer = torch.optim.AdamW(adapter.parameters(), lr=cfg['lr'], eps=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, cfg['train_epoch'] * len(train_loader_F))

    beta, alpha = cfg['init_beta'], cfg['init_alpha']
    best_acc, best_epoch = 0.0, 0

    best_path = cfg['cache_dir'] + "/best_F_" + str(cfg['shots']) + "shots.pt"
    # Only restore a checkpoint written by this run; a file left over from an
    # earlier run may belong to a different configuration.
    checkpoint_saved = False

    for train_idx in range(cfg['train_epoch']):
        # Train
        adapter.train()
        correct_samples, all_samples = 0, 0
        loss_list = []
        print('Train Epoch: {:} / {:}'.format(train_idx, cfg['train_epoch']))

        for i, (images, target) in enumerate(tqdm(train_loader_F)):
            images, target = images.cuda(), target.cuda()
            # The CLIP encoder is not trained; only the adapter gets gradients.
            with torch.no_grad():
                image_features = clip_model.encode_image(images)
                image_features /= image_features.norm(dim=-1, keepdim=True)

            affinity = adapter(image_features)
            cache_logits = ((-1) * (beta - beta * affinity)).exp() @ cache_values
            clip_logits = 100. * image_features @ clip_weights
            tip_logits = clip_logits + cache_logits * alpha

            loss = F.cross_entropy(tip_logits, target)

            acc = cls_acc(tip_logits, target)
            correct_samples += acc / 100 * len(tip_logits)
            all_samples += len(tip_logits)
            loss_list.append(loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

        current_lr = scheduler.get_last_lr()[0]
        print('LR: {:.6f}, Acc: {:.4f} ({:}/{:}), Loss: {:.4f}'.format(current_lr, correct_samples / all_samples, correct_samples, all_samples, sum(loss_list)/len(loss_list)))

        # Eval
        adapter.eval()

        # No gradients are needed for evaluation; skipping graph construction
        # avoids holding activations for the whole test set.
        with torch.no_grad():
            affinity = adapter(test_features)
            cache_logits = ((-1) * (beta - beta * affinity)).exp() @ cache_values
            clip_logits = 100. * test_features @ clip_weights
            tip_logits = clip_logits + cache_logits * alpha
        acc = cls_acc(tip_logits, test_labels)

        print("**** Tip-Adapter-F's test accuracy: {:.2f}. ****\n".format(acc))
        if acc > best_acc:
            best_acc = acc
            best_epoch = train_idx
            torch.save(adapter.weight, best_path)
            checkpoint_saved = True

    if checkpoint_saved:
        adapter.weight = torch.load(best_path)
    print(f"**** After fine-tuning, Tip-Adapter-F's best test accuracy: {best_acc:.2f}, at epoch: {best_epoch}. ****\n")

    print("\n-------- Searching hyperparameters on the val set. --------")

    # Search Hyperparameters
    best_beta, best_alpha = search_hp(cfg, cache_keys, cache_values, val_features, val_labels, clip_weights, adapter=adapter)

    print("\n-------- Evaluating on the test set. --------")

    with torch.no_grad():
        # Recompute the zero-shot logits here instead of relying on the value
        # left behind by the last training epoch (undefined with 0 epochs).
        clip_logits = 100. * test_features @ clip_weights
        affinity = adapter(test_features)
        cache_logits = ((-1) * (best_beta - best_beta * affinity)).exp() @ cache_values
        tip_logits = clip_logits + cache_logits * best_alpha
    acc = cls_acc(tip_logits, test_labels)
    print("**** Tip-Adapter-F's test accuracy: {:.2f}. ****\n".format(max(best_acc, acc)))


def main():
    """Load the config, prepare data and features, then run both variants.

    Raises:
        FileNotFoundError: if ``--config`` is missing or does not exist.
    """
    # Load config file
    args = get_arguments()
    # Explicit check instead of ``assert``, which is removed under ``python -O``.
    if not args.config or not os.path.exists(args.config):
        raise FileNotFoundError(f"Config file not found: {args.config!r}")

    # NOTE(review): yaml.Loader can construct arbitrary Python objects; only
    # use trusted config files, or switch to yaml.SafeLoader if the configs
    # contain plain data only.
    with open(args.config, 'r') as config_file:
        cfg = yaml.load(config_file, Loader=yaml.Loader)

    cache_dir = os.path.join('./caches', cfg['dataset'])
    os.makedirs(cache_dir, exist_ok=True)
    cfg['cache_dir'] = cache_dir

    print("\nRunning configs.")
    print(cfg, "\n")

    # CLIP
    clip_model, preprocess = clip.load(cfg['backbone'])
    clip_model.eval()

    # Prepare dataset
    random.seed(1)
    torch.manual_seed(1)

    print("Preparing dataset.")
    dataset = build_dataset(cfg['dataset'], cfg['root_path'], cfg['shots'])

    val_loader = build_data_loader(data_source=dataset.val, batch_size=64, is_train=False, tfm=preprocess, shuffle=False)
    test_loader = build_data_loader(data_source=dataset.test, batch_size=64, is_train=False, tfm=preprocess, shuffle=False)

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(size=224, scale=(0.5, 1), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
    ])

    # Same data and augmentation for both loaders; only the shuffling differs.
    train_loader_cache = build_data_loader(data_source=dataset.train_x, batch_size=256, tfm=train_transform, is_train=True, shuffle=False)
    train_loader_F = build_data_loader(data_source=dataset.train_x, batch_size=256, tfm=train_transform, is_train=True, shuffle=True)

    # Textual features
    print("\nGetting textual features as CLIP's classifier.")
    clip_weights = clip_classifier(dataset.classnames, dataset.template, clip_model)

    # Construct the cache model by few-shot training set
    print("\nConstructing cache model by few-shot visual features and labels.")
    cache_keys, cache_values = build_cache_model(cfg, clip_model, train_loader_cache)

    # Pre-load val features
    print("\nLoading visual features and labels from val set.")
    val_features, val_labels = pre_load_features(cfg, "val", clip_model, val_loader)

    # Pre-load test features
    print("\nLoading visual features and labels from test set.")
    test_features, test_labels = pre_load_features(cfg, "test", clip_model, test_loader)

    # ------------------------------------------ Tip-Adapter ------------------------------------------
    run_tip_adapter(cfg, cache_keys, cache_values, val_features, val_labels, test_features, test_labels, clip_weights)

    # ------------------------------------------ Tip-Adapter-F ------------------------------------------
    run_tip_adapter_F(cfg, cache_keys, cache_values, val_features, val_labels, test_features, test_labels, clip_weights, clip_model, train_loader_F)


if __name__ == '__main__':
    main()