store.py
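"""Split documents into chunks, embed them, and store them in a Qdrant collection."""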
import argparse
from itertools import islice
from pathlib import Path

import torch
from tqdm import tqdm
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Qdrant

from loaders import get_loader, LOADER_NAMES
from config import DB_CONFIG

CHUNK_SIZE = 500


def get_text_chunk(docs):
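    """Split documents into chunks of about CHUNK_SIZE characters with no overlap."""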
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=CHUNK_SIZE, chunk_overlap=0
    )
    texts = text_splitter.split_documents(docs)
    return texts


def batched(iterable, *, size=100):
    """Batch data into tuples of length ``size``. The last batch may be shorter.

    Like itertools.batched() in Python 3.12+, but with a keyword-only batch size.

    batched('ABCDEFG', size=3) --> ('A', 'B', 'C') ('D', 'E', 'F') ('G',)
    """
    if size < 1:
        raise ValueError("size must be at least one")
    it = iter(iterable)
    while batch := tuple(islice(it, size)):
        yield batch


def store(texts):
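    """Embed text chunks with multilingual-e5-large and store them in Qdrant, 100 chunks per batch."""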
    model_name = "intfloat/multilingual-e5-large"
    model_kwargs = {"device": "cuda:0" if torch.cuda.is_available() else "cpu"}
    encode_kwargs = {"normalize_embeddings": False}
    embeddings = HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs,
    )
    db_url, db_api_key, db_collection_name = DB_CONFIG
    for batch in tqdm(batched(texts, size=100)):
        _ = Qdrant.from_documents(
            batch,
            embeddings,
            url=db_url,
            api_key=db_api_key,
            collection_name=db_collection_name,
        )


def get_parser():
    p = argparse.ArgumentParser()
    p.add_argument("index", type=str)
    p.add_argument("inputfile", metavar="INPUTFILE", type=str)
    p.add_argument("-l", "--loader", type=str, choices=LOADER_NAMES, required=True)
    return p


def index_annotated_docs(docs, index):
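    """Set each document's metadata["index"] to the given index name, yielding documents lazily."""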
    for doc in docs:
        doc.metadata["index"] = index
        yield doc


def main():
    """
    $ python store.py --loader wikipage "index" "FILE_PATH"
    $ python store.py -l wikipage wiki data/wiki.json
    $ python store.py -l rtdhtmlpage django ./docs.djangoproject.com/
    """
    p = get_parser()
    args = p.parse_args()
    loader = get_loader(
        args.loader,
        inputfile=Path(args.inputfile),
    )
    docs = loader.lazy_load()
    texts = get_text_chunk(index_annotated_docs(docs, args.index))
    store(texts)


if __name__ == "__main__":
    main()