Commit bbdeb4f

add codes
1 parent 6ed8735 commit bbdeb4f


45 files changed: +4948 -2 lines changed

README.md

+100 -2
Removed:

# Diffusion-Models-Improve-AT
Code for the paper "Better Diffusion Models Further Improve Adversarial Training"

Added:
# Better Diffusion Models Further Improve Adversarial Training

## Environment settings and libraries we used in our experiments

This project is tested under the following environment settings:

- OS: Ubuntu 20.04.3
- GPU: NVIDIA A100
- CUDA: 11.1, cuDNN: v8.2
- Python: 3.9.5
- PyTorch: 1.8.0
- Torchvision: 0.9.0

## Acknowledgement

The code is modified from the [PyTorch implementation](https://github.com/imrahulr/adversarial_robustness_pytorch) of [Rebuffi et al., 2021](https://arxiv.org/abs/2103.01946).
## Requirements

- Install or download [AutoAttack](https://github.com/fra31/auto-attack):

```
pip install git+https://github.com/fra31/auto-attack
```

- Install or download [RandAugment](https://github.com/ildoonet/pytorch-randaugment):

```
pip install git+https://github.com/ildoonet/pytorch-randaugment
```

- Download the EDM-generated data. Since the 20M and 50M data files are too large, we split them into several parts:

| dataset | size | link |
|---|:---:|:---:|
| CIFAR-10 | 1M | [npz](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_ddpm.npz) |
| CIFAR-10 | 5M | [npz](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_ddpm.npz) |
| CIFAR-10 | 10M | [npz](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_ddpm.npz) |
| CIFAR-10 | 20M | [part1](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_ddpm.npz) [part2](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_ddpm.npz) |
| CIFAR-10 | 50M | [part1](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_ddpm.npz) [part2](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_ddpm.npz) [part3](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_ddpm.npz) [part4](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_ddpm.npz) |
| CIFAR-100 | 1M | [npz](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_ddpm.npz) |
| CIFAR-100 | 50M | [part1](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_ddpm.npz) [part2](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_ddpm.npz) [part3](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_ddpm.npz) [part4](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_ddpm.npz) |

- Merge the 20M and 50M generated data (a sketch of this step follows the command below):

```
python merge_data.py
```
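For orientation, here is a minimal sketch of what the merge step might do, assuming each downloaded part is an `.npz` file holding `image` and `label` arrays; the part filenames, array keys, and output name are illustrative assumptions, and the actual `merge_data.py` may differ:

```python
import numpy as np

# Hypothetical part filenames; substitute the files you actually downloaded.
parts = ['cifar10_50m_part1.npz', 'cifar10_50m_part2.npz',
         'cifar10_50m_part3.npz', 'cifar10_50m_part4.npz']

images, labels = [], []
for path in parts:
    with np.load(path) as data:
        # Assumes each part stores its samples under 'image' and 'label'.
        images.append(data['image'])
        labels.append(data['label'])

# Concatenate along the sample axis and write one merged archive.
np.savez('cifar10_50m.npz',
         image=np.concatenate(images, axis=0),
         label=np.concatenate(labels, axis=0))
```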
## Training Commands

Run [`train-wa.py`](./train-wa.py) to reproduce the results reported in the paper. For example, to train a WideResNet-28-10 model via [TRADES](https://github.com/yaodongyu/TRADES) on CIFAR-10 with the additional generated data provided by EDM ([Karras et al., 2022](https://github.com/NVlabs/edm)):

```
python train-wa.py --data-dir 'cifar-data' \
    --log-dir 'trained_models' \
    --desc 'WRN28-10Swish_cifar10s_lr0p2_TRADES5_epoch400_bs512_fraction0p7_ls0p1' \
    --data cifar10s \
    --batch-size 512 \
    --model wrn-28-10-swish \
    --num-adv-epochs 400 \
    --lr 0.2 \
    --beta 5.0 \
    --unsup-fraction 0.7 \
    --aux-data-filename <path_to_additional_data> \
    --ls 0.1
```
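For context on `--beta`: TRADES optimizes a clean cross-entropy term plus a beta-weighted KL term that pulls predictions on adversarial inputs toward predictions on the clean inputs, so `--beta 5.0` controls the robustness/accuracy trade-off. A minimal sketch of that objective, assuming `x_adv` comes from an inner maximization step; this is an illustration, not the exact loss implemented in `train-wa.py`:

```python
import torch.nn.functional as F

def trades_loss(model, x, x_adv, y, beta=5.0):
    logits = model(x)
    logits_adv = model(x_adv)
    # Natural loss: standard cross-entropy on clean inputs.
    natural_loss = F.cross_entropy(logits, y)
    # Robust loss: KL divergence between adversarial and clean predictions.
    robust_loss = F.kl_div(F.log_softmax(logits_adv, dim=1),
                           F.softmax(logits, dim=1),
                           reduction='batchmean')
    return natural_loss + beta * robust_loss
```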
## Downloading models

We provide pre-trained checkpoints; download a model from the links listed in the following table. Clean and robust accuracies are measured on the full test set. The robust accuracy is measured using [AutoAttack](https://github.com/fra31/auto-attack).

| dataset | norm | radius | architecture | clean | robust | link |
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| CIFAR-10 | &#8467;<sub>&infin;</sub> | 8 / 255 | WRN-28-10 | 92.44% | 67.31% | [checkpoint](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn28-10_with.pt) [argtxt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn28-10_with.pt) |
| CIFAR-10 | &#8467;<sub>&infin;</sub> | 8 / 255 | WRN-70-16 | 93.25% | 70.69% | [checkpoint](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn70-16_with.pt) [argtxt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn28-10_with.pt) |
| CIFAR-10 | &#8467;<sub>2</sub> | 128 / 255 | WRN-28-10 | 95.16% | 83.63% | [checkpoint](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_l2_wrn70-16_with.pt) [argtxt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn28-10_with.pt) |
| CIFAR-10 | &#8467;<sub>2</sub> | 128 / 255 | WRN-70-16 | 95.54% | 84.86% | [checkpoint](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_l2_wrn70-16_without.pt) [argtxt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn28-10_with.pt) |
| CIFAR-100 | &#8467;<sub>&infin;</sub> | 8 / 255 | WRN-28-10 | 72.58% | 38.83% | [checkpoint](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn70-16_with.pt) [argtxt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn28-10_with.pt) |
| CIFAR-100 | &#8467;<sub>&infin;</sub> | 8 / 255 | WRN-70-16 | 75.22% | 42.67% | [checkpoint](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn70-16_without.pt) [argtxt](https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn28-10_with.pt) |

- **Download `checkpoint` to `trained_models/mymodel/weights-best.pt`** (a download sketch follows this list)
- **Download `argtxt` to `trained_models/mymodel/args.txt`**
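As a convenience, here is a sketch of fetching a checkpoint into the layout the scripts expect, using the CIFAR-10 &#8467;<sub>&infin;</sub> WRN-28-10 link from the table; the directory name `mymodel` must match the `--desc` passed to the evaluation scripts, and any HTTP client works equally well:

```python
import os
import urllib.request

model_dir = os.path.join('trained_models', 'mymodel')
os.makedirs(model_dir, exist_ok=True)

# Checkpoint URL taken from the table above; save the corresponding
# argtxt file to trained_models/mymodel/args.txt the same way.
urllib.request.urlretrieve(
    'https://storage.googleapis.com/dm-adversarial-robustness/cifar10_linf_wrn28-10_with.pt',
    os.path.join(model_dir, 'weights-best.pt'))
```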
## Evaluation Commands

The trained models can be evaluated by running [`eval-aa.py`](./eval-aa.py), which uses [AutoAttack](https://github.com/fra31/auto-attack) to evaluate the robust accuracy. Run the command:

```
python eval-aa.py --data-dir 'cifar-data' \
    --log-dir 'trained_models' \
    --desc mymodel
```

To evaluate the model from the last epoch under AutoAttack, run the command:

```
python eval-last-aa.py --data-dir 'cifar-data' \
    --log-dir 'trained_models' \
    --desc mymodel
```

core/__init__.py

+1

core/attacks/__init__.py

+65
```python
from .base import Attack

from .apgd import LinfAPGDAttack
from .apgd import L2APGDAttack

from .fgsm import FGMAttack
from .fgsm import FGSMAttack
from .fgsm import L2FastGradientAttack
from .fgsm import LinfFastGradientAttack

from .pgd import PGDAttack
from .pgd import L2PGDAttack
from .pgd import LinfPGDAttack

from .deepfool import DeepFoolAttack
from .deepfool import LinfDeepFoolAttack
from .deepfool import L2DeepFoolAttack

from .utils import CWLoss


ATTACKS = ['fgsm', 'linf-pgd', 'fgm', 'l2-pgd', 'linf-df', 'l2-df', 'linf-apgd', 'l2-apgd']


def create_attack(model, criterion, attack_type, attack_eps, attack_iter, attack_step, rand_init_type='uniform',
                  clip_min=0., clip_max=1.):
    """
    Initialize adversary.
    Arguments:
        model (nn.Module): forward pass function.
        criterion (nn.Module): loss function.
        attack_type (str): name of the attack.
        attack_eps (float): attack radius.
        attack_iter (int): number of attack iterations.
        attack_step (float): step size for the attack.
        rand_init_type (str): random initialization type for PGD (default: uniform).
        clip_min (float): minimum value per input dimension.
        clip_max (float): maximum value per input dimension.
    Returns:
        Attack
    """

    if attack_type == 'fgsm':
        attack = FGSMAttack(model, criterion, eps=attack_eps, clip_min=clip_min, clip_max=clip_max)
    elif attack_type == 'fgm':
        attack = FGMAttack(model, criterion, eps=attack_eps, clip_min=clip_min, clip_max=clip_max)
    elif attack_type == 'linf-pgd':
        attack = LinfPGDAttack(model, criterion, eps=attack_eps, nb_iter=attack_iter, eps_iter=attack_step,
                               rand_init_type=rand_init_type, clip_min=clip_min, clip_max=clip_max)
    elif attack_type == 'l2-pgd':
        attack = L2PGDAttack(model, criterion, eps=attack_eps, nb_iter=attack_iter, eps_iter=attack_step,
                             rand_init_type=rand_init_type, clip_min=clip_min, clip_max=clip_max)
    elif attack_type == 'linf-df':
        attack = LinfDeepFoolAttack(model, overshoot=0.02, nb_iter=attack_iter, search_iter=0, clip_min=clip_min,
                                    clip_max=clip_max)
    elif attack_type == 'l2-df':
        attack = L2DeepFoolAttack(model, overshoot=0.02, nb_iter=attack_iter, search_iter=0, clip_min=clip_min,
                                  clip_max=clip_max)
    elif attack_type == 'linf-apgd':
        attack = LinfAPGDAttack(model, criterion, n_restarts=2, eps=attack_eps, nb_iter=attack_iter)
    elif attack_type == 'l2-apgd':
        attack = L2APGDAttack(model, criterion, n_restarts=2, eps=attack_eps, nb_iter=attack_iter)
    else:
        raise NotImplementedError('{} is not yet implemented!'.format(attack_type))
    return attack
```
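A hypothetical usage sketch of `create_attack`, building a 10-step Linf PGD adversary; the toy model and the eps/step values are illustrative placeholders, not defaults from this repository:

```python
import torch.nn as nn

# Toy classifier standing in for a WideResNet.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
criterion = nn.CrossEntropyLoss()

attack = create_attack(model, criterion, attack_type='linf-pgd',
                       attack_eps=8 / 255, attack_iter=10, attack_step=2 / 255)
# The returned object is then used to craft adversarial examples, e.g.:
# x_adv, r_adv = attack.perturb(x, y)
```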

core/attacks/apgd.py

+70
```python
import numpy as np

import torch

from autoattack.autopgd_base import APGDAttack


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class APGD():
    """
    APGD attack (from AutoAttack) (Croce et al., 2020).
    The attack performs nb_iter steps of adaptive size, while always staying within eps from the initial point.
    Arguments:
        predict (nn.Module): forward pass function.
        loss_fn (str): loss function - ce or dlr.
        n_restarts (int): number of random restarts.
        eps (float): maximum distortion.
        nb_iter (int): number of iterations.
        ord (int): (optional) the order of maximum distortion (inf or 2).
    """
    def __init__(self, predict, loss_fn='ce', n_restarts=2, eps=0.3, nb_iter=40, ord=np.inf, seed=1):
        assert loss_fn in ['ce', 'dlr'], 'Only loss_fn=ce or loss_fn=dlr are supported!'
        assert ord in [2, np.inf], 'Only ord=inf or ord=2 are supported!'

        norm = 'Linf' if ord == np.inf else 'L2'
        self.apgd = APGDAttack(predict, n_restarts=n_restarts, n_iter=nb_iter, verbose=False, eps=eps, norm=norm,
                               eot_iter=1, rho=.75, seed=seed, device=device)
        self.apgd.loss = loss_fn

    def perturb(self, x, y):
        x_adv = self.apgd.perturb(x, y)[1]
        r_adv = x_adv - x
        return x_adv, r_adv


class LinfAPGDAttack(APGD):
    """
    APGD attack (from AutoAttack) with order=Linf.
    The attack performs nb_iter steps of adaptive size, while always staying within eps from the initial point.
    Arguments:
        predict (nn.Module): forward pass function.
        loss_fn (str): loss function - ce or dlr.
        n_restarts (int): number of random restarts.
        eps (float): maximum distortion.
        nb_iter (int): number of iterations.
    """

    def __init__(self, predict, loss_fn='ce', n_restarts=2, eps=0.3, nb_iter=40, seed=1):
        ord = np.inf
        super(LinfAPGDAttack, self).__init__(
            predict=predict, loss_fn=loss_fn, n_restarts=n_restarts, eps=eps, nb_iter=nb_iter, ord=ord, seed=seed)


class L2APGDAttack(APGD):
    """
    APGD attack (from AutoAttack) with order=L2.
    The attack performs nb_iter steps of adaptive size, while always staying within eps from the initial point.
    Arguments:
        predict (nn.Module): forward pass function.
        loss_fn (str): loss function - ce or dlr.
        n_restarts (int): number of random restarts.
        eps (float): maximum distortion.
        nb_iter (int): number of iterations.
    """

    def __init__(self, predict, loss_fn='ce', n_restarts=2, eps=0.3, nb_iter=40, seed=1):
        ord = 2
        super(L2APGDAttack, self).__init__(
            predict=predict, loss_fn=loss_fn, n_restarts=n_restarts, eps=eps, nb_iter=nb_iter, ord=ord, seed=seed)
```
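A hypothetical usage sketch of the Linf wrapper, reusing the module's `device` and assuming AutoAttack is installed as described in the README (and that the installed version's `perturb` returns results in the form this wrapper indexes):

```python
import torch
import torch.nn as nn

# Toy classifier standing in for a trained WideResNet.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10)).to(device).eval()

x = torch.rand(8, 3, 32, 32, device=device)    # CIFAR-10-shaped batch in [0, 1]
y = torch.randint(0, 10, (8,), device=device)  # placeholder labels

# Linf APGD at the 8/255 radius used in the README's robust evaluations.
attack = LinfAPGDAttack(model, loss_fn='ce', eps=8 / 255, nb_iter=40)
x_adv, r_adv = attack.perturb(x, y)
```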

core/attacks/base.py

+63
```python
import torch
import torch.nn as nn

from .utils import replicate_input


class Attack(object):
    """
    Abstract base class for all attack classes.
    Arguments:
        predict (nn.Module): forward pass function.
        loss_fn (nn.Module): loss function.
        clip_min (float): minimum value per input dimension.
        clip_max (float): maximum value per input dimension.
    """

    def __init__(self, predict, loss_fn, clip_min, clip_max):
        self.predict = predict
        self.loss_fn = loss_fn
        self.clip_min = clip_min
        self.clip_max = clip_max

    def perturb(self, x, **kwargs):
        """
        Virtual method for generating the adversarial examples.
        Arguments:
            x (torch.Tensor): the model's input tensor.
            **kwargs: optional parameters used by child classes.
        Returns:
            adversarial examples.
        """
        error = "Sub-classes must implement perturb."
        raise NotImplementedError(error)

    def __call__(self, *args, **kwargs):
        return self.perturb(*args, **kwargs)


class LabelMixin(object):
    def _get_predicted_label(self, x):
        """
        Compute predicted labels given x. Used to prevent label leaking during adversarial training.
        Arguments:
            x (torch.Tensor): the model's input tensor.
        Returns:
            torch.Tensor containing predicted labels.
        """
        with torch.no_grad():
            outputs = self.predict(x)
        _, y = torch.max(outputs, dim=1)
        return y

    def _verify_and_process_inputs(self, x, y):
        if self.targeted:
            assert y is not None

        if not self.targeted:
            if y is None:
                y = self._get_predicted_label(x)

        x = replicate_input(x)
        y = replicate_input(y)
        return x, y
```
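To illustrate how the two classes are meant to be combined, here is a hypothetical minimal subclass, a one-step FGSM-style attack; it sketches the intended inheritance pattern and is not the repository's actual `FGSMAttack` in `core/attacks/fgsm.py`:

```python
class ToyFGSM(Attack, LabelMixin):
    """One-step Linf attack sketching the Attack/LabelMixin pattern."""

    def __init__(self, predict, loss_fn, eps=8 / 255, clip_min=0., clip_max=1.):
        super(ToyFGSM, self).__init__(predict, loss_fn, clip_min, clip_max)
        self.eps = eps
        self.targeted = False  # required by _verify_and_process_inputs

    def perturb(self, x, y=None):
        # Falls back to predicted labels when y is None (avoids label leaking).
        x, y = self._verify_and_process_inputs(x, y)
        x.requires_grad_(True)
        loss = self.loss_fn(self.predict(x), y)
        grad = torch.autograd.grad(loss, x)[0]
        # One signed-gradient step, clipped to the valid input range.
        return torch.clamp(x + self.eps * grad.sign(),
                           self.clip_min, self.clip_max).detach()
```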
