Skip to content

Commit

Permalink
Upgrade applications/text_classifications/multi_class to use Trainer …
Browse files Browse the repository at this point in the history
…API (#3679)

* add trainer to multi_class finetuning

* styles

* fix styles

* log eval metrics

* address comments

* address comments

* precommit

* fix README style

Co-authored-by: lugimzzz <63761690+lugimzzz@users.noreply.github.com>
  • Loading branch information
sijunhe and lugimzzz authored Nov 17, 2022
1 parent bfec1e7 commit 60a51f0
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 239 deletions.
116 changes: 71 additions & 45 deletions applications/text_classification/multi_class/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -188,67 +188,93 @@ data.txt(待预测数据文件),需要预测标签的文本数据。
#### 2.4.1 预训练模型微调


使用CPU/GPU训练,默认为GPU训练,使用CPU训练只需将设备参数配置改为`--device "cpu"`
使用CPU/GPU训练,默认为GPU训练,使用CPU训练只需将设备参数配置改为`--device cpu`
```shell
python train.py \
--dataset_dir "data" \
--device "gpu" \
--model_name_or_path ernie-3.0-medium-zh \
--data_dir ./data/ \
--output_dir checkpoint \
--device gpu \
--learning_rate 3e-5 \
--num_train_epochs 100 \
--early_stopping_patience 4 \
--max_seq_length 128 \
--model_name "ernie-3.0-medium-zh" \
--batch_size 32 \
--early_stop \
--epochs 100
--per_device_eval_batch_size 32 \
--per_device_train_batch_size 32 \
--num_train_epochs 100 \
--do_train \
--do_eval \
--metric_for_best_model accuracy \
--load_best_model_at_end \
--evaluation_strategy epoch \
--save_strategy epoch \
--save_total_limit 1
```

如果在CPU环境下训练,可以指定`nproc_per_node`参数进行多核训练:
```shell
python -m paddle.distributed.launch --nproc_per_node 8 --backend "gloo" train.py \
--dataset_dir "data" \
--device "cpu" \
python -m paddle.distributed.launch --nproc_per_node 8 --backend gloo train.py \
--model_name_or_path ernie-3.0-medium-zh \
--data_dir ./data/ \
--output_dir checkpoint \
--device cpu \
--learning_rate 3e-5 \
--num_train_epochs 100 \
--max_seq_length 128 \
--model_name "ernie-3.0-medium-zh" \
--batch_size 32 \
--early_stop \
--epochs 100
--per_device_eval_batch_size 32 \
--per_device_train_batch_size 32 \
--num_train_epochs 100 \
--early_stopping_patience 4 \
--do_train \
--do_eval \
--metric_for_best_model accuracy \
--load_best_model_at_end \
--evaluation_strategy epoch \
--save_strategy epoch \
--save_total_limit 1
```

如果在GPU环境中使用,可以指定`gpus`参数进行单卡/多卡训练。使用多卡训练可以指定多个GPU卡号,例如 --gpus "0,1"。如果设备只有一个GPU卡号默认为0,可使用`nvidia-smi`命令查看GPU使用情况:
如果在GPU环境中使用,可以指定`gpus`参数进行单卡/多卡训练。使用多卡训练可以指定多个GPU卡号,例如 --gpus 0,1。如果设备只有一个GPU卡号默认为0,可使用`nvidia-smi`命令查看GPU使用情况:

```shell
unset CUDA_VISIBLE_DEVICES
python -m paddle.distributed.launch --gpus "0" train.py \
--dataset_dir "data" \
--device "gpu" \
python -m paddle.distributed.launch --gpus 0,1 train.py \
--data_dir ./data/ \
--output_dir checkpoint \
--device cpu \
--learning_rate 3e-5 \
--num_train_epochs 100 \
--max_seq_length 128 \
--model_name "ernie-3.0-medium-zh" \
--batch_size 32 \
--early_stop \
--epochs 100
--per_device_eval_batch_size 32 \
--per_device_train_batch_size 32 \
--num_train_epochs 100 \
--early_stopping_patience 4 \
--do_train \
--do_eval \
--metric_for_best_model accuracy \
--load_best_model_at_end \
--evaluation_strategy epoch \
--save_strategy epoch \
--save_total_limit 1
```

可支持配置的参数:
主要的配置的参数为:
- `model_name_or_path`: 内置模型名,或者模型参数配置目录路径。默认为`ernie-3.0-base-zh`
- `data_dir`: 训练数据集路径,数据格式要求详见[数据标注](#数据标注)
- `output_dir`: 模型参数、训练日志和静态图导出的保存目录。
- `max_seq_length`: 最大句子长度,超过该长度的文本将被截断,不足的以Pad补全。提示文本不会被截断。
- `num_train_epochs`: 训练轮次,使用早停法时可以选择100
- `early_stopping_patience`: 在设定的早停训练轮次内,模型在开发集上表现不再上升,训练终止;默认为4。
- `learning_rate`: 预训练语言模型参数基础学习率大小,将与learning rate scheduler产生的值相乘作为当前学习率。
- `do_train`: 是否进行训练。
- `do_eval`: 是否进行评估。
- `device`: 使用的设备,默认为`gpu`
- `per_device_train_batch_size`: 每次训练每张卡上的样本数量。可根据实际GPU显存适当调小/调大此配置。
- `per_device_eval_batch_size`: 每次评估每张卡上的样本数量。可根据实际GPU显存适当调小/调大此配置。

* `device`: 选用什么设备进行训练,选择cpu、gpu、xpu、npu。如使用gpu训练,可使用参数--gpus指定GPU卡号;默认为"gpu"。
* `dataset_dir`:必须,本地数据集路径,数据集路径中应包含train.txt,dev.txt和label.txt文件;默认为None。
* `save_dir`:保存训练模型的目录;默认保存在当前目录checkpoint文件夹下。
* `max_seq_length`:分词器tokenizer使用的最大序列长度,ERNIE模型最大不能超过2048。请根据文本长度选择,通常推荐128、256或512,若出现显存不足,请适当调低这一参数;默认为128。
* `model_name`:选择预训练模型,可选"ernie-1.0-large-zh-cw","ernie-3.0-xbase-zh", "ernie-3.0-base-zh", "ernie-3.0-medium-zh", "ernie-3.0-micro-zh", "ernie-3.0-mini-zh", "ernie-3.0-nano-zh", "ernie-2.0-base-en", "ernie-2.0-large-en","ernie-m-base","ernie-m-large";默认为"ernie-3.0-medium-zh"。
* `batch_size`:批处理大小,请结合显存情况进行调整,若出现显存不足,请适当调低这一参数;默认为32。
* `learning_rate`:训练最大学习率;默认为3e-5。
* `epochs`: 训练轮次,使用早停法时可以选择100;默认为10。
* `early_stop`:选择是否使用早停法(EarlyStopping),模型在开发集经过一定epoch后精度表现不再上升,训练终止;默认为False。
* `early_stop_nums`:在设定的早停训练轮次内,模型在开发集上表现不再上升,训练终止;默认为4。
* `logging_steps`: 训练过程中日志打印的间隔steps数,默认5。
* `weight_decay`:控制正则项力度的参数,用于防止过拟合,默认为0.0。
* `warmup`:是否使用学习率warmup策略,使用时应设置适当的训练轮次(epochs);默认为False。
* `warmup_steps`:学习率warmup策略的比例数,如果设为1000,则学习率会在1000steps数从0慢慢增长到learning_rate, 而后再缓慢衰减;默认为0。
* `init_from_ckpt`: 模型初始checkpoint参数地址,默认None。
* `seed`:随机种子,默认为3。
* `train_file`:本地数据集中训练集文件名;默认为"train.txt"。
* `dev_file`:本地数据集中开发集文件名;默认为"dev.txt"。
* `label_file`:本地数据集中标签集文件名;默认为"label.txt"。
训练脚本支持所有`TraingArguments`的参数,更多参数介绍可参考[TrainingArguments 参数介绍](https://paddlenlp.readthedocs.io/zh/latest/trainer.html#trainingarguments)

程序运行时将会自动进行训练,评估。同时训练过程中会自动保存开发集上最佳模型在指定的 `save_dir` 中,保存模型文件结构如下所示:
程序运行时将会自动进行训练,评估。同时训练过程中会自动保存开发集上最佳模型在指定的 `output_dir` 中,保存模型文件结构如下所示:

```text
checkpoint/
Expand All @@ -260,8 +286,8 @@ checkpoint/

**NOTE:**

* 如需恢复模型训练,则可以设置 `init_from_ckpt` , 如 `init_from_ckpt=checkpoint/model_state.pdparams`
* 如需训练英文文本分类任务,只需更换预训练模型参数 `model_name` 。英文训练任务推荐使用"ernie-2.0-base-en"、"ernie-2.0-large-en"。
* 如需恢复模型训练,则可以设置 `resume_from_checkpoint` , 如 `resume_from_checkpoint=./checkpoints/checkpoint-217`
* 如需训练英文文本分类任务,只需更换预训练模型参数 `model_name_or_path` 。英文训练任务推荐使用"ernie-2.0-base-en"、"ernie-2.0-large-en"。
* 英文和中文以外语言的文本分类任务,推荐使用基于96种语言(涵盖法语、日语、韩语、德语、西班牙语等几乎所有常见语言)进行预训练的多语言预训练模型"ernie-m-base"、"ernie-m-large",详情请参见[ERNIE-M论文](https://arxiv.org/pdf/2012.15674.pdf)

#### 2.4.2 训练评估与模型优化
Expand Down
Loading

0 comments on commit 60a51f0

Please sign in to comment.