# simple-ui / app.py
# Scraped repository page header: author dinhdat1110, commit 95b6c1d
# ("Update app.py"), 2.06 kB. Page links: raw · history · blame · contribute ·
# delete. Security scan: no virus detected.
from threading import Thread
import gradio as gr
import torch
from transformers import (
pipeline,
AutoTokenizer,
TextIteratorStreamer,
)
def chat_history(history, tokenizer=None) -> str:
    """Render Gradio chat history as a single generation prompt string.

    ``history`` is a list of ``[user_message, bot_message]`` pairs, as stored
    by the Chatbot component. Within each pair, even positions become
    ``"user"`` turns and odd positions become ``"assistant"`` turns. The
    final message is dropped. When called from ``bot`` it is the
    still-``None`` assistant slot of the turn being generated.

    Args:
        history: Sequence of ``[user, assistant]`` message pairs.
        tokenizer: Object providing ``apply_chat_template``. Defaults to the
            tokenizer of the module-level ``pipe``, which preserves the
            original call signature ``chat_history(history)``.

    Returns:
        The text returned by ``tokenizer.apply_chat_template`` with
        ``tokenize=False`` and a generation prompt appended.
    """
    if tokenizer is None:
        # Backward-compatible fallback: ``pipe`` is bound at module scope in
        # the ``__main__`` block.
        tokenizer = pipe.tokenizer
    messages = [
        {"role": "user" if i % 2 == 0 else "assistant", "content": message}
        for dialog in history
        for i, message in enumerate(dialog)
    ]
    if messages:
        # Drop the trailing placeholder for the reply about to be generated.
        messages.pop()
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
def model_loading_pipeline(model_id="vilm/vinallama-2.7b", timeout=5):
    """Load a text-generation pipeline together with a streamer wired into it.

    Args:
        model_id: Hugging Face Hub id (or local path) of the causal LM.
        timeout: Seconds a consumer iterating the streamer waits for the next
            text chunk before the streamer raises ``queue.Empty``. This stops
            a failed or stalled generation thread from blocking the UI
            forever. ``None`` waits indefinitely. Note that a slow model may
            exceed a small timeout before producing its first token.

    Returns:
        ``(pipe, streamer)``. ``streamer`` is registered as a default
        generation argument of ``pipe``, so every call to ``pipe`` streams
        decoded text into it.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # The keyword must be lower-case ``timeout``. Any other extra keyword is
    # collected into ``decode_kwargs`` and forwarded to ``tokenizer.decode``.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, timeout=timeout)
    pipe = pipeline(
        "text-generation",
        model=model_id,
        # Share the tokenizer the streamer decodes with instead of letting the
        # pipeline load a second copy.
        tokenizer=tokenizer,
        model_kwargs={
            "torch_dtype": torch.bfloat16,
        },
        streamer=streamer,
    )
    return pipe, streamer
def launch_app(
    pipe,
    streamer,
    *,
    max_new_tokens=64,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    share=True,
    debug=True,
):
    """Build and launch the Gradio chat UI around a streaming pipeline.

    Args:
        pipe: Text-generation pipeline; each call streams into ``streamer``.
        streamer: Iterable yielding decoded text chunks of the generation
            currently running in ``pipe``.
        max_new_tokens: Maximum number of tokens generated per reply.
        temperature: Sampling temperature passed to ``pipe``.
        top_k: Top-k sampling cutoff passed to ``pipe``.
        top_p: Nucleus-sampling cutoff passed to ``pipe``.
        share: Forwarded to ``demo.launch``.
        debug: Forwarded to ``demo.launch``.
    """
    with gr.Blocks() as demo:
        chat = gr.Chatbot()
        msg = gr.Textbox()
        clear = gr.Button("Clear")

        def user(user_message, history):
            # Clear the textbox and append the new turn with an empty bot slot.
            # ``history`` may be None after the Clear button resets the chat.
            return "", (history or []) + [[user_message, None]]

        def bot(history):
            # Build the prompt while the last assistant slot is still None;
            # chat_history drops that trailing placeholder.
            prompt = chat_history(history)
            history[-1][1] = ""
            generate_kwargs = {
                "text_inputs": prompt,
                "max_new_tokens": max_new_tokens,
                "do_sample": True,
                "temperature": temperature,
                "top_k": top_k,
                "top_p": top_p,
            }
            # Generation runs in a worker thread while this generator consumes
            # the shared streamer and pushes partial replies to the UI.
            worker = Thread(target=pipe, kwargs=generate_kwargs, daemon=True)
            worker.start()
            for token in streamer:
                history[-1][1] += token
                yield history
            # The streamer has ended; let the pipeline finish its postprocessing.
            worker.join()

        msg.submit(user, [msg, chat], [msg, chat], queue=False).then(bot, chat, chat)
        clear.click(lambda: None, None, chat, queue=False)
    demo.queue()
    demo.launch(share=share, debug=debug)
if __name__ == "__main__":
    # Bind `pipe` at module scope on purpose: chat_history() looks up its
    # tokenizer through this global name.
    pipe, streamer = model_loading_pipeline()
    launch_app(pipe=pipe, streamer=streamer)