|
import configparser
import functools

from pymilvus import Collection, DataType, FieldSchema, CollectionSchema, connections
from sentence_transformers import SentenceTransformer
|
|
|
|
|
@functools.lru_cache(maxsize=4)
def _load_model(model_name):
    """Load a SentenceTransformer once per name and reuse it across calls."""
    return SentenceTransformer(model_name_or_path=model_name)


def _read_milvus_credentials(config_path):
    """Return ``(uri, token)`` from the ``[example]`` section of *config_path*."""
    cfp = configparser.RawConfigParser()
    # read() silently skips a missing file; that case then surfaces as
    # configparser.NoSectionError from get() below.
    cfp.read(config_path)
    return cfp.get('example', 'uri'), cfp.get('example', 'token')


def _molecule_schema(dim):
    """Build the schema used for the ``molecule_embeddings`` collection."""
    molecule_cid = FieldSchema(name="molecule_cid", dtype=DataType.INT64,
                               description="cid", is_primary=True)
    molecule_name = FieldSchema(name="molecule_name", dtype=DataType.VARCHAR,
                                max_length=256, description="name")
    molecule_embedding = FieldSchema(name="molecule_embedding",
                                     dtype=DataType.FLOAT_VECTOR, dim=dim)
    return CollectionSchema(fields=[molecule_cid, molecule_name, molecule_embedding],
                            auto_id=False,
                            description="my first collection!")


def retrieve_molecule_index(molecule, *, topk=1, config_path='config.ini'):
    """Embed *molecule* with BERT and return the ids of its nearest stored molecules.

    The text is encoded with ``bert-base-uncased`` via SentenceTransformer and
    searched (inner-product metric) against the ``molecule_embedding`` field
    of the ``molecule_embeddings`` Milvus collection.

    Args:
        molecule: Text passed to ``SentenceTransformer.encode`` as a single query.
        topk: Number of nearest hits to return (default 1, as before).
        config_path: INI file with ``uri`` and ``token`` under ``[example]``.

    Returns:
        The ``ids`` of the hits for the single query vector.

    Raises:
        configparser.Error: if the config file/section/keys are missing.
        Any pymilvus exception raised while connecting, loading or searching;
        the ``"default"`` connection is always closed before it propagates.
    """
    model = _load_model("bert-base-uncased")
    # encode() on one string yields a 1-D vector; reshape to a single query row.
    search_vector = model.encode(molecule).reshape(1, -1)

    milvus_uri, token = _read_milvus_credentials(config_path)

    print(f"Connecting to DB: {milvus_uri}")
    connections.connect("default", uri=milvus_uri, token=token)
    try:
        collection_name = "molecule_embeddings"
        # 768 matches the hidden size of bert-base-uncased.
        dim = 768

        print(f"Creating example collection: {collection_name}")
        collection = Collection(name=collection_name, schema=_molecule_schema(dim))
        # Milvus rejects searches on collections that are not loaded into
        # memory; load() is a no-op if the collection is already loaded.
        collection.load()

        search_params = {"metric_type": "IP"}
        # search() documents its data argument as list[list[float]].
        results = collection.search(search_vector.tolist(),
                                    anns_field='molecule_embedding',
                                    param=search_params,
                                    limit=topk)
        print(results)
    finally:
        # Always release the connection, even if the search above failed.
        connections.disconnect("default")
        print("Disconnected from Milvus server.")

    return results[0].ids