
Commit 7e084d9

[6/7] SFTDataset: revamp instruct/chat (pytorch#1286)
1 parent 3e29e6b commit 7e084d9

36 files changed, +596 -652 lines changed

docs/source/api_ref_data.rst (-4)

@@ -25,11 +25,7 @@ and models.
     PromptTemplate
     PromptTemplateInterface
     ChatMLTemplate
-
     ChatFormat
-    ChatMLFormat
-    Llama2ChatFormat
-    MistralChatFormat
 
 Types
 -----

tests/assets/chat_tiny.json (+26)

@@ -0,0 +1,26 @@
+[
+    {
+        "conversations": [
+            {
+                "from": "system",
+                "value": "You are an AI assistant."
+            },
+            {
+                "from": "human",
+                "value": "What is the meaning of life?"
+            },
+            {
+                "from": "gpt",
+                "value": "The meaning of life is 42."
+            },
+            {
+                "from": "human",
+                "value": "That's ridiculous."
+            },
+            {
+                "from": "gpt",
+                "value": "I agree."
+            }
+        ]
+    }
+]
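As exercised by the updated tests/torchtune/datasets/test_chat_dataset.py further down, this ShareGPT-style asset is loaded through the chat_dataset builder. A minimal usage sketch mirroring that test (the DummyTokenizer stand-in is only for illustration; in practice you would pass a real model tokenizer):

    from tests.test_utils import DummyTokenizer
    from torchtune.datasets import chat_dataset

    # Load the JSON asset via the Hugging Face "json" loader and convert each
    # row's "conversations" column from ShareGPT format into messages.
    ds = chat_dataset(
        tokenizer=DummyTokenizer(),  # stand-in; use a real model tokenizer
        source="json",
        data_files="tests/assets/chat_tiny.json",
        conversation_column="conversations",
        conversation_style="sharegpt",
        train_on_input=False,
        packed=False,
        split="train",
    )
    tokens, labels = ds[0]["tokens"], ds[0]["labels"]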

tests/assets/instruct_tiny.json (+10)

@@ -0,0 +1,10 @@
+[
+    {
+        "instruction": "What time is it in London?",
+        "response": "It is 10:00 AM in London"
+    },
+    {
+        "instruction": "Is it Istanbul or Constantinople?",
+        "response": "Istanbul was Constantinople. Now it's Istanbul, not Constantinople."
+    }
+]
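No test for this asset appears in the hunks shown here, but it presumably feeds the revamped instruct path the same way. A hypothetical sketch only: the instruct_dataset builder and its column_map argument (mapping the builder's default "input"/"output" columns to this asset's "instruction"/"response" keys) are assumptions not confirmed by this diff:

    from tests.test_utils import DummyTokenizer
    from torchtune.datasets import instruct_dataset

    # Hypothetical usage; parameter names are assumptions, not taken from this diff.
    ds = instruct_dataset(
        tokenizer=DummyTokenizer(),
        source="json",
        data_files="tests/assets/instruct_tiny.json",
        column_map={"input": "instruction", "output": "response"},
        train_on_input=False,
        split="train",
    )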

tests/common.py (+3)

@@ -3,5 +3,8 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+from pathlib import Path
 
 TUNE_PATH = "torchtune/_cli/tune.py"
+
+ASSETS = Path(__file__).parent / "assets"

tests/test_utils.py (+2, -2)

@@ -19,7 +19,7 @@
 
 import torch
 from torch import nn
-from torchtune.data import ChatFormat, Message, PromptTemplate, truncate
+from torchtune.data import Message, PromptTemplate, truncate
 from torchtune.modules.tokenizers import ModelTokenizer
 from torchtune.modules.transforms import Transform
 
@@ -164,7 +164,7 @@ def image_id(self):
         return -2
 
 
-class DummyChatFormat(ChatFormat):
+class DummyChatFormat:
 
     B_SYS, E_SYS = "System:\n", "\n"
     B_INST, E_INST = "User:\n", "\nAssistant:\n"

tests/torchtune/_cli/test_validate.py (+1, -5)

@@ -7,12 +7,8 @@
 import runpy
 import sys
 
-from pathlib import Path
-
 import pytest
-from tests.common import TUNE_PATH
-
-ASSETS = Path(__file__).parent.parent.parent / "assets"
+from tests.common import ASSETS, TUNE_PATH
 
 
 class TestTuneValidateCommand:

tests/torchtune/data/test_chat_formats.py (-99)

This file was deleted.

tests/torchtune/datasets/test_chat_dataset.py (+60, -50)

@@ -4,63 +4,20 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from unittest import mock
-
 import pytest
+from tests.common import ASSETS
 from tests.test_utils import DummyChatFormat, DummyTokenizer
-from torchtune.data import Message
+from torchtune.data import get_sharegpt_messages
 from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
-from torchtune.datasets import ChatDataset
+from torchtune.datasets import chat_dataset, ChatDataset
 
 
 class TestChatDataset:
     @pytest.fixture
     def chat_format(self):
         return DummyChatFormat
 
-    @pytest.fixture
-    def dialogue(self):
-        return [
-            {
-                "dialogue": [
-                    Message.from_dict(
-                        {
-                            "role": "system",
-                            "content": "You are an AI assistant.",
-                            "masked": True,
-                        }
-                    ),
-                    Message.from_dict(
-                        {
-                            "role": "user",
-                            "content": "What is the meaning of life?",
-                            "masked": True,
-                        }
-                    ),
-                    Message.from_dict(
-                        {
-                            "role": "assistant",
-                            "content": "The meaning of life is 42.",
-                            "masked": False,
-                        }
-                    ),
-                    Message.from_dict(
-                        {
-                            "role": "user",
-                            "content": "That's ridiculous.",
-                            "masked": True,
-                        }
-                    ),
-                    Message.from_dict(
-                        {"role": "assistant", "content": "I agree.", "masked": False}
-                    ),
-                ],
-            },
-        ]
-
-    @mock.patch("torchtune.datasets._chat.load_dataset")
-    def test_get_item(self, mock_load_dataset, chat_format, dialogue):
-        mock_load_dataset.return_value = dialogue
+    def test_get_item(self, chat_format):
         expected_tokenized_prompts = [
             [
                 0,
@@ -104,15 +61,68 @@ def test_get_item(self, mock_load_dataset, chat_format, dialogue):
         ]
         ds = ChatDataset(
             tokenizer=DummyTokenizer(),
-            source="iam/agoofy/goober",
-            convert_to_messages=lambda x, y: x["dialogue"],
+            source="json",
+            convert_to_messages=get_sharegpt_messages,
             chat_format=chat_format,
             max_seq_len=100,
             train_on_input=False,
+            data_files=str(ASSETS / "chat_tiny.json"),
+            split="train",
         )
         assert len(ds) == 1
-        mock_load_dataset.assert_called_once()
+        prompt, label = ds[0]["tokens"], ds[0]["labels"]
+        assert prompt == expected_tokenized_prompts[0]
+        assert label == expected_labels[0]
+
+        expected_tokenized_prompts = [
+            [
+                0,
+                3,
+                3,
+                2,
+                2,
+                10,
+                4,
+                2,
+                3,
+                7,
+                2,
+                5,
+                3,
+                7,
+                2,
+                4,
+                2,
+                3,
+                -1,
+                0,
+                6,
+                11,
+                1,
+                6,
+                -1,
+            ]
+        ]
+        prompt_lengths = (12, 3)
+        expected_labels = [
+            [CROSS_ENTROPY_IGNORE_IDX] * prompt_lengths[0]
+            + [3, 7, 2, 4, 2, 3, -1]
+            + [CROSS_ENTROPY_IGNORE_IDX] * prompt_lengths[1]
+            + [1, 6, -1]
+        ]
+
+        ds = chat_dataset(
+            tokenizer=DummyTokenizer(),
+            source="json",
+            data_files=str(ASSETS / "chat_tiny.json"),
+            conversation_column="conversations",
+            conversation_style="sharegpt",
+            train_on_input=False,
+            packed=False,
+            split="train",
+        )
 
+        assert len(ds) == 1
         prompt, label = ds[0]["tokens"], ds[0]["labels"]
         assert prompt == expected_tokenized_prompts[0]
         assert label == expected_labels[0]
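For reference, a minimal sketch of what the ShareGPT conversion used above does: get_sharegpt_messages turns each "from"/"value" entry in the "conversations" column into a torchtune Message, with non-assistant turns masked when train_on_input=False. The sample dict mirrors tests/assets/chat_tiny.json; the printed fields are an assumption based on the Message fixture removed above, not something shown in this diff:

    from torchtune.data import get_sharegpt_messages

    # One row in the same format as tests/assets/chat_tiny.json.
    sample = {
        "conversations": [
            {"from": "system", "value": "You are an AI assistant."},
            {"from": "human", "value": "What is the meaning of life?"},
            {"from": "gpt", "value": "The meaning of life is 42."},
        ]
    }

    # "system"/"human" turns become masked system/user messages;
    # "gpt" turns become unmasked assistant messages.
    messages = get_sharegpt_messages(sample, train_on_input=False)
    for message in messages:
        print(message.role, message.masked)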

tests/torchtune/datasets/test_grammar_dataset.py (+2, -2)

@@ -36,7 +36,7 @@ def test_label_no_masking(self, load_dataset, tokenizer):
             ]
         )
 
-        grammar_ds = grammar_dataset(model_transform=tokenizer, train_on_input=True)
+        grammar_ds = grammar_dataset(tokenizer=tokenizer, train_on_input=True)
         input, labels = grammar_ds[0]["tokens"], grammar_ds[0]["labels"]
 
         assert input == [0, 7, 2, 3, 6, 4, 8, 5, 8, 5, 7, 4, 3, 6, 4, 8, 9, 2, 9, -1]
@@ -58,7 +58,7 @@ def test_label_masking(self, load_dataset, tokenizer):
             ]
         )
 
-        grammar_ds = grammar_dataset(model_transform=tokenizer)
+        grammar_ds = grammar_dataset(tokenizer=tokenizer)
 
         # Generate the input and labels
         input, labels = grammar_ds[0]["tokens"], grammar_ds[0]["labels"]
