File size: 1,778 Bytes
a1551fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import configparser
import functools

from pymilvus import Collection, DataType, FieldSchema, CollectionSchema, connections
from sentence_transformers import SentenceTransformer


@functools.lru_cache(maxsize=1)
def _load_encoder():
    """Return the bert-base-uncased SentenceTransformer, constructed only once.

    Building the model is expensive, so the instance is cached and reused
    across calls to ``retrieve_molecule_index``.
    """
    return SentenceTransformer(model_name_or_path="bert-base-uncased")


def _read_milvus_credentials(config_path):
    """Return ``(uri, token)`` from the ``[example]`` section of *config_path*.

    Raises:
        configparser.NoSectionError / configparser.NoOptionError: if the file
            is missing, or lacks the ``example`` section or its keys
            (``RawConfigParser.read`` silently ignores unreadable files).
    """
    cfp = configparser.RawConfigParser()
    cfp.read(config_path)
    return cfp.get('example', 'uri'), cfp.get('example', 'token')


def _molecule_schema(dim):
    """Build the ``molecule_embeddings`` collection schema for *dim*-sized vectors."""
    molecule_cid = FieldSchema(name="molecule_cid", dtype=DataType.INT64, description="cid", is_primary=True)
    molecule_name = FieldSchema(name="molecule_name", dtype=DataType.VARCHAR, max_length=256, description="name")
    molecule_embeddings = FieldSchema(name="molecule_embedding", dtype=DataType.FLOAT_VECTOR, dim=dim)
    return CollectionSchema(fields=[molecule_cid, molecule_name, molecule_embeddings],
                            auto_id=False,
                            description="my first collection!")


def retrieve_molecule_index(molecule, *, topk=1, config_path='config.ini'):
    """Embed *molecule* and return the ids of its nearest neighbours in Milvus.

    The text is encoded with bert-base-uncased. The resulting vector is used
    for an inner-product ("IP") search over the ``molecule_embedding`` field
    of the ``molecule_embeddings`` collection.

    Args:
        molecule: Text passed to ``SentenceTransformer.encode``. Presumably a
            single molecule name or description string; TODO confirm with
            callers.
        topk: Maximum number of hits to return. Defaults to 1.
        config_path: INI file whose ``[example]`` section provides ``uri``
            and ``token`` for the Milvus connection.

    Returns:
        The ``ids`` of the hits for the single query vector, i.e. the
        ``molecule_cid`` primary keys.

    Raises:
        configparser.Error: if the connection settings cannot be read.
        Exceptions from pymilvus propagate unchanged. The connection is
        always closed before they do.
    """
    search_vector = _load_encoder().encode(molecule).reshape(1, -1)
    milvus_uri, token = _read_milvus_credentials(config_path)
    connections.connect("default",
                        uri=milvus_uri,
                        token=token)
    print(f"Connecting to DB: {milvus_uri}")
    try:
        collection_name = "molecule_embeddings"
        dim = 768  # must match the embedding size produced by bert-base-uncased

        print(f"Creating example collection: {collection_name}")
        # Collection() creates the collection if absent, otherwise binds to
        # the existing one.
        collection = Collection(name=collection_name, schema=_molecule_schema(dim))
        # Milvus refuses to search a collection that is not loaded into
        # memory. load() is a no-op when it is already loaded.
        collection.load()

        search_params = {"metric_type": "IP"}
        results = collection.search(search_vector, anns_field='molecule_embedding',
                                    param=search_params, limit=topk)
        print(results)
    finally:
        # Always release the connection, even when creation or search fails.
        connections.disconnect("default")
        print("Disconnected from Milvus server.")
    return results[0].ids