diff --git a/tests/models/gpt2/test_tokenization_gpt2.py b/tests/models/gpt2/test_tokenization_gpt2.py
index 653a1d21fc07..3273fbfce773 100644
--- a/tests/models/gpt2/test_tokenization_gpt2.py
+++ b/tests/models/gpt2/test_tokenization_gpt2.py
@@ -299,3 +299,33 @@ def test_serialize_deserialize_fast_opt(self):
             text,
         )
         self.assertEqual(tokens_ids, [2, 250, 1345, 9, 10, 4758])
+
+    def test_fast_slow_equivalence(self):
+        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m", use_fast=False)
+        text = "A photo of a cat"
+
+        tokens_ids = tokenizer.encode(
+            text,
+        )
+        # The slow tokenizer yields the same ids as the fast one above
+        self.assertEqual(tokens_ids, [2, 250, 1345, 9, 10, 4758])
+
+    def test_users_can_modify_bos(self):
+        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m", from_slow=True)
+
+        tokenizer.bos_token = "bos"
+        tokenizer.bos_token_id = tokenizer.get_vocab()["bos"]
+
+        text = "A photo of a cat"
+        tokens_ids = tokenizer.encode(
+            text,
+        )
+        # The custom bos token is prepended with its new id
+        self.assertEqual(tokens_ids, [31957, 250, 1345, 9, 10, 4758])
+        tokenizer.save_pretrained("./tok")
+        tokenizer = AutoTokenizer.from_pretrained("./tok")
+        self.assertTrue(tokenizer.is_fast)
+        tokens_ids = tokenizer.encode(
+            text,
+        )
+        self.assertEqual(tokens_ids, [31957, 250, 1345, 9, 10, 4758])
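For quick reference, here is a minimal standalone sketch of the behavior the two new tests pin down. The token ids and the `./tok` path are taken directly from the tests; the snippet assumes the `facebook/opt-350m` checkpoint is reachable and uses only the public `AutoTokenizer` API.

```python
from transformers import AutoTokenizer

text = "A photo of a cat"

# 1) Fast/slow equivalence: both tokenizers should produce identical ids.
slow = AutoTokenizer.from_pretrained("facebook/opt-350m", use_fast=False)
fast = AutoTokenizer.from_pretrained("facebook/opt-350m", from_slow=True)
assert slow.encode(text) == fast.encode(text) == [2, 250, 1345, 9, 10, 4758]

# 2) A user-modified bos token must survive save_pretrained / from_pretrained.
fast.bos_token = "bos"                       # "bos" already exists in the BPE vocab
fast.bos_token_id = fast.get_vocab()["bos"]  # id 31957 in this vocab
before = fast.encode(text)                   # [31957, 250, 1345, 9, 10, 4758]

fast.save_pretrained("./tok")                # serializes the updated special tokens
reloaded = AutoTokenizer.from_pretrained("./tok")
assert reloaded.is_fast                      # round-trips back to a fast tokenizer
assert reloaded.encode(text) == before       # custom bos id is preserved
```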