# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# Copyright 2021 deepset GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Streamlit demo: semantic search over resume documents.

At import time the script parses its command-line options, builds (or
reconnects to) a FAISS or Milvus document store with a dense retriever,
creates a ranker, and wires both into a ``SemanticSearchPipeline``.
``main()`` then renders the Streamlit page that queries this pipeline.
"""

import argparse
import os
from json import JSONDecodeError

import streamlit as st
from markdown import markdown
from pipelines import SemanticSearchPipeline
from pipelines.document_stores import FAISSDocumentStore, MilvusDocumentStore
from pipelines.nodes import DensePassageRetriever, ErnieRanker
from pipelines.utils import (
    convert_files_to_dicts,
    fetch_archive_from_http,
    print_documents,
)


def _str2bool(value):
    """Convert a command-line string into a real boolean.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as ``True``; this converter parses the text instead.

    Args:
        value: The raw option value (a string, or already a bool).

    Returns:
        ``True`` or ``False``.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in {"true", "t", "yes", "y", "1"}:
        return True
    # The empty string stays False, matching what ``bool("")`` returned before.
    if lowered in {"false", "f", "no", "n", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}.")


# yapf: disable
parser = argparse.ArgumentParser()
parser.add_argument('--device', choices=['cpu', 'gpu'], default="gpu", help="Select which device to run dense_qa system, defaults to gpu.")
parser.add_argument("--index_name", default='dureader_index', type=str, help="The ann index name of ANN.")
parser.add_argument("--search_engine", choices=['faiss', 'milvus'], default="faiss", help="The type of ANN search engine.")
parser.add_argument("--max_seq_len_query", default=64, type=int, help="The maximum total length of query after tokenization.")
parser.add_argument("--max_seq_len_passage", default=256, type=int, help="The maximum total length of passage after tokenization.")
parser.add_argument("--retriever_batch_size", default=16, type=int, help="The batch size of retriever to extract passage embedding for building ANN index.")
parser.add_argument("--query_embedding_model", default="rocketqa-zh-nano-query-encoder", type=str, help="The query_embedding_model path")
# NOTE(review): this default is the same name as the query encoder above --
# confirm that sharing the query encoder for passages is intended.
parser.add_argument("--passage_embedding_model", default="rocketqa-zh-nano-query-encoder", type=str, help="The passage_embedding_model path")
parser.add_argument("--params_path", default="checkpoints/model_40/model_state.pdparams", type=str, help="The checkpoint path")
parser.add_argument("--embedding_dim", default=312, type=int, help="The embedding_dim of index")
parser.add_argument('--host', type=str, default="localhost", help='host ip of ANN search engine')
parser.add_argument('--port', type=str, default="8530", help='port of ANN search engine')
parser.add_argument('--embed_title', default=False, type=_str2bool, help="The title to be embedded into embedding")
parser.add_argument('--model_type', choices=['ernie_search', 'ernie', 'bert', 'neural_search'], default="ernie", help="the ernie model types")
args = parser.parse_args()
# yapf: enable


def _build_retriever(document_store, use_gpu):
    """Create the dense retriever for ``document_store`` from the CLI options.

    ``output_emb_size`` is only passed for the ``ernie_search`` and
    ``neural_search`` model types; for the others it is ``None``.
    """
    if args.model_type in ("ernie_search", "neural_search"):
        output_emb_size = args.embedding_dim
    else:
        output_emb_size = None
    return DensePassageRetriever(
        document_store=document_store,
        query_embedding_model=args.query_embedding_model,
        passage_embedding_model=args.passage_embedding_model,
        params_path=args.params_path,
        output_emb_size=output_emb_size,
        max_seq_len_query=args.max_seq_len_query,
        max_seq_len_passage=args.max_seq_len_passage,
        batch_size=args.retriever_batch_size,
        use_gpu=use_gpu,
        embed_title=args.embed_title,
    )


def _build_milvus_store():
    """Create a ``MilvusDocumentStore`` (HNSW index) from the CLI options."""
    return MilvusDocumentStore(
        embedding_dim=args.embedding_dim,
        host=args.host,
        index=args.index_name,
        port=args.port,
        index_param={"M": 16, "efConstruction": 50},
        index_type="HNSW",
    )


