|
| 1 | +import argparse |
| 2 | + |
| 3 | + |
| 4 | +def parse_args(): |
| 5 | + parser = argparse.ArgumentParser( |
| 6 | + description='PyTorch Lottery Tickets Experiments') |
| 7 | + |
| 8 | + ##################################### Dataset ################################################# |
| 9 | + parser.add_argument('--data', type=str, default='../data', |
| 10 | + help='location of the data corpus') |
| 11 | + parser.add_argument('--dataset', type=str, |
| 12 | + default='cifar10', help='dataset') |
| 13 | + parser.add_argument('--input_size', type=int, |
| 14 | + default=32, help='size of input images') |
| 15 | + parser.add_argument('--data_dir', type=str, |
| 16 | + default='./tiny-imagenet-200', help='dir to tiny-imagenet') |
| 17 | + parser.add_argument('--num_workers', type=int, default=4) |
| 18 | + parser.add_argument('--num_classes', type=int, default=10) |
| 19 | + ##################################### Architecture ############################################ |
| 20 | + parser.add_argument('--arch', type=str, |
| 21 | + default='resnet18', help='model architecture') |
| 22 | + parser.add_argument('--imagenet_arch', action="store_true", |
| 23 | + help="architecture for imagenet size samples") |
| 24 | + parser.add_argument('--train_y_file', type=str, |
| 25 | + default='./labels/train_ys.pth', help='labels for training files') |
| 26 | + parser.add_argument('--val_y_file', type=str, |
| 27 | + default='./labels/val_ys.pth', help='labels for validation files') |
| 28 | + ##################################### General setting ############################################ |
| 29 | + parser.add_argument('--seed', default=2, type=int, help='random seed') |
| 30 | + parser.add_argument('--train_seed', default=1, type=int, |
| 31 | + help='seed for training (default value same as args.seed)') |
| 32 | + parser.add_argument('--gpu', type=int, default=0, help='gpu device id') |
| 33 | + parser.add_argument('--workers', type=int, default=4, |
| 34 | + help='number of workers in dataloader') |
| 35 | + parser.add_argument('--resume', action="store_true", |
| 36 | + help="resume from checkpoint") |
| 37 | + parser.add_argument('--checkpoint', type=str, |
| 38 | + default=None, help='checkpoint file') |
| 39 | + parser.add_argument( |
| 40 | + '--save_dir', help='The directory used to save the trained models', default=None, type=str) |
| 41 | + parser.add_argument('--mask', type=str, default=None, help='sparse model') |
| 42 | + |
| 43 | + ##################################### Training setting ################################################# |
| 44 | + parser.add_argument('--batch_size', type=int, |
| 45 | + default=256, help='batch size') |
| 46 | + parser.add_argument('--lr', default=0.1, type=float, |
| 47 | + help='initial learning rate') |
| 48 | + parser.add_argument('--momentum', default=0.9, type=float, help='momentum') |
| 49 | + parser.add_argument('--weight_decay', default=5e-4, |
| 50 | + type=float, help='weight decay') |
| 51 | + parser.add_argument('--epochs', default=182, type=int, |
| 52 | + help='number of total epochs to run') |
| 53 | + parser.add_argument('--warmup', default=0, type=int, help='warm up epochs') |
| 54 | + parser.add_argument('--print_freq', default=50, |
| 55 | + type=int, help='print frequency') |
| 56 | + parser.add_argument('--decreasing_lr', default='91,136', |
| 57 | + help='decreasing strategy') |
| 58 | + parser.add_argument('--no-aug', action='store_true', default=False, |
| 59 | + help='No augmentation in training dataset (transformation).') |
| 60 | + parser.add_argument('--no-l1-epochs', default=0, type=int, help='non l1 epochs') |
| 61 | + ##################################### Pruning setting ################################################# |
| 62 | + parser.add_argument('--prune', type=str, default="omp", |
| 63 | + help="method to prune") |
| 64 | + parser.add_argument('--pruning_times', default=1, |
| 65 | + type=int, help='overall times of pruning (only works for IMP)') |
| 66 | + parser.add_argument('--rate', default=0.95, type=float, |
| 67 | + help='pruning rate') # pruning rate is always 20% |
| 68 | + parser.add_argument('--prune_type', default='rewind_lt', type=str, |
| 69 | + help='IMP type (lt, pt or rewind_lt)') |
| 70 | + parser.add_argument('--random_prune', action='store_true', |
| 71 | + help='whether using random prune') |
| 72 | + parser.add_argument('--rewind_epoch', default=0, |
| 73 | + type=int, help='rewind checkpoint') |
| 74 | + parser.add_argument('--rewind_pth', default=None, |
| 75 | + type=str, help='rewind checkpoint to load') |
| 76 | + parser.add_argument('--hf_vit', default='NO', type=str, |
| 77 | + choices=['YES', 'NO'], help='lora method') |
| 78 | + parser.add_argument('--debug', default=False, type=bool, |
| 79 | + choices=[True, False], help='lora method') |
| 80 | + |
| 81 | + ##################################### Unlearn setting ################################################# |
| 82 | + parser.add_argument('--unlearn', type=str, |
| 83 | + default='retrain', help='method to unlearn') |
| 84 | + parser.add_argument('--unlearn_lr', default=0.01, type=float, |
| 85 | + help='initial learning rate') |
| 86 | + parser.add_argument('--unlearn_epochs', default=10, type=int, |
| 87 | + help='number of total epochs for unlearn to run') |
| 88 | + parser.add_argument('--num_indexes_to_replace', type=int, default=None, |
| 89 | + help='Number of data to forget') |
| 90 | + parser.add_argument('--class_to_replace', type=int, default=0, |
| 91 | + help='Specific class to forget') |
| 92 | + parser.add_argument('--indexes_to_replace', type=list, default=None, |
| 93 | + help='Specific index data to forget') |
| 94 | + parser.add_argument('--alpha', default=0.2, type=float, |
| 95 | + help='unlearn noise') |
| 96 | + parser.add_argument('--lora', default='NO', type=str, |
| 97 | + choices=['YES', 'NO'], help='lora method') |
| 98 | + |
| 99 | + ##################################### Attack setting ################################################# |
| 100 | + parser.add_argument('--attack', type=str, |
| 101 | + default='backdoor', help='method to unlearn') |
| 102 | + parser.add_argument('--trigger_size', type=int, default=4, |
| 103 | + help='The size of trigger of backdoor attack') |
| 104 | + return parser.parse_args() |
0 commit comments