52100303-TranPhuocSang committed on
Commit 24988f9
1 Parent(s): 3fdfe52

Update model, RAG with CTransformers

Files changed (10):
  1. .gitattributes +1 -0
  2. .gitignore +164 -0
  3. README.md +1 -1
  4. app.py +116 -39
  5. chain.py +83 -0
  6. db/index.faiss +0 -0
  7. db/index.pkl +3 -0
  8. requirements.txt +22 -1
  9. test.py +10 -0
  10. vector_db.py +51 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ models/vinallama-7b-chat_q5_0.gguf filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,164 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ models/
+ documents/
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+ .pdm.toml
+ .pdm-python
+ .pdm-build/
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
README.md CHANGED
@@ -1,5 +1,5 @@
  ---
- title: RAG LLM Chatbot CTransformer
+ title: Chatbot Llms Rag
  emoji: 💬
  colorFrom: yellow
  colorTo: purple
app.py CHANGED
@@ -1,59 +1,136 @@
  import gradio as gr
  from huggingface_hub import InferenceClient

- """
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
- """
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")


- def respond(
-     message,
-     history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
- ):
-     messages = [{"role": "system", "content": system_message}]

-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})

-     messages.append({"role": "user", "content": message})

-     response = ""

-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         token = message.choices[0].delta.content

-         response += token
-         yield response


  """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
  demo = gr.ChatInterface(
      respond,
      additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
          gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
          gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
          ),
      ],
  )
 
  import gradio as gr
  from huggingface_hub import InferenceClient
+ import os
+ from dotenv import load_dotenv
+ from langchain_community.llms import CTransformers
+ from langchain_community.llms import HuggingFaceHub
+ from langchain.prompts import PromptTemplate
+ from langchain.chains import RetrievalQA
+ from langchain_community.vectorstores import FAISS
+ from langchain_huggingface import HuggingFaceEmbeddings

+ model_name = "vilm/vinallama-2.7b-chat-GGUF"
+ model_file_path = './models/vinallama-7b-chat_q5_0.gguf'
+ model_embedding_name = 'bkai-foundation-models/vietnamese-bi-encoder'
+ vectorDB_path = './db'

+ load_dotenv()
+ HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")

+ # def load_model(model_file_path,
+ #                model_type,
+ #                temperature=0.01,
+ #                context_length=1024,
+ #                max_new_tokens=1024
+ #                ):
+ #     llm = CTransformers(
+ #         model = model_file_path,
+ #         model_type = model_type,
+ #         max_new_tokens = max_new_tokens,
+ #         temperature = temperature,
+ #         config = {
+ #             'context_length': context_length,
+ #         },
+ #     )
+ #     return llm

+ def load_model(model_name,
+                api_token,
+                temperature=0.01,
+                context_length=1024,
+                max_new_tokens=1024):
+     client = InferenceClient(model=model_name, token=api_token)
+     llm = HuggingFaceHub(
+         client = client,
+         max_new_tokens = max_new_tokens,
+         temperature = temperature,
+         context_length = context_length,
+     )
+     return llm

+ def load_db():
+     model_kwargs = {'device': 'cuda'}
+     encode_kwargs = {'normalize_embeddings': False}
+     embeddings = HuggingFaceEmbeddings(
+         model_name=model_embedding_name,
+         model_kwargs=model_kwargs,
+         encode_kwargs=encode_kwargs
+     )
+     db = FAISS.load_local(vectorDB_path, embeddings, allow_dangerous_deserialization=True)
+     return db

+ def create_prompt(template):
+     prompt = PromptTemplate(
+         template=template,
+         input_variables=['context', 'question'],
+     )

+     return prompt

+ def create_chain(llm,
+                  prompt,
+                  db,
+                  top_k_documents=3,
+                  return_source_documents=True):
+
+     chain = RetrievalQA.from_chain_type(
+         llm = llm,
+         chain_type = 'stuff',
+         retriever = db.as_retriever(
+             search_kwargs={
+                 "k": top_k_documents
+             }
+         ),
+         return_source_documents = return_source_documents,
+         chain_type_kwargs = {
+             'prompt': prompt,
+         },
+     )

+     return chain
+
+ db = load_db()
+ llm = load_model(
+     model_file_path=model_file_path,
+     model_type='llama',
+     context_length=2048
+ )
+
+
+ template = """<|im_start|>system
+ Sử dụng thông tin sau đây để trả lời câu hỏi. Nếu bạn không biết câu trả lời, hãy nói không biết, đừng cố tạo ra câu trả lời \n
+ {context}<|im_end|>\n
+ <|im_start|>user\n
+ {question}!<|im_end|>\n
+ <|im_start|>assistant
  """
+
+ prompt = create_prompt(template=template)
+ llm_chain = create_chain(llm, prompt, db)
+
+ def respond(message,
+             history: list[tuple[str, str]],
+             system_message,
+             max_tokens,
+             temperature,
+             top_k_documents,
+             ):
+     response = llm_chain.invoke({"query": message})
+
+     history.append((message, response['result']))
+
+     yield response['result']
+
+
+
  demo = gr.ChatInterface(
      respond,
+     title="Chatbot",
      additional_inputs=[
+         # gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
          gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
          gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
+         gr.Slider(minimum=1, maximum=8, value=3, step=1, label="Top k documents to search for answers in",
          ),
      ],
  )
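
Two details of the committed app.py are worth flagging. First, `load_model` is defined with Inference-API style parameters (`model_name`, `api_token`) but is called with local-file arguments (`model_file_path`, `model_type`, `context_length`), which raises a `TypeError` at startup. Second, `respond` still declares a `system_message` parameter even though the corresponding Textbox is commented out of `additional_inputs`, so Gradio supplies three additional values to a function that expects four. (The Vietnamese system prompt in `template` roughly reads: "Use the following information to answer the question. If you do not know the answer, say you do not know; do not make up an answer.") Below is a minimal sketch of a loader that matches the existing call site, essentially the commented-out CTransformers variant above with the generation settings folded into the ctransformers `config` dict; the model path and the 2048-token context length are taken from this commit, and the exact config keys are an assumption to verify against the ctransformers documentation.

from langchain_community.llms import CTransformers

def load_model(model_file_path,
               model_type='llama',
               temperature=0.01,
               context_length=2048,
               max_new_tokens=1024):
    # Run the local quantized GGUF model via ctransformers instead of the
    # Hugging Face Inference API (mirrors the commented-out loader above).
    return CTransformers(
        model=model_file_path,   # './models/vinallama-7b-chat_q5_0.gguf'
        model_type=model_type,   # 'llama' for the vinallama GGUF build
        config={
            'max_new_tokens': max_new_tokens,
            'temperature': temperature,
            'context_length': context_length,
        },
    )

With this signature the existing call `load_model(model_file_path=model_file_path, model_type='llama', context_length=2048)` works unchanged.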
chain.py ADDED
@@ -0,0 +1,83 @@
+ from langchain_community.llms import CTransformers
+ from langchain.chains import LLMChain
+ from langchain.prompts import PromptTemplate
+ from langchain.chains import RetrievalQA
+ from langchain_community.vectorstores import FAISS
+ from langchain_huggingface import HuggingFaceEmbeddings
+
+ model_name = "vilm/vinallama-2.7b-chat-GGUF"
+ model_file_path = './models/vinallama-7b-chat_q5_0.gguf'
+ model_embedding_name = 'bkai-foundation-models/vietnamese-bi-encoder'
+
+ def load_model():
+     llm = CTransformers(
+         model = model_file_path,
+         model_type = 'llama',
+         max_new_tokens = 1024,
+         temperature = 0.01,
+         config = {
+             'context_length': 1024,
+         },
+     )
+     return llm
+
+ def create_prompt(template):
+     prompt = PromptTemplate(
+         template=template,
+         input_variables=['context', 'question'],
+     )
+
+     return prompt
+
+ def create_chain(llm, prompt, db):
+     chain = RetrievalQA.from_chain_type(
+         llm = llm,
+         chain_type = 'stuff',
+         retriever = db.as_retriever(search_kwargs={"k": 3}),
+         return_source_documents = True,
+         chain_type_kwargs = {
+             'prompt': prompt,
+         },
+     )
+
+     return chain
+
+ vectorDB_path = './db'
+ def load_db():
+     model_kwargs = {'device': 'cuda'}
+     encode_kwargs = {'normalize_embeddings': False}
+     embeddings = HuggingFaceEmbeddings(
+         model_name=model_embedding_name,
+         model_kwargs=model_kwargs,
+         encode_kwargs=encode_kwargs
+     )
+     db = FAISS.load_local(vectorDB_path, embeddings, allow_dangerous_deserialization=True)
+     return db
+
+
+ db = load_db()
+ llm = load_model()
+
+ template = """<|im_start|>system
+ Sử dụng thông tin sau đây để trả lời câu hỏi. Nếu bạn không biết câu trả lời, hãy nói không biết, đừng cố tạo ra câu trả lời \n
+ {context}<|im_end|>\n
+ <|im_start|>user\n
+ {question}!<|im_end|>\n
+ <|im_start|>assistant
+ """
+
+ prompt = create_prompt(template=template)
+ llm_chain = create_chain(llm, prompt, db)
+
+ # Test the chain
+ # question = "2/9 ở Việt Nam là ngày gì ?"
+ question = "Diễn biến Chiến dịch biên giới thu đông 1950"
+ response = llm_chain.invoke({"query": question})
+ print(response)
+ print()
+
+ print(response['query'])
+ print(response['result'])
+ print()
+
+ print(response['source_documents'])
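
chain.py wires the same pieces together as a standalone script: RetrievalQA fills the prompt's `{question}` slot with the `query` passed to `invoke` and stuffs the top-3 retrieved chunks into `{context}`; because `return_source_documents=True`, the response dict also carries those chunks. Note that with three chunks of up to 1024 characters each, the configured `context_length` of 1024 tokens can be tight. A small sketch of a helper for inspecting where an answer came from, assuming the chunks keep PyPDFLoader's default `source` and `page` metadata:

def show_answer(response):
    # 'result' is the generated answer; 'source_documents' are the retrieved chunks.
    print(response['result'])
    for doc in response['source_documents']:
        # PyPDFLoader records the originating file and page number in metadata.
        src = doc.metadata.get('source', '?')
        page = doc.metadata.get('page', '?')
        print(f"- {src}, page {page}: {doc.page_content[:80]}...")

show_answer(response)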
db/index.faiss ADDED
Binary file (756 kB).
 
db/index.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3ab25a54daf668704a3bf9da7faa92c6a1eb97ecc11c2dd07a80d8af752c9b31
+ size 193552
requirements.txt CHANGED
@@ -1 +1,22 @@
- huggingface_hub==0.22.2
+ huggingface_hub==0.22.2
+ langchain_huggingface
+ langchain_ai21
+ python-dotenv
+ gradio
+ minijinja
+ transformers
+ ctransformers
+ langchain
+ langchain-community
+ torch
+ pypdf
+ sentence-transformers
+ gpt4all
+ faiss-cpu
+ openai
+ bitsandbytes
+ accelerate
+ xformers
+ einops
+ re
+ llama-cpp-python
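
One entry to double-check: `re` is a Python standard-library module, not a PyPI package, so `pip install -r requirements.txt` will most likely fail on that line; dropping it should be safe, since the scripts only use the built-in `re`.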
test.py ADDED
@@ -0,0 +1,10 @@
+ import gradio as gr
+
+ def echo(message, history):
+     print(message)
+     print('---')
+     print(history)
+     return message
+
+ demo = gr.ChatInterface(fn=echo, examples=["hello", "hola", "merhaba"], title="Echo Bot")
+ demo.launch()
vector_db.py ADDED
@@ -0,0 +1,51 @@
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader
+ from langchain_community.vectorstores import FAISS
+ from langchain_huggingface import HuggingFaceEmbeddings
+ from langchain_ai21 import AI21SemanticTextSplitter
+ from dotenv import load_dotenv
+ import re
+ import os
+
+ load_dotenv()
+
+
+ pdf_data_path = './documents'
+ vector_db_path = './db'
+ model_name = 'bkai-foundation-models/vietnamese-bi-encoder'
+ AI21_TOKEN = os.getenv('AI21_TOKEN')
+ os.environ["AI21_API_KEY"] = AI21_TOKEN
+
+
+ def clean_text(text):
+     text = re.sub(r'[^\w\s,.-]', '', text)
+     text = re.sub(r'\s+', ' ', text).strip()
+     text = text.replace(" \n", "\n").replace("\n ", "\n").replace("\n", "\n\n")
+
+     return text
+
+ def create_db_from_files():
+     loader = DirectoryLoader(pdf_data_path, glob="*.pdf", loader_cls = PyPDFLoader)
+     documents = loader.load()
+
+     # text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=128)
+     text_splitter = AI21SemanticTextSplitter(chunk_size=1024, chunk_overlap=128)
+
+     chunks = text_splitter.split_documents(documents)
+
+     for chunk in chunks:
+         chunk.page_content = clean_text(chunk.page_content)
+
+     model_kwargs = {'device': 'cuda'}
+     encode_kwargs = {'normalize_embeddings': False}
+     embeddings = HuggingFaceEmbeddings(
+         model_name=model_name,
+         model_kwargs=model_kwargs,
+         encode_kwargs=encode_kwargs
+     )
+
+     db = FAISS.from_documents(chunks, embeddings)
+     db.save_local(vector_db_path)
+     return db
+
+ create_db_from_files()
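
Once `create_db_from_files()` has written the FAISS index to `./db`, a quick sanity check is to load it back and run a similarity search against it. A minimal sketch, reusing the test question from chain.py; swap `'cuda'` for `'cpu'` in `model_kwargs` on a machine without a GPU:

from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings

embeddings = HuggingFaceEmbeddings(
    model_name='bkai-foundation-models/vietnamese-bi-encoder',
    model_kwargs={'device': 'cuda'},  # use 'cpu' if no GPU is available
)
db = FAISS.load_local('./db', embeddings, allow_dangerous_deserialization=True)

# Print the three chunks closest to the chain.py test question.
for doc in db.similarity_search("Diễn biến Chiến dịch biên giới thu đông 1950", k=3):
    print(doc.metadata.get('page', '?'), doc.page_content[:100])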