def get_faiss_retriever(use_gpu):
    """Return a dense retriever backed by a FAISS document store.

    If both the index file (``args.index_name``) and the
    ``faiss_document_store.db`` file exist, the saved index is loaded.
    Otherwise any leftover file of the pair is removed, the documents under
    ``resume_data`` are converted and written to a fresh store, embeddings
    are computed, and the index is saved to ``args.index_name``.

    Args:
        use_gpu: Forwarded to the retriever's ``use_gpu`` option.

    Returns:
        The configured ``DensePassageRetriever``.
    """
    faiss_document_store = "faiss_document_store.db"
    if os.path.exists(args.index_name) and os.path.exists(faiss_document_store):
        # connect to existed FAISS Index
        print("connect to existed FAISS Index!")
        document_store = FAISSDocumentStore.load(args.index_name)
        return _build_retriever(document_store, use_gpu)

    doc_dir = "resume_data"
    dicts = convert_files_to_dicts(dir_path=doc_dir, split_paragraphs=True, encoding="utf-8")

    # Only one of the two files may be present; drop it so the index and the
    # document database are rebuilt together.
    for stale_path in (args.index_name, faiss_document_store):
        if os.path.exists(stale_path):
            os.remove(stale_path)

    document_store = FAISSDocumentStore(embedding_dim=args.embedding_dim, faiss_index_factory_str="Flat")
    document_store.write_documents(dicts)

    retriever = _build_retriever(document_store, use_gpu)

    # update Embedding
    document_store.update_embeddings(retriever)

    # save index
    document_store.save(args.index_name)
    return retriever


def get_milvus_retriever(use_gpu):
    """Return a dense retriever backed by a Milvus document store.

    If the marker file ``milvus_document_store.db`` exists, the function only
    connects to the store. Otherwise it downloads the DuReader dev archive
    into ``data/dureader_dev``, converts the files, writes them to the store
    and computes their embeddings.

    Args:
        use_gpu: Forwarded to the retriever's ``use_gpu`` option.

    Returns:
        The configured ``DensePassageRetriever``.
    """
    milvus_document_store = "milvus_document_store.db"
    if os.path.exists(milvus_document_store):
        # connect to existed Milvus Index
        document_store = _build_milvus_store()
        return _build_retriever(document_store, use_gpu)

    doc_dir = "data/dureader_dev"
    dureader_data = "https://paddlenlp.bj.bcebos.com/applications/dureader_dev.zip"

    fetch_archive_from_http(url=dureader_data, output_dir=doc_dir)
    dicts = convert_files_to_dicts(dir_path=doc_dir, split_paragraphs=True, encoding="utf-8")

    document_store = _build_milvus_store()
    retriever = _build_retriever(document_store, use_gpu)

    document_store.write_documents(dicts)
    # update Embedding
    document_store.update_embeddings(retriever)
    return retriever


# NOTE(review): everything below runs at module import. Streamlit re-executes
# the script on user interaction, so this setup presumably repeats on every
# rerun -- consider caching the pipeline; verify against the Streamlit docs.
use_gpu = args.device == "gpu"

if args.search_engine == "milvus":
    retriever = get_milvus_retriever(use_gpu)
else:
    retriever = get_faiss_retriever(use_gpu)

# Ranker
ranker = ErnieRanker(model_name_or_path="rocketqa-zh-dureader-cross-encoder", use_gpu=use_gpu)

# Pipeline
pipe = SemanticSearchPipeline(retriever, ranker)


def semantic_search_tutorial(query, top_k_retriever=20, top_k_ranker=10):
    """Run the search pipeline for ``query`` and return the ranked documents.

    Args:
        query: The user's question text.
        top_k_retriever: ``top_k`` passed to the ``Retriever`` node
            (default 20, the value that used to be hard-coded).
        top_k_ranker: ``top_k`` passed to the ``Ranker`` node
            (default 10, the value that used to be hard-coded).

    Returns:
        The ``"documents"`` entry of the pipeline prediction.
    """
    prediction = pipe.run(
        query=query,
        params={"Retriever": {"top_k": top_k_retriever}, "Ranker": {"top_k": top_k_ranker}},
    )
    print(prediction)
    return prediction["documents"]


