Shreyas094 committed on
Commit b4dffd4
1 Parent(s): 5ecdc0c

Update app.py

Files changed (1)
  1. app.py +183 -115
app.py CHANGED
@@ -12,11 +12,19 @@ from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from llama_parse import LlamaParse
from langchain_core.documents import Document
+ from huggingface_hub import InferenceClient
+ import inspect

# Environment variables and configurations
huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")
llama_cloud_api_key = os.environ.get("LLAMA_CLOUD_API_KEY")

+ MODELS = [
+     "mistralai/Mistral-7B-Instruct-v0.3",
+     "mistralai/Mixtral-8x7B-Instruct-v0.1",
+     "microsoft/Phi-3-mini-4k-instruct"
+ ]
+
# Initialize LlamaParse
llama_parser = LlamaParse(
    api_key=llama_cloud_api_key,
@@ -26,7 +34,7 @@ llama_parser = LlamaParse(
    language="en",
)

- def load_document(file: NamedTemporaryFile, parser: str = "pypdf") -> List[Document]:
+ def load_document(file: NamedTemporaryFile, parser: str = "llamaparse") -> List[Document]:
    """Loads and splits the document into pages."""
    if parser == "pypdf":
        loader = PyPDFLoader(file.name)
@@ -69,53 +77,55 @@ def update_vectors(files, parser):

    return f"Vector store updated successfully. Processed {total_chunks} chunks from {len(files)} files using {parser}."

- def generate_chunked_response(prompt, max_tokens=1000, max_chunks=5, temperature=0.7, repetition_penalty=1.1):
-     API_URL = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.3"
-     headers = {"Authorization": f"Bearer {huggingface_token}"}
-     payload = {
-         "inputs": prompt,
-         "parameters": {
-             "max_new_tokens": max_tokens,
-             "temperature": temperature,
-             "top_p": 0.4,
-             "top_k": 40,
-             "repetition_penalty": repetition_penalty,
-             "stop": ["</s>", "[/INST]"]
-         }
-     }
-
+ def generate_chunked_response(prompt, model, max_tokens=1000, num_calls=3, temperature=0.2, should_stop=False):
+     print(f"Starting generate_chunked_response with {num_calls} calls")
+     client = InferenceClient(model, token=huggingface_token)
    full_response = ""
-     for _ in range(max_chunks):
-         response = requests.post(API_URL, headers=headers, json=payload)
-         if response.status_code == 200:
-             result = response.json()
-             if isinstance(result, list) and len(result) > 0:
-                 chunk = result[0].get('generated_text', '')
-
-                 # Remove any part of the chunk that's already in full_response
-                 new_content = chunk[len(full_response):].strip()
-
-                 if not new_content:
-                     break # No new content, so we're done
-
-                 full_response += new_content
-
-                 if chunk.endswith((".", "!", "?", "</s>", "[/INST]")):
-                     break
-
-                 # Update the prompt for the next iteration
-                 payload["inputs"] = full_response
-             else:
-                 break
-         else:
+     messages = [{"role": "user", "content": prompt}]
+
+     for i in range(num_calls):
+         print(f"Starting API call {i+1}")
+         if should_stop:
+             print("Stop clicked, breaking loop")
            break
+         try:
+             for message in client.chat_completion(
+                 messages=messages,
+                 max_tokens=max_tokens,
+                 temperature=temperature,
+                 stream=True,
+             ):
+                 if should_stop:
+                     print("Stop clicked during streaming, breaking")
+                     break
+                 if message.choices and message.choices[0].delta and message.choices[0].delta.content:
+                     chunk = message.choices[0].delta.content
+                     full_response += chunk
+             print(f"API call {i+1} completed")
+         except Exception as e:
+             print(f"Error in generating response: {str(e)}")

    # Clean up the response
    clean_response = re.sub(r'<s>\[INST\].*?\[/INST\]\s*', '', full_response, flags=re.DOTALL)
    clean_response = clean_response.replace("Using the following context:", "").strip()
    clean_response = clean_response.replace("Using the following context from the PDF documents:", "").strip()

-     return clean_response
+     # Remove duplicate paragraphs and sentences
+     paragraphs = clean_response.split('\n\n')
+     unique_paragraphs = []
+     for paragraph in paragraphs:
+         if paragraph not in unique_paragraphs:
+             sentences = paragraph.split('. ')
+             unique_sentences = []
+             for sentence in sentences:
+                 if sentence not in unique_sentences:
+                     unique_sentences.append(sentence)
+             unique_paragraphs.append('. '.join(unique_sentences))
+
+     final_response = '\n\n'.join(unique_paragraphs)
+
+     print(f"Final clean response: {final_response[:100]}...")
+     return final_response

def duckduckgo_search(query):
    with DDGS() as ddgs:
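
Review note: the rewritten generate_chunked_response drops the raw requests.post loop against the Inference API in favor of huggingface_hub's InferenceClient with streamed chat completions. A minimal sketch of that streaming path in isolation (model name, token, and prompt are placeholders, not values from this commit; assumes a recent huggingface_hub):

```python
from huggingface_hub import InferenceClient

# Placeholders: any chat-capable model on the Inference API and a valid token.
client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3", token="hf_...")

full_response = ""
for message in client.chat_completion(
    messages=[{"role": "user", "content": "Explain streaming in one paragraph."}],
    max_tokens=1000,
    temperature=0.2,
    stream=True,
):
    # With stream=True, each chunk carries an incremental delta, not the full text.
    if message.choices and message.choices[0].delta.content:
        full_response += message.choices[0].delta.content
print(full_response)
```

Worth noting: num_calls re-sends the same prompt and appends every call's output to a single accumulator, so with num_calls > 1 the function returns several answers concatenated; the de-duplication pass at the end appears to exist to compensate for exactly that.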
@@ -127,113 +137,171 @@ class CitingSources(BaseModel):
        ...,
        description="List of sources to cite. Should be an URL of the source."
    )
+ def chatbot_interface(message, history, use_web_search, model, temperature, num_calls):
+     if not message.strip():
+         return "", history

- def get_response_from_pdf(query, temperature=0.7, repetition_penalty=1.1):
-     embed = get_embeddings()
-     if os.path.exists("faiss_database"):
-         database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
-     else:
-         return "No documents available. Please upload PDF documents to answer questions."
-
-     retriever = database.as_retriever()
-     relevant_docs = retriever.get_relevant_documents(query)
-     context_str = "\n".join([doc.page_content for doc in relevant_docs])
-
-     prompt = f"""<s>[INST] Using the following context from the PDF documents:
-     {context_str}
-     Write a detailed and complete response that answers the following user question: '{query}'
-     Do not include a list of sources in your response. [/INST]"""
+     history = history + [(message, "")]

-     generated_text = generate_chunked_response(prompt, temperature=temperature, repetition_penalty=repetition_penalty)
+     try:
+         if use_web_search:
+             for main_content, sources in get_response_with_search(message, model, num_calls=num_calls, temperature=temperature):
+                 history[-1] = (message, f"{main_content}\n\n{sources}")
+                 yield history
+         else:
+             for partial_response in get_response_from_pdf(message, model, num_calls=num_calls, temperature=temperature):
+                 history[-1] = (message, partial_response)
+                 yield history
+     except gr.CancelledError:
+         yield history

-     # Clean the response
-     clean_text = re.sub(r'<s>\[INST\].*?\[/INST\]\s*', '', generated_text, flags=re.DOTALL)
-     clean_text = clean_text.replace("Using the following context from the PDF documents:", "").strip()
+ def retry_last_response(history, use_web_search, model, temperature, num_calls):
+     if not history:
+         return history
+
+     last_user_msg = history[-1][0]
+     history = history[:-1]  # Remove the last response
+
+     return chatbot_interface(last_user_msg, history, use_web_search, model, temperature, num_calls)

-     return clean_text
+ def respond(message, history, model, temperature, num_calls, use_web_search):
+     if use_web_search:
+         for main_content, sources in get_response_with_search(message, model, num_calls=num_calls, temperature=temperature):
+             yield f"{main_content}\n\n{sources}"
+     else:
+         for partial_response in get_response_from_pdf(message, model, num_calls=num_calls, temperature=temperature):
+             yield partial_response

- def get_response_with_search(query, temperature=0.7, repetition_penalty=1.1):
+ def get_response_with_search(query, model, num_calls=3, temperature=0.2):
    search_results = duckduckgo_search(query)
    context = "\n".join(f"{result['title']}\n{result['body']}\nSource: {result['href']}\n"
                        for result in search_results if 'body' in result)

-     prompt = f"""<s>[INST] Using the following context:
+     prompt = f"""Using the following context:
    {context}
    Write a detailed and complete research document that fulfills the following user request: '{query}'
-     After writing the document, please provide a list of sources used in your response. [/INST]"""
+     After writing the document, please provide a list of sources used in your response."""

-     generated_text = generate_chunked_response(prompt, temperature=temperature, repetition_penalty=repetition_penalty)
+     client = InferenceClient(model, token=huggingface_token)

-     # Clean the response
-     clean_text = re.sub(r'<s>\[INST\].*?\[/INST\]\s*', '', generated_text, flags=re.DOTALL)
-     clean_text = clean_text.replace("Using the following context:", "").strip()
-
-     # Split the content and sources
-     parts = clean_text.split("Sources:", 1)
-     main_content = parts[0].strip()
-     sources = parts[1].strip() if len(parts) > 1 else ""
+     main_content = ""
+     for i in range(num_calls):
+         for message in client.chat_completion(
+             messages=[{"role": "user", "content": prompt}],
+             max_tokens=1000,
+             temperature=temperature,
+             stream=True,
+         ):
+             if message.choices and message.choices[0].delta and message.choices[0].delta.content:
+                 chunk = message.choices[0].delta.content
+                 main_content += chunk
+                 yield main_content, ""  # Yield partial main content without sources

-     return main_content, sources
+ def get_response_from_pdf(query, model, num_calls=3, temperature=0.2):
+     embed = get_embeddings()
+     if os.path.exists("faiss_database"):
+         database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
+     else:
+         yield "No documents available. Please upload PDF documents to answer questions."
+         return
+
+     retriever = database.as_retriever()
+     relevant_docs = retriever.get_relevant_documents(query)
+     context_str = "\n".join([doc.page_content for doc in relevant_docs])
+
+     prompt = f"""Using the following context from the PDF documents:
+     {context_str}
+     Write a detailed and complete response that answers the following user question: '{query}'"""
+
+     client = InferenceClient(model, token=huggingface_token)

+     response = ""
+     for i in range(num_calls):
+         for message in client.chat_completion(
+             messages=[{"role": "user", "content": prompt}],
+             max_tokens=1000,
+             temperature=temperature,
+             stream=True,
+         ):
+             if message.choices and message.choices[0].delta and message.choices[0].delta.content:
+                 chunk = message.choices[0].delta.content
+                 response += chunk
+                 yield response  # Yield partial response

- def chatbot_interface(message, history, use_web_search, temperature, repetition_penalty):
-     if use_web_search:
-         main_content, sources = get_response_with_search(message, temperature, repetition_penalty)
-         formatted_response = f"{main_content}\n\nSources:\n{sources}"
+ def vote(data: gr.LikeData):
+     if data.liked:
+         print(f"You upvoted this response: {data.value}")
    else:
-         response = get_response_from_pdf(message, temperature, repetition_penalty)
-         formatted_response = response
+         print(f"You downvoted this response: {data.value}")
+
+ css = """
+ /* Add your custom CSS here */
+ """

-     history.append((message, formatted_response))
-     return history
+ # Define the checkbox outside the demo block
+ use_web_search = gr.Checkbox(label="Use Web Search", value=False)

- # Gradio interface
- with gr.Blocks() as demo:
-     gr.Markdown("# AI-powered Web Search and PDF Chat Assistant")
+ demo = gr.ChatInterface(
+     respond,
+     additional_inputs=[
+         gr.Dropdown(choices=MODELS, label="Select Model", value=MODELS[0]),
+         gr.Slider(minimum=0.1, maximum=1.0, value=0.2, step=0.1, label="Temperature"),
+         gr.Slider(minimum=1, maximum=5, value=1, step=1, label="Number of API Calls"),
+         use_web_search  # Add this line to include the checkbox
+     ],
+     title="AI-powered Web Search and PDF Chat Assistant",
+     description="Chat with your PDFs or use web search to answer questions.",
+     theme=gr.themes.Soft(
+         primary_hue="orange",
+         secondary_hue="amber",
+         neutral_hue="gray",
+         font=[gr.themes.GoogleFont("Exo"), "ui-sans-serif", "system-ui", "sans-serif"]
+     ).set(
+         body_background_fill_dark="#0c0505",
+         block_background_fill_dark="#0c0505",
+         block_border_width="1px",
+         block_title_background_fill_dark="#1b0f0f",
+         input_background_fill_dark="#140b0b",
+         button_secondary_background_fill_dark="#140b0b",
+         border_color_accent_dark="#1b0f0f",
+         border_color_primary_dark="#1b0f0f",
+         background_fill_secondary_dark="#0c0505",
+         color_accent_soft_dark="transparent",
+         code_background_fill_dark="#140b0b"
+     ),
+     css=css,
+     examples=[
+         ["Tell me about the contents of the uploaded PDFs."],
+         ["What are the main topics discussed in the documents?"],
+         ["Can you summarize the key points from the PDFs?"]
+     ],
+     cache_examples=False,
+     analytics_enabled=False,
+ )
+
+ # Add file upload functionality
+ with demo:
+     gr.Markdown("## Upload PDF Documents")
+
    with gr.Row():
        file_input = gr.Files(label="Upload your PDF documents", file_types=[".pdf"])
-         parser_dropdown = gr.Dropdown(choices=["pypdf", "llamaparse"], label="Select PDF Parser", value="pypdf")
+         parser_dropdown = gr.Dropdown(choices=["pypdf", "llamaparse"], label="Select PDF Parser", value="llamaparse")
        update_button = gr.Button("Upload Document")

        update_output = gr.Textbox(label="Update Status")
        update_button.click(update_vectors, inputs=[file_input, parser_dropdown], outputs=update_output)
-
-     chatbot = gr.Chatbot(label="Conversation")
-     msg = gr.Textbox(label="Ask a question")
-     use_web_search = gr.Checkbox(label="Use Web Search", value=False)
-
-     with gr.Row():
-         temperature_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature")
-         repetition_penalty_slider = gr.Slider(minimum=1.0, maximum=2.0, value=1.1, step=0.1, label="Repetition Penalty")
-
-     submit = gr.Button("Submit")
-
-     gr.Examples(
-         examples=[
-             ["What are the latest developments in AI?"],
-             ["Tell me about recent updates on GitHub"],
-             ["What are the best hotels in Galapagos, Ecuador?"],
-             ["Summarize recent advancements in Python programming"],
-         ],
-         inputs=msg,
-     )
-
-     submit.click(chatbot_interface,
-                  inputs=[msg, chatbot, use_web_search, temperature_slider, repetition_penalty_slider],
-                  outputs=[chatbot])
-     msg.submit(chatbot_interface,
-                inputs=[msg, chatbot, use_web_search, temperature_slider, repetition_penalty_slider],
-                outputs=[chatbot])
+
    gr.Markdown(
        """
        ## How to use
        1. Upload PDF documents using the file input at the top.
        2. Select the PDF parser (pypdf or llamaparse) and click "Upload Document" to update the vector store.
-         3. Ask questions in the textbox.
+         3. Ask questions in the chat interface.
        4. Toggle "Use Web Search" to switch between PDF chat and web search.
-         5. Adjust Temperature and Repetition Penalty sliders to fine-tune the response generation.
-         6. Click "Submit" or press Enter to get a response.
+         5. Adjust Temperature and Number of API Calls to fine-tune the response generation.
+         6. Use the provided examples or ask your own questions.
        """
    )
 
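Review note on the de-duplication pass in generate_chunked_response: the membership test compares each raw paragraph against already-processed paragraphs (the stored entries have had duplicate sentences removed), so a repeated paragraph whose sentences were deduplicated can still be appended twice. A hedged alternative that tracks the processed form instead (a sketch, not part of the commit):

```python
def dedup_paragraphs(text: str) -> str:
    seen = set()
    unique_paragraphs = []
    for paragraph in text.split('\n\n'):
        # Order-preserving de-dup of sentences within the paragraph.
        processed = '. '.join(dict.fromkeys(paragraph.split('. ')))
        if processed not in seen:
            seen.add(processed)
            unique_paragraphs.append(processed)
    return '\n\n'.join(unique_paragraphs)
```

Splitting on '. ' is still crude (it breaks on abbreviations and numbered items), but it matches the commit's own heuristic.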
 
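Review note: get_response_from_pdf keeps the retrieval path unchanged; only the generation side moved to InferenceClient. The retrieval step in isolation (a sketch; get_embeddings() is the app's own helper and is not shown in this diff, so a plain HuggingFaceEmbeddings stands in, and the imports assume the usual langchain_community paths):

```python
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

embed = HuggingFaceEmbeddings()  # stand-in for the app's get_embeddings()
database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)

retriever = database.as_retriever()  # defaults to top-k similarity search
relevant_docs = retriever.get_relevant_documents("What are the key points?")
context_str = "\n".join(doc.page_content for doc in relevant_docs)
```

allow_dangerous_deserialization=True is required by newer LangChain releases when loading a pickled FAISS index; it is reasonable here only because the index is produced by the app itself.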
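Review note: vote() is defined but never attached to any event in this commit, so likes and dislikes are currently inert. One hedged way to wire it up (hypothetical, assuming Gradio 4.x, where ChatInterface exposes its inner Chatbot as .chatbot and Chatbot emits a like event):

```python
# Hypothetical wiring, not part of this commit.
with demo:
    demo.chatbot.like(vote, None, None)
```

retry_last_response is in the same situation: defined in this commit, but no button triggers it yet.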