Skip to content

Commit bde1e0f

Browse files
authored
Merge pull request #757 from yusefes/fix-tokenizer-loading
Fix tokenizer loading for GPT2
2 parents 488821a + d291819 commit bde1e0f

File tree

3 files changed

+23
-1
lines changed

3 files changed

+23
-1
lines changed

scrapegraphai/utils/tokenizer.py

+8
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
 from langchain_ollama import ChatOllama
 from langchain_mistralai import ChatMistralAI
 from langchain_core.language_models.chat_models import BaseChatModel
+from transformers import GPT2TokenizerFast


 def num_tokens_calculus(string: str, llm_model: BaseChatModel) -> int:
     """
@@ -23,6 +24,13 @@ def num_tokens_calculus(string: str, llm_model: BaseChatModel) -> int:
         from .tokenizers.tokenizer_ollama import num_tokens_ollama
         num_tokens_fn = num_tokens_ollama

+    elif isinstance(llm_model, GPT2TokenizerFast):
+        def num_tokens_gpt2(text: str, model: BaseChatModel) -> int:
+            tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
+            tokens = tokenizer.encode(text)
+            return len(tokens)
+        num_tokens_fn = num_tokens_gpt2
+
     else:
         from .tokenizers.tokenizer_openai import num_tokens_openai
         num_tokens_fn = num_tokens_openai

scrapegraphai/utils/tokenizers/tokenizer_ollama.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
 """
 from langchain_core.language_models.chat_models import BaseChatModel
 from ..logging import get_logger
+from transformers import GPT2TokenizerFast


 def num_tokens_ollama(text: str, llm_model: BaseChatModel) -> int:
     """
@@ -21,8 +22,12 @@ def num_tokens_ollama(text: str, llm_model: BaseChatModel) -> int:

     logger.debug(f"Counting tokens for text of {len(text)} characters")

+    if isinstance(llm_model, GPT2TokenizerFast):
+        tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
+        tokens = tokenizer.encode(text)
+        return len(tokens)
+
     # Use langchain token count implementation
     # NB: https://github.com/ollama/ollama/issues/1716#issuecomment-2074265507
     tokens = llm_model.get_num_tokens(text)
     return tokens
-

tests/graphs/smart_scraper_ollama_test.py

+9
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
 """
 import pytest
 from scrapegraphai.graphs import SmartScraperGraph
+from transformers import GPT2TokenizerFast


 @pytest.fixture
@@ -50,3 +51,11 @@ def test_get_execution_info(graph_config: dict):
     graph_exec_info = smart_scraper_graph.get_execution_info()

     assert graph_exec_info is not None
+
+
+def test_gpt2_tokenizer_loading():
+    """
+    Test loading of GPT2TokenizerFast
+    """
+    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
+    assert tokenizer is not None

0 commit comments

Comments
 (0)