# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# Copyright 2021 deepset GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from json import JSONDecodeError
import streamlit as st
from markdown import markdown
from pipelines.document_stores import FAISSDocumentStore, MilvusDocumentStore
from pipelines.nodes import DensePassageRetriever, ErnieRanker
from pipelines.utils import (
convert_files_to_dicts,
fetch_archive_from_http,
print_documents,
)
# yapf: disable
def _str2bool(value):
    """Parse a command-line boolean such as "true"/"false", "yes"/"no" or "1"/"0".

    ``type=bool`` cannot be used for flags that take a value: ``bool("False")``
    is ``True`` because every non-empty string is truthy, so
    ``--embed_title False`` would silently enable title embedding.

    Raises:
        argparse.ArgumentTypeError: if the value is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}.")


parser = argparse.ArgumentParser()
parser.add_argument('--device', choices=['cpu', 'gpu'], default="gpu",
                    help="Select which device to run dense_qa system, defaults to gpu.")
parser.add_argument("--index_name", default='dureader_index', type=str, help="The ann index name of ANN.")
parser.add_argument("--search_engine", choices=['faiss', 'milvus'], default="faiss",
                    help="The type of ANN search engine.")
parser.add_argument("--max_seq_len_query", default=64, type=int,
                    help="The maximum total length of query after tokenization.")
parser.add_argument("--max_seq_len_passage", default=256, type=int,
                    help="The maximum total length of passage after tokenization.")
parser.add_argument("--retriever_batch_size", default=16, type=int,
                    help="The batch size of retriever to extract passage embedding for building ANN index.")
parser.add_argument("--query_embedding_model", default="rocketqa-zh-nano-query-encoder", type=str,
                    help="The query_embedding_model path")
# NOTE(review): the passage encoder defaults to the same model as the query encoder;
# presumably intentional for a shared-weight model fine-tuned into --params_path — confirm.
parser.add_argument("--passage_embedding_model", default="rocketqa-zh-nano-query-encoder", type=str,
                    help="The passage_embedding_model path")
parser.add_argument("--params_path", default="checkpoints/model_40/model_state.pdparams", type=str,
                    help="The checkpoint path")
parser.add_argument("--embedding_dim", default=312, type=int, help="The embedding_dim of index")
parser.add_argument('--host', type=str, default="localhost", help='host ip of ANN search engine')
parser.add_argument('--port', type=str, default="8530", help='port of ANN search engine')
# Use an explicit string-to-bool converter: type=bool would treat "False" as True.
parser.add_argument('--embed_title', default=False, type=_str2bool, help="The title to be embedded into embedding")
parser.add_argument('--model_type', choices=['ernie_search', 'ernie', 'bert', 'neural_search'], default="ernie",
                    help="the ernie model types")
args = parser.parse_args()
# yapf: enable
def get_faiss_retriever(use_gpu):
    """Create a dense retriever on top of a FAISS document store.

    When both the FAISS index file (``args.index_name``) and the
    ``faiss_document_store.db`` file exist, the saved index is loaded.
    Otherwise any leftover file is deleted, the paragraphs under
    ``resume_data`` are written to a fresh "Flat" FAISS store, embedded
    with the retriever and the index is saved to ``args.index_name``.

    Args:
        use_gpu: Passed to the retriever to select GPU or CPU inference.

    Returns:
        The configured ``DensePassageRetriever``.
    """
    sqlite_file = "faiss_document_store.db"
    reuse_existing = os.path.exists(args.index_name) and os.path.exists(sqlite_file)

    if reuse_existing:
        # connect to existed FAISS Index
        print("connect to existed FAISS Index!")
        store = FAISSDocumentStore.load(args.index_name)
    else:
        paragraphs = convert_files_to_dicts(dir_path="resume_data", split_paragraphs=True, encoding="utf-8")
        # Only one of the two files may be present; drop whichever is left over.
        for stale_path in (args.index_name, sqlite_file):
            if os.path.exists(stale_path):
                os.remove(stale_path)
        store = FAISSDocumentStore(embedding_dim=args.embedding_dim, faiss_index_factory_str="Flat")
        store.write_documents(paragraphs)

    # Only the ernie_search / neural_search model types get an explicit output size.
    emb_size = args.embedding_dim if args.model_type in ["ernie_search", "neural_search"] else None
    retriever = DensePassageRetriever(
        document_store=store,
        query_embedding_model=args.query_embedding_model,
        passage_embedding_model=args.passage_embedding_model,
        params_path=args.params_path,
        output_emb_size=emb_size,
        max_seq_len_query=args.max_seq_len_query,
        max_seq_len_passage=args.max_seq_len_passage,
        batch_size=args.retriever_batch_size,
        use_gpu=use_gpu,
        embed_title=args.embed_title,
    )

    if not reuse_existing:
        # Embed the freshly written documents, then persist the index for the next start.
        store.update_embeddings(retriever)
        store.save(args.index_name)
    return retriever
def get_milvus_retriever(use_gpu):
    """Create a dense retriever on top of a Milvus (HNSW) document store.

    The store connects to the Milvus server at ``args.host:args.port`` using
    collection ``args.index_name``. If ``milvus_document_store.db`` does not
    exist yet, the paragraphs under ``resume_data`` are written to the store
    and embedded with the retriever. This is the same corpus the FAISS backend
    indexes, so both search engines serve the same resume data.

    Args:
        use_gpu: Passed to the retriever to select GPU or CPU inference.

    Returns:
        The configured ``DensePassageRetriever``.
    """
    milvus_document_store = "milvus_document_store.db"
    # Check before constructing the store, in case construction creates the file.
    # NOTE(review): assumes MilvusDocumentStore persists its metadata to this file;
    # if its SQL backend uses a different location, the index is rebuilt on every start — confirm.
    index_exists = os.path.exists(milvus_document_store)

    document_store = MilvusDocumentStore(
        embedding_dim=args.embedding_dim,
        host=args.host,
        index=args.index_name,
        port=args.port,
        index_param={"M": 16, "efConstruction": 50},
        index_type="HNSW",
    )
    retriever = DensePassageRetriever(
        document_store=document_store,
        query_embedding_model=args.query_embedding_model,
        passage_embedding_model=args.passage_embedding_model,
        params_path=args.params_path,
        # Only the ernie_search / neural_search model types get an explicit output size.
        output_emb_size=args.embedding_dim if args.model_type in ["ernie_search", "neural_search"] else None,
        max_seq_len_query=args.max_seq_len_query,
        max_seq_len_passage=args.max_seq_len_passage,
        batch_size=args.retriever_batch_size,
        use_gpu=use_gpu,
        embed_title=args.embed_title,
    )

    if not index_exists:
        # Index the same resume corpus as the FAISS backend (previously this branch
        # downloaded and indexed the unrelated DuReader dev set).
        doc_dir = "resume_data"
        dicts = convert_files_to_dicts(dir_path=doc_dir, split_paragraphs=True, encoding="utf-8")
        document_store.write_documents(dicts)
        # update Embedding
        document_store.update_embeddings(retriever)
    return retriever
# Inference runs on GPU only when --device gpu was selected.
use_gpu = args.device == "gpu"

# Choose the ANN backend from --search_engine and build its retriever.
build_retriever = get_milvus_retriever if args.search_engine == "milvus" else get_faiss_retriever
retriever = build_retriever(use_gpu)

# Cross-encoder used to re-rank the retrieved candidates.
ranker = ErnieRanker(model_name_or_path="rocketqa-zh-dureader-cross-encoder", use_gpu=use_gpu)

# Retrieve-then-rerank pipeline shared by every query from the UI.
from pipelines import SemanticSearchPipeline  # noqa: E402

pipe = SemanticSearchPipeline(retriever, ranker)
def semantic_search_tutorial(query):
    """Run ``query`` through the module-level retrieve-then-rerank pipeline.

    The retriever is asked for 20 candidates and the ranker for its top 10.
    The full pipeline output is printed to stdout.

    Returns:
        The ``"documents"`` entry of the pipeline output.
    """
    run_params = {
        "Retriever": {"top_k": 20},
        "Ranker": {"top_k": 10},
    }
    prediction = pipe.run(query=query, params=run_params)
    print(prediction)
    return prediction["documents"]
# Adjust to a question that you would like users to see in the search bar when they load the UI:
DEFAULT_QUESTION_AT_STARTUP = os.environ.get("DEFAULT_QUESTION_AT_STARTUP", "具有金融科技背景的人才?")
DEFAULT_ANSWER_AT_STARTUP = os.environ.get("DEFAULT_ANSWER_AT_STARTUP", "")

# Sliders: initial slider positions, overridable through environment variables.
DEFAULT_DOCS_FROM_RETRIEVER = int(os.environ.get("DEFAULT_DOCS_FROM_RETRIEVER", "50"))
DEFAULT_NUMBER_OF_ANSWERS = int(os.environ.get("DEFAULT_NUMBER_OF_ANSWERS", "10"))
def set_state_if_absent(key, value):
    """Store ``value`` under ``key`` in the Streamlit session state unless ``key`` is already present."""
    if key in st.session_state:
        return
    st.session_state[key] = value
def on_change_text():
    """Search-box callback: copy the box's value (widget key ``quest``) into ``question`` and clear stale outputs."""
    state = st.session_state
    state.question = state.quest
    # Anything computed for the previous question is no longer valid.
    for stale_key in ("answer", "results", "raw_json"):
        setattr(state, stale_key, None)
