
Commit ca473b3

adding add_end_token to Qwen
1 parent 1241231 commit ca473b3

File tree

1 file changed:
torchtune/models/qwen2/_tokenizer.py (+94 -56)
@@ -325,26 +325,85 @@ def decode(
         text = "".join(sub_texts)
         return text
 
+    def _tokenize_header(self, message: Message) -> List[int]:
+        return (
+            [self.im_start_id]
+            + self.encode(f"{message.role}\n", add_bos=False, add_eos=False)
+        )
+
+    def _tokenize_body(self, message: Message) -> List[int]:
+        tokenized_body = []
+        for item in message.content:
+            if item["type"] == "text":
+                tokenized_body += self.encode(
+                    item["content"], add_bos=False, add_eos=False,
+                )
+            else:
+                raise RuntimeError(f"Unsupported message content type: {item['type']}")
+        return tokenized_body
+
+    def _tokenize_end(self, message: Message) -> List[int]:
+        return (
+            [self.im_end_id]
+            + self.encode("\n", add_bos=False, add_eos=False)
+        )
+
+    def tokenize_message(
+        self,
+        message: Message,
+        *,
+        add_start_tokens: bool = True,
+        add_end_tokens: bool = True,
+    ) -> List[int]:
+        """
+        Tokenize a message into a list of token ids.
+
+        Args:
+            message (Message): The message to tokenize.
+            add_start_tokens (bool): Whether to prepend a tokenized header to the message. Default is True.
+            add_end_tokens (bool): Whether to append the end-of-turn tokens (im_end and newline)
+                at the end of the message. Default is True.
+
+        Returns:
+            List[int]: The list of token ids.
+        """
+        tokenized_header = self._tokenize_header(message) if add_start_tokens else []
+        tokenized_body = self._tokenize_body(message)
+        tokenized_end = self._tokenize_end(message) if add_end_tokens else []
+
+        tokenized_message = tokenized_header + tokenized_body + tokenized_end
+        return tokenized_message
+
     def tokenize_messages(
         self,
         messages: List[Message],
         *,
-        add_eos: bool = True,
+        add_end_tokens: bool = True,
     ) -> Tuple[List[int], List[bool]]:
         """
-        Given a list of messages, return a list of tokens for the concatenated
-        and formatted messages.
+        Tokenize a list of messages into a list of token ids and masks.
 
         Args:
-            messages (List[Message]): The message list to tokenize.
-            add_eos (bool): Wether to add the tokenizer's eos_id at the end of the
-                sequence of messages. Default is True.
+            messages (List[Message]): The list of messages to tokenize.
+            add_end_tokens (bool): Whether to append end token ids (end-of-turn and end-of-sequence)
+                at the end of the last message. This value should be set to False for generation.
+                Default is True.
+
+        Examples:
+            >>> # Tokenize a list of messages with default settings
+            >>> messages = [
+            ...     Message(role="user", content="Hello world!", masked=True),
+            ...     Message(role="assistant", content="How are you?", masked=False),
+            ... ]
+            >>> tokenizer = Qwen2Tokenizer("/path/to/tt_model")
+            >>> tokenizer.tokenize_messages(messages)
+            ([1, 31587, 29644, 102, 1, 31587, 29644, 102, 2], [True, True, True, True, True, False, False, False, True])
+
+            >>> # Tokenize a list of messages with add_end_tokens set to False
+            >>> tokenizer.tokenize_messages(messages, add_end_tokens=False)
+            ([1, 31587, 29644, 102, 1, 31587, 29644], [True, True, True, True, True, False, False])
 
         Returns:
             Tuple[List[int], List[bool]]: The list of token ids and the list of masks.
-
-        Raises:
-            RuntimeError: If a message contains non-text content
         """
         assert not isinstance(self.prompt_template, ChatMLTemplate), (
             "Using ChatMLTemplate with tokenize_messages will result in multiple <|im_*|> tokens wrapping each message."
@@ -355,69 +414,48 @@ def tokenize_messages(
             if self.prompt_template is not None
             else messages
         )
+        tokens = [self.bos_id]
+        # bos and eos are always masked
+        mask = [True]
+
+        num_messages = len(templated_messages)
+        for i, message in enumerate(templated_messages):
+            # End tokens are always added to every message except the last,
+            # where add_end_tokens controls whether they are appended
+            add_end_tokens_to_message = (
+                add_end_tokens if i == num_messages - 1 else True
+            )
+            tokenized_message = self.tokenize_message(
+                message, add_end_tokens=add_end_tokens_to_message
+            )
 
-        tokenized_messages = []
-        mask = []
-        for index, message in enumerate(templated_messages):
-            tokens = []
-
-            # message header
-            if message.role != "ipython":
-                tokens.append(self.im_start_id)
-                tokens.extend(
-                    self.encode(f"{message.role}\n", add_bos=False, add_eos=False)
-                )
-
-            # message content
-            for item in message.content:
-                if item["type"] == "text":
-                    tokens.extend(
-                        self.encode(
-                            item["content"],
-                            add_bos=False,
-                            add_eos=False,
-                        )
-                    )
-                else:
-                    raise RuntimeError(
-                        f"Unsupported message content type: {item['type']}"
-                    )
-
-            # message footer
-            if message.role != "ipython" and (
-                message.role != "assistant" or index != len(messages) - 1
-            ):
-                tokens.append(self.im_end_id)
-                tokens.extend(self.encode("\n", add_bos=False, add_eos=False))
-
-            tokenized_messages.extend(tokens)
-            mask.extend([message.masked] * len(tokens))
+            tokens = tokens + tokenized_message
+            mask = mask + ([message.masked] * len(tokenized_message))
 
             # Break out early if we reach max_seq_len
-            if self.max_seq_len and len(tokenized_messages) >= self.max_seq_len:
+            if self.max_seq_len and len(tokens) >= self.max_seq_len:
                 break
 
-        # Add the End-Of-Sequence token
-        if add_eos:
-            tokenized_messages.append(self.eos_id)
-            mask.append(mask[-1])
+        if add_end_tokens:
+            tokens = tokens + [self.eos_id]
+            mask = mask + [True]
 
         # Finally, truncate if necessary
         if self.max_seq_len:
-            tokenized_messages = truncate(
-                tokens=tokenized_messages,
+            tokens = truncate(
+                tokens=tokens,
                 max_seq_len=self.max_seq_len,
-                eos_id=self.eos_id if add_eos else None,
+                eos_id=self.eos_id if add_end_tokens else None,
                 truncation_type=self.truncation_type,
             )
             mask = truncate(
                 tokens=mask,
                 max_seq_len=self.max_seq_len,
-                eos_id=True if add_eos else None,
+                eos_id=True if add_end_tokens else None,
                 truncation_type=self.truncation_type,
             )
 
-        return tokenized_messages, mask
+        return tokens, mask
 
     def __call__(
         self, sample: Mapping[str, Any], inference: bool = False
@@ -436,7 +474,7 @@ def __call__(
             inference (bool): Whether the template is being used for inference or not.
         """
         messages = sample.pop("messages")
-        tokens, mask = self.tokenize_messages(messages)
+        tokens, mask = self.tokenize_messages(messages, add_end_tokens=not inference)
         sample["tokens"] = tokens
         sample["mask"] = mask
         return sample
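
To see the new flag end to end, a minimal usage sketch follows (not part of the commit). It mirrors the docstring example above: the model path is a placeholder, and the single-argument constructor is taken from that example rather than verified against the real Qwen2Tokenizer signature.

from torchtune.data import Message
from torchtune.models.qwen2._tokenizer import Qwen2Tokenizer

# Placeholder path, following the docstring example; the real constructor
# may require additional arguments (e.g. a merges file)
tokenizer = Qwen2Tokenizer("/path/to/tt_model")

messages = [
    Message(role="user", content="Hello world!", masked=True),
    Message(role="assistant", content="How are you?", masked=False),
]

# Training: the last message is closed with im_end, "\n", and eos
train_tokens, train_mask = tokenizer.tokenize_messages(messages, add_end_tokens=True)

# Generation: leave the last turn open so the model can continue it;
# __call__ now does this automatically by passing add_end_tokens=not inference
prompt_tokens, prompt_mask = tokenizer.tokenize_messages(messages, add_end_tokens=False)

# Only the trailing end tokens are omitted; bos is still prepended (and
# masked) in both cases
assert len(prompt_tokens) < len(train_tokens)

In both calls the mask mirrors each message's masked flag, with bos and any trailing eos always masked (True), matching the "bos and eos are always masked" comment in the diff.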
