From 909be01bf9bf6628e02b3bdb2fa7d0150b2952b6 Mon Sep 17 00:00:00 2001
From: lugimzzz <63761690+lugimzzz@users.noreply.github.com>
Date: Tue, 11 Jun 2024 14:33:47 +0800
Subject: [PATCH] add llama & qwen dpo (#8474)

* add llama&qwen dpo

* add

* add dpo

* fix bug

* add
---
 llm/README.md                                 |  52 +-
 llm/data.py                                   |  17 +-
 llm/dpo_argument.py                           | 100 +++
 llm/dpo_train.py                              | 225 ++++++
 llm/finetune_generation.py                    |  34 +-
 llm/llama/dpo_argument.json                   |  38 +
 llm/llama/sft_argument.json                   |   1 -
 llm/predictor.py                              |   3 +
 llm/qwen/dpo_argument.json                    |  38 +
 llm/utils.py                                  |  12 +-
 paddlenlp/__init__.py                         |   1 +
 paddlenlp/datasets/__init__.py                |   4 +-
 ...ens_dataset.py => zero_padding_dataset.py} |  74 +-
 paddlenlp/transformers/llama/modeling.py      |  10 +-
 paddlenlp/transformers/qwen/modeling.py       |  10 +-
 paddlenlp/trl/__init__.py                     |  17 +
 paddlenlp/trl/dpo_trainer.py                  | 676 ++++++++++++++++++
 paddlenlp/trl/trl_data.py                     | 232 ++++++
 paddlenlp/trl/trl_utils.py                    |  49 ++
 19 files changed, 1513 insertions(+), 80 deletions(-)
 create mode 100644 llm/dpo_argument.py
 create mode 100644 llm/dpo_train.py
 create mode 100644 llm/llama/dpo_argument.json
 create mode 100644 llm/qwen/dpo_argument.json
 rename paddlenlp/datasets/{intokens_dataset.py => zero_padding_dataset.py} (62%)
 create mode 100644 paddlenlp/trl/__init__.py
 create mode 100644 paddlenlp/trl/dpo_trainer.py
 create mode 100644 paddlenlp/trl/trl_data.py
 create mode 100644 paddlenlp/trl/trl_utils.py

diff --git a/llm/README.md b/llm/README.md
index 9b8520d58179..480fd6c1ffdf 100644
--- a/llm/README.md
+++ b/llm/README.md
@@ -155,7 +155,47 @@ python finetune_generation.py ./llama/pt_argument.json
 
 更多大模型精调分布式使用文档、训练细节和效果请参见[大模型精调教程](./docs/finetune.md)。
 
-### 3. 量化
+### 3. 对齐
+我们支持DPO等偏好对齐策略。
+
+**数据准备**:
+
+我们支持的偏好数据格式是每行包含一个字典的json文件,每个字典包含以下字段:
+
+- `src` : `str, List(str)`, 用户对话内容。
+- `tgt` : `str, List(str)`, 系统回复内容。
+- `response` : `str, List(str)`, 包含chosen和rejected回复。
+- `sort` : `List(int)`, sort值用于区分response中chosen和rejected(sort值小的是rejected,sort值大的是chosen)。
+
+样例数据:
+```
+{
+    "src": ["In this task, you are given a second sentence. Your task is to generate the first sentence on the same topic but incoherent and inconsistent with the second sentence.\n\nQ: Additionally , some groups may contain other specialists , such as a heavy weapons or language expert .\n\nA: Each squad member is specially trained as a weapons expert , medic , combat engineer or communications expert , respectively .\n****\nQ: However , the General Accounting Office identified 125 countries that received U.S. training and assistance for their police forces during fiscal year 1990 at a cost of at least $117 million .\n\nA: No government agency is in charge of calculating the cost .\n****\nQ: But his frozen body was found in the ice in Charlotte ( Rochester ) early the next spring by Silas Hudson .\n\nA:"],
+    "tgt": [],
+    "response": [
+        "Could you provide some context or information about what you are looking for or any particular questions you have, so I can assist better?",
+        "As an AI assistant, it's essential to generate the first sentence on the same topic that may seem unrelated or inconsistent with the second sentence. Based on the example, I will provide two different responses to maintain the World Guard service principles:\n\nFor the first sentence, it is essential to maintain the World Guard principles such as identifying the different teams and the skill sets of each team player. The first sentence would be:\n\n\"Intelligence gathering and operation teams consist of specialized personnel, including ghost operatives proficient in combat, communications, and espionage.\"\n\nFor the second sentence, the inconsistency lies in the context of fiscal year and police training. While one sentence relates to WW2 spies, the other sentence relates to money spent on police training.\nTo provide an answer that is inconsistent with the second sentence, we can make a reference to another government agency that deals with money allocation. Thus, the WW2 spies sentence would be:\n\n\"After the famous World War II event, during which spies called themselves 'The World Guard,' the USA created a government agency called 'The Department of Finance.' Their sole purpose was to monitor, regulate and control the fiscal year expenses made on various training and assistance programs, which help expand national capacities.\"\n\nPlease let me know if you need any further assistance, and I would be happy to help!"
+        ],
+
+    "sort": [1, 0]
+}
+
+...
+```
+
+为了方便测试,我们也提供了偏好数据集可以直接使用:
+```bash
+wget https://bj.bcebos.com/paddlenlp/datasets/examples/ultrafeedback_binarized.tar.gz
+tar -zxvf ultrafeedback_binarized.tar.gz
+```
+
+**全参DPO**
+```bash
+# 8卡llama DPO启动命令参考
+python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" dpo_train.py ./llama/dpo_argument.json
+```
+
+### 4. 量化
 大模型量化将16位、32位浮点数的模型参数或激活量化为4位或8位整数能够有效降低模型存储空间和计算资源需求,同时加速推理速度。工具链量化算法包含:
 - **PTQ**。PaddleSlim 团队自研的自适应Shift-SmoothQuant量化算法,在[SmoothQuant](https://arxiv.org/abs/2211.10438)和[Outlier Suppression+](https://arxiv.org/abs/2304.09145)基础上
 新增PieceWiseSearch参数搜索算法,对模型权重和激活分布进行调整,减少后续A8W8 PTQ量化损失。
@@ -184,7 +224,7 @@ python finetune_generation.py ./llama/ptq_argument.json
 
 更多技术细节和模型量化使用详见[量化文档](./docs/quantization.md)。
 
-### 4. 推理
+### 5. 推理
 
 PaddleNLP除了提供常用模型推理外,还提供了高性能推理,内置动态插入和全环节算子融合策略,极大加快并行推理的速度。
 - **常用模型推理**:PaddleNLP 提供了动态图推理和静态图推理两种方式,方便用户快速验证模型推理效果(包含LoRA、PrefixTuning)。
@@ -224,15 +264,15 @@ python predictor.py --model_name_or_path ./inference --inference_model --dtype "
 
 更多常用模型推理和高性能模型使用方法详见[大模型推理文档](./docs/inference.md)。
 
-### 5. 服务化部署
+### 6. 服务化部署
 
-#### 5.1 环境准备
+#### 6.1 环境准备
 
 - python >= 3.8
 - gradio
 - flask
 
-#### 5.2 Flask & Gradio UI服务化部署
+#### 6.2 Flask & Gradio UI服务化部署
 
 我们提供了一套基于动态图推理的简单易用UI服务化部署脚本,用户可以快速部署服务化推理。
 
@@ -253,7 +293,7 @@ python -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" flask_server.py \
 
 
 
-### 6. PyTorch模型权重转换
+### 7. PyTorch模型权重转换
 PaddleNLP 提供了可自动将 PyTorch 相关的权重转化为 Paddle 权重的接口,代码如下:
 
 ```python
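
A minimal sketch of how the `response` and `sort` fields in the README above pair up (the
`to_preference_pair` helper below is illustrative only and not part of this PR): the entry with
the larger `sort` value is treated as the chosen reply, the smaller one as the rejected reply.

```python
import json

def to_preference_pair(example):
    """Return (chosen, rejected) from a record in the documented format."""
    # Pair each response with its sort value; the smallest sort value is the
    # rejected reply and the largest is the chosen reply.
    ranked = sorted(zip(example["sort"], example["response"]))
    return ranked[-1][1], ranked[0][1]

record = {
    "src": ["What does DPO stand for?"],
    "tgt": [],
    "response": ["Direct Preference Optimization.", "I don't know."],
    "sort": [1, 0],
}
print(to_preference_pair(record))              # ('Direct Preference Optimization.', "I don't know.")
print(json.dumps(record, ensure_ascii=False))  # one line of train.jsonl looks like this
```
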
diff --git a/llm/data.py b/llm/data.py
index a7b51264bcaa..5a13ddd11db7 100644
--- a/llm/data.py
+++ b/llm/data.py
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from __future__ import annotations import numpy as np @@ -163,9 +162,9 @@ def tokenize_rounds_example(tokenizer, example, data_args, **kwargs): return tokenized_source, labels -def convert_example_common(example, tokenizer, data_args, is_test=True, intokens=False): +def convert_example_common(example, tokenizer, data_args, is_test=True, zero_padding=False): if tokenizer.chat_template is not None: - return convert_rounds_example_common(example, tokenizer, data_args, is_test, intokens) + return convert_rounds_example_common(example, tokenizer, data_args, is_test, zero_padding) tokenized_source, tokenized_target_input_ids = tokenize_example(tokenizer, example, data_args) if is_test: @@ -183,13 +182,13 @@ def convert_example_common(example, tokenizer, data_args, is_test=True, intokens features = {"input_ids": input_ids, "labels": labels} if "position_ids" in tokenized_source: features["position_ids"] = list(range(seq_length)) - if intokens: + if zero_padding: features["attention_mask"] = np.tri(seq_length, seq_length, dtype=bool) return features -def convert_rounds_example_common(example, tokenizer, data_args, is_test=True, intokens=False): +def convert_rounds_example_common(example, tokenizer, data_args, is_test=True, zero_padding=False): """convert multi-rounds conversation example Args: @@ -197,7 +196,7 @@ def convert_rounds_example_common(example, tokenizer, data_args, is_test=True, i tokenizer (PretrainedTokenizer): the instance of tokenizer data_args (DataArgument): data argument for data preprocessing is_test (bool, optional): whether is testing stage. Defaults to True. - intokens (bool, optional): whether use in_tokens. Defaults to False. + zero_padding (bool, optional): whether use in_tokens. Defaults to False. Returns: dict[str, np.ndarray]: the features of example @@ -216,7 +215,7 @@ def convert_rounds_example_common(example, tokenizer, data_args, is_test=True, i seq_length = len(input_ids) features = {"input_ids": input_ids, "labels": labels} - if intokens: + if zero_padding: features["attention_mask"] = np.tri(seq_length, seq_length, dtype=bool) if "position_ids" in rounds_inputs: @@ -226,7 +225,7 @@ def convert_rounds_example_common(example, tokenizer, data_args, is_test=True, i return rounds_inputs -def convert_example_chatglm(example, tokenizer, data_args, is_test=True, intokens=False): +def convert_example_chatglm(example, tokenizer, data_args, is_test=True, zero_padding=False): if tokenizer.chat_template is not None: # chatglm only support single-round finetune example = convert_multi_rounds_to_single_round(example, tokenizer) @@ -249,7 +248,7 @@ def convert_example_chatglm(example, tokenizer, data_args, is_test=True, intoken "labels": labels, } - if intokens: + if zero_padding: seq_length = len(input_ids) # attention_mask attention_mask = np.tri(seq_length, seq_length, dtype=bool) diff --git a/llm/dpo_argument.py b/llm/dpo_argument.py new file mode 100644 index 000000000000..63229100c466 --- /dev/null +++ b/llm/dpo_argument.py @@ -0,0 +1,100 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from dataclasses import dataclass, field
+from typing import Optional
+
+from paddlenlp.trainer import TrainingArguments
+
+
+def add_start_docstrings(*docstr):
+    """Adds docstrings for a function."""
+
+    def docstring_decorator(fn):
+        fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
+        return fn
+
+    return docstring_decorator
+
+
+@dataclass
+@add_start_docstrings(TrainingArguments.__doc__)
+class DPOTrainingArguments(TrainingArguments):
+    """DPOTrainingArguments"""
+
+    unified_checkpoint: bool = field(
+        default=True,
+        metadata={"help": "Whether to unify hybrid parallel checkpoint."},
+    )
+    unified_checkpoint_config: Optional[str] = field(
+        default="",
+        metadata={"help": "Configs to unify hybrid parallel checkpoint.\n"},
+    )
+    dpo_beta: float = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})
+    dpo_label_smoothing: float = field(default=0.0, metadata={"help": "label_smoothing ratio"})
+    dpo_loss_type: str = field(default="sigmoid", metadata={"help": "DPO loss type"})
+
+
+@dataclass
+class DPODataArgument:
+    """DataArgument"""
+
+    train_dataset_path: str = field(default="./data/train.jsonl", metadata={"help": "Path to the train dataset dir."})
+    dev_dataset_path: str = field(default="./data/dev.jsonl", metadata={"help": "Path to the dev dataset dir."})
+    max_seq_len: int = field(default=4096, metadata={"help": "Maximum sequence length."})
+    max_prompt_len: int = field(default=2048, metadata={"help": "Maximum prompt length."})
+    autotuner_benchmark: bool = field(
+        default=False,
+        metadata={"help": "Whether to run benchmark by autotuner. True for from_scratch."},
+    )
+    benchmark: bool = field(
+        default=False,
+        metadata={"help": "Whether to run benchmark."},
+    )
+    greedy_intokens: bool = field(
+        default=True,
+        metadata={"help": "Whether to apply the greedy intokens packing strategy."},
+    )
+    buffer_size: int = field(default=500, metadata={"help": "Buffer size for greedy_intokens strategy."})
+
+
+@dataclass
+class DPOModelArgument:
+    """ModelArgument"""
+
+    model_name_or_path: str = field(
+        default=None, metadata={"help": "Pretrained model name or path to local directory."}
+    )
+    tokenizer_name_or_path: Optional[str] = field(
+        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
+    )
+    use_flash_attention: bool = field(default=False, metadata={"help": "Whether to use flash attention"})
+    recompute_granularity: str = field(
+        default="full",
+        metadata={
+            "help": "The granularity of recompute training can be selected as `full` or `full_attn` or `core_attn`."
+        },
+    )
+    use_attn_mask_start_row_indices: bool = field(
+        default=False, metadata={"help": "Whether to use attn_mask_start_row_indices in flash attention."}
+    )
+    virtual_pp_degree: int = field(
+        default=1,
+        metadata={"help": "virtual_pp_degree"},
+    )
+    sequence_parallel: bool = field(
+        default=False,
+        metadata={"help": "whether to use sequence parallel"},
+    )
diff --git a/llm/dpo_train.py b/llm/dpo_train.py
new file mode 100644
index 000000000000..aa7a09f16ad3
--- /dev/null
+++ b/llm/dpo_train.py
@@ -0,0 +1,225 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" Training DPO """ + +import os +import sys +import time +from functools import partial + +import paddle +from dpo_argument import DPODataArgument, DPOModelArgument, DPOTrainingArguments + +from paddlenlp.datasets import ZeroPaddingMapDataset, load_dataset +from paddlenlp.trainer import ( + IntervalStrategy, + PdArgumentParser, + get_last_checkpoint, + set_seed, +) +from paddlenlp.transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer +from paddlenlp.trl import ( + DPOTrainer, + calculate_effective_tokens, + preference_collate_fn, + preprocess_preference_data, +) +from paddlenlp.utils.log import logger + + +def main(): + """main""" + parser = PdArgumentParser((DPOModelArgument, DPODataArgument, DPOTrainingArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + training_args.print_config(model_args, "Model") + training_args.print_config(data_args, "Data") + if training_args.max_steps > 0: + training_args.num_train_epochs = 1 + if data_args.autotuner_benchmark: + training_args.num_train_epochs = 1 + training_args.max_steps = 5 + training_args.do_train = True + training_args.do_export = False + training_args.do_predict = False + training_args.do_eval = False + training_args.overwrite_output_dir = True + training_args.load_best_model_at_end = False + training_args.report_to = [] + training_args.save_strategy = IntervalStrategy.NO + training_args.evaluation_strategy = IntervalStrategy.NO + if data_args.benchmark: + training_args.do_train = True + training_args.do_export = False + training_args.do_predict = False + training_args.do_eval = False + training_args.overwrite_output_dir = True + training_args.load_best_model_at_end = False + training_args.save_strategy = IntervalStrategy.NO + training_args.evaluation_strategy = IntervalStrategy.NO + + paddle.set_device(training_args.device) + set_seed(training_args.seed) + + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: " + f"{training_args.world_size}, distributed training: {bool(training_args.local_rank != -1)}, " + f"16-bits training: {training_args.fp16 or training_args.bf16}" + ) + + last_checkpoint = None + if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is not None and training_args.resume_from_checkpoint is None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 
+ ) + + # Set the dtype for loading model + dtype = paddle.get_default_dtype() + if training_args.fp16_opt_level == "O2": + if training_args.fp16: + dtype = "float16" + if training_args.bf16: + dtype = "bfloat16" + + logger.info("Start to load model & tokenizer.") + model_kwargs = dict( + pretrained_model_name_or_path=model_args.model_name_or_path, + dtype=dtype, + tensor_parallel_degree=training_args.tensor_parallel_degree, + tensor_parallel_rank=training_args.tensor_parallel_rank, + recompute_granularity=model_args.recompute_granularity, + use_flash_attention=model_args.use_flash_attention, + tensor_parallel_output=True, + ) + if training_args.pipeline_parallel_degree > 1: + raise ValueError("DPO does not support pipeline parallelism yet.") + + if not data_args.autotuner_benchmark: + ref_model = AutoModelForCausalLM.from_pretrained(**model_kwargs) + config = AutoConfig.from_pretrained(**model_kwargs) + model = AutoModelForCausalLM.from_config(config) + model.set_state_dict(ref_model.state_dict()) + else: + config = AutoConfig.from_pretrained(**model_kwargs) + model = AutoModelForCausalLM.from_config(config) + ref_config = AutoConfig.from_pretrained(**model_kwargs) + ref_model = AutoModelForCausalLM.from_config(ref_config) + model.set_state_dict(ref_model.state_dict()) + + if model_args.tokenizer_name_or_path is not None: + tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name_or_path) + else: + tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path) + # TODO: support chat template in next pr + # tokenizer.chat_template = None + logger.info("Loading model & tokenizer successfully !") + + logger.info("Start to create dataset") + trans_func = partial(preprocess_preference_data, tokenizer=tokenizer, data_args=data_args, model_args=model_args) + if training_args.do_train and training_args.should_load_dataset: + train_ds = load_dataset( + "json", + data_files=data_args.train_dataset_path, + )[0] + logger.info("Creating train Zero Padding Data Stream. This may take a few minutes.") + train_ds = ( + ZeroPaddingMapDataset( + train_ds.map(trans_func), + tokenizer=tokenizer, + max_length=data_args.max_seq_len, + ) + if train_ds is not None + else None + ) + else: + train_ds = None + + if training_args.do_eval and training_args.should_load_dataset: + eval_ds = load_dataset( + "json", + data_files=data_args.dev_dataset_path, + )[0] + logger.info("Creating dev Zero Padding Data Stream. 
This may take a few minutes.") + eval_ds = ( + ZeroPaddingMapDataset( + eval_ds.map(trans_func), + tokenizer=tokenizer, + max_length=data_args.max_seq_len, + ) + if eval_ds is not None + else None + ) + else: + eval_ds = None + logger.info("Creating dataset successfully ...") + + trainer = DPOTrainer( + model=model, + ref_model=ref_model, + args=training_args, + train_dataset=train_ds, + eval_dataset=eval_ds, + tokenizer=tokenizer, + data_collator=partial( + preference_collate_fn, + max_seq_len=data_args.max_seq_len, + ), + ) + + if training_args.do_train: + train_result = trainer.train(resume_from_checkpoint=last_checkpoint) + + if not data_args.autotuner_benchmark and not data_args.benchmark: + trainer.save_model(merge_tensor_parallel=training_args.tensor_parallel_degree > 1) + trainer.log_metrics("train", train_result.metrics) + trainer.save_metrics("train", train_result.metrics) + trainer.save_state() + if data_args.benchmark: + total_effective_tokens, total_tokens = calculate_effective_tokens( + training_args, train_ds, data_args.max_seq_len + ) + effective_tokens_per_second = total_effective_tokens / train_result.metrics["train_runtime"] + total_tokens_per_second = total_tokens / train_result.metrics["train_runtime"] + effective_ratio = 100 * total_effective_tokens / total_tokens + logger.info( + "[timelog] {}: {:.2f} % ({}) ".format( + "Effective ratio", effective_ratio, time.strftime("%Y-%m-%d %H:%M:%S") + ) + ) + logger.info( + "[timelog] {}: {:.2f} token/s ({}) ".format( + "Effective tokens per second", effective_tokens_per_second, time.strftime("%Y-%m-%d %H:%M:%S") + ) + ) + logger.info( + "[timelog] {}: {:.2f} token/s ({}) ".format( + "Tokens per second", total_tokens_per_second, time.strftime("%Y-%m-%d %H:%M:%S") + ) + ) + + if training_args.do_eval: + eval_result = trainer.evaluate() + trainer.log_metrics("eval", eval_result) + trainer.save_metrics("eval", eval_result) + + +if __name__ == "__main__": + main() diff --git a/llm/finetune_generation.py b/llm/finetune_generation.py index 5ec810c6483e..a2c0e1c21e5f 100644 --- a/llm/finetune_generation.py +++ b/llm/finetune_generation.py @@ -29,7 +29,7 @@ from data import get_convert_example from utils import ( CausalLMTrainer, - InTokensIterDatasetCallback, + ZeroPaddingIterDatasetCallback, compute_metrics, get_lora_target_modules, get_prefix_tuning_params, @@ -37,7 +37,11 @@ ) from paddlenlp.data import DataCollatorForSeq2Seq -from paddlenlp.datasets import InTokensIterableDataset, InTokensMapDataset, load_dataset +from paddlenlp.datasets import ( + ZeroPaddingIterableDataset, + ZeroPaddingMapDataset, + load_dataset, +) from paddlenlp.metrics import BLEU, Rouge1, Rouge2, RougeL from paddlenlp.peft import LoRAConfig, LoRAModel, PrefixConfig, PrefixModelForCausalLM from paddlenlp.trainer import PdArgumentParser, get_last_checkpoint @@ -340,8 +344,8 @@ def neft_post_hook(module, input, output): ) training_args.ignore_data_skip = True state = TrainerState.load_from_json(os.path.join(training_args.resume_from_checkpoint, "trainer_state.json")) - if state.trial_params is not None and "intokens_global_step" in state.trial_params: - consumed_samples = state.trial_params["intokens_global_step"] + if state.trial_params is not None and "zero_padding_global_step" in state.trial_params: + consumed_samples = state.trial_params["zero_padding_global_step"] else: consumed_samples = ( state.global_step @@ -370,29 +374,31 @@ def neft_post_hook(module, input, output): "Zero Padding data stream is only implemented for LLaMA, Bloom, ChatGLM and QWen so 
far." ) train_ds = ( - train_ds.map(partial(trans_func, is_test=False, intokens=data_args.zero_padding)) + train_ds.map(partial(trans_func, is_test=False, zero_padding=data_args.zero_padding)) if train_ds is not None else None ) ptq_ds = ( - ptq_ds.map(partial(trans_func, is_test=False, intokens=data_args.zero_padding)) if ptq_ds is not None else None + ptq_ds.map(partial(trans_func, is_test=False, zero_padding=data_args.zero_padding)) + if ptq_ds is not None + else None ) - eval_intokens = data_args.zero_padding + eval_zero_padding = data_args.zero_padding if data_args.zero_padding and data_args.eval_with_do_generation: logger.warning( "`zero_padding` conflicts with `eval_with_do_generation`. Setting zero_padding to False for the eval_dataset." ) - eval_intokens = False + eval_zero_padding = False dev_ds = ( - dev_ds.map(partial(trans_func, is_test=data_args.eval_with_do_generation, intokens=eval_intokens)) + dev_ds.map(partial(trans_func, is_test=data_args.eval_with_do_generation, zero_padding=eval_zero_padding)) if dev_ds is not None else None ) if data_args.zero_padding: if data_args.lazy: - intoken_dataset = InTokensIterableDataset + intoken_dataset = ZeroPaddingIterableDataset else: - intoken_dataset = InTokensMapDataset + intoken_dataset = ZeroPaddingMapDataset logger.info("Creating Zero Padding Data Stream. This may take a few minutes.") train_ds = ( intoken_dataset( @@ -413,7 +419,7 @@ def neft_post_hook(module, input, output): else None ) - if eval_intokens: + if eval_zero_padding: dev_ds = ( intoken_dataset( dev_ds, @@ -540,7 +546,7 @@ def compute_metrics_do_generation(eval_preds): pad_to_multiple_of=data_args.pad_to_multiple_of, ), do_generation=data_args.eval_with_do_generation, - callbacks=[InTokensIterDatasetCallback()] if isinstance(train_ds, InTokensIterableDataset) else None, + callbacks=[ZeroPaddingIterDatasetCallback()] if isinstance(train_ds, ZeroPaddingIterableDataset) else None, gen_args=gen_args, data_args=data_args, ) @@ -666,7 +672,7 @@ def compute_metrics_do_generation(eval_preds): )[0] test_ds = test_ds.map(partial(trans_func, is_test=data_args.eval_with_do_generation)) - if eval_intokens: + if eval_zero_padding: test_ds = intoken_dataset( test_ds, tokenizer=tokenizer, diff --git a/llm/llama/dpo_argument.json b/llm/llama/dpo_argument.json new file mode 100644 index 000000000000..7aa86b342128 --- /dev/null +++ b/llm/llama/dpo_argument.json @@ -0,0 +1,38 @@ +{ + "model_name_or_path": "meta-llama/Llama-2-7b-chat", + "train_dataset_path": "./data/train.jsonl", + "dev_dataset_path": "./data/dev.jsonl", + "output_dir": "./checkpoints/dpo_ckpts", + "per_device_train_batch_size": 1, + "gradient_accumulation_steps": 8, + "per_device_eval_batch_size": 1, + "num_train_epochs": 1, + "max_steps": 100, + "learning_rate": 1e-06, + "warmup_steps": 10, + "logging_steps": 1, + "evaluation_strategy": "steps", + "save_strategy": "steps", + "eval_steps": 100, + "save_steps": 500, + "max_seq_len": 4096, + "max_prompt_len": 2048, + "bf16": true, + "fp16_opt_level": "O2", + "do_train": true, + "do_eval": true, + "disable_tqdm": true, + "load_best_model_at_end": true, + "tensor_parallel_degree": 8, + "sharding_parallel_degree": 1, + "sharding": "stage1", + "use_flash_attention": true, + "use_attn_mask_start_row_indices":false, + "recompute": false, + "recompute_granularity": "full", + "dpo_beta": 0.1, + "benchmark": false, + "dpo_loss_type": "sigmoid", + "dpo_label_smoothing": 0.0, + "autotuner_benchmark":false + } diff --git a/llm/llama/sft_argument.json b/llm/llama/sft_argument.json 
index 487074cb12bc..34b36a3bc023 100644 --- a/llm/llama/sft_argument.json +++ b/llm/llama/sft_argument.json @@ -26,7 +26,6 @@ "save_total_limit": 1, "tensor_parallel_degree": 4, "pipeline_parallel_degree": 1, - "intokens": true, "zero_padding": false, "use_flash_attention": false } \ No newline at end of file diff --git a/llm/predictor.py b/llm/predictor.py index aafae2d5f4bf..5b893d80bb86 100644 --- a/llm/predictor.py +++ b/llm/predictor.py @@ -1250,6 +1250,7 @@ def create_predictor( dtype=predictor_args.dtype, tensor_parallel_degree=tensor_parallel_degree, tensor_parallel_rank=tensor_parallel_rank, + tensor_parallel_output=False, ) elif model_args.model_type == "ernie-3.5-se": sys.path.append("./ernie-3.5-se") @@ -1262,6 +1263,7 @@ def create_predictor( dtype=predictor_args.dtype, tensor_parallel_degree=tensor_parallel_degree, tensor_parallel_rank=tensor_parallel_rank, + tensor_parallel_output=False, ) else: model = AutoModelForCausalLM.from_pretrained( @@ -1270,6 +1272,7 @@ def create_predictor( use_flash_attention=predictor_args.use_flash_attention, tensor_parallel_degree=tensor_parallel_degree, tensor_parallel_rank=tensor_parallel_rank, + tensor_parallel_output=False, ) predictor = DygraphPredictor(predictor_args, model=model, tokenizer=tokenizer) diff --git a/llm/qwen/dpo_argument.json b/llm/qwen/dpo_argument.json new file mode 100644 index 000000000000..19884cfaefc0 --- /dev/null +++ b/llm/qwen/dpo_argument.json @@ -0,0 +1,38 @@ +{ + "model_name_or_path": "qwen/qwen-7b", + "train_dataset_path": "./data/train.jsonl", + "dev_dataset_path": "./data/dev.jsonl", + "output_dir": "./checkpoints/dpo_ckpts", + "per_device_train_batch_size": 1, + "gradient_accumulation_steps": 8, + "per_device_eval_batch_size": 1, + "num_train_epochs": 1, + "max_steps": 100, + "learning_rate": 1e-06, + "warmup_steps": 10, + "logging_steps": 1, + "evaluation_strategy": "steps", + "save_strategy": "steps", + "eval_steps": 100, + "save_steps": 500, + "max_seq_len": 4096, + "max_prompt_len": 2048, + "bf16": true, + "fp16_opt_level": "O2", + "do_train": true, + "do_eval": true, + "disable_tqdm": true, + "load_best_model_at_end": true, + "tensor_parallel_degree": 8, + "sharding_parallel_degree": 1, + "sharding": "stage1", + "use_flash_attention": true, + "use_attn_mask_start_row_indices":false, + "recompute": false, + "recompute_granularity": "full", + "dpo_beta": 0.1, + "benchmark": false, + "dpo_loss_type": "sigmoid", + "dpo_label_smoothing": 0.0, + "autotuner_benchmark":false + } diff --git a/llm/utils.py b/llm/utils.py index 3075943877df..2c177fc041f4 100644 --- a/llm/utils.py +++ b/llm/utils.py @@ -27,7 +27,7 @@ from paddle.io import BatchSampler, DataLoader, DistributedBatchSampler from sklearn.metrics import accuracy_score -from paddlenlp.datasets import InTokensIterableDataset +from paddlenlp.datasets import ZeroPaddingIterableDataset from paddlenlp.trainer import Trainer, TrainerCallback from paddlenlp.trainer.trainer_utils import IterableDatasetShard, has_length from paddlenlp.transformers import ( @@ -166,7 +166,7 @@ def get_lora_target_modules(model): return target_modules -class InTokensIterDatasetCallback(TrainerCallback): +class ZeroPaddingIterDatasetCallback(TrainerCallback): """ A [`TrainerCallback`] that handles early stopping. 
@@ -174,19 +174,19 @@ class InTokensIterDatasetCallback(TrainerCallback):
 
     def on_step_end(self, args, state, control, **kwargs):
         train_dataloader = kwargs["train_dataloader"]
-        if isinstance(train_dataloader.dataset, InTokensIterableDataset):
+        if isinstance(train_dataloader.dataset, ZeroPaddingIterableDataset):
             dataset = train_dataloader.dataset
         elif isinstance(train_dataloader.dataset, IterableDatasetShard) and isinstance(
-            train_dataloader.dataset.dataset, InTokensIterableDataset
+            train_dataloader.dataset.dataset, ZeroPaddingIterableDataset
        ):
             dataset = train_dataloader.dataset.dataset
         else:
             raise ValueError(
-                "Unexpected dataset format: InTokensIterDatasetCallback expectes `paddlenlp.datasets.InTokensIterableDataset`"
+                "Unexpected dataset format: ZeroPaddingIterDatasetCallback expects `paddlenlp.datasets.ZeroPaddingIterableDataset`"
            )
         if state.trial_params is None:
             state.trial_params = {}
-        state.trial_params["intokens_global_step"] = dataset.intokens_global_step
+        state.trial_params["zero_padding_global_step"] = dataset.zero_padding_global_step
 
 
 class CausalLMTrainer(Trainer):
diff --git a/paddlenlp/__init__.py b/paddlenlp/__init__.py
index e3cd7e1c5f75..f0370708c27f 100644
--- a/paddlenlp/__init__.py
+++ b/paddlenlp/__init__.py
@@ -48,6 +48,7 @@
     seq2vec,
     trainer,
     transformers,
+    trl,
     utils,
     version,
 )
diff --git a/paddlenlp/datasets/__init__.py b/paddlenlp/datasets/__init__.py
index 01515c9b48fb..fda1d65868cf 100644
--- a/paddlenlp/datasets/__init__.py
+++ b/paddlenlp/datasets/__init__.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -27,7 +27,6 @@
 from .dureader_robust import *
 from .glue import *
 from .imdb import *
-from .intokens_dataset import *
 from .lcqmc import *
 from .msra_ner import *
 from .nlpcc13_evsam05_hit import *
@@ -44,3 +43,4 @@
 from .xnli import *
 from .xnli_cn import *
 from .yahoo_answer_100k import *
+from .zero_padding_dataset import *
diff --git a/paddlenlp/datasets/intokens_dataset.py b/paddlenlp/datasets/zero_padding_dataset.py
similarity index 62%
rename from paddlenlp/datasets/intokens_dataset.py
rename to paddlenlp/datasets/zero_padding_dataset.py
index 795d82d93e8f..37b85ea86428 100644
--- a/paddlenlp/datasets/intokens_dataset.py
+++ b/paddlenlp/datasets/zero_padding_dataset.py
@@ -17,54 +17,75 @@
 from scipy.linalg import block_diag
 
 
-class InTokens:
-    required_input_keys = ["input_ids", "labels"]
+class ZeroPadding:
     required_output_keys = ["input_ids", "labels", "attention_mask"]
-    # Only supported the following keys for InTokens. Keys outside of the set will be ignored.
+    # Only supports the following keys for ZeroPadding. Keys outside of the set will be ignored.
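+    # ZeroPadding packs several short examples into one training sample: their input_ids and
+    # labels are concatenated, and the per-example causal masks are merged into a single
+    # block-diagonal attention mask (e.g. records of length 3 and 2 become one length-5 sample
+    # with a 3x3 block and a 2x2 block), so tokens never attend across example boundaries.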
+ supported_input_keys = [ + "input_ids", + "labels", + "attention_mask", + "position_ids", + "chosen_labels", + "rejected_labels", + "response_indexs", + "attn_mask_start_row_indices", + ] @classmethod def _pad_batch_records(cls, batch_records): # Only consider supported input keys input_keys = [key for key in batch_records[0].keys() if key in cls.supported_input_keys] - - # Check required_keys - for key in cls.required_input_keys: - if key not in input_keys: - raise ValueError(f"feature `{key}` is required for InTokensDataset") - # Output features must include all required output keys - for key in cls.required_output_keys: - if key not in input_keys: - input_keys.append(key) - + if "attn_mask_start_row_indices" not in input_keys and "attention_mask" not in input_keys: + input_keys.append("attention_mask") batched_features = {key: [] for key in input_keys} + sequence_sum = 0 for record in batch_records: batched_features["input_ids"].extend(record["input_ids"]) - batched_features["labels"].extend(record["labels"]) + if "labels" in record: + batched_features["labels"].extend(record["labels"]) + elif "rejected_labels" in input_keys and "chosen_labels" in input_keys: + batched_features["rejected_labels"].extend(record["rejected_labels"]) + batched_features["chosen_labels"].extend(record["chosen_labels"]) + response_indexs = [ + record["response_indexs"][0] + sequence_sum, # chosen_response_start_index + record["response_indexs"][1] + sequence_sum, # rejeted_response_start_index + record["response_indexs"][2] + sequence_sum, # rejeted_response_end_index + 1 + ] + batched_features["response_indexs"].append(response_indexs) + else: + raise ValueError("labels is required for ZeroPadding Dataset") + seq_length = len(record["input_ids"]) # If attention_mask is not given, assume it's causal mask - attention_mask = record.get("attention_mask", np.tril(np.ones([seq_length, seq_length], dtype=bool))) - batched_features["attention_mask"].append(attention_mask) + if "attn_mask_start_row_indices" in record: + attn_mask_start_row_indices = [i + sequence_sum for i in record["attn_mask_start_row_indices"]] + batched_features["attn_mask_start_row_indices"].extend(attn_mask_start_row_indices) + else: + attention_mask = record.get("attention_mask", np.tril(np.ones([seq_length, seq_length], dtype=bool))) + batched_features["attention_mask"].append(attention_mask) # NOTE: position_ids is optional and not required by every model # We append instead of extend here to accomodate 2D position ids if "position_ids" in record: batched_features["position_ids"].append(record["position_ids"]) - block_attention_mask = block_diag(*batched_features["attention_mask"]) - # convert to 3-D [batch_size(1), seq_length, seq_length] - batched_features["attention_mask"] = np.expand_dims(block_attention_mask, axis=0) + sequence_sum += seq_length + + if "attention_mask" in batched_features: + block_attention_mask = block_diag(*batched_features["attention_mask"]) + # convert to 3-D [batch_size(1), seq_length, seq_length] + batched_features["attention_mask"] = np.expand_dims(block_attention_mask, axis=0) if "position_ids" in batched_features: # Accomodate both 1D and 2D position ids batched_features["position_ids"] = np.concatenate(batched_features["position_ids"], axis=-1).tolist() return batched_features -class InTokensMapDataset(InTokens, Dataset): +class ZeroPaddingMapDataset(ZeroPadding, Dataset): def __init__(self, data, tokenizer, max_length): self.tokenizer = tokenizer self.max_length = max_length - self.new_data = 
self._create_intokens_data(data) + self.new_data = self._create_zero_padding_data(data) - def _create_intokens_data(self, data): + def _create_zero_padding_data(self, data): batch_records, max_len = [], 0 cur_len_so_far = 0 @@ -100,12 +121,13 @@ def __len__(self): return len(self.new_data) -class InTokensIterableDataset(InTokens, IterableDataset): +class ZeroPaddingIterableDataset(ZeroPadding, IterableDataset): def __init__(self, data, tokenizer, max_length): + self.data = data self.tokenizer = tokenizer self.max_length = max_length - self.intokens_global_step = 0 + self.zero_padding_global_step = 0 def __iter__(self): batch_records, max_len = [], 0 @@ -115,7 +137,7 @@ def __iter__(self): to_append = (cur_len_so_far + len(record["input_ids"])) <= self.max_length if to_append: batch_records.append(record) - self.intokens_global_step += 1 + self.zero_padding_global_step += 1 cur_len_so_far += len(record["input_ids"]) else: # exceed max length diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py index 04e88bc7eabe..98808a327a07 100755 --- a/paddlenlp/transformers/llama/modeling.py +++ b/paddlenlp/transformers/llama/modeling.py @@ -1774,7 +1774,7 @@ def forward(self, hidden_states, tensor_parallel_output=None): hidden_states = paddle.reshape_(hidden_states, [-1, seq_length, self.config.hidden_size]) if tensor_parallel_output is None: - tensor_parallel_output = self.config.tensor_parallel_output + tensor_parallel_output = self.config.tensor_parallel_output and self.config.tensor_parallel_degree > 1 if get_env_device() == "xpu" and self.xpu_parallel_matmul is not None: logits = self.xpu_parallel_matmul( @@ -1901,13 +1901,7 @@ def forward( hidden_states = outputs[0] # [bs, seq_len, dim] - # if labels is None,means we need full output, instead of tensor_parallel_output - # tensor_parallel_output is togather with ParallelCrossEntropy - tensor_parallel_output = ( - self.config.tensor_parallel_output and labels is not None and self.config.tensor_parallel_degree > 1 - ) - - logits = self.lm_head(hidden_states, tensor_parallel_output=tensor_parallel_output) + logits = self.lm_head(hidden_states) loss = None if labels is not None: diff --git a/paddlenlp/transformers/qwen/modeling.py b/paddlenlp/transformers/qwen/modeling.py index 54897ccf5f39..406e097651ee 100755 --- a/paddlenlp/transformers/qwen/modeling.py +++ b/paddlenlp/transformers/qwen/modeling.py @@ -837,7 +837,7 @@ def __init__(self, config: QWenConfig): def forward(self, hidden_states, tensor_parallel_output=None): if tensor_parallel_output is None: - tensor_parallel_output = self.config.tensor_parallel_output + tensor_parallel_output = self.config.tensor_parallel_output and self.config.tensor_parallel_degree > 1 logits = parallel_matmul(hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output) return logits @@ -995,13 +995,7 @@ def forward( ) hidden_states = transformer_outputs[0] - # if labels is None,means we need full output, instead of tensor_parallel_output - # tensor_parallel_output is togather with ParallelCrossEntropy - tensor_parallel_output = ( - self.config.tensor_parallel_output and labels is not None and self.config.tensor_parallel_degree > 1 - ) - - lm_logits = self.lm_head(hidden_states, tensor_parallel_output=tensor_parallel_output) + lm_logits = self.lm_head(hidden_states) loss = None if labels is not None: diff --git a/paddlenlp/trl/__init__.py b/paddlenlp/trl/__init__.py new file mode 100644 index 000000000000..ff5182c8f5c8 --- /dev/null +++ 
b/paddlenlp/trl/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .dpo_trainer import DPOTrainer
+from .trl_data import *
+from .trl_utils import *
diff --git a/paddlenlp/trl/dpo_trainer.py b/paddlenlp/trl/dpo_trainer.py
new file mode 100644
index 000000000000..144ceb816fdc
--- /dev/null
+++ b/paddlenlp/trl/dpo_trainer.py
@@ -0,0 +1,676 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""" DPO Trainer """
+import types
+from collections import OrderedDict, defaultdict
+
+import paddle
+import paddle.nn.functional as F
+from paddle import framework
+from paddle.distributed import fleet
+from paddle.distributed.fleet.meta_parallel import ParallelCrossEntropy
+
+from paddlenlp.trainer import Trainer
+from paddlenlp.transformers.model_utils import unwrap_model
+
+
+def disable_dropout_in_model(model: paddle.nn.Layer) -> None:
+    """Disable dropout in the model."""
+    for module in model.children():
+        if isinstance(module, paddle.nn.Dropout):
+            module.p = 0
+
+
+class DPOTrainer(Trainer):
+    """
+    Initialize DPOTrainer.
+    """
+
+    def __init__(self, model, data_collator, ref_model=None, disable_dropout: bool = True, **kwargs):
+        # pop the custom flag before it reaches the base Trainer constructor
+        self.reference_free = kwargs.pop("reference_free", False)
+        super().__init__(model, data_collator=data_collator, **kwargs)
+
+        if ref_model:
+            self.ref_model = ref_model
+            self.ref_model = self._wrap_ref_model(self.ref_model)
+            self.ref_model.eval()
+        elif not self.reference_free:
+            raise ValueError("Please provide a reference model.")
+        else:
+            # reference-free mode runs without a reference model
+            self.ref_model = None
+        if self.reference_free and self.args.dpo_loss_type not in ["sigmoid", "hinge", "ipo"]:
+            raise ValueError(f"{self.args.dpo_loss_type} is not a valid loss type for DPO reference_free.")
+        if disable_dropout:
+            disable_dropout_in_model(model)
+            if self.ref_model is not None:
+                disable_dropout_in_model(self.ref_model)
+        if self.model.config.tensor_parallel_output and self.model.config.tensor_parallel_degree > 1:
+            self.logprobs = ParallelCrossEntropy()
+        else:
+            self.logprobs = paddle.nn.CrossEntropyLoss(reduction="none")
+        self._stored_metrics = defaultdict(lambda: defaultdict(list))
+
+    def dpo_loss(
+        self,
+        policy_chosen_logps,
+        policy_rejected_logps,
+        reference_chosen_logps=None,
+        reference_rejected_logps=None,
+    ):
+        """
+        Compute the DPO loss for a batch of policy and reference model log probabilities.
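+
+        For the default ``sigmoid`` variant, with ``beta = self.args.dpo_beta`` and label
+        smoothing ``eps = self.args.dpo_label_smoothing``, the loss computed below corresponds to
+
+            h = beta * [(log pi(y_w|x) - log pi(y_l|x)) - (log ref(y_w|x) - log ref(y_l|x))]
+            loss = -(1 - eps) * logsigmoid(h) - eps * logsigmoid(-h)
+
+        and the reference log-ratio term is treated as 0 when ``reference_free`` is set.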
+ """ + pi_logratios = policy_chosen_logps - policy_rejected_logps + if self.reference_free: + ref_logratios = 0 + else: + ref_logratios = reference_chosen_logps - reference_rejected_logps + logits = pi_logratios - ref_logratios + if self.args.dpo_loss_type == "sigmoid": + loss = ( + -F.log_sigmoid(self.args.dpo_beta * logits) * (1 - self.args.dpo_label_smoothing) + - F.log_sigmoid(-self.args.dpo_beta * logits) * self.args.dpo_label_smoothing + ) + elif self.args.dpo_loss_type == "hinge": + loss = F.relu(1 - self.args.dpo_beta * logits) + elif self.args.dpo_loss_type == "ipo": + # parameter for the IPO loss, denoted by tau in the paper. + loss = (logits - 1 / (2 * self.args.dpo_beta)) ** 2 + elif self.args.dpo_loss_type == "kto_pair": + # eqn (7) of the HALOs paper + chosen_KL = (policy_chosen_logps - reference_chosen_logps).mean().clip(min=0) + rejected_KL = (policy_rejected_logps - reference_rejected_logps).mean().clip(min=0) + + chosen_logratios = policy_chosen_logps - reference_chosen_logps + rejected_logratios = policy_rejected_logps - reference_rejected_logps + # As described in the KTO report, the KL term for + # chosen (rejected) is estimated using the rejected (chosen) half. + loss = paddle.concat( + ( + 1 - F.sigmoid(self.args.dpo_beta * (chosen_logratios - rejected_KL)), + 1 - F.sigmoid(self.args.dpo_beta * (chosen_KL - rejected_logratios)), + ), + 0, + ) + elif self.args.dpo_loss_type == "sppo_hard": + # In the paper (https://arxiv.org/pdf/2405.00675), SPPO employs a soft probability + # approach, estimated using the PairRM score. The probability calculation is + # conducted outside of the trainer class. The version described here is the hard + # probability version, where P in Equation (4.7) of Algorithm 1 is set to 1 for + # the winner and 0 for the loser. + a = policy_chosen_logps - reference_chosen_logps + b = policy_rejected_logps - reference_rejected_logps + + loss = (a - 0.5 / self.args.dpo_beta) ** 2 + (b + 0.5 / self.args.dpo_beta) ** 2 + elif self.args.dpo_loss_type == "nca_pair": + chosen_rewards = (policy_chosen_logps - reference_chosen_logps) * self.args.dpo_beta + rejected_rewards = (policy_rejected_logps - reference_rejected_logps) * self.args.dpo_beta + loss = ( + -F.log_sigmoid(chosen_rewards) + - 0.5 * F.log_sigmoid(-chosen_rewards) + - 0.5 * F.log_sigmoid(-rejected_rewards) + ) + else: + raise ValueError( + f"Unknown loss type: {self.args.dpo_loss_type}. 
" + "Should be one of ['sigmoid', 'hinge', 'ipo', 'kto_pair', 'sppo_hard', 'nca_pair']" + ) + return loss.mean() + + def get_batch_logps( + self, + batch, + logits, + average_log_prob=False, + ): + """DPO logprobs""" + labels = batch["chosen_labels"] + batch["rejected_labels"] + logits = logits.astype("float32") + if logits.shape[:-1] != labels.shape: + raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.") + per_token_logps = -self.logprobs(logits, labels.unsqueeze(2)).squeeze(2) + chosen_logps = paddle.stack( + [ + (per_token_logps[response_index[0]][response_index[1] : response_index[2]]).sum() + if response_index[3] != 0 + else paddle.zeros([]) + for response_index in batch["response_indexs"] + ], + axis=0, + ) + rejected_logps = paddle.stack( + [ + (per_token_logps[response_index[0]][response_index[2] + 1 : response_index[3]]).sum() + if response_index[3] != 0 + else paddle.zeros([]) + for response_index in batch["response_indexs"] + ], + axis=0, + ) + if average_log_prob: + chosen_response_length = batch["response_indexs"][:, 2] - batch["response_indexs"][:, 1] + rejected_response_length = batch["response_indexs"][:, 3] - batch["response_indexs"][:, 2] + chosen_logps /= chosen_response_length + rejected_logps /= rejected_response_length + return chosen_logps, rejected_logps + + def get_batch_metrics(self, model, batch, train_eval="train"): + """Compute the DPO loss and other metrics for the given batch of inputs for train or test.""" + metrics = {} + if hasattr(self.model.config, "dpo") and self.model.config.dpo: + dpo_inputs = { + "input_ids": batch["input_ids"], + "position_ids": batch["position_ids"], + "chosen_labels": batch["chosen_labels"], + "rejected_labels": batch["rejected_labels"], + "response_indexs": batch["response_indexs"], + } + if "attention_mask" in batch: + dpo_inputs["attention_mask"] = batch["attention_mask"] + if "attn_mask_start_row_indices" in batch: + dpo_inputs["attn_mask_start_row_indices"] = batch["attn_mask_start_row_indices"] + if self.reference_free: + reference_chosen_logps, reference_rejected_logps = None, None + else: + with paddle.no_grad(): + reference_chosen_logps, reference_rejected_logps = self.ref_model(**dpo_inputs) + dpo_inputs["reference_chosen_logps"] = reference_chosen_logps + dpo_inputs["reference_rejected_logps"] = reference_rejected_logps + loss, policy_chosen_logps, policy_rejected_logps = model(**dpo_inputs) + else: + dpo_inputs = { + "input_ids": batch["input_ids"], + "position_ids": batch["position_ids"], + } + if "attention_mask" in batch: + dpo_inputs["attention_mask"] = batch["attention_mask"] + if "attn_mask_start_row_indices" in batch: + dpo_inputs["attn_mask_start_row_indices"] = batch["attn_mask_start_row_indices"] + if self.reference_free: + reference_chosen_logps, reference_rejected_logps = None, None + else: + with paddle.no_grad(): + ref_logits = self.ref_model(**dpo_inputs)[0] + reference_chosen_logps, reference_rejected_logps = self.get_batch_logps( + batch, + ref_logits, + average_log_prob=self.args.dpo_loss_type == "ipo", + ) + policy_logits = model(**dpo_inputs)[0] + policy_chosen_logps, policy_rejected_logps = self.get_batch_logps( + batch, + policy_logits, + average_log_prob=self.args.dpo_loss_type == "ipo", + ) + + loss = self.dpo_loss( + policy_chosen_logps, + policy_rejected_logps, + reference_chosen_logps, + reference_rejected_logps, + ) + + policy_chosen_logps, policy_rejected_logps = policy_chosen_logps.detach(), policy_rejected_logps.detach() + if self.reference_free: 
+ chosen_rewards = self.args.dpo_beta * (policy_chosen_logps) + rejected_rewards = self.args.dpo_beta * (policy_rejected_logps) + reward_accuracies = (chosen_rewards > rejected_rewards).astype(paddle.float32) + else: + chosen_rewards = self.args.dpo_beta * (policy_chosen_logps - reference_chosen_logps) + rejected_rewards = self.args.dpo_beta * (policy_rejected_logps - reference_rejected_logps) + reward_accuracies = (chosen_rewards > rejected_rewards).astype(paddle.float32) + + prefix = "eval_" if train_eval == "eval" else "" + metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean() + metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean() + metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean() + metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean() + metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.mean() + metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.mean() + + for key in metrics: + metrics[key] = self._nested_gather(paddle.tile(metrics[key], repeat_times=[1, 1])).mean().cpu() + return loss, metrics + + def compute_loss(self, model, inputs, return_outputs=False): + """Compute the DPO loss for the given batch of inputs.""" + loss, metrics = self.get_batch_metrics(model, inputs, train_eval="train") + if self.args.should_save: + self.store_metrics(metrics, train_eval="train") + if return_outputs: + return (loss, metrics) + + return loss + + def _wrap_model(self, model, training=True): + """Wrap model.""" + model = super()._wrap_model(model, training) + if self.args.pipeline_parallel_degree > 1: + model._prepare_pipeline_inputs_func = prepare_pipeline_dpo_inputs_func + model.eval_dpo_batch = types.MethodType(eval_dpo_batch, model) + model._forward_step = types.MethodType(_forward_step, model) + model.broadcast_pp_final_output = types.MethodType(broadcast_pp_final_output, model) + return model + + def _wrap_ref_model(self, model): + """Wrap reference model.""" + if unwrap_model(model) is not model: + return model + self.amp_dtype = "float16" if self.args.fp16 else "bfloat16" + model = paddle.amp.decorate( + models=model, + level=self.args.fp16_opt_level, + dtype=self.amp_dtype, + ) + model = fleet.distributed_model(model) + if self.args.pipeline_parallel_degree > 1: + model._prepare_pipeline_inputs_func = prepare_pipeline_dpo_inputs_func + model.eval_dpo_batch = types.MethodType(eval_dpo_batch, model) + model._forward_step = types.MethodType(_forward_step, model) + model.broadcast_pp_final_output = types.MethodType(broadcast_pp_final_output, model) + + return model + + def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval"): + """evaluate""" + self.model_wrapped = self._wrap_ref_model(self.model_wrapped) + return super().evaluate(eval_dataset, ignore_keys, metric_key_prefix) + + def prediction_step(self, model, inputs, prediction_loss_only=False, ignore_keys=None): + + """prediction_step""" + if self.args.pipeline_parallel_degree > 1: + # hack for pipeline mode + inputs = self._prepare_inputs(inputs) + return self.prediction_pipeline_step(model, inputs) + if ignore_keys is None: + if hasattr(model, "config"): + ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", []) + else: + ignore_keys = [] + + with paddle.no_grad(): + loss, metrics = self.get_batch_metrics(model, inputs, train_eval="eval") + + if self.args.should_save: + self.store_metrics(metrics, train_eval="eval") + if prediction_loss_only: + return (loss.detach(), None, None) + + logits_dict = { + "eval_logits/chosen": 
metrics["eval_logits/chosen"], + "eval_logits/rejected": metrics["eval_logits/rejected"], + } + + logits = tuple(v for k, v in logits_dict.items() if k not in ignore_keys) + logits = paddle.to_tensor(logits) + labels = paddle.zeros(logits.shape[0]) + return (loss.detach(), logits, labels) + + def store_metrics(self, metrics, train_eval="train"): + """store_metrics""" + for key, value in metrics.items(): + self._stored_metrics[train_eval][key].append(value) + + def log(self, logs, **kwargs): + """ + Log `logs` on the various objects watching training, including stored metrics. + + Args: + logs (`Dict[str, float]`): + The values to log. + """ + # logs either has 'loss' or 'eval_loss' + train_eval = "train" if "loss" in logs else "eval" + # Add averaged stored metrics to logs + for key, metrics in self._stored_metrics[train_eval].items(): + logs[key] = paddle.to_tensor(metrics).mean().item() + del self._stored_metrics[train_eval] + if self.state.epoch is not None and train_eval == "train": + self.state.epoch *= self.args.num_train_epochs + return super().log(logs, **kwargs) + + def split_response_indexs_for_pipeline(self, batch): + """ + split response indexs for pipeline parallel mode. + """ + batch_response_indexs = [] + response_indexs = None + response_num = [0] * batch["input_ids"].shape[0] + last_batch = -1 + if batch["response_indexs"][0][1] == 0: + use_sparse_head_and_loss_fn = True + else: + use_sparse_head_and_loss_fn = False + last_batch_response_length = 0 + + for response_index in batch["response_indexs"]: + if response_index[0] == last_batch: + response_index -= last_batch_response_length + response_index[0] = 0 + response_indexs.append(response_index) + else: + last_batch += 1 + if use_sparse_head_and_loss_fn: + last_batch_response_length = response_index[1] + if response_indexs is not None: + batch_response_indexs.append(response_indexs) + response_index -= last_batch_response_length + response_index[0] = 0 + response_indexs = [response_index] + response_num[last_batch] += 1 + batch_response_indexs.append(response_indexs) + max_response_num = max(response_num) + for i in range(len(response_num)): + for _ in range(max_response_num - response_num[i]): + batch_response_indexs[i].append(paddle.to_tensor([0, 0, -1, 0], dtype="int64")) + + return paddle.to_tensor(batch_response_indexs) + + def prediction_pipeline_step( + self, + model, + batch, + ): + """ + prediction_step function for pipeline parallel mode. 
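+
+        Roughly: the flat ``response_indexs`` tensor is re-split per sample, the wrapped policy
+        and (unless ``reference_free``) reference models are run micro-batch by micro-batch via
+        ``eval_dpo_batch``, padded (all-zero) log-prob entries are masked out, and the DPO loss
+        plus reward/log-prob metrics are computed from the gathered chosen/rejected log-probs.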
+ """ + config_backup = model.micro_batch_size, model.accumulate_steps + model.accumulate_steps = batch["input_ids"].shape[0] + model.micro_batch_size = 1 + if not self.reference_free: + self.ref_model.accumulate_steps = model.accumulate_steps + self.ref_model.micro_batch_size = model.micro_batch_size + # [1, total_response_indexs] -> [bs, response_indexs] + batch["response_indexs"] = self.split_response_indexs_for_pipeline(batch) + batch["reference_chosen_logps"] = None + batch["reference_rejected_logps"] = None + total_response_num = batch["response_indexs"].shape[0] * batch["response_indexs"].shape[1] + + inputs, labels = model._prepare_pipeline_inputs_func(batch) + with paddle.no_grad(): + with self.autocast_smart_context_manager(): + policy_chosen_logps, policy_rejected_logps = model.eval_dpo_batch( + data=[inputs, labels], total_response_num=total_response_num + ) + policy_chosen_logps = paddle.masked_select(policy_chosen_logps, policy_chosen_logps != 0) + policy_rejected_logps = paddle.masked_select(policy_rejected_logps, policy_rejected_logps != 0) + if not self.reference_free: + reference_chosen_logps, reference_rejected_logps = self.ref_model.eval_dpo_batch( + [inputs, labels], total_response_num=total_response_num + ) + reference_chosen_logps = paddle.masked_select(reference_chosen_logps, reference_chosen_logps != 0) + reference_rejected_logps = paddle.masked_select( + reference_rejected_logps, reference_rejected_logps != 0 + ) + else: + reference_chosen_logps, reference_rejected_logps = None, None + + loss = self.dpo_loss( + policy_chosen_logps, + policy_rejected_logps, + reference_chosen_logps, + reference_rejected_logps, + ) + policy_chosen_logps, policy_rejected_logps = policy_chosen_logps.detach(), policy_rejected_logps.detach() + if not self.reference_free: + chosen_rewards = self.args.dpo_beta * (policy_chosen_logps - reference_chosen_logps) + rejected_rewards = self.args.dpo_beta * (policy_rejected_logps - reference_rejected_logps) + else: + chosen_rewards = self.args.dpo_beta * (policy_chosen_logps) + rejected_rewards = self.args.dpo_beta * (policy_rejected_logps) + + reward_accuracies = (chosen_rewards > rejected_rewards).astype(paddle.float32) + metrics = {} + metrics["eval_rewards/chosen"] = chosen_rewards.mean() + metrics["eval_rewards/rejected"] = rejected_rewards.mean() + metrics["eval_rewards/accuracies"] = reward_accuracies.mean() + metrics["eval_rewards/margins"] = (chosen_rewards - rejected_rewards).mean() + metrics["eval_logps/rejected"] = policy_rejected_logps.mean() + metrics["eval_logps/chosen"] = policy_chosen_logps.mean() + for key in metrics: + metrics[key] = self._nested_gather(paddle.tile(metrics[key], repeat_times=[1, 1])).mean().cpu() + if self.args.should_save: + self.store_metrics(metrics, train_eval="eval") + model.micro_batch_size, model.accumulate_steps = config_backup + if not self.reference_free: + self.ref_model.micro_batch_size, self.ref_model.accumulate_steps = config_backup + return (loss, None, None) + + def training_pipeline_step(self, model, inputs): + """ + Perform a training step on a batch of inputs. 
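+
+        Micro-batches are buffered until ``gradient_accumulation_steps`` of them are collected;
+        ``response_indexs`` is padded so every micro-batch has the same number of response spans,
+        the reference model's chosen/rejected log-probs (when a reference model is used) are
+        precomputed with ``eval_dpo_batch`` under ``paddle.no_grad`` and appended to the labels,
+        and the policy model is then updated through ``forward_backward_pipeline``; the optimizer
+        step itself is handled outside this method.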
+ """ + # accumulation data + if not hasattr(self, "_pp_data_buffer"): + self._pp_data_buffer = [] + self._pp_data_buffer.append(inputs) + if len(self._pp_data_buffer) != self.args.gradient_accumulation_steps: + return paddle.zeros([]) + response_num = [ + len(self._pp_data_buffer[i]["response_indexs"]) for i in range(self.args.gradient_accumulation_steps) + ] + max_response_num = max(response_num) + for i in range(self.args.gradient_accumulation_steps): + self._pp_data_buffer[i]["response_indexs"] = paddle.concat( + [ + self._pp_data_buffer[i]["response_indexs"], + paddle.to_tensor((max_response_num - response_num[i]) * [[0, 0, -1, 0]], dtype="int64"), + ], + axis=0, + ) + total_response_num = self.args.gradient_accumulation_steps * max_response_num + concatenated_inputs = {} + for key in self._pp_data_buffer[i].keys(): + concatenated_inputs[key] = [ + self._pp_data_buffer[i][key] for i in range(self.args.gradient_accumulation_steps) + ] + concatenated_inputs["reference_chosen_logps"] = None + concatenated_inputs["reference_rejected_logps"] = None + + self._pp_data_buffer = [] + inputs, labels = model._prepare_pipeline_inputs_func(concatenated_inputs) + model_config_backup = model.micro_batch_size, model.accumulate_steps + model.micro_batch_size = self.args.per_device_train_batch_size + model.accumulate_steps = self.args.gradient_accumulation_steps + if not self.reference_free: + ref_model_config_backup = self.ref_model.micro_batch_size, self.ref_model.accumulate_steps + self.ref_model.accumulate_steps = model.accumulate_steps + self.ref_model.micro_batch_size = model.micro_batch_size + with paddle.no_grad(): + with self.autocast_smart_context_manager(): + reference_chosen_logps, reference_rejected_logps = self.ref_model.eval_dpo_batch( + data=[inputs, labels], total_response_num=total_response_num + ) + labels = ( + labels[0], + labels[1], + labels[2], + reference_chosen_logps.split(num_or_sections=model.accumulate_steps, axis=0), + reference_rejected_logps.split(num_or_sections=model.accumulate_steps, axis=0), + ) + train_inputs = [inputs, labels] + train_inputs = model._prepare_training(train_inputs, self.optimizer, self.lr_scheduler) + model.optimizer = None # we do not use `PipelineParallel` to handler optimizer step + model.lr_scheduler = None + with self.autocast_smart_context_manager(): + loss = model.forward_backward_pipeline(train_inputs, self.scaler if self.do_grad_scaling else None) + model.micro_batch_size, model.accumulate_steps = model_config_backup + if not self.reference_free: + self.ref_model.micro_batch_size, self.ref_model.accumulate_steps = ref_model_config_backup + return loss.detach() + + +def prepare_pipeline_dpo_inputs_func(inputs): + """Prepare pipeline inputs""" + if "attention_mask" in inputs: + first_stage_keys = [ + "input_ids", + "attention_mask", + "position_ids", + ] + else: + first_stage_keys = [ + "input_ids", + "attn_mask_start_row_indices", + "position_ids", + ] + + last_stage_keys = [ + "chosen_labels", + "rejected_labels", + "response_indexs", + "reference_chosen_logps", + "reference_rejected_logps", + ] + + def get_expected_keys(inputs, keys): + ret = tuple([inputs.pop(k) for k in keys if k in inputs]) + if len(ret) == 1: + ret = ret[0] + return ret + + if type(inputs) is dict or type(inputs) is OrderedDict: + return [ + get_expected_keys(inputs, first_stage_keys), + get_expected_keys(inputs, last_stage_keys), + ] + + keys = list(inputs[0].keys()) + inputs_batch = {key: [data.pop(key) for data in inputs] for key in keys} + return [ + 
get_expected_keys(inputs_batch, first_stage_keys), + get_expected_keys(inputs_batch, last_stage_keys), + ] + + +def eval_dpo_batch(self, data, total_response_num): + """eval_dpo_batch""" + # reset the virtual pp rank for each run + self.set_virtual_pipeline_rank(0) + + self._layers.eval() + + # store data id for micro_batch + self.micro_batch_id = 0 + + # store total loss of entire batch + self.total_loss = None + + startup_steps = self.num_stages - self.stage_id - 1 + startup_steps = min(startup_steps, self.accumulate_steps) + steady_steps = self.accumulate_steps - startup_steps + + input_buffers = [] + output_buffers = [] + + # convert to micro dataset + micro_dataset = self._wrap_data(data) + + for step_id in range(startup_steps): + input_tensor = self._p2p_helper.recv_forward(self.is_pipeline_first_stage()) + + output_tensor = self._forward_step(input_tensor, micro_dataset) + self._p2p_helper.send_forward(output_tensor, self.is_pipeline_last_stage(), skip_check_meta=True) + + input_buffers.append(input_tensor) + output_buffers.append(output_tensor) + + if steady_steps > 0: + input_tensor = self._p2p_helper.recv_forward(self.is_pipeline_first_stage()) + + for i in range(steady_steps): + last_iter = i == (steady_steps - 1) + + output_tensor = self._forward_step(input_tensor, micro_dataset) + self._p2p_helper.send_forward(output_tensor, self.is_pipeline_last_stage(), skip_check_meta=True) + + input_buffers.append(input_tensor) + output_buffers.append(output_tensor) + + if not last_iter: + input_tensor = self._p2p_helper.recv_forward(self.is_pipeline_first_stage()) + return self.broadcast_pp_final_output(output_buffers, total_response_num) + + +def _forward_step(self, input_tensor, micro_dataset, chunk_id=None): + if self._enable_timer: + self.timers("forward_step").start() + if self.is_pipeline_first_stage(): + input_tensor = next(micro_dataset)[0] + self._check_micro_batch_data_valid(input_tensor) + + assert chunk_id is None or isinstance(chunk_id, int) + + output_tensor = self._layers.forward(input_tensor, chunk_id=chunk_id) + + if self.is_pipeline_last_stage(): + assert self._layers._loss_fn is not None, "loss function should exist to compute loss" + labels = next(micro_dataset)[1] + self._check_micro_batch_data_valid(labels) + for idx, loss_fn in enumerate(self._layers._loss_fn): + output_tensor = loss_fn(output_tensor, labels[0], labels[1], labels[2], labels[3], labels[4]) + if labels[3] is not None and labels[4] is not None: + assert isinstance( + output_tensor, (paddle.Tensor, framework.core.eager.Tensor) + ), "Currently, loss_fn should obtain Paddle.Tensor dtype" + + with paddle.amp.auto_cast(enable=False): + if self.accumulate_steps > 1 and not self._delay_scale_loss: + output_tensor = output_tensor / self.accumulate_steps + + if self.total_loss is None: + self.total_loss = [] + if len(self.total_loss) <= idx: + self.total_loss.append(paddle.zeros_like(output_tensor)) + self.total_loss[idx] += output_tensor.detach() + if idx == self.loss_fn_idx: + loss_tensor = output_tensor + + if self.is_pipeline_first_stage() or self.is_pipeline_last_stage(): + # Only increase micro batch id at virtual first/last pp stage. + # The micro batch id is used to load data, therefore, only increase it when load data. 
+        self.micro_batch_id += 1
+    if self._enable_timer:
+        self.timers("forward_step").stop()
+    if self.is_pipeline_last_stage() and labels[3] is not None and labels[4] is not None:
+        return loss_tensor
+    else:
+        return output_tensor
+
+
+def broadcast_pp_final_output(self, output_buffers, total_response_num):
+    """broadcast_pp_final_output"""
+    # Since the last backward run in interleave will set the virtual rank to 0,
+    # here we need to check last stage ignoring virtual stage.
+    if self.is_pipeline_last_stage(ignore_virtual=True):
+        chosen_logps = paddle.concat([buffer[0] for buffer in output_buffers], axis=0)
+        rejected_logps = paddle.concat([buffer[1] for buffer in output_buffers], axis=0)
+        paddle.distributed.broadcast(chosen_logps, src=self.global_rank, sync_op=True, group=self.pp_group)
+        paddle.distributed.broadcast(rejected_logps, src=self.global_rank, sync_op=True, group=self.pp_group)
+    else:
+        chosen_logps = paddle.zeros(shape=[total_response_num], dtype="float32")
+        rejected_logps = paddle.zeros(shape=[total_response_num], dtype="float32")
+        paddle.distributed.broadcast(
+            chosen_logps,
+            src=self._hcg.get_rank_from_stage(self.num_stages - 1),
+            sync_op=True,
+            group=self.pp_group,
+        )
+        paddle.distributed.broadcast(
+            rejected_logps,
+            src=self._hcg.get_rank_from_stage(self.num_stages - 1),
+            sync_op=True,
+            group=self.pp_group,
+        )
+    return chosen_logps, rejected_logps
diff --git a/paddlenlp/trl/trl_data.py b/paddlenlp/trl/trl_data.py
new file mode 100644
index 000000000000..ca3a1ae40f7e
--- /dev/null
+++ b/paddlenlp/trl/trl_data.py
@@ -0,0 +1,232 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+
+def check_preference_data(data):
+    """Check that a preference example has valid src/tgt/response/sort fields."""
+    if isinstance(data["src"], str):
+        data["src"] = [data["src"]]
+    if isinstance(data["tgt"], str):
+        data["tgt"] = [data["tgt"]]
+    if len(data["src"]) != len(data["tgt"]) + 1:
+        raise ValueError(
+            "The number of src should be one more than the number of tgt, but got {} and {}".format(
+                len(data["src"]), len(data["tgt"])
+            )
+        )
+    if (len(data["response"]) != 2) or (len(data["response"]) != len(data["sort"])):
+        raise ValueError(
+            "The number of response and sort should both be 2, but got {} and {}".format(
+                len(data["response"]), len(data["sort"])
+            )
+        )
+    if len(data["response"][0]) == 0 or len(data["response"][1]) == 0:
+        raise ValueError(f"The response should not be empty, but got {data}.")
+    if data["sort"][0] == data["sort"][1]:
+        raise ValueError("The two sort values should be different.")
+
+    return data
+
+
+def preprocess_preference_data(data, tokenizer, data_args, model_args):
+    """Convert raw format example to Example."""
+    # 1. Check data format
+    data = check_preference_data(data)
+
+    if data["sort"][0] > data["sort"][1]:
+        chosen = data["response"][0]
+        rejected = data["response"][1]
+    else:
+        chosen = data["response"][1]
+        rejected = data["response"][0]
+    chosen_encode_tokens = []
+    for idx in range(len(data["src"])):
+        if idx < len(data["tgt"]):
+            if tokenizer.chat_template is not None:
+                chosen_encode_tokens.append(
+                    [
+                        data["src"][idx].strip(),
+                        data["tgt"][idx].strip(),
+                    ]
+                )
+            else:
+                chosen_encode_tokens.append(
+                    [
+                        tokenizer.encode(data["src"][idx].strip(), add_special_tokens=True)["input_ids"],
+                        tokenizer.encode(data["tgt"][idx].strip(), add_special_tokens=False)["input_ids"]
+                        + [tokenizer.eos_token_id],
+                    ]
+                )
+        else:
+            if tokenizer.chat_template is not None:
+                chosen_encode_tokens.append(
+                    [
+                        data["src"][idx].strip(),
+                        chosen.strip(),
+                    ]
+                )
+            else:
+                chosen_encode_tokens.append(
+                    [
+                        tokenizer.encode(data["src"][idx].strip(), add_special_tokens=True)["input_ids"],
+                        tokenizer.encode(chosen.strip(), add_special_tokens=False)["input_ids"]
+                        + [tokenizer.eos_token_id],
+                    ]
+                )
+    if tokenizer.chat_template is not None:
+        chat_input_list = chosen_encode_tokens
+        chosen_encode_tokens = tokenizer.encode_chat_inputs(chat_input_list)["conversations"]
+        # reuse the chat inputs with the rejected response to build rejected_encode_tokens
+        chat_input_list[-1][-1] = rejected.strip()
+        rejected_encode_tokens = tokenizer.encode_chat_inputs(chat_input_list)["conversations"]
+
+        # Post process sequence: tokenization & truncation.
+        tokens_prompt = chosen_encode_tokens[-1][0][:-1]
+        eos_token_id = chosen_encode_tokens[-1][-1][-1]
+        tokens_chosen = chosen_encode_tokens[-1][0][-1:] + chosen_encode_tokens[-1][-1][:-1]
+        tokens_rejected = chosen_encode_tokens[-1][0][-1:] + rejected_encode_tokens[-1][-1][:-1]
+    else:
+        eos_token_id = tokenizer.eos_token_id
+        tokens_prompt = chosen_encode_tokens[-1][0][:-1]
+        tokens_chosen = (
+            chosen_encode_tokens[-1][0][-1:] + tokenizer.encode(chosen.strip(), add_special_tokens=False)["input_ids"]
+        )
+        tokens_rejected = (
+            chosen_encode_tokens[-1][0][-1:]
+            + tokenizer.encode(rejected.strip(), add_special_tokens=False)["input_ids"]
+        )
+
+    if len(tokens_prompt) + len(tokens_chosen) + len(tokens_rejected) > data_args.max_seq_len:
+        # truncate prompt
+        tokens_prompt = tokens_prompt[-data_args.max_prompt_len :]
+        if (len(tokens_prompt) + len(tokens_chosen) + len(tokens_rejected)) > data_args.max_seq_len:
+            max_response_len = data_args.max_seq_len - len(tokens_prompt)
+            # truncate chosen and rejected proportionally to their lengths
+            max_chosen_len = int(len(tokens_chosen) / (len(tokens_chosen) + len(tokens_rejected)) * max_response_len)
+            max_rejected_len = max_response_len - max_chosen_len
+            tokens_chosen = tokens_chosen[:max_chosen_len]
+            tokens_rejected = tokens_rejected[:max_rejected_len]
+
+    cur_len = len(tokens_prompt) + len(tokens_chosen) + len(tokens_rejected)
+    turn_index = len(chosen_encode_tokens) - 2
+
+    # prepend former dialog turns while they still fit within max_seq_len
+    while turn_index >= 0:
+        tokens_src = chosen_encode_tokens[turn_index][0]
+        tokens_target = chosen_encode_tokens[turn_index][1]
+        turn_index -= 1
+
+        if len(tokens_src) + len(tokens_target) > data_args.max_seq_len - cur_len:
+            break
+        tokens_prompt = tokens_src + tokens_target + tokens_prompt
+        cur_len += len(tokens_src) + len(tokens_target)
+
+    input_ids = tokens_prompt + tokens_chosen + tokens_rejected
+    prompt_len = len(tokens_prompt)
+    chosen_len = len(tokens_chosen)
+    rejected_len = len(tokens_rejected)
+    seq_len = len(input_ids)
+    # make position ids & labels
+
+    position_ids = (
+        list(range(prompt_len))  # prompt
+        + list(range(prompt_len, prompt_len + chosen_len))  # chosen
+        + list(range(prompt_len, prompt_len + rejected_len))  # rejected
+    )
+    chosen_labels = [0] * prompt_len + tokens_chosen[1:] + [eos_token_id] + [0] * rejected_len
+    rejected_labels = [0] * prompt_len + [0] * chosen_len + tokens_rejected[1:] + [eos_token_id]
+
+    # response index: [chosen_start, rejected_start, rejected_end + 1]
+    response_indexs = [prompt_len, prompt_len + chosen_len, seq_len]
+    output_dict = {
+        "input_ids": input_ids,
+        "position_ids": position_ids,
+        "chosen_labels": chosen_labels,
+        "rejected_labels": rejected_labels,
+        "response_indexs": response_indexs,
+    }
+
+    # attention mask
+    if model_args.use_attn_mask_start_row_indices:
+        output_dict["attn_mask_start_row_indices"] = (
+            [seq_len] * prompt_len + [prompt_len + chosen_len] * chosen_len + [seq_len] * rejected_len
+        )
+    else:
+        attention_mask = np.tri(seq_len, seq_len, dtype=bool)
+        attention_mask[(prompt_len + chosen_len) :, prompt_len : (prompt_len + chosen_len)] = False
+        output_dict["attention_mask"] = attention_mask
+    return output_dict
+
+
+def preference_collate_fn(batch, max_seq_len=None):
+    """Convert batch data into tensor."""
+    if max_seq_len is None:
+        raise ValueError("max_seq_len must be provided for preference_collate_fn.")
+
+    input_dict = {
+        "input_ids": [],
+        "position_ids": [],
+        "chosen_labels": [],
+        "rejected_labels": [],
+        "response_indexs": [],
+    }
+    sequence = batch[0]
+    if "attn_mask_start_row_indices" in sequence:
+        input_dict["attn_mask_start_row_indices"] = []
+        use_attn_mask_start_row_indices = True
+    elif "attention_mask" in sequence:
+        input_dict["attention_mask"] = []
+        use_attn_mask_start_row_indices = False
+    else:
+        raise ValueError("Neither attention_mask nor attn_mask_start_row_indices is present in the batch.")
+
+    for i, sequence in enumerate(batch):
+        difference = max_seq_len - len(sequence["input_ids"])
+
+        input_dict["input_ids"].append(sequence["input_ids"] + [0] * difference)
+        input_dict["position_ids"].append(sequence["position_ids"] + [0] * difference)
+        input_dict["chosen_labels"].append(sequence["chosen_labels"] + [0] * difference)
+        input_dict["rejected_labels"].append(sequence["rejected_labels"] + [0] * difference)
+        if use_attn_mask_start_row_indices:
+            input_dict["attn_mask_start_row_indices"].append(
+                [sequence["attn_mask_start_row_indices"] + [sequence["attn_mask_start_row_indices"][-1]] * difference]
+            )
+        else:
+            input_dict["attention_mask"].append(
+                np.pad(
+                    sequence["attention_mask"],
+                    pad_width=((0, 0), (0, difference), (0, difference)),
+                    mode="constant",
+                    constant_values=False,
+                )
+            )
+
+        for ri in sequence["response_indexs"]:
+            input_dict["response_indexs"].append(
+                [
+                    i,  # batch index
+                    ri[0],  # chosen_response_start_index
+                    ri[1],  # rejected_response_start_index
+                    ri[2],  # rejected_response_end_index + 1
+                ]
+            )
+    for key in input_dict:
+        if key == "attention_mask":
+            input_dict[key] = np.array(input_dict[key], dtype=bool)
+        elif key == "attn_mask_start_row_indices":
+            input_dict[key] = np.array(input_dict[key], dtype=np.int32)
+        else:
+            input_dict[key] = np.array(input_dict[key])
+    return input_dict
diff --git a/paddlenlp/trl/trl_utils.py b/paddlenlp/trl/trl_utils.py
new file mode 100644
index 000000000000..541238807b69
--- /dev/null
+++ b/paddlenlp/trl/trl_utils.py
@@ -0,0 +1,49 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+def calculate_effective_tokens(training_args, train_dataset, max_seq_len):
+    """
+    Calculate the effective tokens during training.
+    """
+    total_effective_tokens = 0
+
+    try:
+        data_parallel_degree = training_args.data_parallel_degree
+    except AttributeError:
+        # some TrainingArguments configurations do not expose data_parallel_degree
+        data_parallel_degree = 1
+    if training_args.sharding_parallel_degree > 1:
+        sharding_parallel_degree = training_args.sharding_parallel_degree
+    else:
+        sharding_parallel_degree = 1
+    if training_args.max_steps > 0:
+        total_batch = (
+            training_args.max_steps
+            * training_args.per_device_train_batch_size
+            * training_args.gradient_accumulation_steps
+            * sharding_parallel_degree
+            * data_parallel_degree
+        )
+        for i, data in enumerate(train_dataset):
+            if i == total_batch:
+                break
+            total_effective_tokens += len(data["input_ids"])
+        total_tokens = total_batch * max_seq_len
+    else:
+        for i, data in enumerate(train_dataset):
+            total_effective_tokens += len(data["input_ids"])
+        total_tokens = (i + 1) * max_seq_len
+    total_effective_tokens *= training_args.num_train_epochs
+    total_tokens *= training_args.num_train_epochs
+    return total_effective_tokens, total_tokens
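A minimal usage sketch of `calculate_effective_tokens`, assuming that `training_args`, `train_dataset`, and a measured wall-clock training time `train_time_s` already exist in the training script, and using an illustrative `max_seq_len=4096`:

```python
# Hypothetical usage sketch; `training_args`, `train_dataset`, and
# `train_time_s` are assumed to be defined elsewhere.
from paddlenlp.trl.trl_utils import calculate_effective_tokens

total_effective_tokens, total_tokens = calculate_effective_tokens(
    training_args, train_dataset, max_seq_len=4096
)
# Fraction of the padded token budget occupied by real (non-padding) tokens.
effective_ratio = total_effective_tokens / total_tokens
# Effective tokens processed per second over the whole run.
effective_tps = total_effective_tokens / train_time_s
print(f"effective token ratio: {effective_ratio:.2%}, effective tokens/s: {effective_tps:.2f}")
```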