# summsc_gradio.py
import os

import deepl
import gradio as gr

from configs.opt_summsc import OPT
from configs.summsc_config import DEFAULT_CONFIG
from scripts.summarizer import summarize_text
from scripts.safety import filter_unsafe

# Enter the local ParlAI checkout so that the `parlai` package can be imported.
# The original os.system("cd ParlAI") only changed the directory of a throwaway
# subshell, not of this Python process, so os.chdir is used instead.
os.chdir("ParlAI")
from parlai.core.agents import create_agent_from_model_file
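
# Illustrative shape of the imported configuration: a sketch for readers without
# the configs package (the real values live in configs/summsc_config.py and
# configs/opt_summsc.py); every value shown here is hypothetical. Only the keys
# read in the __main__ block below are listed:
#
#   DEFAULT_CONFIG = {
#       "USE_CUDA": False,                          # hypothetical
#       "N_TURNS": 10,                              # hypothetical
#       "MODEL": "zoo:blender/blender_90M/model",   # hypothetical
#       "PRINT_TERMINAL": True,                     # hypothetical
#       "RESET_HISTORY": False,                     # hypothetical
#       "HISTORY_PATH": "history.txt",              # hypothetical
#   }
#   OPT = {...}  # ParlAI option overrides; 'no_cuda' is set from USE_CUDA below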
def predict(query_fr, session=[]):
    """Translate the user's French query, ask the agent, and return the chat history.

    The mutable default `session` deliberately persists between Gradio calls
    and serves as the running chat history for this single-user demo.
    """
    global context, chat_agent
    # French -> English: the ParlAI agent speaks English
    query_en = deepl.translate(
        text=query_fr,
        target_language='EN',
        source_language='FR'
    )
    # On the first turn, prepend the summarized persona context
    if not session:
        turn = {
            'text': context + query_en,
            'episode_done': False
        }
    else:
        turn = {
            'text': query_en,
            'episode_done': False
        }
    chat_agent.observe(turn)
    response_en = chat_agent.act()
    # print(response_en['beam_texts'])
    # Run the safety filter over the answer and its beam candidates
    response_en = filter_unsafe(
        response_en['text'],
        response_en['beam_texts'])
    # English -> French for display
    response_fr = deepl.translate(
        text=response_en,
        target_language='FR',
        source_language='EN')
    # Extend the persona context with this exchange
    context = context + \
        "partner's persona: " + query_en + "\n" \
        + "your persona: " + response_en + "\n"
    session.append((query_fr, response_fr))
    return session
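
# Minimal usage sketch (an assumption: the agent and context must already be set
# up as in the __main__ block below). Gradio calls predict with only the text box
# value, so `session` accumulates in the mutable default between turns:
#
#   history = predict("Bonjour, comment vas-tu ?")
#   history = predict("Quel est ton plat préféré ?")
#   # history == [(query_fr_1, response_fr_1), (query_fr_2, response_fr_2)]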
if __name__ == '__main__':
    # Demo settings
    use_cuda = DEFAULT_CONFIG["USE_CUDA"]
    n_turns = DEFAULT_CONFIG["N_TURNS"]
    model = DEFAULT_CONFIG["MODEL"]
    print_terminal = DEFAULT_CONFIG["PRINT_TERMINAL"]
    reset_history = DEFAULT_CONFIG["RESET_HISTORY"]
    history_path = DEFAULT_CONFIG["HISTORY_PATH"]

    # Set up the model
    opt = OPT  # note: the imported option dict is mutated in place
    opt['no_cuda'] = not use_cuda
    chat_agent = create_agent_from_model_file(model, opt)
    # Reset the agent's dialogue history
    chat_agent.reset()

    # Load the persona context, optionally wiping the stored history first
    if reset_history:
        with open(history_path, "w") as history:
            history.write("")
    with open(history_path, "r") as history:
        context = ''.join(history.readlines())
    context = "partner's persona: " + summarize_text(context) + '\n'

    # Launch the Gradio chat UI; launch() blocks until the app is closed
    gr.Interface(fn=predict,
                 inputs="text",
                 outputs="chatbot",
                 title="Chatbot",
                 description="Chatbot",
                 ).launch(share=True,
                          height=500,
                          width=500)

    # Persist the updated context once the UI has been closed
    # (the `with` block already closes the file, so no explicit close() is needed)
    with open(history_path, "w") as history:
        history.write(context)
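
# Expected invocation, assuming the configs/, scripts/ and ParlAI/ directories
# sit next to this file and the required packages are installed:
#
#   python summsc_gradio.py
#
# With share=True, Gradio also prints a temporary public URL for the demo.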