# decoding_paraphrase.py
import torch
import torch.nn.functional as F
import numpy as np
import time
import wandb
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from util import *
import random
stop_words = set(stopwords.words('english'))
def decode(model, tokenizer, classifier, device, x="", z="", constraints=None, args=None, sys_prompt=None, prefix=None, model_back=None, zz=None):
    '''
    x: left context (the prompt in the lexically constrained task)
    z: optimization target (the original ending in the counterfactual task)
    constraints: the constraint set in the lexically constrained task
    '''
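    # High-level flow of the routine below:
    #   1. encode the chat prefix, the source prompt x, and the target z;
    #   2. initialize a relaxed token distribution y_logits and optimize an
    #      additive perturbation epsilon with Adam;
    #   3. each iteration combines an attack loss toward z, n-gram and
    #      embedding paraphrase losses toward x, and a fluency loss, and
    #      periodically injects annealed Gaussian noise;
    #   4. finally, decode the optimized logits, generate responses, and
    #      report perplexity.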
    model.eval()
    # Keyword list of refusal / safety-related terms; joined into `bad_words`
    # below (note: `bad_words` is not referenced again in this function).
    words = ['sorry', 'apologetic', 'regretful',
'unethical', 'immoral', 'dishonest',
'illegal', 'unlawful', 'illicit',
'harmful', 'detrimental', 'injurious',
'afraid', 'fearful', 'apprehensive',
'cannot', 'unable',
'apologize', 'regret',
'unable', 'incapable', 'powerless',
'not', 'no', 'nay',
'responsible', 'accountable', 'liable',
'ethic', 'morality', 'moral',
'legal', 'lawful', 'legitimate',
'ethical', 'moral', 'principled',
'fulfill', 'accomplish', 'achieve',
'just', 'fair', 'equitable',
'trustworthy', 'reliable', 'dependable',
'repentant', 'penitent', 'remorseful',
'unprincipled', 'unscrupulous', 'corrupt',
'unlawful', 'prohibited', 'forbidden',
'noxious', 'destructive', 'damaging',
'timid', 'terrified', 'panicked',
'unable', 'incompetent', 'incapacitated',
'prohibited', 'banned', 'forbidden',
'deny', 'refuse', 'decline',
'answerable', 'culpable', 'answerable',
'moral code', 'values', 'principles',
'legitimate', 'valid', 'authorized',
'virtuous', 'righteous', 'honorable',
'satisfy', 'complete', 'execute',
'fair-minded', 'impartial', 'equitable',
'reliable', 'trustable', 'faithful', 'invalid','safe', 'not', "can't", "but", "against","0", "1", "2",
"3", "4", "5", "6", "7", "8", "9"]
    uppercase_words = [word.upper() for word in words]
    bad_words = words + uppercase_words
    bad_words = ' '.join(bad_words)
    if args.use_sysprompt:
        x_sys = sys_prompt + x
        x_ = tokenizer.encode(x_sys)[1:]
    # The chat prefix is needed regardless of use_sysprompt, so encode it unconditionally.
    prefix_ = tokenizer.encode(prefix)[1:]
    prefix_t = torch.tensor(prefix_, device=device, dtype=torch.long)
    prefix_onehot = one_hot(prefix_t, dimension=tokenizer.vocab_size)
    # repeat batch_size times
    prefix_t = prefix_t.unsqueeze(0).repeat(args.batch_size, 1)
    prefix_onehot = prefix_onehot.repeat(args.batch_size, 1, 1)
    x_mask = None
    # encode the optimization target z
z_ = tokenizer.encode(z)[1:]
z_t = torch.tensor(z_, device=device, dtype=torch.long)
z_onehot = one_hot(z_t, dimension=tokenizer.vocab_size)
z_onehot = z_onehot.repeat(args.batch_size, 1, 1)
z_t = z_t.unsqueeze(0).repeat(args.batch_size, 1)
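    # A non-positive args.length is interpreted relative to the target:
    # length = 0 decodes exactly len(z) tokens, length = -k decodes len(z) + k.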
length = args.length
if length <= 0:
length = z_t.shape[1] - length
    ###################################################
    # x_mask: [batch_size, length, vocab_size]
    x_words = word_tokenize(x)
    # keep only alphanumeric non-stopword tokens of x
    x_nonstop_words = [w.lower() for w in x_words if w.lower() not in stop_words and w.isalnum()]
    x_nonstop_words = ' '.join(x_nonstop_words)
print('|' + x_nonstop_words + '|')
x_nonstop_ = tokenizer.encode(x_nonstop_words.strip())[1:]
x_t = torch.tensor(x_nonstop_, device=device, dtype=torch.long)
x_t = x_t.unsqueeze(0).repeat(args.batch_size, 1)
x_seq = tokenizer.encode(x)[1:]
x_seq_t = torch.tensor(x_seq, device=device, dtype=torch.long)
x_seq_t = x_seq_t.unsqueeze(0).repeat(args.batch_size, 1)
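    # x_mask flags the vocabulary ids of x's non-stopword tokens; it is passed
    # as extra_mask to the top-k filters below so these source words stay
    # available to the paraphrase at every position.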
x_mask = np.zeros([tokenizer.vocab_size])
x_mask[x_nonstop_] = 1.
x_mask = torch.tensor(x_mask, device=device)
x_mask = x_mask.unsqueeze(0).unsqueeze(0).repeat(args.batch_size, length, 1)
print(f"\n the shape of x_mask is {x_mask.shape}\n")
if args.verbose:
print("prefix:\t|%s|\nz:\t|%s|\nx:\t|%s| length:\t%d\nconstraints:\t%s" % (
tokenizer.decode(prefix_), tokenizer.decode(z_), tokenizer.decode(x_nonstop_), length, constraints))
###################################################
ref_embedding = get_ref_embedding(model, x, device, tokenizer).to(torch.float16)
    if args.init_mode == 'original':
        init_logits = initialize(model, x_t, length, args.init_temp, args.batch_size, device, tokenizer)
    else:
        # initialize from the target: a sharply peaked relaxation of z's one-hots
        init_logits = z_onehot / 0.01
        init_logits = init_logits[:, :length, :]
if length > init_logits.shape[1]:
init_logits = torch.cat(
[init_logits,
10 * torch.rand([args.batch_size, length - init_logits.shape[1], tokenizer.vocab_size], device=device)],
dim=1)
text, _, _ = get_text_from_logits(init_logits, tokenizer)
for bi in range(args.batch_size):
print("[initial]: %s" % (text[bi]))
    if args.wandb:
        wandb.init(
            project=args.mode + str(int(round(time.time() * 1000))),
            config=args)
    y_logits = init_logits
    # Optimize an additive perturbation epsilon on top of the initial logits.
    epsilon = torch.nn.Parameter(torch.zeros_like(y_logits))
    # Prefix tuning is not supported here, so only epsilon is optimized.
    assert args.prefix_length <= 0, "The current code does not support prefix-length > 0"
    optim = torch.optim.Adam([epsilon], lr=args.stepsize)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optim, step_size=args.stepsize_iters,
gamma=args.stepsize_ratio)
frozen_len = args.frozen_length
y_logits_ = None
noise_std = 0.0
    ## Encode the prefix beforehand
    soft_forward_prefix = prefix_onehot[:, -1:, :]  # the last prefix token is fed to soft_forward
if prefix_t.shape[1] == 1:
prefix_model_past = None
else:
prefix_model_outputs = model(prefix_t[:, :-1], use_cache=True)
prefix_model_past = prefix_model_outputs.past_key_values
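    # The prefix (all but its last token) is run through the model only once;
    # its key/value cache (prefix_model_past) is reused at every iteration, so
    # each optimization step only feeds the soft tokens forward.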
    # Top-k vocabulary mask, rebuilt from the model's output logits each iteration.
    mask_t = None
for iter in range(args.num_iters):
optim.zero_grad()
y_logits_ = y_logits + epsilon
soft_forward_y = y_logits_ / 0.001
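        # Straight-through estimator: the forward pass sees a hard (or top-k
        # filtered) low-temperature version of y_logits_, while the backward
        # pass treats it as the identity, so gradients reach the soft logits.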
if args.straight_through:
if mask_t is None:
soft_forward_y = (y_logits_.detach() / 0.001 - y_logits_).detach() + y_logits_
else:
soft_forward_y = top_k_filter_3d(y_logits_, args.topk, mask=mask_t, extra_mask=x_mask, bad_mask=None) / 0.001
if args.fp16:
with torch.autocast(device_type="cuda", dtype=torch.float16):
y_logits_t = soft_forward(model, soft_forward_prefix, soft_forward_y, args.topk, extra_mask=x_mask, x_past=prefix_model_past, bad_mask=None) # without gradient
else:
y_logits_t = soft_forward(model, soft_forward_prefix, soft_forward_y, args.topk, extra_mask=x_mask, x_past=prefix_model_past, bad_mask=None)
if args.topk == 0:
mask_t = None
else:
_, indices_t = torch.topk(y_logits_t, args.topk)
mask_t = torch.zeros_like(y_logits_t).scatter_(2, indices_t, 1)
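        # Fluency loss: soft cross-entropy between the model's next-token
        # distributions and the current relaxed sequence distribution.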
flu_loss = soft_nll(
top_k_filter_3d(y_logits_t / args.output_lgt_temp, args.topk, extra_mask=x_mask, bad_mask=None),
y_logits_ / args.input_lgt_temp)
soft_forward_y_ = (y_logits_.detach() / 0.001 - y_logits_).detach() + y_logits_
if args.fp16:
with torch.autocast(device_type="cuda", dtype=torch.float16):
xyz_logits, xy_length = soft_forward_xyz(model, soft_forward_prefix, soft_forward_y_, z_onehot)
else:
xyz_logits, xy_length = soft_forward_xyz(model, soft_forward_prefix, soft_forward_y_, z_onehot)
        # Reshaping: gather the logits at the target (z) positions for each batch element
bz = args.batch_size
lg = xyz_logits.shape[1]
st = xy_length - 1
ed = xyz_logits.shape[1] - 1
xyz_logits = xyz_logits.view(-1, xyz_logits.shape[-1])
z_logits = torch.cat([xyz_logits[bi * lg + st:bi * lg + ed, :] for bi in range(bz)], dim=0)
# Jailbreak attack loss
c_loss_1 = torch.nn.CrossEntropyLoss(reduction='none')(
z_logits,
z_t.view(-1))
c_loss_1 = c_loss_1.view(args.batch_size, -1).mean(-1)
        # Paraphrase loss 1: differentiable n-gram (BLEU-style) overlap with the source x
c_loss_2 = batch_log_bleulosscnn_ae(
decoder_outputs=top_k_filter_3d(y_logits_, args.topk, mask=mask_t, extra_mask=x_mask).transpose(0, 1),
target_idx=x_seq_t,
ngram_list=list(range(1, args.counterfactual_max_ngram + 1))
)
        # Paraphrase loss 2: embedding similarity between the relaxed sequence and x's reference embedding
c_loss_3 = sim_score(model, top_k_filter_3d(y_logits_, args.topk, mask=mask_t, extra_mask=x_mask), ref_embedding)
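        # Total objective: weighted attack loss + n-gram paraphrase loss +
        # weighted fluency loss, minus the similarity score (so higher
        # similarity to x's embedding is rewarded).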
loss = args.goal_weight * c_loss_1 + args.rej_weight * c_loss_2 + args.lr_nll_portion * flu_loss - c_loss_3
loss = loss.mean()
if iter < args.num_iters - 1: # so that the mask_t at the last iteration will not change
loss.backward()
optim.step()
scheduler.step()
last_lr = scheduler.get_last_lr()[0]
if args.verbose and ((iter + 1) % args.print_every == 0 or iter == 0 or iter + 1 == args.num_iters):
text, _, last_text_ids = decode_with_model_topk(
model, y_logits_, args.topk, soft_forward_prefix, prefix_model_past, tokenizer, extra_mask=x_mask, bad_mask=None)
text_post = text
for bi in range(args.batch_size):
prompt = prefix + text_post[bi]
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
print(f"\n Output of the model:\n")
                output_ids = model.generate(inputs=input_ids, temperature=0.7, max_length=512, pad_token_id=tokenizer.pad_token_id, do_sample=True, top_k=args.topk)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
print(
"%d, loss: %.4f, c_loss_1: %.4f, c_loss_2: %.4f, c_loss_3: %.4f, lr: %.4f, |%s|" % (
iter + 1, loss.item(), c_loss_1[bi].item(), c_loss_2[bi].item(), c_loss_3[bi].item(), last_lr, text_post[bi]))
        ## Exploration noise: periodically perturb the logits; iterations below
        ## the thresholds in args.large_noise_iters use larger standard deviations.
if iter < args.num_iters - 1:
large_noise_iters = [int(_) for _ in args.large_noise_iters.split(',')]
large_gs_stds = [float(_) for _ in args.large_gs_std.split(',')]
noise_std = 0.
if iter % args.noise_iters == 0:
noise_last = True
for ni in range(len(large_noise_iters)):
if iter < large_noise_iters[ni]:
noise_last = False
break
if noise_last:
noise_std = args.gs_std
else:
noise_std = large_gs_stds[ni]
                noise = torch.normal(mean=args.gs_mean, std=noise_std, size=epsilon.size(),
                                     device=device, requires_grad=False)
if args.win_anneal_iters >= 0 and iter >= args.win_anneal_iters:
zeros = torch.zeros_like(noise)
noise_mix = torch.cat([zeros[:, :frozen_len], noise[:, frozen_len:]], dim=1)
y_logits = y_logits + noise_mix
else:
y_logits = y_logits + noise
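    # Final pass: decode the optimized logits, sample a model response for
    # each batch element, and score the decoded prompt's perplexity.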
text, _, last_text_ids = decode_with_model_topk(
model, y_logits_, args.topk, soft_forward_prefix, prefix_model_past, tokenizer, extra_mask=x_mask, bad_mask=None)
text_post = text
decoded_text = []
for bi in range(args.batch_size):
prompt = prefix + " " + text[bi]
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
        output_ids = model.generate(inputs=input_ids, temperature=0.7, max_length=512, pad_token_id=tokenizer.pad_token_id, do_sample=True, top_k=args.topk)
output_ids = output_ids[:, input_ids.shape[1]:]
text_dec = tokenizer.decode(output_ids[0], skip_special_tokens=True)
decoded_text.append(text_dec.strip())
    # Perplexity of the final decoded sequence under the target model.
    last_rank_loss = model(input_ids=last_text_ids, labels=last_text_ids).loss
    last_rank_loss = last_rank_loss.detach().cpu().numpy()
    ppl_last = np.exp(last_rank_loss)
prompt_with_adv = [x + " " + t for t in text_post]
return ppl_last, text, text_post, decoded_text, prompt_with_adv
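

# ---------------------------------------------------------------------------
# Minimal usage sketch (hypothetical; not part of the original file, whose
# repo builds `args` with its own argument parser). It assumes a Hugging Face
# causal LM whose tokenizer prepends a BOS token (the code above strips it
# with `[1:]`), and fills in illustrative values for every attribute that
# decode() reads; none of these are claimed to be the original settings.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from argparse import Namespace
    from transformers import AutoModelForCausalLM, AutoTokenizer

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_name = "meta-llama/Llama-2-7b-chat-hf"  # hypothetical model choice
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

    args = Namespace(
        use_sysprompt=False, batch_size=2, length=20, verbose=True,
        init_mode='original', init_temp=1.0, wandb=False, mode='paraphrase',
        prefix_length=0, stepsize=0.1, stepsize_iters=50, stepsize_ratio=0.9,
        frozen_length=0, num_iters=200, straight_through=True, fp16=False,
        topk=10, output_lgt_temp=1.0, input_lgt_temp=1.0,
        counterfactual_max_ngram=3, goal_weight=100, rej_weight=100,
        lr_nll_portion=1.0, print_every=50, noise_iters=1, gs_mean=0.0,
        gs_std=0.01, win_anneal_iters=1000, large_noise_iters='50,200',
        large_gs_std='1.0,0.5',
    )

    ppl, text, text_post, decoded_text, prompt_with_adv = decode(
        model, tokenizer, None, device,
        x="<prompt to paraphrase>",
        z="<target opening of the model's response>",
        args=args,
        prefix="<chat template prefix>")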