-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcreate_prompts.py
94 lines (81 loc) · 2.38 KB
/
create_prompts.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import json
import os
import random
import click
from datasets import load_dataset
def _field_text(row, field):
    """Return ``row[field]`` as a stripped string, or ``""`` when unusable.

    A missing key and a ``None`` value both map to ``""``. Any other value is
    converted with ``str()`` first, so non-string columns cannot crash the
    ``.strip()`` call.
    """
    value = row.get(field)
    if value is None:
        return ""
    return str(value).strip()


def _build_prompt(row, instruction_field, input_field):
    """Combine the instruction and optional input of *row* into one prompt.

    The input is appended on its own line only when it is non-empty.
    Otherwise the prompt is just the instruction text.
    """
    instruction = _field_text(row, instruction_field)
    inp = _field_text(row, input_field)
    return f"{instruction}\n{inp}" if inp else instruction


@click.command()
@click.option(
    "--dataset_name",
    required=True,
    help="Name of the Hugging Face dataset, e.g., tatsu-lab/alpaca",
)
@click.option(
    "--split",
    default="train",
    help="Split of the dataset to use (default: train)",
)
@click.option(
    "--instruction_field",
    default="instruction",
    help="Field name for the instruction in the dataset (default: instruction)",
)
@click.option(
    "--input_field",
    default="input",
    help="Field name for the input in the dataset (default: input)",
)
@click.option(
    "--prompts_file",
    required=True,
    # dir_okay=False gives a clean CLI error if an existing directory is passed.
    type=click.Path(dir_okay=False),
    help="Path to the JSONL output file, e.g., prompts/nlp-alpaca.jsonl",
)
@click.option(
    "--shuffle",
    is_flag=True,
    help="Shuffle the dataset before processing (default: False)",
)
@click.option(
    "--limit",
    # A negative sample count is meaningless; reject it at parse time.
    type=click.IntRange(min=0),
    default=None,
    help="Limit the number of samples to process (default: all)",
)
def main(
    dataset_name,
    split,
    instruction_field,
    input_field,
    prompts_file,
    shuffle,
    limit,
):
    """Creates prompts from an existing Hugging Face dataset and saves them as a JSONL file.

    Each output line is a JSON object ``{"dataset_name": ..., "prompt": ...}``.
    The prompt is the instruction text, followed by the input on a new line
    when the row's input is non-empty.
    """
    if os.path.exists(prompts_file) and not click.confirm(
        f"{prompts_file} already exists. Overwrite?", default=False
    ):
        click.echo("Operation cancelled.")
        return

    # Load the requested split from Hugging Face.
    ds = load_dataset(dataset_name, split=split)

    # Fail loudly on a mistyped instruction column. Otherwise every row would
    # silently produce an empty prompt. The check runs before the output file
    # is opened, so an existing file is not truncated. The input column stays
    # optional, matching the original `.get` fallback.
    if instruction_field not in ds.column_names:
        raise click.UsageError(
            f"Field '{instruction_field}' not found in {dataset_name} "
            f"(available: {', '.join(ds.column_names)})"
        )

    if shuffle:
        ds = ds.shuffle(seed=42)  # Fixed seed keeps runs reproducible.

    # Use `is not None` so that `--limit 0` means "zero samples", not "all".
    # min() clamps to the dataset size so select() never goes out of range.
    if limit is not None:
        ds = ds.select(range(min(limit, len(ds))))

    # Create the parent directory (e.g. "prompts/") if it does not exist yet.
    parent = os.path.dirname(prompts_file)
    if parent:
        os.makedirs(parent, exist_ok=True)

    with open(prompts_file, "w", encoding="utf-8") as f:
        for row in ds:
            record = {
                "dataset_name": dataset_name,  # Add dataset handle
                "prompt": _build_prompt(row, instruction_field, input_field),
            }
            f.write(json.dumps(record, ensure_ascii=False) + "\n")

    click.echo(f"Prompts successfully saved to {prompts_file}")
# Run the Click command only when this file is executed as a script.
if __name__ == "__main__":
    main()