# Adjust to a question that you would like users to see in the search bar when they load the UI:
DEFAULT_QUESTION_AT_STARTUP = os.getenv("DEFAULT_QUESTION_AT_STARTUP", "具有金融科技背景的人才?")
DEFAULT_ANSWER_AT_STARTUP = os.getenv("DEFAULT_ANSWER_AT_STARTUP", "")

# Sliders
DEFAULT_DOCS_FROM_RETRIEVER = int(os.getenv("DEFAULT_DOCS_FROM_RETRIEVER", "50"))
DEFAULT_NUMBER_OF_ANSWERS = int(os.getenv("DEFAULT_NUMBER_OF_ANSWERS", "10"))


def set_state_if_absent(key, value):
    """Initialise ``st.session_state[key]`` to ``value`` unless already set."""
    if key not in st.session_state:
        st.session_state[key] = value


def reset_results(*_unused):
    """Clear the stored answer, results and raw JSON in the session state.

    Accepts and ignores any positional arguments so it can be used directly
    as a widget ``on_change`` callback.
    """
    st.session_state.answer = None
    st.session_state.results = None
    st.session_state.raw_json = None


def on_change_text():
    """Search-bar callback: store the new question and clear old results."""
    st.session_state.question = st.session_state.quest
    reset_results()


def main():
    """Render the Streamlit page and run the search for the current question."""
    st.set_page_config(
        page_title="人才简历语义检索演示系统",
        page_icon="https://github.com/PaddlePaddle/Paddle/blob/develop/doc/imgs/logo.png",
    )

    # Persistent state
    set_state_if_absent("question", DEFAULT_QUESTION_AT_STARTUP)
    set_state_if_absent("results", None)
    set_state_if_absent("raw_json", None)
    set_state_if_absent("random_question_requested", False)

    # Title
    st.write("# PaddleNLP 人才简历语义检索演示系统")

    # Sidebar
    st.sidebar.header("选项")
    # Changing either slider clears the results, which triggers a fresh
    # search below using the new values.
    top_k_reader = st.sidebar.slider(
        "最大的答案的数量",
        min_value=1,
        max_value=30,
        value=DEFAULT_NUMBER_OF_ANSWERS,
        step=1,
        on_change=reset_results,
    )
    top_k_retriever = st.sidebar.slider(
        "最大检索数量",
        min_value=1,
        max_value=100,
        value=DEFAULT_DOCS_FROM_RETRIEVER,
        step=1,
        on_change=reset_results,
    )

    # Search bar
    question = st.text_input(
        "",
        value=st.session_state.question,
        key="quest",
        on_change=on_change_text,
        max_chars=100,
        placeholder="请输入您的问题",
    )
    col1, col2 = st.columns(2)
    col1.markdown("", unsafe_allow_html=True)
    col2.markdown("", unsafe_allow_html=True)

    # Run button
    run_pressed = col1.button("运行")

    st.session_state.random_question_requested = False
    run_query = (
        run_pressed or question != st.session_state.question
    ) and not st.session_state.random_question_requested

    # Get results for query
    if (run_query or st.session_state.results is None) and question:
        reset_results()
        st.session_state.question = question
        with st.spinner(
            "🧠    Performing neural search on documents... \n "
            "Do you want to optimize speed or accuracy? \n"
        ):
            try:
                # Honour the sidebar sliders instead of fixed top_k values.
                st.session_state.results = semantic_search_tutorial(
                    question,
                    top_k_retriever=top_k_retriever,
                    top_k_ranker=top_k_reader,
                )
            except JSONDecodeError:
                st.error("👓    An error occurred reading the results. Is the document store working?")
                return

    if st.session_state.results:
        st.write("## 返回结果:")
        for result in st.session_state.results:
            context = result.content
            st.write(
                markdown(context),
                unsafe_allow_html=True,
            )
            # st.write("**Relevance:** ", result["relevance"])
            st.write("___")


if __name__ == "__main__":
    main()