ChartGPT / app.py
xiaofeifei's picture
update requirements.txt
22f6e93
raw
history blame contribute delete
No virus
3.72 kB
from prompts import *
from langchain.agents.initialize import initialize_agent
from langchain.chains.conversation.memory import ConversationBufferMemory
from langchain.llms.openai import OpenAI
import gradio as gr
import re
from tools import bar_tool, line_tool, pie_tool, scatter_tool
import os
class ConversationBot:
    """Chat bot wrapping a LangChain conversational agent equipped with chart tools.

    The agent is not created in ``__init__``; :meth:`init_agent` must be called
    (the UI does this when a language is selected) before :meth:`run_text`.
    """

    # Image paths to render inline, e.g. ``image/abc-1.png`` (presumably written
    # by the chart tools). Raw string so ``\w`` is a regex class rather than an
    # invalid string escape; the dot is escaped so only a literal ".png" matches.
    _IMAGE_PATH_RE = re.compile(r'(image/[-\w]*\.png)')

    def __init__(self):
        self.tools = [line_tool, bar_tool, pie_tool, scatter_tool]
        print(f"All the Available Tools: {self.tools}")
        # The agent is built with return_intermediate_steps=True, so it returns
        # more than one output key; output_key tells the memory which to store.
        self.memory = ConversationBufferMemory(memory_key="chat_history", output_key='output')

    def init_agent(self, lang, openai_api_key):
        """(Re)build the agent for the selected UI language.

        Args:
            lang: ``'English'`` selects the English prompts; any other value
                (the UI offers ``'Chinese'``) selects the ``*_CN`` prompts.
            openai_api_key: OpenAI key, handed to the LLM via the environment.

        Returns:
            Four ``gr.update`` objects, used by the UI for (input row, language
            radio, textbox, clear button): show the input row, hide the radio,
            and localise the textbox placeholder and the clear-button label.
        """
        self.memory.clear()  # clear previous history
        if lang == 'English':
            prefix, format_instructions, suffix = (
                CHARTGPT_PREFIX, CHARTGPT_FORMAT_INSTRUCTIONS, CHARTGPT_SUFFIX)
            place = "Enter text and press enter"
            label_clear = "Clear"
        else:
            prefix, format_instructions, suffix = (
                CHARTGPT_PREFIX_CN, CHARTGPT_FORMAT_INSTRUCTIONS_CN, CHARTGPT_SUFFIX_CN)
            place = "输入文字并回车"
            label_clear = "清除"
        # NOTE(review): os.environ is process-wide, so the most recently entered
        # key is the one every later OpenAI() in this process picks up.
        os.environ["OPENAI_API_KEY"] = openai_api_key
        self.llm = OpenAI()
        self.agent = initialize_agent(
            self.tools,
            self.llm,
            agent="conversational-react-description",
            verbose=True,
            memory=self.memory,
            return_intermediate_steps=True,
            agent_kwargs={'prefix': prefix, 'format_instructions': format_instructions,
                          'suffix': suffix},
        )
        print("init bot finished!")
        return (gr.update(visible=True), gr.update(visible=False),
                gr.update(placeholder=place), gr.update(value=label_clear))

    @classmethod
    def _format_output(cls, output):
        """Convert image paths in agent output into inline Markdown images.

        Backslashes are normalised to forward slashes, then every
        ``image/<name>.png`` is replaced by ``![](file=PATH)*PATH*``: an image
        served through Gradio's ``file=`` route, followed by the path in italics.
        """
        output = output.replace("\\", "/")
        return cls._IMAGE_PATH_RE.sub(
            lambda m: f'![](file={m.group(0)})*{m.group(0)}*', output)

    def run_text(self, text, state):
        """Send ``text`` to the agent and append the exchange to the history.

        Args:
            text: raw user input from the textbox.
            state: chat history as a list of ``(user_text, bot_response)`` pairs.

        Returns:
            The updated history twice (for the chatbot component and the state).
            Blank or whitespace-only input is ignored and ``state`` is returned
            unchanged, avoiding a pointless agent/API call.
        """
        if not text.strip():
            return state, state
        print(self.agent.memory)
        res = self.agent({"input": text.strip()})
        print("res:", res)
        response = self._format_output(res['output'])
        state = state + [(text, response)]
        print(f"\nProcessed run_text, Input text: {text}\nCurrent state: {state}\n")
        return state, state
if __name__ == '__main__':
    bot = ConversationBot()
    with gr.Blocks(css="#chatbot .overflow-y-auto{height:500px}") as demo:
        # Masked so the secret key is not displayed in plain text on screen.
        openai_api_key = gr.Text(label="OPENAI_API_KEY", type="password")
        # No default selection: picking a language triggers bot.init_agent below.
        lang = gr.Radio(choices=['Chinese', 'English'], value=None, label='Language')
        chatbot = gr.Chatbot(elem_id="chatbot", label="ChartGPT")
        state = gr.State([])
        # Hidden until init_agent returns gr.update(visible=True) for this row.
        with gr.Row(visible=False) as input_raws:
            with gr.Column(scale=0.7):
                txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter").style(
                    container=False)
            with gr.Column(scale=0.15, min_width=0):
                clear = gr.Button("Clear")
        gr.Examples(
            ['Visualize [2,4,5,6,10,100,120,200,135,50,10] as a line chart',
             '将[2,4,5,6,10,20,100]可视化为柱状图',
             '将[2,4,5,6,10,100,120,200,135,50,10]可视化为折线图',
             '''将下面的数据可视化为饼图['A', 'B', 'C', 'D', 'E'][25, 20, 15, 10, 30]''',
             '''将下面的数据可视化为散点图x = [1, 2, 3, 4, 5]y = [5, 4, 3, 2, 1]'''], txt)
        # init_agent returns one update per output: row, radio, textbox, button.
        lang.change(bot.init_agent, [lang, openai_api_key], [input_raws, lang, txt, clear])
        txt.submit(bot.run_text, [txt, state], [chatbot, state])
        txt.submit(lambda: "", None, txt)  # empty the textbox after submitting
        clear.click(bot.memory.clear)  # reset the agent's conversation memory
        clear.click(lambda: [], None, chatbot)
        clear.click(lambda: [], None, state)
    demo.launch()