
Commit 27142a4

Added sentence transformer for emb

Parent: dede55e

4 files changed: 46 additions, 10 deletions


Makefile (+4)

@@ -12,6 +12,9 @@ $(VENV)/bin/activate: requirements.txt
 	python3 -m venv $(VENV)
 	$(PIP) install -r requirements.txt
 
+emb: $(VENV)/bin/activate
+	$(PYTHON) emb.py
+
 crawl: $(VENV)/bin/activate
 	$(PYTHON) crawl_index.py
 
@@ -20,6 +23,7 @@ esgpt: $(VENV)/bin/activate
 
 test: $(VENV)/bin/activate
 	$(PYTEST) --verbose es_gpt_test.py -s -vv
+
 app: $(VENV)/bin/activate
 	$(UVICORN) app:app --reload --port 7002

emb.py (new file, +32)

@@ -0,0 +1,32 @@
+import os
+import openai
+from sentence_transformers import SentenceTransformer
+
+
+EMB_USE_OPENAI = os.getenv('EMB_USE_OPENAI', '0')
+
+
+def _get_openai_embedding(input):
+    openai.api_key = os.environ["OPENAI_API_KEY"]
+    return openai.Embedding.create(
+        input=input, engine='text-embedding-ada-002')['data'][0]['embedding']
+
+
+def _get_transformer_embedding(input):
+    model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
+
+    # Sentences are encoded by calling model.encode()
+    embedding = model.encode(input)
+    return embedding
+
+
+def get_embedding(input):
+    if EMB_USE_OPENAI == '1':
+        return _get_openai_embedding(input)
+    else:
+        return _get_transformer_embedding(input)
+
+
+if __name__ == "__main__":
+    print("Transformer: ", _get_transformer_embedding('hello world')[0])
+    print("OpenAI: ", _get_openai_embedding('hello world'))

es_gpt.py (+8 −9)

@@ -9,6 +9,7 @@
 import tiktoken
 import openai
 from openai.embeddings_utils import distances_from_embeddings
+from emb import get_embedding
 
 
 ES_URL = os.environ["ES_URL"]
@@ -29,6 +30,7 @@ def __init__(self, index_name):
         self.api_key = os.environ["OPENAI_API_KEY"]
         openai.api_key = self.api_key
         self.max_tokens = 1000
+        self.split_max_tokens = 500
 
         # Load the cl100k_base tokenizer which is designed to work with the ada-002 model
         self.tokenizer = tiktoken.get_encoding("cl100k_base")
@@ -60,9 +62,9 @@ def _paper_results_to_text(self, results):
     # Function to split the text into chunks of a maximum number of tokens
     def _split_into_many(self, text):
         sentences = []
-        for sentence in text.split('.'):
+        for sentence in re.split(r'[{}]'.format(string.punctuation), text):
             sentence = sentence.strip()
-            if sentence and (any(char.isalpha() for char in sentence) or any(char.isdigit() for char in sentence)) and (not all(char in string.punctuation for char in sentence)):
+            if sentence and (any(char.isalpha() for char in sentence) or any(char.isdigit() for char in sentence)):
                 sentences.append(sentence)
 
         n_tokens = [len(self.tokenizer.encode(" " + sentence))
@@ -77,14 +79,14 @@ def _split_into_many(self, text):
             # If the number of tokens so far plus the number of tokens in the current sentence is greater
             # than the max number of tokens, then add the chunk to the list of chunks and reset
             # the chunk and tokens so far
-            if tokens_so_far + token > self.max_tokens:
+            if tokens_so_far + token > self.split_max_tokens and chunk:
                 chunks.append(". ".join(chunk) + ".")
                 chunk = []
                 tokens_so_far = 0
 
             # If the number of tokens in the current sentence is greater than the max number of
             # tokens, go to the next sentence
-            if token > self.max_tokens:
+            if token > self.split_max_tokens:
                 continue
 
             # Otherwise, add the sentence to the chunk and add the number of tokens to the total
@@ -97,9 +99,6 @@ def _split_into_many(self, text):
 
         return chunks
 
-    def _get_embedding(self, input):
-        return openai.Embedding.create(
-            input=input, engine='text-embedding-ada-002')['data'][0]['embedding']
 
     def _create_emb_dict_list(self, long_text):
         shortened = self._split_into_many(long_text)
@@ -108,7 +107,7 @@ def _create_emb_dict_list(self, long_text):
 
         for text in shortened:
             n_tokens = len(self.tokenizer.encode(text))
-            embeddings = self._get_embedding(input=text)
+            embeddings = get_embedding(input=text)
             embeddings_dict = {}
             embeddings_dict["text"] = text
             embeddings_dict["n_tokens"] = n_tokens
@@ -123,7 +122,7 @@ def _create_context(self, question, df):
         """
 
         # Get the embeddings for the question
-        q_embeddings = self._get_embedding(input=question)
+        q_embeddings = get_embedding(input=question)
 
         # Get the distances from the embeddings
         df['distances'] = distances_from_embeddings(
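The chunking change is easier to read outside the class. A simplified, self-contained sketch of the updated _split_into_many logic (split_into_chunks is a hypothetical standalone name; a whitespace token count stands in for the tiktoken encoder, and the final flush of a partial chunk is an assumption, since the tail of the method is not shown in the hunks):

import re
import string

def split_into_chunks(text, max_tokens=500):
    # Split on any punctuation character (the new behavior), not just '.'
    parts = re.split(r'[{}]'.format(string.punctuation), text)
    # Keep only fragments containing at least one letter or digit
    sentences = [s.strip() for s in parts if any(c.isalnum() for c in s)]

    chunks, chunk, tokens_so_far = [], [], 0
    for sentence in sentences:
        token = len((" " + sentence).split())  # stand-in for tiktoken
        # The added `and chunk` guard avoids emitting an empty chunk
        if tokens_so_far + token > max_tokens and chunk:
            chunks.append(". ".join(chunk) + ".")
            chunk, tokens_so_far = [], 0
        # A single sentence longer than the cap is skipped outright
        if token > max_tokens:
            continue
        chunk.append(sentence)
        tokens_so_far += token
    if chunk:  # flush the remainder (assumed; not shown in the hunks)
        chunks.append(". ".join(chunk) + ".")
    return chunks

print(split_into_chunks("Hello world! This is a test. Short, punchy, sentences."))

The new split_max_tokens (500) also decouples the chunk-size cap from self.max_tokens (1000), which the diff leaves in place.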

requirements.txt (+2 −1)

@@ -11,4 +11,5 @@ plotly
 pandas
 scipy
 scikit-learn
-pytest
+pytest
+sentence-transformers
