
📦 [SFT] Deprecate batched formatting_func #3147


Merged · 13 commits · Apr 8, 2025
11 changes: 7 additions & 4 deletions docs/source/sft_trainer.md
@@ -288,16 +288,19 @@ If you don't want to pack your `eval_dataset`, you can pass `eval_packing=False`
If your dataset has several fields that you want to combine, for example `question` and `answer` fields, you can pass a formatting function to the trainer that will take care of that. For example:

```python
def formatting_func(example):
    text = f"### Question: {example['question']}\n ### Answer: {example['answer']}"
    return text
def formatting_prompts_func(example):
    output_texts = []
    for i in range(len(example['question'])):
        text = f"### Question: {example['question'][i]}\n ### Answer: {example['answer'][i]}"
        output_texts.append(text)
    return output_texts

training_args = SFTConfig(packing=True)
trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset,
    args=training_args,
    formatting_func=formatting_func
    formatting_func=formatting_prompts_func
)

trainer.train()
```
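The two functions in the diff above differ only in their calling convention. A minimal standalone sketch of that difference (the sample data here is hypothetical, not from the PR): a non-batched function maps one example dict to one string, while a batched function receives each column as a list and returns a list of strings, as `datasets.Dataset.map(..., batched=True)` expects.

```python
def formatting_func(example):
    # Non-batched: `example` is a single row, the result is a single string.
    return f"### Question: {example['question']}\n ### Answer: {example['answer']}"

def formatting_prompts_func(batch):
    # Batched: `batch` maps each column name to a list of values,
    # and the result is a list of formatted strings.
    return [
        f"### Question: {q}\n ### Answer: {a}"
        for q, a in zip(batch["question"], batch["answer"])
    ]

# Hypothetical sample data for illustration only.
single = {"question": "What is TRL?", "answer": "A post-training library."}
batch = {"question": ["What is TRL?"], "answer": ["A post-training library."]}

# A one-element batch produces the same text as the single-example call.
assert formatting_prompts_func(batch) == [formatting_func(single)]
```

Either style produces the same `text` column in the end; the batched form just processes many rows per call.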
6 changes: 2 additions & 4 deletions trl/trainer/sft_trainer.py
@@ -189,7 +189,7 @@ class SFTTrainer(Trainer):
peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`):
    PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
formatting_func (`Optional[Callable]`):
    Formatting function applied to the dataset before tokenization.
    A batched formatting function applied to the dataset before tokenization.
"""

_tag_names = ["trl", "sft"]
@@ -473,12 +473,10 @@ def _prepare_dataset(
if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
    map_kwargs["desc"] = f"Applying formatting function to {dataset_name} dataset"

batched = isinstance(formatting_func(next(iter(dataset))), list)

def _func(example):
    return {"text": formatting_func(example)}

dataset = dataset.map(_func, batched=batched, **map_kwargs)
dataset = dataset.map(_func, batched=True, **map_kwargs)

# If the dataset is prompt-completion, convert it to language modeling type
first_example = next(iter(dataset))
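The line this hunk removes probed the formatting function with the first example and inferred batched mode from a list return value. A standalone sketch of that heuristic (hypothetical helpers and data, not TRL internals) also shows why it is fragile: on a single probe example, a batched function ends up iterating over a string rather than a list of rows, yet still returns a list.

```python
def detect_batched(formatting_func, first_example):
    # The removed heuristic: a list return value is taken to mean
    # the formatting function is batched.
    return isinstance(formatting_func(first_example), list)

# Hypothetical probe example, as `next(iter(dataset))` would yield one row.
first = {"question": "What is TRL?", "answer": "A post-training library."}

def single_formatter(example):
    return f"### Question: {example['question']}\n ### Answer: {example['answer']}"

def batched_formatter(example):
    # On a real batch, example["question"] is a list of questions; on the
    # single probe row it is a string, so this iterates its characters.
    # It still returns a list, which is why the probe flags it as batched.
    return [f"### Question: {q}" for q in example["question"]]

assert detect_batched(single_formatter, first) is False
assert detect_batched(batched_formatter, first) is True
```

Dropping the probe and always calling `map` with `batched=True` removes this per-dataset inference at the cost of requiring a single calling convention.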