File size: 2,312 Bytes
24988f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from langchain_community.llms import CTransformers
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings

model_name = "vilm/vinallama-2.7b-chat-GGUF"
model_file_path = './models/vinallama-7b-chat_q5_0.gguf'
model_embedding_name = 'bkai-foundation-models/vietnamese-bi-encoder'

def load_model():
    llm = CTransformers(
        model = model_file_path,
        model_type = 'llama',
        max_new_tokens = 1024,
        temperature = 0.01,
        config = {
            'context_length': 1024,
        },
    )
    return llm

def create_prompt(template):
    prompt = PromptTemplate(
        template=template,
        input_variables=['context', 'question'],
    )

    return prompt

def create_chain(llm, prompt, db):
    chain = RetrievalQA.from_chain_type(
        llm = llm,
        chain_type = 'stuff',
        retriever = db.as_retriever( search_kwargs={"k": 3}),
        return_source_documents = True,
        chain_type_kwargs = {
            'prompt': prompt,
        },
    )

    return chain

vectorDB_path = './db'
def load_db():
    model_kwargs = {'device': 'cuda'}
    encode_kwargs = {'normalize_embeddings': False}
    embeddings = HuggingFaceEmbeddings(
        model_name=model_embedding_name,
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs
    )
    db = FAISS.load_local(vectorDB_path, embeddings, allow_dangerous_deserialization=True)
    return db


db = load_db()
llm = load_model()

template = """<|im_start|>system
Sử dụng thông tin sau đây để trả lời câu hỏi. Nếu bạn không biết câu trả lời, hãy nói không biết, đừng cố tạo ra câu trả lời \n
{context}<|im_end|>\n
<|im_start|>user\n
{question}!<|im_end|>\n
<|im_start|>assistant
"""

prompt = create_prompt(template=template)
llm_chain = create_chain(llm, prompt, db)

# Test the chain
# question = "2/9 ở Việt Nam là ngày gì ?"
question = "Diễn biến Chiến dịch biên giới thu đông 1950"
response = llm_chain.invoke({"query": question})
print(response)
print()

print(response['query'])
print(response['result'])
print()

print(response['source_documents'])