@@ -325,26 +325,85 @@ def decode(
         text = "".join(sub_texts)
         return text
 
+    def _tokenize_header(self, message: Message) -> List[int]:
+        return (
+            [self.im_start_id]
+            + self.encode(f"{message.role}\n", add_bos=False, add_eos=False)
+        )
+
+    def _tokenize_body(self, message: Message) -> List[int]:
+        tokenized_body = []
+        for item in message.content:
+            if item["type"] == "text":
+                tokenized_body += self.encode(
+                    item["content"], add_bos=False, add_eos=False
+                )
+            else:
+                raise RuntimeError(f"Unsupported message content type: {item['type']}")
+        return tokenized_body
+
+    def _tokenize_end(self, message: Message) -> List[int]:
+        return (
+            [self.im_end_id]
+            + self.encode("\n", add_bos=False, add_eos=False)
+        )
+
+    def tokenize_message(
+        self,
+        message: Message,
+        *,
+        add_start_tokens: bool = True,
+        add_end_tokens: bool = True,
+    ) -> List[int]:
+        """
+        Tokenize a message into a list of token ids.
+
+        Args:
+            message (Message): The message to tokenize.
+            add_start_tokens (bool): Whether to prepend the tokenized header to the message. Default is True.
+            add_end_tokens (bool): Whether to append the end tokens (im_end and trailing newline)
+                to the message. Default is True.
+
+        Returns:
+            List[int]: The list of token ids.
+        """
+        tokenized_header = self._tokenize_header(message) if add_start_tokens else []
+        tokenized_body = self._tokenize_body(message)
+        tokenized_end = self._tokenize_end(message) if add_end_tokens else []
+
+        tokenized_message = tokenized_header + tokenized_body + tokenized_end
+        return tokenized_message
+
     def tokenize_messages(
         self,
         messages: List[Message],
         *,
-        add_eos: bool = True,
+        add_end_tokens: bool = True,
     ) -> Tuple[List[int], List[bool]]:
         """
-        Given a list of messages, return a list of tokens for the concatenated
-        and formatted messages.
+        Tokenize a list of messages into a list of token ids and masks.
 
         Args:
-            messages (List[Message]): The message list to tokenize.
-            add_eos (bool): Wether to add the tokenizer's eos_id at the end of the
-                sequence of messages. Default is True.
+            messages (List[Message]): The list of messages to tokenize.
+            add_end_tokens (bool): Whether to append end token ids (end-of-seq, end-of-turn, end-of-message)
+                at the end of the last assistant message. This value should be set to False for generation.
+                Default is True.
+
+        Examples:
+            >>> # Tokenize a list of messages with default settings
+            >>> messages = [
+            ...     Message(role="user", content="Hello world!", masked=True),
+            ...     Message(role="assistant", content="How are you?", masked=False),
+            ... ]
+            >>> tokenizer = Qwen2Tokenizer("/path/to/tt_model")
+            >>> tokenizer.tokenize_messages(messages)
+            ([1, 31587, 29644, 102, 1, 31587, 29644, 102, 2], [True, True, True, True, True, False, False, False, True])
+
+            >>> # Tokenize a list of messages with add_end_tokens set to False
+            >>> tokenizer.tokenize_messages(messages, add_end_tokens=False)
+            ([1, 31587, 29644, 102, 1, 31587, 29644], [True, True, True, True, True, False, False])
 
         Returns:
             Tuple[List[int], List[bool]]: The list of token ids and the list of masks.
-
-        Raises:
-            RuntimeError: If a message contains non-text content
         """
         assert not isinstance(self.prompt_template, ChatMLTemplate), (
             "Using ChatMLTemplate with tokenize_messages will result in multiple <|im_*|> tokens wrapping each message."
@@ -355,69 +414,48 @@ def tokenize_messages(
             if self.prompt_template is not None
             else messages
         )
+        tokens = [self.bos_id]
+        # bos and eos are always masked
+        mask = [True]
+
+        num_messages = len(templated_messages)
+        for i, message in enumerate(templated_messages):
+            # End tokens on the last message are controlled by add_end_tokens;
+            # every earlier message always gets its end tokens.
+            add_end_tokens_to_message = (
+                add_end_tokens if i == num_messages - 1 else True
+            )
+            tokenized_message = self.tokenize_message(
+                message, add_end_tokens=add_end_tokens_to_message
+            )
 
-        tokenized_messages = []
-        mask = []
-        for index, message in enumerate(templated_messages):
-            tokens = []
-
-            # message header
-            if message.role != "ipython":
-                tokens.append(self.im_start_id)
-                tokens.extend(
-                    self.encode(f"{message.role}\n", add_bos=False, add_eos=False)
-                )
-
-            # message content
-            for item in message.content:
-                if item["type"] == "text":
-                    tokens.extend(
-                        self.encode(
-                            item["content"],
-                            add_bos=False,
-                            add_eos=False,
-                        )
-                    )
-                else:
-                    raise RuntimeError(
-                        f"Unsupported message content type: {item['type']}"
-                    )
-
-            # message footer
-            if message.role != "ipython" and (
-                message.role != "assistant" or index != len(messages) - 1
-            ):
-                tokens.append(self.im_end_id)
-                tokens.extend(self.encode("\n", add_bos=False, add_eos=False))
-
-            tokenized_messages.extend(tokens)
-            mask.extend([message.masked] * len(tokens))
+            tokens = tokens + tokenized_message
+            mask = mask + ([message.masked] * len(tokenized_message))
 
             # Break out early if we reach max_seq_len
-            if self.max_seq_len and len(tokenized_messages) >= self.max_seq_len:
+            if self.max_seq_len and len(tokens) >= self.max_seq_len:
                 break
 
-        # Add the End-Of-Sequence token
-        if add_eos:
-            tokenized_messages.append(self.eos_id)
-            mask.append(mask[-1])
+        if add_end_tokens:
+            tokens = tokens + [self.eos_id]
+            mask = mask + [True]
 
         # Finally, truncate if necessary
         if self.max_seq_len:
-            tokenized_messages = truncate(
-                tokens=tokenized_messages,
+            tokens = truncate(
+                tokens=tokens,
                 max_seq_len=self.max_seq_len,
-                eos_id=self.eos_id if add_eos else None,
+                eos_id=self.eos_id if add_end_tokens else None,
                 truncation_type=self.truncation_type,
             )
             mask = truncate(
                 tokens=mask,
                 max_seq_len=self.max_seq_len,
-                eos_id=True if add_eos else None,
+                eos_id=True if add_end_tokens else None,
                 truncation_type=self.truncation_type,
             )
 
-        return tokenized_messages, mask
+        return tokens, mask
 
     def __call__(
         self, sample: Mapping[str, Any], inference: bool = False
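
The rewritten loop closes every turn except, optionally, the last one: with `add_end_tokens=False` the sequence ends right after the final message body, which is what generation needs. A sketch of both call patterns, reusing the tokenizer from the previous example and assuming no `max_seq_len` truncation:

```python
messages = [
    Message(role="user", content="Hello world!", masked=True),
    Message(role="assistant", content="How are you?", masked=False),
]

# Training: every turn closed and eos appended; bos and eos are masked True.
tokens, mask = tokenizer.tokenize_messages(messages)
assert len(tokens) == len(mask)
assert tokens[0] == tokenizer.bos_id and tokens[-1] == tokenizer.eos_id

# Generation: the last turn stays open; no <|im_end|>, newline, or eos at the
# tail, so the generation-time ids are a strict prefix of the training-time ids.
gen_tokens, _ = tokenizer.tokenize_messages(messages, add_end_tokens=False)
assert gen_tokens == tokens[: len(gen_tokens)]
```
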
@@ -436,7 +474,7 @@ def __call__(
             inference (bool): Whether the template is being used for inference or not.
         """
         messages = sample.pop("messages")
-        tokens, mask = self.tokenize_messages(messages)
+        tokens, mask = self.tokenize_messages(messages, add_end_tokens=not inference)
         sample["tokens"] = tokens
         sample["mask"] = mask
         return sample
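
The `__call__` change routes the existing `inference` flag into `add_end_tokens`, so data transforms pick the right behavior without extra plumbing. A sketch of the intended use, continuing from the examples above:

```python
# Training transform: end tokens and eos are appended.
train_sample = tokenizer({"messages": messages})

# Inference transform: the last turn is left open for the model to continue.
infer_sample = tokenizer({"messages": messages}, inference=True)
assert infer_sample["tokens"][-1] != tokenizer.eos_id
```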