
Commit dede55e

Minor fixes in splits
1 parent 47ba458 commit dede55e

File tree: Makefile, es_gpt.py, requirements.txt

3 files changed: +18 -13 lines

Makefile (+3)

@@ -2,6 +2,7 @@ VENV = .venv
 PYTHON = $(VENV)/bin/python3
 PIP = $(VENV)/bin/pip3
 UVICORN = $(VENV)/bin/uvicorn
+PYTEST = $(VENV)/bin/pytest
 
 include .env
 export
@@ -17,6 +18,8 @@ crawl: $(VENV)/bin/activate
 esgpt: $(VENV)/bin/activate
 	$(PYTHON) es_gpt.py
 
+test: $(VENV)/bin/activate
+	$(PYTEST) --verbose es_gpt_test.py -s -vv
 app: $(VENV)/bin/activate
 	$(UVICORN) app:app --reload --port 7002
 
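The new test target runs pytest against es_gpt_test.py, which is not touched by this commit, so its contents are not shown here. Purely as an illustration of the kind of check it could contain for the sentence-splitting fix below, the hypothetical sketch that follows mirrors the new filter as a standalone function; the function name and test cases are assumptions, not the repository's actual tests.

# Hypothetical sketch only -- the real es_gpt_test.py is not part of this commit.
import string


def split_into_sentences(text):
    # Mirrors the filtering added to ESGPT._split_into_many in this commit:
    # keep a fragment only if it contains a letter or digit and is not
    # made up purely of punctuation.
    sentences = []
    for sentence in text.split('.'):
        sentence = sentence.strip()
        if sentence and (any(c.isalpha() for c in sentence) or any(c.isdigit() for c in sentence)) \
                and not all(c in string.punctuation for c in sentence):
            sentences.append(sentence)
    return sentences


def test_split_drops_empty_and_punctuation_only_fragments():
    text = "First sentence. Second one!  ... !!! . Has 2 digits."
    assert split_into_sentences(text) == ["First sentence", "Second one!", "Has 2 digits"]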

es_gpt.py (+13 -12)

@@ -3,7 +3,7 @@
 import requests
 import re
 import pandas as pd
-
+import string
 from elasticsearch import Elasticsearch
 
 import tiktoken
@@ -19,7 +19,7 @@
 
 class ESGPT:
     def __init__(self, index_name):
-        self.es = Elasticsearch(ES_URL, http_auth=(ES_USER, ES_PASS),
+        self.es = Elasticsearch(ES_URL, basic_auth=(ES_USER, ES_PASS),
                                 ca_certs=ES_CA_CERT, verify_certs=True)
         self.index_name = index_name
 
@@ -59,11 +59,12 @@ def _paper_results_to_text(self, results):
     # Code from https://github.com/openai/openai-cookbook/blob/main/apps/web-crawl-q-and-a/web-qa.py
     # Function to split the text into chunks of a maximum number of tokens
     def _split_into_many(self, text):
+        sentences = []
+        for sentence in text.split('.'):
+            sentence = sentence.strip()
+            if sentence and (any(char.isalpha() for char in sentence) or any(char.isdigit() for char in sentence)) and (not all(char in string.punctuation for char in sentence)):
+                sentences.append(sentence)
 
-        # Split the text into sentences
-        sentences = text.split('. ')
-
-        # Get the number of tokens for each sentence
         n_tokens = [len(self.tokenizer.encode(" " + sentence))
                     for sentence in sentences]
 
@@ -96,16 +97,18 @@ def _split_into_many(self, text):
 
         return chunks
 
+    def _get_embedding(self, input):
+        return openai.Embedding.create(
+            input=input, engine='text-embedding-ada-002')['data'][0]['embedding']
+
     def _create_emb_dict_list(self, long_text):
         shortened = self._split_into_many(long_text)
 
         embeddings_dict_list = []
 
         for text in shortened:
             n_tokens = len(self.tokenizer.encode(text))
-            embeddings = openai.Embedding.create(
-                input=text,
-                engine='text-embedding-ada-002')['data'][0]['embedding']
+            embeddings = self._get_embedding(input=text)
             embeddings_dict = {}
             embeddings_dict["text"] = text
             embeddings_dict["n_tokens"] = n_tokens
@@ -120,8 +123,7 @@ def _create_context(self, question, df):
         """
 
         # Get the embeddings for the question
-        q_embeddings = openai.Embedding.create(
-            input=question, engine='text-embedding-ada-002')['data'][0]['embedding']
+        q_embeddings = self._get_embedding(input=question)
 
         # Get the distances from the embeddings
         df['distances'] = distances_from_embeddings(
@@ -132,7 +134,6 @@ def _create_context(self, question, df):
 
         # Sort by distance and add the text to the context until the context is too long
         for i, row in df.sort_values('distances', ascending=True).iterrows():
-
            # Add the length of the text to the current length
            cur_len += row['n_tokens'] + 4
 
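Two smaller changes ride along with the split fix. The two inline openai.Embedding.create calls are consolidated into a single _get_embedding helper, so the per-chunk path and the per-question path now share one code path. The Elasticsearch client keyword changes from http_auth to basic_auth, which matches the elasticsearch-py 8.x client, where http_auth is deprecated. Below is a minimal sketch of the updated client construction; the URL, credentials and CA path are placeholders standing in for the values the project loads from its .env file, not the real settings.

# Minimal sketch, assuming elasticsearch-py 8.x.
from elasticsearch import Elasticsearch

ES_URL = "https://localhost:9200"        # placeholder
ES_USER = "elastic"                      # placeholder
ES_PASS = "changeme"                     # placeholder
ES_CA_CERT = "/path/to/http_ca.crt"      # placeholder

es = Elasticsearch(
    ES_URL,
    basic_auth=(ES_USER, ES_PASS),   # was http_auth=(...) before this commit
    ca_certs=ES_CA_CERT,
    verify_certs=True,
)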

requirements.txt (+2 -1)

@@ -10,4 +10,5 @@ matplotlib
 plotly
 pandas
 scipy
-scikit-learn
+scikit-learn
+pytest
