Skip to content

Commit 3cceb86

Browse files
authored
Refactored modules/tokenizers to be a subdir of modules/transforms (pytorch#2231)
1 parent 5764650 commit 3cceb86

Some content is hidden

Large commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed: +113 −72 lines changed

docs/source/api_ref_modules.rst

+6-6
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,10 @@ model specific tokenizers.
4848
:toctree: generated/
4949
:nosignatures:
5050

51-
tokenizers.SentencePieceBaseTokenizer
52-
tokenizers.TikTokenBaseTokenizer
53-
tokenizers.ModelTokenizer
54-
tokenizers.BaseTokenizer
51+
transforms.tokenizers.SentencePieceBaseTokenizer
52+
transforms.tokenizers.TikTokenBaseTokenizer
53+
transforms.tokenizers.ModelTokenizer
54+
transforms.tokenizers.BaseTokenizer
5555

5656
Tokenizer Utilities
5757
-------------------
@@ -61,8 +61,8 @@ These are helper methods that can be used by any tokenizer.
6161
:toctree: generated/
6262
:nosignatures:
6363

64-
tokenizers.tokenize_messages_no_special_tokens
65-
tokenizers.parse_hf_tokenizer_json
64+
transforms.tokenizers.tokenize_messages_no_special_tokens
65+
transforms.tokenizers.parse_hf_tokenizer_json
6666

6767

6868
PEFT Components

docs/source/basics/custom_components.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ our models in torchtune - see :func:`~torchtune.models.llama3_2_vision.llama3_2_
117117
#
118118
from torchtune.datasets import SFTDataset, PackedDataset
119119
from torchtune.data import InputOutputToMessages
120-
from torchtune.modules.tokenizers import ModelTokenizer
120+
from torchtune.modules.transforms.tokenizers import ModelTokenizer
121121
122122
# Example builder function for a custom code instruct dataset not in torchtune, but using
123123
# different dataset building blocks from torchtune

docs/source/basics/model_transforms.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ The following methods are required on the model transform:
101101

102102
.. code-block:: python
103103
104-
from torchtune.modules.tokenizers import ModelTokenizer
104+
from torchtune.modules.transforms.tokenizers import ModelTokenizer
105105
from torchtune.modules.transforms import Transform
106106
107107
class MyMultimodalTransform(ModelTokenizer, Transform):

docs/source/basics/tokenizers.rst

+5-5
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ For example, here we change the ``"<|begin_of_text|>"`` and ``"<|end_of_text|>"`
168168
Base tokenizers
169169
---------------
170170

171-
:class:`~torchtune.modules.tokenizers.BaseTokenizer` are the underlying byte-pair encoding modules that perform the actual raw string to token ID conversion and back.
171+
:class:`~torchtune.modules.transforms.tokenizers.BaseTokenizer` are the underlying byte-pair encoding modules that perform the actual raw string to token ID conversion and back.
172172
In torchtune, they are required to implement ``encode`` and ``decode`` methods, which are called by the :ref:`model_tokenizers` to convert
173173
between raw text and token IDs.
174174

@@ -202,13 +202,13 @@ between raw text and token IDs.
202202
"""
203203
pass
204204
205-
If you load any :ref:`model_tokenizers`, you can see that it calls its underlying :class:`~torchtune.modules.tokenizers.BaseTokenizer`
205+
If you load any :ref:`model_tokenizers`, you can see that it calls its underlying :class:`~torchtune.modules.transforms.tokenizers.BaseTokenizer`
206206
to do the actual encoding and decoding.
207207

208208
.. code-block:: python
209209
210210
from torchtune.models.mistral import mistral_tokenizer
211-
from torchtune.modules.tokenizers import SentencePieceBaseTokenizer
211+
from torchtune.modules.transforms.tokenizers import SentencePieceBaseTokenizer
212212
213213
m_tokenizer = mistral_tokenizer("/tmp/Mistral-7B-v0.1/tokenizer.model")
214214
# Mistral uses SentencePiece for its underlying BPE
@@ -227,7 +227,7 @@ to do the actual encoding and decoding.
227227
Model tokenizers
228228
----------------
229229

230-
:class:`~torchtune.modules.tokenizers.ModelTokenizer` are specific to a particular model. They are required to implement the ``tokenize_messages`` method,
230+
:class:`~torchtune.modules.transforms.tokenizers.ModelTokenizer` are specific to a particular model. They are required to implement the ``tokenize_messages`` method,
231231
which converts a list of Messages into a list of token IDs.
232232

233233
.. code-block:: python
@@ -259,7 +259,7 @@ is because they add all the necessary special tokens or prompt templates require
259259
.. code-block:: python
260260
261261
from torchtune.models.mistral import mistral_tokenizer
262-
from torchtune.modules.tokenizers import SentencePieceBaseTokenizer
262+
from torchtune.modules.transforms.tokenizers import SentencePieceBaseTokenizer
263263
from torchtune.data import Message
264264
265265
m_tokenizer = mistral_tokenizer("/tmp/Mistral-7B-v0.1/tokenizer.model")

recipes/eleuther_eval.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@
3131
from torchtune.modules import TransformerDecoder
3232
from torchtune.modules.common_utils import local_kv_cache
3333
from torchtune.modules.model_fusion import DeepFusionModel
34-
from torchtune.modules.tokenizers import ModelTokenizer
3534
from torchtune.modules.transforms import Transform
35+
from torchtune.modules.transforms.tokenizers import ModelTokenizer
3636
from torchtune.recipe_interfaces import EvalRecipeInterface
3737
from torchtune.training import FullModelTorchTuneCheckpointer
3838

tests/test_utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
import torch
2121
from torch import nn
2222
from torchtune.data import Message, PromptTemplate, truncate
23-
from torchtune.modules.tokenizers import ModelTokenizer
2423
from torchtune.modules.transforms import Transform
24+
from torchtune.modules.transforms.tokenizers import ModelTokenizer
2525

2626
skip_if_cuda_not_available = unittest.skipIf(
2727
not torch.cuda.is_available(), "CUDA is not available"

tests/torchtune/modules/tokenizers/test_sentencepiece.py tests/torchtune/modules/transforms/tokenizers/test_sentencepiece.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import pytest
88

99
from tests.common import ASSETS
10-
from torchtune.modules.tokenizers import SentencePieceBaseTokenizer
10+
from torchtune.modules.transforms.tokenizers import SentencePieceBaseTokenizer
1111

1212

1313
class TestSentencePieceBaseTokenizer:

tests/torchtune/modules/tokenizers/test_tiktoken.py tests/torchtune/modules/transforms/tokenizers/test_tiktoken.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from tests.common import ASSETS
1010
from torchtune.models.llama3._tokenizer import CL100K_PATTERN
11-
from torchtune.modules.tokenizers import TikTokenBaseTokenizer
11+
from torchtune.modules.transforms.tokenizers import TikTokenBaseTokenizer
1212

1313

1414
class TestTikTokenBaseTokenizer:

tests/torchtune/modules/tokenizers/test_utils.py tests/torchtune/modules/transforms/tokenizers/test_utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from tests.test_utils import DummyTokenizer
1010
from torchtune.data import Message
1111

12-
from torchtune.modules.tokenizers import tokenize_messages_no_special_tokens
12+
from torchtune.modules.transforms.tokenizers import tokenize_messages_no_special_tokens
1313

1414

1515
class TestTokenizerUtils:

torchtune/data/_messages.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@
2222
class Message:
2323
"""
2424
This class represents individual messages in a fine-tuning dataset. It supports
25-
text-only content, text with interleaved images, and tool calls. The :class:`~torchtune.modules.tokenizers.ModelTokenizer`
26-
will tokenize the content of the message using ``tokenize_messages`` and attach
27-
the appropriate special tokens based on the flags set in this class.
25+
text-only content, text with interleaved images, and tool calls. The
26+
:class:`~torchtune.modules.transforms.tokenizers.ModelTokenizer` will tokenize
27+
the content of the message using ``tokenize_messages`` and attach the appropriate
28+
special tokens based on the flags set in this class.
2829
2930
Args:
3031
role (Role): role of the message writer. Can be "system" for system prompts,

torchtune/datasets/_alpaca.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from torchtune.datasets._packed import PackedDataset
1414
from torchtune.datasets._sft import SFTDataset
15-
from torchtune.modules.tokenizers import ModelTokenizer
15+
from torchtune.modules.transforms.tokenizers import ModelTokenizer
1616

1717

1818
def alpaca_dataset(

torchtune/datasets/_chat.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from torchtune.data._messages import OpenAIToMessages, ShareGPTToMessages
1010
from torchtune.datasets._packed import PackedDataset
1111
from torchtune.datasets._sft import SFTDataset
12-
from torchtune.modules.tokenizers import ModelTokenizer
12+
from torchtune.modules.transforms.tokenizers import ModelTokenizer
1313

1414

1515
def chat_dataset(

torchtune/datasets/_cnn_dailymail.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from torchtune.datasets._text_completion import TextCompletionDataset
1010

11-
from torchtune.modules.tokenizers import ModelTokenizer
11+
from torchtune.modules.transforms.tokenizers import ModelTokenizer
1212

1313

1414
def cnn_dailymail_articles_dataset(

torchtune/datasets/_grammar.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from torchtune.data import InputOutputToMessages
1111
from torchtune.datasets._packed import PackedDataset
1212
from torchtune.datasets._sft import SFTDataset
13-
from torchtune.modules.tokenizers import ModelTokenizer
13+
from torchtune.modules.transforms.tokenizers import ModelTokenizer
1414

1515

1616
def grammar_dataset(

torchtune/datasets/_hh_rlhf_helpful.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from torchtune.data import ChosenRejectedToMessages
1010
from torchtune.datasets._preference import PreferenceDataset
11-
from torchtune.modules.tokenizers import ModelTokenizer
11+
from torchtune.modules.transforms.tokenizers import ModelTokenizer
1212

1313

1414
def hh_rlhf_helpful_dataset(

torchtune/datasets/_instruct.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from torchtune.data import InputOutputToMessages
1010
from torchtune.datasets._packed import PackedDataset
1111
from torchtune.datasets._sft import SFTDataset
12-
from torchtune.modules.tokenizers import ModelTokenizer
12+
from torchtune.modules.transforms.tokenizers import ModelTokenizer
1313

1414

1515
def instruct_dataset(

torchtune/datasets/_preference.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@
1111
from torch.utils.data import Dataset
1212

1313
from torchtune.data import ChosenRejectedToMessages, CROSS_ENTROPY_IGNORE_IDX
14-
15-
from torchtune.modules.tokenizers import ModelTokenizer
1614
from torchtune.modules.transforms import Transform
1715

16+
from torchtune.modules.transforms.tokenizers import ModelTokenizer
17+
1818

1919
class PreferenceDataset(Dataset):
2020
"""
@@ -84,7 +84,7 @@ class requires the dataset to have "chosen" and "rejected" model responses. Thes
8484
of messages are stored in the ``"chosen"`` and ``"rejected"`` keys.
8585
tokenizer (ModelTokenizer): Tokenizer used by the model that implements the ``tokenize_messages`` method.
8686
Since PreferenceDataset only supports text data, it requires a
87-
:class:`~torchtune.modules.tokenizers.ModelTokenizer` instead of the ``model_transform`` in
87+
:class:`~torchtune.modules.transforms.tokenizers.ModelTokenizer` instead of the ``model_transform`` in
8888
:class:`~torchtune.datasets.SFTDataset`.
8989
filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. See
9090
the Hugging Face `docs <https://huggingface.co/docs/datasets/v2.20.0/process#select-and-filter>`_ for more

torchtune/datasets/_samsum.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from torchtune.data import InputOutputToMessages
1111
from torchtune.datasets._packed import PackedDataset
1212
from torchtune.datasets._sft import SFTDataset
13-
from torchtune.modules.tokenizers import ModelTokenizer
13+
from torchtune.modules.transforms.tokenizers import ModelTokenizer
1414

1515

1616
def samsum_dataset(

torchtune/datasets/_sft.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,13 @@ class SFTDataset(Dataset):
6969
multimodal datasets requires processing the images in a way specific to the vision
7070
encoder being used by the model and is agnostic to the specific dataset.
7171
72-
Tokenization is handled by the ``model_transform``. All :class:`~torchtune.modules.tokenizers.ModelTokenizer`
73-
can be treated as a ``model_transform`` since it uses the model-specific tokenizer to
74-
transform the list of messages outputted from the ``message_transform`` into tokens
75-
used by the model for training. Text-only datasets will simply pass the :class:`~torchtune.modules.tokenizers.ModelTokenizer`
76-
into ``model_transform``. Tokenizers handle prompt templating, if configured.
72+
Tokenization is handled by the ``model_transform``. All
73+
:class:`~torchtune.modules.transforms.tokenizers.ModelTokenizer` can be treated as
74+
a ``model_transform`` since it uses the model-specific tokenizer to transform the
75+
list of messages outputted from the ``message_transform`` into tokens used by the
76+
model for training. Text-only datasets will simply pass the
77+
:class:`~torchtune.modules.transforms.tokenizers.ModelTokenizer` into ``model_transform``.
78+
Tokenizers handle prompt templating, if configured.
7779
7880
Args:
7981
source (str): path to dataset repository on Hugging Face. For local datasets,

torchtune/datasets/_slimorca.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from torchtune.datasets._packed import PackedDataset
1111

1212
from torchtune.datasets._sft import SFTDataset
13-
from torchtune.modules.tokenizers import ModelTokenizer
13+
from torchtune.modules.transforms.tokenizers import ModelTokenizer
1414

1515

1616
def slimorca_dataset(

torchtune/datasets/_stack_exchange_paired.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88

99
from torchtune.data import Message
1010
from torchtune.datasets._preference import PreferenceDataset
11-
from torchtune.modules.tokenizers import ModelTokenizer
1211
from torchtune.modules.transforms import Transform
12+
from torchtune.modules.transforms.tokenizers import ModelTokenizer
1313

1414

1515
class StackExchangePairedToMessages(Transform):

torchtune/datasets/_text_completion.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from torch.utils.data import Dataset
1111
from torchtune.data._utils import truncate
1212
from torchtune.datasets._packed import PackedDataset
13-
from torchtune.modules.tokenizers import ModelTokenizer
13+
from torchtune.modules.transforms.tokenizers import ModelTokenizer
1414

1515

1616
class TextCompletionDataset(Dataset):

torchtune/datasets/_wikitext.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
TextCompletionDataset,
1414
)
1515

16-
from torchtune.modules.tokenizers import ModelTokenizer
16+
from torchtune.modules.transforms.tokenizers import ModelTokenizer
1717

1818

1919
def wikitext_dataset(

torchtune/models/clip/_tokenizer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import regex as re
99

10-
from torchtune.modules.tokenizers._utils import BaseTokenizer
10+
from torchtune.modules.transforms.tokenizers._utils import BaseTokenizer
1111

1212
WORD_BOUNDARY = "</w>"
1313

torchtune/models/gemma/_tokenizer.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@
77
from typing import Any, List, Mapping, Optional, Tuple
88

99
from torchtune.data import Message, PromptTemplate
10-
from torchtune.modules.tokenizers import (
10+
from torchtune.modules.transforms import Transform
11+
from torchtune.modules.transforms.tokenizers import (
1112
ModelTokenizer,
1213
SentencePieceBaseTokenizer,
1314
tokenize_messages_no_special_tokens,
1415
)
15-
from torchtune.modules.transforms import Transform
1616

1717
WHITESPACE_CHARS = [" ", "\n", "\t", "\r", "\v"]
1818

torchtune/models/llama2/_tokenizer.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@
88

99
from torchtune.data import Message, PromptTemplate
1010
from torchtune.models.llama2._prompt_template import Llama2ChatTemplate
11-
from torchtune.modules.tokenizers import (
11+
from torchtune.modules.transforms import Transform
12+
from torchtune.modules.transforms.tokenizers import (
1213
ModelTokenizer,
1314
SentencePieceBaseTokenizer,
1415
tokenize_messages_no_special_tokens,
1516
)
16-
from torchtune.modules.transforms import Transform
1717

1818
WHITESPACE_CHARS = [" ", "\n", "\t", "\r", "\v"]
1919

torchtune/models/llama3/_model_builders.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
from torchtune.modules import TransformerDecoder
1515
from torchtune.modules.peft import LORA_ATTN_MODULES
16-
from torchtune.modules.tokenizers import parse_hf_tokenizer_json
16+
from torchtune.modules.transforms.tokenizers import parse_hf_tokenizer_json
1717

1818

1919
"""

torchtune/models/llama3/_tokenizer.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,11 @@
88
from typing import Any, Dict, List, Mapping, Optional, Tuple
99

1010
from torchtune.data import Message, PromptTemplate, truncate
11-
from torchtune.modules.tokenizers import ModelTokenizer, TikTokenBaseTokenizer
1211
from torchtune.modules.transforms import Transform
12+
from torchtune.modules.transforms.tokenizers import (
13+
ModelTokenizer,
14+
TikTokenBaseTokenizer,
15+
)
1316

1417

1518
CL100K_PATTERN = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""" # noqa

torchtune/models/llama3_2_vision/_transform.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010

1111
from torchtune.models.clip import CLIPImageTransform
1212
from torchtune.models.llama3 import llama3_tokenizer
13-
from torchtune.modules.tokenizers import ModelTokenizer
1413
from torchtune.modules.transforms import Transform, VisionCrossAttentionMask
14+
from torchtune.modules.transforms.tokenizers import ModelTokenizer
1515

1616

1717
class Llama3VisionTransform(ModelTokenizer, Transform):

torchtune/models/mistral/_tokenizer.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@
88

99
from torchtune.data import Message, PromptTemplate
1010
from torchtune.models.mistral._prompt_template import MistralChatTemplate
11-
from torchtune.modules.tokenizers import (
11+
from torchtune.modules.transforms import Transform
12+
from torchtune.modules.transforms.tokenizers import (
1213
ModelTokenizer,
1314
SentencePieceBaseTokenizer,
1415
tokenize_messages_no_special_tokens,
1516
)
16-
from torchtune.modules.transforms import Transform
1717

1818
WHITESPACE_CHARS = [" ", "\n", "\t", "\r", "\v"]
1919

torchtune/models/phi3/_model_builders.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from torchtune.modules import TransformerDecoder
77
from torchtune.modules.peft import LORA_ATTN_MODULES
88
from functools import partial
9-
from torchtune.modules.tokenizers import parse_hf_tokenizer_json
9+
from torchtune.modules.transforms.tokenizers import parse_hf_tokenizer_json
1010
from torchtune.data._prompt_templates import _TemplateType
1111
from torchtune.data._prompt_templates import _get_prompt_template
1212

torchtune/models/phi3/_tokenizer.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,11 @@
99
from torchtune.data._messages import Message
1010
from torchtune.data._prompt_templates import PromptTemplate
1111
from torchtune.data._utils import truncate
12-
from torchtune.modules.tokenizers import ModelTokenizer, SentencePieceBaseTokenizer
1312
from torchtune.modules.transforms import Transform
13+
from torchtune.modules.transforms.tokenizers import (
14+
ModelTokenizer,
15+
SentencePieceBaseTokenizer,
16+
)
1417

1518
PHI3_SPECIAL_TOKENS = {
1619
"<|endoftext|>": 32000,

0 commit comments

Comments (0)