# Chem-210-Autograder / retrieve.py
# Author: anthony-chen; commit a1551fc ("kek"); 1.78 kB
import configparser
import functools

from pymilvus import Collection, DataType, FieldSchema, CollectionSchema, connections
from sentence_transformers import SentenceTransformer


@functools.lru_cache(maxsize=None)
def _load_encoder(model_name):
    """Load the SentenceTransformer *model_name* once and reuse it.

    Instantiating the model is expensive. Without this cache every call to
    retrieve_molecule_index rebuilt it from scratch. The key space is the set
    of model names actually used, so an unbounded cache is fine here.
    """
    return SentenceTransformer(model_name_or_path=model_name)


def _read_milvus_credentials(config_path, section):
    """Return ``(uri, token)`` from *section* of the INI file at *config_path*.

    RawConfigParser.read silently skips files it cannot open. A missing
    config therefore surfaces as configparser.NoSectionError from ``get``,
    exactly as in the original inline code.
    """
    cfp = configparser.RawConfigParser()
    cfp.read(config_path)
    return cfp.get(section, 'uri'), cfp.get(section, 'token')


def _molecule_schema(dim):
    """Build the schema for the molecule_embeddings collection.

    The fields are:
    - an INT64 primary key ``molecule_cid``;
    - a VARCHAR(256) ``molecule_name``;
    - a FLOAT_VECTOR ``molecule_embedding`` of width *dim*.

    Collection() only uses this schema when the collection does not exist yet.
    """
    molecule_cid = FieldSchema(name="molecule_cid", dtype=DataType.INT64, description="cid", is_primary=True)
    molecule_name = FieldSchema(name="molecule_name", dtype=DataType.VARCHAR, max_length=256, description="name")
    molecule_embeddings = FieldSchema(name="molecule_embedding", dtype=DataType.FLOAT_VECTOR, dim=dim)
    return CollectionSchema(fields=[molecule_cid, molecule_name, molecule_embeddings],
                            auto_id=False,
                            description="my first collection!")


def retrieve_molecule_index(molecule, *, config_path='config.ini', topk=1):
    """Find the stored molecule(s) whose embedding best matches *molecule*.

    The query text is embedded with ``bert-base-uncased``. The closest rows of
    the Milvus collection ``molecule_embeddings`` are then found by
    inner-product ("IP") search on the ``molecule_embedding`` field.

    Args:
        molecule: Text to embed and search for (passed to ``model.encode``).
        config_path: INI file holding the ``[example]`` section with ``uri``
            and ``token`` for the Milvus server. It is resolved relative to
            the current working directory, as before.
        topk: Number of nearest hits to return. The default is 1, matching
            the original behaviour.

    Returns:
        The ``ids`` of the hits for the single query vector. These are
        primary keys (``molecule_cid`` values), best match first.

    Raises:
        configparser.Error: If the config file or its keys are missing.
        Any pymilvus exception raised by connect/load/search. The connection
        is still closed in that case.
    """
    # Encoder output width must equal the collection's vector dim (768 here).
    search_vector = _load_encoder("bert-base-uncased").encode(molecule).reshape(1, -1)

    milvus_uri, token = _read_milvus_credentials(config_path, 'example')
    connections.connect("default",
                        uri=milvus_uri,
                        token=token)
    print(f"Connecting to DB: {milvus_uri}")
    try:
        collection_name = "molecule_embeddings"
        dim = 768  # Adjust based on the dimensionality of your embeddings
        print(f"Creating example collection: {collection_name}")
        collection = Collection(name=collection_name, schema=_molecule_schema(dim))
        # Milvus only searches collections loaded into memory. load() makes
        # this call self-sufficient instead of relying on an earlier loader.
        collection.load()
        search_params = {"metric_type": "IP"}
        results = collection.search(search_vector, anns_field='molecule_embedding',
                                    param=search_params, limit=topk)
        print(results)
    finally:
        # Always release the connection, even if load/search raised.
        connections.disconnect("default")
        print("Disconnected from Milvus server.")
    return results[0].ids