@@ -35,6 +35,10 @@ def __init__(self, index_name):
35
35
# Load the cl100k_base tokenizer which is designed to work with the ada-002 model
36
36
self .tokenizer = tiktoken .get_encoding ("cl100k_base" )
37
37
38
+ self .answer_generation_prompt = "Based on the context below\" \n \n Context: {}\n \n ---\n \n Please provide concise answer for this questions: {}"
39
+ self .question_suggestion_prompt = "Based on the context below\" \n \n Context: {}\n \n ---\n \n Please recommend 3 more questions to be curious about {}"
40
+ self .just_question_prompt = "{}{}"
41
+
38
42
def index (self , doc_id , doc , text ):
39
43
doc ["embeddings_dict_list" ] = self ._create_emb_dict_list (text )
40
44
self .es .index (index = self .index_name ,
@@ -146,10 +150,17 @@ def _create_context(self, question, df):
146
150
# Return the context and the length of the context
147
151
return "\n \n ###\n \n " .join (returns ), cur_len
148
152
149
- def _gpt_api_call (self , query , input_token_len , context ):
153
+ def _gpt_api_call (self , query , input_token_len , context , call_type ):
154
+ if call_type == "answer" :
155
+ prompt = self .answer_generation_prompt
156
+ elif call_type == "question" :
157
+ prompt = self .just_question_prompt
158
+ else :
159
+ prompt = self .question_suggestion_prompt
160
+
150
161
body = {
151
162
"model" : self .model_engine ,
152
- "prompt" : f"Based on the context below \" \n \n Context: { context } \n \n --- \n \n Please provide concise answer for this questions: { query } " ,
163
+ "prompt" : prompt . format ( context , query ) ,
153
164
"max_tokens" : self .model_max_tokens - input_token_len ,
154
165
"n" : 1 ,
155
166
"temperature" : 0.5 ,
@@ -165,6 +176,7 @@ def _gpt_api_call(self, query, input_token_len, context):
165
176
stream = True )
166
177
return resp
167
178
179
+
168
180
def gpt_answer (self , query , es_results = None , text_results = None ):
169
181
# Generate summaries for each search result
170
182
if text_results :
@@ -204,7 +216,32 @@ def gpt_answer(self, query, es_results=None, text_results=None):
204
216
else :
205
217
assert False , "Must provide either es_results or text_results"
206
218
207
- return self ._gpt_api_call (query , input_token_len , context )
219
+ return self ._gpt_api_call (query , input_token_len , context , call_type = "answer" )
220
+
221
def gpt_question_generator(self, text_results=None):
    """Ask the model to suggest follow-up questions about *text_results*.

    The text is token-counted with the instance tokenizer and truncated to
    ``self.max_tokens`` when necessary, then sent through ``_gpt_api_call``
    with the "suggestion" prompt (the query slot is left empty).

    Args:
        text_results: Source text to generate question suggestions for.

    Returns:
        The (streaming) response object from ``_gpt_api_call``.

    Raises:
        ValueError: If *text_results* is ``None`` or empty.
    """
    if not text_results:
        # Raise instead of `assert False`: asserts are stripped under -O
        # and input validation must survive optimized runs.
        raise ValueError("Text results are not found")

    input_token_len = len(self.tokenizer.encode(text_results))
    if input_token_len < self.max_tokens:
        context = text_results
    else:
        # Character slicing is conservative (each token is >= 1 character),
        # so the truncated context cannot exceed max_tokens tokens; the
        # exact count is approximated as max_tokens.
        context = text_results[:self.max_tokens]
        input_token_len = self.max_tokens

    return self._gpt_api_call("", input_token_len, context, call_type="suggestion")
233
+
234
def gpt_direct_answer(self, q):
    """Send the question *q* straight to the model with no retrieved context.

    The question is token-counted and, when it exceeds ``self.max_tokens``,
    truncated so the request stays within the model's input budget. Uses the
    "question" call type, whose prompt is just ``"{}{}"`` (empty context).

    Args:
        q: The user's question text.

    Returns:
        The (streaming) response object from ``_gpt_api_call``.
    """
    input_token_len = len(self.tokenizer.encode(q))
    if input_token_len < self.max_tokens:
        query = q
    else:
        # Character slicing is conservative (each token is >= 1 character),
        # so the truncated text cannot exceed max_tokens tokens.
        query = q[:self.max_tokens]
        input_token_len = self.max_tokens
    # Bug fix: the original passed the untruncated `q`, leaving the computed
    # `query` unused, so oversized questions were sent without truncation.
    return self._gpt_api_call(query, input_token_len, "", call_type="question")
242
+
243
+
244
+
208
245
209
246
210
247
# Example usage
0 commit comments