add mteb evaluation (#8538)
* add mteb evaluation

cxa-unique authored Jun 5, 2024
1 parent 3c21f0a commit 1cf780e
Showing 4 changed files with 324 additions and 5 deletions.
92 changes: 88 additions & 4 deletions pipelines/examples/contrastive_training/README.md
@@ -2,10 +2,11 @@

## Installation

We recommend installing the GPU version of [PaddlePaddle](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html). Taking Paddle for CUDA 11.7 as an example, the install command is:
We recommend installing the GPU version of [PaddlePaddle](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/conda/linux-conda.html). Taking Paddle for CUDA 11.7 as an example, the install command is:

```
python -m pip install paddlepaddle-gpu==2.6.0.post117 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
conda install nccl -c conda-forge
conda install paddlepaddle-gpu==2.6.1 -c https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/Paddle/ -c conda-forge
```
Install the other dependencies:
```
@@ -98,15 +99,98 @@ python evaluation/benchmarks.py --model_type bert \
--passage_model checkpoints/checkpoint-1500 \
--query_max_length 64 \
--passage_max_length 512 \
--evaluate_all
```
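- `model_type`: the model type, e.g. bert or roberta
- `query_model`: path to the query embedding model
- `passage_model`: path to the passage embedding model
- `query_max_length`: maximum query length
- `passage_max_length`: maximum passage length
- `evaluate_all`: whether to evaluate all checkpoints; defaults to False, meaning only the specified checkpoint is evaluated
- `checkpoint_dir`: used together with `evaluate_all`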


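## MTEB Evaluation
[MTEB](https://github.com/embeddings-benchmark/mteb)
is a large-scale text embedding benchmark that includes a wide range of embedding retrieval tasks and datasets.
This repository mainly targets its Chinese and English retrieval tasks (Retrieval) and uses the SciFact dataset as the main example.

Evaluate the RepLLaMA embedding retrieval model ([repllama-v1-7b-lora-passage](https://huggingface.co/castorini/repllama-v1-7b-lora-passage)):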
```
export CUDA_VISIBLE_DEVICES=0
python evaluation/mteb/eval_mteb.py \
--base_model_name_or_path castorini/repllama-v1-7b-lora-passage \
--output_folder en_results/repllama-v1-7b-lora-passage \
--task_name SciFact \
--task_split test \
--query_instruction 'query: ' \
--document_instruction 'passage: ' \
--pooling_method last \
--max_seq_length 512 \
--eval_batch_size 2 \
--pad_token unk_token \
--padding_side right \
--add_bos_token 0 \
--add_eos_token 1
```
The result file is saved at `en_results/repllama-v1-7b-lora-passage/SciFact/last/no_revision_available/SciFact.json` and contains evaluation results similar to the following:
```
'ndcg_at_1': 0.63,
'ndcg_at_3': 0.71785,
'ndcg_at_5': 0.73735,
'ndcg_at_10': 0.75708,
'ndcg_at_20': 0.7664,
'ndcg_at_100': 0.77394,
'ndcg_at_1000': 0.7794
```
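The `ndcg_at_k` values above are nDCG@k scores. As a rough illustration of what the metric measures, here is a minimal sketch for a single query, assuming linear gain and a log2 rank discount. The function names are hypothetical, and this is not the code that MTEB itself runs.

```python
import math


def dcg_at_k(relevances, k):
    """Discounted cumulative gain over the top-k relevance labels (linear gain)."""
    return sum(rel / math.log2(rank + 2) for rank, rel in enumerate(relevances[:k]))


def ndcg_at_k(ranked_relevances, all_relevances, k):
    """nDCG@k for one query.

    ranked_relevances: relevance labels of the retrieved documents, in ranked order.
    all_relevances: relevance labels of every judged document for the query,
        used to build the ideal ranking.
    """
    ideal_dcg = dcg_at_k(sorted(all_relevances, reverse=True), k)
    if ideal_dcg == 0:
        return 0.0
    return dcg_at_k(ranked_relevances, k) / ideal_dcg


# One relevant document, retrieved at rank 2 out of 3.
score = ndcg_at_k([0, 1, 0], [1], k=3)
print(round(score, 5))
```

The reported benchmark score is then the average of this per-query value over all queries in the chosen split.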

Evaluate the BGE embedding retrieval model ([bge-large-en-v1.5](https://huggingface.co/BAAI/bge-large-en-v1.5)):
```
export CUDA_VISIBLE_DEVICES=0
python evaluation/mteb/eval_mteb.py \
--base_model_name_or_path BAAI/bge-large-en-v1.5 \
--output_folder en_results/bge-large-en-v1.5 \
--task_name SciFact \
--task_split test \
--document_instruction 'Represent this sentence for searching relevant passages: ' \
--pooling_method mean \
--max_seq_length 512 \
--eval_batch_size 32 \
--pad_token pad_token \
--padding_side right \
--add_bos_token 0 \
--add_eos_token 0
```
The result file is saved at `en_results/bge-large-en-v1.5/SciFact/mean/no_revision_available/SciFact.json` and contains evaluation results similar to the following:
```
'ndcg_at_1': 0.64667,
'ndcg_at_3': 0.70359,
'ndcg_at_5': 0.7265,
'ndcg_at_10': 0.75675,
'ndcg_at_20': 0.76743,
'ndcg_at_100': 0.77511,
'ndcg_at_1000': 0.77939
```
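Both commands pass instruction prefixes through `--query_instruction` and `--document_instruction`. The sketch below mirrors how `EncodeModel` in `mteb_models.py` from this commit builds the input text before tokenization. The helper names `format_query` and `format_document` are hypothetical and exist only for illustration.

```python
def format_query(query_instruction, query):
    """Prepend the query instruction, if one is configured."""
    if query_instruction is not None:
        return f"{query_instruction}{query}"
    return query


def format_document(document_instruction, doc):
    """Join title and text for dict-style documents, then prepend the instruction."""
    if isinstance(doc, dict):
        text = "{} {}".format(doc.get("title", ""), doc["text"])
    else:
        text = doc
    if document_instruction is not None:
        text = f"{document_instruction}{text}"
    # Strip only for dict-style documents, matching the behaviour sketched here.
    return text.strip() if isinstance(doc, dict) else text


print(format_query("query: ", "does vitamin D help?"))
print(format_document("passage: ", {"title": "Vitamin D", "text": "A trial of ..."}))
```

Because the prefix becomes part of the encoded text, it should match whatever prefix the model was trained with.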

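Configurable parameters:
- `base_model_name_or_path`: model name or path
- `output_folder`: directory where the result files are stored
- `task_name`: task (dataset) name, e.g. SciFact
- `task_split`: the query split to evaluate on, e.g. test or dev
- `query_instruction`: prompt text prepended to each query, e.g. 'query: ' or None
- `document_instruction`: prompt text prepended to each document, e.g. 'passage: ' or None
- `pooling_method`: how the representation is obtained; last takes the last token, mean takes the average, cls takes the `[CLS]` token
- `max_seq_length`: maximum sequence length
- `eval_batch_size`: batch size for model inference (on a single GPU)
- `pad_token`: the token used for padding; one of unk_token, eos_token or pad_token
- `padding_side`: the side on which padding is applied; left or right
- `add_bos_token`: whether to add the beginning-of-sequence token; 0 means no, 1 means yes
- `add_eos_token`: whether to add the end-of-sequence token; 0 means no, 1 means yes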
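To make the interaction between `pooling_method` and `padding_side` concrete, here is a minimal NumPy sketch of the three pooling options. It is an illustration under stated assumptions, not the Paddle implementation in this commit: the `pool` function is hypothetical, and `cls` pooling assumes the `[CLS]` token sits at position 0, which holds for right-padded BERT-style inputs.

```python
import numpy as np


def pool(hidden, mask, method="mean", padding_side="right"):
    """Pool token states into one vector per sequence.

    hidden: array of shape (batch, seq_len, dim) with token hidden states.
    mask: array of shape (batch, seq_len), 1 for real tokens and 0 for padding.
    """
    if method == "cls":
        # Assumes [CLS] is the first token (right padding).
        return hidden[:, 0]
    if method == "last":
        if padding_side == "left":
            # With left padding the last position is always a real token.
            return hidden[:, -1]
        # With right padding, find the last real token of each sequence.
        last_index = mask.sum(axis=1) - 1
        return hidden[np.arange(hidden.shape[0]), last_index]
    if method == "mean":
        weights = mask[..., None].astype(hidden.dtype)
        return (hidden * weights).sum(axis=1) / weights.sum(axis=1)
    raise ValueError(f"Unknown pooling method: {method}")


hidden = np.arange(12, dtype=np.float32).reshape(2, 3, 2)
mask = np.array([[1, 1, 0], [1, 1, 1]])
print(pool(hidden, mask, "last"))
print(pool(hidden, mask, "mean"))
```

With right padding, `last` pooling has to look up each sequence's true length, which is why the commands above set `--padding_side` explicitly.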


## Reference

[1] Aditya Kusupati, Gantavya Bhatt, Aniket Rege, Matthew Wallingford, Aditya Sinha, Vivek Ramanujan, William Howard-Snyder, Kaifeng Chen, Sham M. Kakade, Prateek Jain, Ali Farhadi: Matryoshka Representation Learning. NeurIPS 2022
[1] Aditya Kusupati, Gantavya Bhatt, Aniket Rege, Matthew Wallingford, Aditya Sinha, Vivek Ramanujan, William Howard-Snyder, Kaifeng Chen, Sham M. Kakade, Prateek Jain, Ali Farhadi: Matryoshka Representation Learning. NeurIPS 2022.

[2] Xueguang Ma, Liang Wang, Nan Yang, Furu Wei, Jimmy Lin: Fine-Tuning LLaMA for Multi-Stage Text Retrieval. arXiv 2023.

[3] Shitao Xiao, Zheng Liu, Peitian Zhang, Niklas Muennighoff: C-Pack: Packaged Resources To Advance General Chinese Embedding. SIGIR 2024.
107 changes: 107 additions & 0 deletions pipelines/examples/contrastive_training/evaluation/mteb/eval_mteb.py
@@ -0,0 +1,107 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging

from mteb import MTEB
from mteb_models import EncodeModel

from paddlenlp.transformers import AutoModel, AutoTokenizer


def get_model(peft_model_name, base_model_name):
    if peft_model_name is not None:
        raise NotImplementedError("PEFT model is not supported yet")
    else:
        base_model = AutoModel.from_pretrained(base_model_name)
    return base_model


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--base_model_name_or_path", default="bge-large-en-v1.5", type=str)
    parser.add_argument("--peft_model_name_or_path", default=None, type=str)
    parser.add_argument("--output_folder", default="tmp", type=str)

    parser.add_argument("--task_name", default="SciFact", type=str)
    parser.add_argument(
        "--task_split",
        default="test",
        help='Note that some datasets do not have "test", they only have "dev"',
        type=str,
    )

    parser.add_argument("--query_instruction", default=None, help="add prefix instruction before query", type=str)
    parser.add_argument(
        "--document_instruction", default=None, help="add prefix instruction before document", type=str
    )

    parser.add_argument("--pooling_method", default="last", help="choose in [mean, last, cls]", type=str)
    parser.add_argument("--max_seq_length", default=512, type=int)
    parser.add_argument("--eval_batch_size", default=1, type=int)

    parser.add_argument("--pad_token", default="unk_token", help="unk_token, eos_token or pad_token", type=str)
    parser.add_argument("--padding_side", default="left", help="right or left", type=str)
    parser.add_argument("--add_bos_token", default=0, help="1 means add token", type=int)
    parser.add_argument("--add_eos_token", default=1, help="1 means add token", type=int)

    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()

    logger = logging.getLogger(__name__)
    logging.basicConfig(level=logging.INFO)
    logger.info("Args: {}".format(args))

    model = get_model(args.peft_model_name_or_path, args.base_model_name_or_path)

    tokenizer = AutoTokenizer.from_pretrained(args.base_model_name_or_path)
    assert hasattr(tokenizer, args.pad_token), f"Tokenizer does not have {args.pad_token} token"
    token_dict = {"unk_token": tokenizer.unk_token, "eos_token": tokenizer.eos_token, "pad_token": tokenizer.pad_token}
    tokenizer.pad_token = token_dict[args.pad_token]

    assert args.padding_side in [
        "right",
        "left",
    ], f"padding_side should be either 'right' or 'left', but got {args.padding_side}"
    assert not (
        args.padding_side == "left" and args.pooling_method == "cls"
    ), "Padding 'left' is not supported for pooling method 'cls'"
    tokenizer.padding_side = args.padding_side

    assert args.add_bos_token in [0, 1], f"add_bos_token should be either 0 or 1, but got {args.add_bos_token}"
    assert args.add_eos_token in [0, 1], f"add_eos_token should be either 0 or 1, but got {args.add_eos_token}"
    tokenizer.add_bos_token = bool(args.add_bos_token)
    tokenizer.add_eos_token = bool(args.add_eos_token)

    encode_model = EncodeModel(
        model=model,
        tokenizer=tokenizer,
        pooling_method=args.pooling_method,
        query_instruction=args.query_instruction,
        document_instruction=args.document_instruction,
        eval_batch_size=args.eval_batch_size,
        max_seq_length=args.max_seq_length,
    )

    logger.info("Ready to eval")
    evaluation = MTEB(tasks=[args.task_name])
    evaluation.run(
        encode_model,
        output_folder=f"{args.output_folder}/{args.task_name}/{args.pooling_method}",
        eval_splits=[args.task_split],
    )
127 changes: 127 additions & 0 deletions pipelines/examples/contrastive_training/evaluation/mteb/mteb_models.py
@@ -0,0 +1,127 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Union

import numpy as np
import paddle
from tqdm import tqdm


class EncodeModel:
    def __init__(
        self,
        model,
        tokenizer,
        pooling_method: str = "last",
        query_instruction: str = None,
        document_instruction: str = None,
        eval_batch_size: int = 64,
        max_seq_length: int = 512,
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.pooling_method = pooling_method
        self.query_instruction = query_instruction
        self.document_instruction = document_instruction
        self.eval_batch_size = eval_batch_size
        self.max_seq_length = max_seq_length

        if paddle.device.is_compiled_with_cuda():
            self.device = paddle.device.set_device("gpu")
        else:
            self.device = paddle.device.set_device("cpu")
        self.model = self.model.to(self.device)

        num_gpus = paddle.device.cuda.device_count()
        if num_gpus > 1:
            raise NotImplementedError("Multi-GPU is not supported yet.")

    def encode_queries(self, queries: List[str], **kwargs) -> np.ndarray:
        """
        This function will be used to encode queries for the retrieval task.
        If there is an instruction for queries, we will add it to the query text.
        """
        if self.query_instruction is not None:
            input_texts = [f"{self.query_instruction}{query}" for query in queries]
        else:
            input_texts = queries
        return self.encode(input_texts)

    def encode_corpus(self, corpus: List[Union[Dict[str, str], str]], **kwargs) -> np.ndarray:
        """
        This function will be used to encode the corpus for the retrieval task.
        If there is an instruction for docs, we will add it to the doc text.
        """
        if isinstance(corpus[0], dict):
            if self.document_instruction is not None:
                input_texts = [
                    "{}{} {}".format(self.document_instruction, doc.get("title", ""), doc["text"]).strip()
                    for doc in corpus
                ]
            else:
                input_texts = ["{} {}".format(doc.get("title", ""), doc["text"]).strip() for doc in corpus]
        else:
            if self.document_instruction is not None:
                input_texts = [f"{self.document_instruction}{doc}" for doc in corpus]
            else:
                input_texts = corpus
        return self.encode(input_texts)

    @paddle.no_grad()
    def encode(self, sentences: List[str], **kwargs) -> np.ndarray:
        self.model.eval()
        all_embeddings = []
        for start_index in tqdm(range(0, len(sentences), self.eval_batch_size), desc="Batches"):
            sentences_batch = sentences[start_index : start_index + self.eval_batch_size]

            inputs = self.tokenizer(
                sentences_batch,
                padding=True,
                truncation=True,
                return_tensors="pd",
                max_length=self.max_seq_length,
                return_attention_mask=True,
            )
            outputs = self.model(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                return_dict=True,
                output_hidden_states=True,
            )
            last_hidden_state = outputs.hidden_states[-1]

            if self.pooling_method == "last":
                if self.tokenizer.padding_side == "right":
                    # Index of the last non-padding token in each sequence.
                    sequence_lengths = inputs.attention_mask.sum(axis=1)
                    last_token_indices = sequence_lengths - 1
                    embeddings = last_hidden_state[paddle.arange(last_hidden_state.shape[0]), last_token_indices]
                elif self.tokenizer.padding_side == "left":
                    embeddings = last_hidden_state[:, -1]
                else:
                    raise NotImplementedError(f"Padding side {self.tokenizer.padding_side} not supported.")
            elif self.pooling_method == "cls":
                # [CLS] is the first token (index 0); left padding is rejected for cls pooling by the caller.
                embeddings = last_hidden_state[:, 0]
            elif self.pooling_method == "mean":
                # Average over non-padding tokens only.
                s = paddle.sum(last_hidden_state * inputs.attention_mask.unsqueeze(-1), axis=1)
                d = inputs.attention_mask.sum(axis=1, keepdim=True)
                embeddings = s / d
            else:
                raise NotImplementedError(f"Pooling method {self.pooling_method} not supported.")

            embeddings = paddle.nn.functional.normalize(embeddings, p=2, axis=-1)

            all_embeddings.append(embeddings.cpu().numpy().astype("float32"))

        return np.concatenate(all_embeddings, axis=0)
3 changes: 2 additions & 1 deletion pipelines/examples/contrastive_training/requirements.txt
@@ -1,5 +1,6 @@
paddlenlp>2.6.1
datasets
torch==2.0.1
mteb[beir]
mteb
beir
typer==0.9.0
