
Commit 68e2b8c

Upload files
1 parent 1080c55 commit 68e2b8c


47 files changed: +7140 -0 lines changed

.gitignore

+15
@@ -0,0 +1,15 @@
__pycache__
tensorboard
commands

*.log
*.pt
*.tar
*.pkl
*.bat
*.pth
*.png
*.jpg
*.sh
*.pdf
*.info

LS.py

+19
@@ -0,0 +1,19 @@
import torch.nn as nn
import torch.nn.functional as F


class LabelSmoothingCrossEntropy(nn.Module):
    def __init__(self, eps=0.1, reduction='mean'):
        super(LabelSmoothingCrossEntropy, self).__init__()
        self.eps = eps
        self.reduction = reduction

    def forward(self, output, target):
        c = output.size()[-1]
        log_preds = F.log_softmax(output, dim=-1)
        if self.reduction == 'sum':
            loss = -log_preds.sum()
        else:
            loss = -log_preds.sum(dim=-1)
            if self.reduction == 'mean':
                loss = loss.mean()
        return loss * self.eps / c + (1 - self.eps) * F.nll_loss(log_preds, target, reduction=self.reduction)
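The loss above is standard label smoothing: with C classes and smoothing factor eps, it returns eps/C * sum(-log p) + (1 - eps) * NLL, i.e. a mix of the uniform-target cross-entropy and the usual one-hot cross-entropy. A minimal usage sketch (the batch size and class count below are illustrative assumptions, not from the commit):

import torch
from LS import LabelSmoothingCrossEntropy

criterion = LabelSmoothingCrossEntropy(eps=0.1)

logits = torch.randn(8, 10, requires_grad=True)   # hypothetical: batch of 8, 10 classes
targets = torch.randint(0, 10, (8,))              # hypothetical integer class labels

loss = criterion(logits, targets)  # eps/C uniform term + (1 - eps) NLL term
loss.backward()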

SAM.py

+63
@@ -0,0 +1,63 @@
import torch


class SAM(torch.optim.Optimizer):
    def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
        super(SAM, self).__init__(params, defaults)

        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)

            for p in group["params"]:
                if p.grad is None: continue
                self.state[p]["old_p"] = p.data.clone()
                e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
                p.add_(e_w)  # climb to the local maximum "w + e(w)"

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None: continue
                p.data = self.state[p]["old_p"]  # get back to "w" from "w + e(w)"

        self.base_optimizer.step()  # do the actual "sharpness-aware" update

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
        closure = torch.enable_grad()(closure)  # the closure should do a full forward-backward pass

        self.first_step(zero_grad=True)
        closure()
        self.second_step()

    def _grad_norm(self):
        shared_device = self.param_groups[0]["params"][0].device  # put everything on the same device, in case of model parallelism
        norm = torch.norm(
            torch.stack([
                ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
                for group in self.param_groups for p in group["params"]
                if p.grad is not None
            ]),
            p=2
        )
        return norm

    def load_state_dict(self, state_dict):
        super().load_state_dict(state_dict)
        self.base_optimizer.param_groups = self.param_groups
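For context, SAM's step expects a closure that redoes the full forward-backward pass at the perturbed weights: first_step climbs to "w + e(w)", the closure recomputes gradients there, and second_step restores "w" and applies the base-optimizer update. A minimal training-step sketch, assuming a toy model, random data, and SGD as the base optimizer (none of which are from the commit):

import torch
import torch.nn as nn
from SAM import SAM

model = nn.Linear(10, 2)                      # hypothetical model
criterion = nn.CrossEntropyLoss()
# pass the optimizer class, not an instance; SAM instantiates it internally
optimizer = SAM(model.parameters(), torch.optim.SGD, rho=0.05, lr=0.1, momentum=0.9)

inputs = torch.randn(8, 10)                   # hypothetical batch
targets = torch.randint(0, 2, (8,))

def closure():
    # re-evaluate the loss at the perturbed point "w + e(w)"
    loss = criterion(model(inputs), targets)
    loss.backward()
    return loss

criterion(model(inputs), targets).backward()  # first backward pass at "w"
optimizer.step(closure)                       # first_step -> closure() -> second_step
optimizer.zero_grad()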

arg_parser.py

+104
@@ -0,0 +1,104 @@
import argparse


def parse_args():
    parser = argparse.ArgumentParser(
        description='PyTorch Lottery Tickets Experiments')

    ##################################### Dataset #################################################
    parser.add_argument('--data', type=str, default='../data',
                        help='location of the data corpus')
    parser.add_argument('--dataset', type=str,
                        default='cifar10', help='dataset')
    parser.add_argument('--input_size', type=int,
                        default=32, help='size of input images')
    parser.add_argument('--data_dir', type=str,
                        default='./tiny-imagenet-200', help='dir to tiny-imagenet')
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--num_classes', type=int, default=10)
    ##################################### Architecture ############################################
    parser.add_argument('--arch', type=str,
                        default='resnet18', help='model architecture')
    parser.add_argument('--imagenet_arch', action="store_true",
                        help="architecture for imagenet size samples")
    parser.add_argument('--train_y_file', type=str,
                        default='./labels/train_ys.pth', help='labels for training files')
    parser.add_argument('--val_y_file', type=str,
                        default='./labels/val_ys.pth', help='labels for validation files')
    ##################################### General setting ############################################
    parser.add_argument('--seed', default=2, type=int, help='random seed')
    parser.add_argument('--train_seed', default=1, type=int,
                        help='seed for training (default value same as args.seed)')
    parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
    parser.add_argument('--workers', type=int, default=4,
                        help='number of workers in dataloader')
    parser.add_argument('--resume', action="store_true",
                        help="resume from checkpoint")
    parser.add_argument('--checkpoint', type=str,
                        default=None, help='checkpoint file')
    parser.add_argument(
        '--save_dir', help='The directory used to save the trained models', default=None, type=str)
    parser.add_argument('--mask', type=str, default=None, help='sparse model')

    ##################################### Training setting #################################################
    parser.add_argument('--batch_size', type=int,
                        default=256, help='batch size')
    parser.add_argument('--lr', default=0.1, type=float,
                        help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
    parser.add_argument('--weight_decay', default=5e-4,
                        type=float, help='weight decay')
    parser.add_argument('--epochs', default=182, type=int,
                        help='number of total epochs to run')
    parser.add_argument('--warmup', default=0, type=int, help='warm up epochs')
    parser.add_argument('--print_freq', default=50,
                        type=int, help='print frequency')
    parser.add_argument('--decreasing_lr', default='91,136',
                        help='decreasing strategy')
    parser.add_argument('--no-aug', action='store_true', default=False,
                        help='No augmentation in training dataset (transformation).')
    parser.add_argument('--no-l1-epochs', default=0, type=int, help='non l1 epochs')
    ##################################### Pruning setting #################################################
    parser.add_argument('--prune', type=str, default="omp",
                        help="method to prune")
    parser.add_argument('--pruning_times', default=1,
                        type=int, help='overall times of pruning (only works for IMP)')
    parser.add_argument('--rate', default=0.95, type=float,
                        help='pruning rate')  # pruning rate is always 20%
    parser.add_argument('--prune_type', default='rewind_lt', type=str,
                        help='IMP type (lt, pt or rewind_lt)')
    parser.add_argument('--random_prune', action='store_true',
                        help='whether using random prune')
    parser.add_argument('--rewind_epoch', default=0,
                        type=int, help='rewind checkpoint')
    parser.add_argument('--rewind_pth', default=None,
                        type=str, help='rewind checkpoint to load')
    parser.add_argument('--hf_vit', default='NO', type=str,
                        choices=['YES', 'NO'], help='use HuggingFace ViT')
    parser.add_argument('--debug', default=False, type=bool,
                        choices=[True, False], help='debug mode')

    ##################################### Unlearn setting #################################################
    parser.add_argument('--unlearn', type=str,
                        default='retrain', help='method to unlearn')
    parser.add_argument('--unlearn_lr', default=0.01, type=float,
                        help='initial learning rate')
    parser.add_argument('--unlearn_epochs', default=10, type=int,
                        help='number of total epochs for unlearn to run')
    parser.add_argument('--num_indexes_to_replace', type=int, default=None,
                        help='Number of data to forget')
    parser.add_argument('--class_to_replace', type=int, default=0,
                        help='Specific class to forget')
    parser.add_argument('--indexes_to_replace', type=list, default=None,
                        help='Specific index data to forget')
    parser.add_argument('--alpha', default=0.2, type=float,
                        help='unlearn noise')
    parser.add_argument('--lora', default='NO', type=str,
                        choices=['YES', 'NO'], help='lora method')

    ##################################### Attack setting #################################################
    parser.add_argument('--attack', type=str,
                        default='backdoor', help='attack method')
    parser.add_argument('--trigger_size', type=int, default=4,
                        help='The size of trigger of backdoor attack')
    return parser.parse_args()
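A short usage sketch for the parser (the script name and flags shown are illustrative assumptions, not from the commit). One caveat worth knowing: because --debug uses type=bool, any non-empty string argument, including "--debug False", parses as True; this is a standard argparse pitfall.

# main.py -- hypothetical entry point using the parser above
from arg_parser import parse_args

if __name__ == '__main__':
    args = parse_args()
    print(args.dataset, args.arch, args.unlearn, args.unlearn_epochs)

# invoked, for example, as:
#   python main.py --dataset cifar10 --arch resnet18 --unlearn retrain --unlearn_epochs 10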
