DialogSumLlama2 / app.py
shenoy's picture
Add application file and dependencies
8fd1ef9
raw
history blame contribute delete
No virus
1.31 kB
import gradio as gr
import torch
import transformers
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import BitsAndBytesConfig
# 4-bit NF4 quantization config so the 7B model fits in limited GPU memory.
# NOTE: BitsAndBytesConfig must be imported from transformers (was missing,
# which crashed the app with a NameError on startup).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype="float16",
)

model_name = "TinyPixel/Llama-2-7B-bf16-sharded"

# Load the quantized base model; device_map='auto' lets accelerate place
# the shards on the available device(s).
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    trust_remote_code=True,
    device_map='auto'
)

# Tokenizer for the same checkpoint. return_token_type_ids=False is stored
# as a tokenizer init kwarg — presumably to keep token_type_ids out of the
# encoded inputs; TODO confirm it has the intended effect for this tokenizer.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
    return_token_type_ids=False,
)
# Llama-2 has no pad token; reuse EOS so padding/batching works.
tokenizer.pad_token = tokenizer.eos_token

# Apply the QLoRA adapter fine-tuned for dialogue summarization on top of
# the quantized base model.
model = PeftModel.from_pretrained(model, "shenoy/DialogSumLlama2_qlora", device_map="auto")
# Gradio I/O components. The legacy gr.inputs / gr.outputs namespaces (and
# their type="text" argument) were removed in Gradio 3.x/4.x — use the
# top-level gr.Textbox component instead.
input_text = gr.Textbox(label="Input Text")
output_text = gr.Textbox(label="Output Text")
def predict(text: str) -> str:
    """Generate a continuation/summary of *text* with the QLoRA-adapted model.

    Args:
        text: Raw input dialogue/prompt from the Gradio textbox.

    Returns:
        The decoded output string; note this includes the original prompt,
        since the full `outputs[0]` sequence is decoded.
    """
    # Move the encoded inputs to the model's device — with device_map="auto"
    # the model may live on GPU while tokenization happens on CPU.
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    # Inference only: no_grad avoids building an autograd graph during
    # generation, cutting memory use.
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=100,
            repetition_penalty=1.2,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Gradio wiring: route the textbox input through predict() and display the
# generated text, then start the app server.
interface = gr.Interface(
    fn=predict,
    inputs=input_text,
    outputs=output_text,
)
interface.launch()