add DiffCSE model #2643
Conversation
Leave some comments
export CUDA_VISIBLE_DEVICES=0,1,2,3

python -u -m paddle.distributed.launch --gpus ${gpu_ids} \
When starting jobs with `launch`, please add the `--log_dir` argument to specify the log output directory; otherwise, when several jobs are launched at once they all write into the default `log` directory and their logs get mixed together.

Added.
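A sketch of the suggested invocation with `--log_dir`; the GPU ids and the log path are illustrative, not taken from the PR:

```shell
export CUDA_VISIBLE_DEVICES=0,1,2,3

# Giving each run its own log directory keeps the logs of
# concurrently launched jobs from being interleaved.
python -u -m paddle.distributed.launch \
    --gpus "0,1,2,3" \
    --log_dir ./log/diffcse_run1 \
    run_diffcse.py \
    --mode "train"
```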
python -u -m paddle.distributed.launch --gpus ${gpu_ids} \
    run_diffcse.py \
    --mode "train" \
    --extractor_name "rocketqa-zh-dureader-query-encoder" \
Variable naming should follow the terminology of the paper: `extractor_name` -> sentence encoder.

Renamed to `encoder_name`.
Configurable parameters:
* `mode`: optional; specifies whether this run performs model training, evaluation, or prediction. Only the three modes [train, eval, infer] are supported; defaults to infer.
* `extractor_name`: optional; name of the model used for vector extraction in DiffCSE; defaults to ernie-1.0.
Normalize the variable naming.

Renamed to `encoder_name`.
python run_diffcse.py \
    --mode "eval" \
    --extractor_name "rocketqa-zh-dureader-query-encoder" \
Same as above.

Renamed to `encoder_name`.
if not with_pooler:
    ori_cls_embedding = sequence_output[:, 0, :]
The branch logic here is missing an `else`.

Fixed.
key_token_type_ids=key_token_type_ids,
query_attention_mask=query_attention_mask,
key_attention_mask=key_attention_mask,
cls_token=tokenizer.cls_token_id)
What is the purpose of `cls_token`?

Same explanation as above.
if global_step % (args.eval_steps // 10) == 0 and rank == 0:
    print(
        "global step {}, epoch: {}, batch: {}, loss: {:.5f}, speed: {:.2f} step/s"
        .format(global_step, epoch, step, loss.item(),
                (args.eval_steps // 10) /
                (time.time() - tic_train)))
- The log should also report the loss of the RTD task and the discriminator's prediction accuracy; without them the results are hard to analyze.
- Some of the samples produced by the generator could be saved to a local file for debugging and analysis.

Added the relevant metrics and a plotting feature.
with paddle.no_grad():
    # mask tokens for query and key input_ids and then predict mask token with generator
    input_ids = paddle.concat([query_input_ids, key_input_ids], axis=0)
Why is the text duplicated and concatenated here? Is this equivalent to masking the same sample twice with two different masks?

This follows the setup of the original paper.
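A minimal numpy sketch of the idea being discussed: concatenating two copies of the batch and masking each copy independently gives every sentence two differently-masked views. All identifiers and the mask probability are illustrative, not the PR's actual values:

```python
import numpy as np

def independent_masks(input_ids, mask_id=103, mlm_prob=0.15, rng=None):
    """Concatenate two copies of the batch and mask each copy
    independently, so each sentence receives two different MLM masks."""
    rng = rng or np.random.default_rng(0)
    doubled = np.concatenate([input_ids, input_ids], axis=0)
    mask = rng.random(doubled.shape) < mlm_prob
    mask[:, 0] = False  # never mask the [CLS] position
    return np.where(mask, mask_id, doubled)

# Toy batch of two already-tokenized sentences.
batch = np.array([[101, 7592, 2088, 102],
                  [101, 2023, 2003, 102]])
mlm_input_ids = independent_masks(batch)
# The first and second halves of mlm_input_ids are two
# differently-masked views of the same two sentences.
```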
pred_tokens = self.generator(
    mlm_input_ids, attention_mask=attention_mask).argmax(-1)
Please add the necessary comments here: an example input for `mlm_input_ids` and an example output for `pred_tokens`.

Suggest calling the `paddle.argmax` API here and spelling out the parameter name `axis` for the `-1`, which makes the code semantics clearer.

Done.
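A small numpy sketch of the `argmax` suggestion; shapes are illustrative. Naming the axis makes it obvious the argmax runs over the vocabulary dimension rather than the batch:

```python
import numpy as np

# Fake generator logits: (batch=2, seq_len=3, vocab=5).
logits = np.arange(30).reshape(2, 3, 5)

# Explicit keyword: take the argmax over the last (vocabulary) axis,
# producing one predicted token id per sequence position.
pred_tokens = np.argmax(logits, axis=-1)

print(pred_tokens.shape)  # (2, 3)
```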
    mlm_input_ids, attention_mask=attention_mask).argmax(-1)

pred_tokens[:, 0] = cls_token
e_inputs = pred_tokens * attention_mask
What is the expected `attention_mask` input here?

It is the attention mask returned by the tokenizer; its job is to mask out the padding positions.

The `cls_token` parameter does not seem necessary; the IDs of the special tokens should be obtainable from the tokenizer.
encoded_inputs = tokenizer(text=text,
                           max_seq_len=max_seq_length,
                           return_attention_mask=True)
# print(encoded_inputs)
Leftover comment.

Deleted.
cosine_sim = cosine_sim - paddle.diag(margin_diag)

# scale cosine to ease training convergence
cosine_sim *= self.scale
To stay consistent with the official DiffCSE code, please remove the normalization of the embeddings and the `scale` parameter.
pred_tokens = self.generator(
    mlm_input_ids, attention_mask=attention_mask).argmax(-1)
Suggest calling the `paddle.argmax` API here and spelling out the parameter name `axis` for the `-1`, which makes the code semantics clearer.
key_token_type_ids=None,
query_attention_mask=None,
key_attention_mask=None,
cls_token=1):
Is the `cls_token` parameter required here? I believe the ID of the CLS special token can be obtained from the tokenizer.

- Remove the embedding normalization and the `scale` parameter, consistent with the official DiffCSE code: deleted.
- Use the `paddle.argmax` API with the explicit parameter name `axis`: now specified as `axis=-1`.
- Drop the `cls_token` parameter in favor of the tokenizer's special-token ID: now obtained via the tokenizer.
Leave some comments
def read_text_pair(data_path, is_infer=False):
    with open(data_path, "r", encoding="utf-8") as f:
        for line in f:
            data = line.rstrip().split("\t")
            if is_infer:
                if len(data[0]) == 0 or len(data[1]) == 0:
                    continue
                yield {"text_a": data[0], "text_b": data[1]}
            else:
                if len(data[0]) == 0 or len(data[1]) == 0 or len(data[2]) == 0:
                    continue
                yield {"text_a": data[0], "text_b": data[1], "label": data[2]}
This function looks unused; is it redundant?

`read_text_pair` is used when loading the evaluation dataset.
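A quick self-contained check of how the reader behaves on evaluation data; the file contents are made up for the example:

```python
import os
import tempfile

def read_text_pair(data_path, is_infer=False):
    # Same logic as in the PR: yield text pairs, plus a label when not inferring,
    # skipping lines with any empty field.
    with open(data_path, "r", encoding="utf-8") as f:
        for line in f:
            data = line.rstrip().split("\t")
            if is_infer:
                if len(data[0]) == 0 or len(data[1]) == 0:
                    continue
                yield {"text_a": data[0], "text_b": data[1]}
            else:
                if len(data[0]) == 0 or len(data[1]) == 0 or len(data[2]) == 0:
                    continue
                yield {"text_a": data[0], "text_b": data[1], "label": data[2]}

with tempfile.NamedTemporaryFile("w", suffix=".tsv", delete=False,
                                 encoding="utf-8") as f:
    f.write("sent a\tsent b\t1\n\t\t\n")  # second line is empty -> skipped
    path = f.name

rows = list(read_text_pair(path, is_infer=False))
os.remove(path)
print(rows)  # [{'text_a': 'sent a', 'text_b': 'sent b', 'label': '1'}]
```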
yield {"text_a": data[0], "text_b": data[1], "label": data[2]}


def word_repetition(input_ids, token_type_ids, dup_rate=0.32):
This function also appears to be unused and can be deleted.

`word_repetition` has been deleted.
from sklearn import metrics


def eval_metrics(labels, sims):
The evaluation logic for internal business use does not need to be open-sourced; for DiffCSE, just release the evaluation metrics used in the paper.

Removed the current evaluation logic; evaluation is now uniformly reported as the Spearman coefficient.
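For reference, the Spearman coefficient is just Pearson correlation computed on ranks. A minimal pure-Python sketch (assuming no tied values for simplicity; in practice a library routine such as `scipy.stats.spearmanr` would typically be used):

```python
def spearman(labels, sims):
    """Spearman rank correlation, assuming no tied values."""
    def ranks(xs):
        order = sorted(range(len(xs)), key=lambda i: xs[i])
        r = [0.0] * len(xs)
        for rank, i in enumerate(order):
            r[i] = float(rank)
        return r

    rl, rs = ranks(labels), ranks(sims)
    n = len(labels)
    mean = (n - 1) / 2.0  # ranks 0..n-1 always average to this
    cov = sum((a - mean) * (b - mean) for a, b in zip(rl, rs))
    var = sum((a - mean) ** 2 for a in rl)
    return cov / var  # both rank variances are equal when there are no ties

# Perfectly monotone similarity scores give a coefficient of 1.0.
print(spearman([0, 1, 2, 3], [0.1, 0.4, 0.5, 0.9]))  # 1.0
```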
y = y.unsqueeze(0)
sim = self.cos(x, y)
self.record = sim.detach()
min_size = min(self.record.shape[0], self.record.shape[1])
Can the numbers of vectors in the two inputs `x` and `y` differ?

In test mode, when computing the similarity between `x` and `y`, the two must contain the same number of vectors. In training mode unequal counts are allowed, but in our input-processing setup they are always equal.
self.pos_avg = paddle.diag(self.record).sum().item() / min_size
self.neg_avg = (self.record.sum().item() - paddle.diag(
What are these two variables for?

`pos_avg` tracks the average similarity of the positive pairs within an input batch; `neg_avg` tracks the average similarity of the negative pairs.
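The two statistics can be illustrated with numpy on a hypothetical batch similarity matrix, mirroring the diagonal/off-diagonal split in the code above:

```python
import numpy as np

# Hypothetical similarity matrix: record[i, j] is the cosine similarity
# between query i and key j, so the diagonal holds the positive pairs.
record = np.array([[0.9, 0.1, 0.2],
                   [0.0, 0.8, 0.3],
                   [0.1, 0.2, 0.7]])
min_size = min(record.shape)

# Average similarity of positive pairs (the diagonal)...
pos_avg = np.diag(record).sum() / min_size
# ...and of negative pairs (everything off the diagonal).
neg_avg = (record.sum() - np.diag(record).sum()) / (record.size - min_size)

print(pos_avg, neg_avg)  # roughly 0.8 and 0.15
```

A healthy contrastive model drives `pos_avg` up and `neg_avg` down, which makes the pair a cheap training-progress signal.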
LGTM
PR types
New features
PR changes
Models
Description
Add the DiffCSE model to the PaddleNLP repo.