def main():
    """Render the resume semantic-search Streamlit page.

    Configures the page and the persistent session state, draws the sidebar
    sliders and the search box, and runs the retrieve-then-rerank pipeline when:
    - the run button is pressed,
    - the question text changes, or
    - no results are cached yet.

    The sliders control how many candidates the retriever returns and how many
    documents the ranker keeps. Each returned document's content is rendered
    as Markdown.
    """
    st.set_page_config(
        page_title="人才简历语义检索演示系统",
        page_icon="https://github.com/PaddlePaddle/Paddle/blob/develop/doc/imgs/logo.png",
    )

    # Persistent state
    set_state_if_absent("question", DEFAULT_QUESTION_AT_STARTUP)
    set_state_if_absent("results", None)
    set_state_if_absent("raw_json", None)
    set_state_if_absent("random_question_requested", False)

    # Small callback to reset the interface in case the text of the question changes
    def reset_results(*args):
        st.session_state.answer = None
        st.session_state.results = None
        st.session_state.raw_json = None

    # Title
    st.write("# PaddleNLP 人才简历语义检索演示系统")

    # Sidebar
    st.sidebar.header("选项")
    top_k_reader = st.sidebar.slider(
        "最大的答案的数量",
        min_value=1,
        max_value=30,
        value=DEFAULT_NUMBER_OF_ANSWERS,
        step=1,
        on_change=reset_results,
    )
    top_k_retriever = st.sidebar.slider(
        "最大检索数量",
        min_value=1,
        max_value=100,
        value=DEFAULT_DOCS_FROM_RETRIEVER,
        step=1,
        on_change=reset_results,
    )

    # Search bar
    question = st.text_input(
        "",
        value=st.session_state.question,
        key="quest",
        on_change=on_change_text,
        max_chars=100,
        placeholder="请输入您的问题",
    )
    col1, col2 = st.columns(2)
    col1.markdown("", unsafe_allow_html=True)
    col2.markdown("", unsafe_allow_html=True)

    # Run button
    run_pressed = col1.button("运行")
    # This page has no random-question button; the flag is kept False so the key stays defined.
    st.session_state.random_question_requested = False
    run_query = run_pressed or question != st.session_state.question

    # Get results for query
    if (run_query or st.session_state.results is None) and question:
        reset_results()
        st.session_state.question = question
        with st.spinner(
            "🧠 Performing neural search on documents... \n "
            "Do you want to optimize speed or accuracy? \n"
        ):
            try:
                # Run the pipeline with the sidebar values so the sliders actually control
                # how many candidates are retrieved and how many re-ranked documents are kept.
                prediction = pipe.run(
                    query=question,
                    params={"Retriever": {"top_k": top_k_retriever}, "Ranker": {"top_k": top_k_reader}},
                )
                st.session_state.results = prediction["documents"]
            except JSONDecodeError:
                st.error("👓 An error occurred reading the results. Is the document store working?")
                return

    if st.session_state.results:
        st.write("## 返回结果:")
        for result in st.session_state.results:
            # NOTE(review): document content is converted to HTML and rendered unescaped;
            # sanitize it if the indexed files may come from untrusted sources.
            st.write(
                markdown(result.content),
                unsafe_allow_html=True,
            )
            st.write("___")
if __name__ == "__main__":
    # Render the page only when this file is executed as the main script.
    main()