anthony-chen committed on
Commit
a1551fc
1 Parent(s): 2c72ca6
Autograder.py ADDED
@@ -0,0 +1,41 @@
+ import streamlit as st
+ from generator2 import response
+ from retrieve import retrieve_molecule_index
+ from PIL import Image
+
+ st.title("Chem 210 Autograder")
+
+ if "messages" not in st.session_state:
+     st.session_state.messages = []
+
+ for message in st.session_state.messages:
+     with st.chat_message(message["role"]):
+         st.markdown(message["content"])
+
+ # Text input naming the molecule to grade
+ text_input = st.text_input("Indicate the molecule you want graded")
+
+ # File uploader for the student's drawn structure
+ image_input = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
+
+ if st.button("Submit"):
+     if text_input and image_input:
+         if text_input:
+             with st.chat_message("user"):
+                 st.markdown(f"Does the following image indicate the correct chemical structure of {text_input}?")
+             st.session_state.messages.append({"role": "user", "content": text_input})
+             index = int(retrieve_molecule_index(text_input)[0])
+             image_path = f"test/CID_{index}.png"
+             image2 = Image.open(image_path).convert('RGB')
+
+         if image_input:
+             with st.chat_message("user"):
+                 st.image(image_input, caption="User image", use_column_width=True)
+             st.session_state.messages.append({"role": "user", "content": "User uploaded an image."})
+             image = Image.open(image_input).convert('RGB')
+
+         answer = response(image2, image)
+         with st.chat_message("AI"):
+             st.markdown(answer)
+
+         st.session_state.messages.append({"role": "AI", "content": answer})
README.md CHANGED
@@ -5,7 +5,7 @@ colorFrom: green
  colorTo: gray
  sdk: streamlit
  sdk_version: 1.36.0
- app_file: app.py
+ app_file: Autograder.py
  pinned: false
  ---
 
config.ini ADDED
@@ -0,0 +1,3 @@
+ [example]
+ uri = https://in03-594505a48b55a63.api.gcp-us-west1.zillizcloud.com
+ token = 4e989aded0471e56339d6d1ea894daac5435006f1b6bedd1fa5f1e38ac23aeade4e8b01797d397c2b2a07a292608fe7487b4c671
embeddings.py ADDED
@@ -0,0 +1,99 @@
+ import pandas as pd
+ import time
+ import random
+ from sentence_transformers import SentenceTransformer
+ from pymilvus import connections, DataType, FieldSchema, CollectionSchema, Collection, utility
+ import configparser
+ from tqdm import tqdm
+
+ # Initialize SentenceTransformer model for embeddings
+ embedding_model = SentenceTransformer(model_name_or_path="bert-base-uncased")
+
+ # Read molecule names from CSV
+ csv_path = 'molecules-small.csv'
+ df = pd.read_csv(csv_path)
+ max_name_length = 256
+ molecules = df['cmpdname'].tolist()
+
+ for i, molecule in enumerate(molecules):
+     if len(molecule) > max_name_length:
+         molecules[i] = molecule[:max_name_length]
+
+ cids = df['cid'].tolist()
+
+ # Encode embeddings for each molecule
+ embeddings_list = []
+ for molecule in tqdm(molecules, desc="Generating Embeddings"):
+     embeddings = embedding_model.encode(molecule)
+     embeddings_list.append(embeddings)
+
+ cfp = configparser.RawConfigParser()
+ cfp.read('config.ini')
+ milvus_uri = cfp.get('example', 'uri')
+ token = cfp.get('example', 'token')
+ connections.connect("default",
+                     uri=milvus_uri,
+                     token=token)
+ print(f"Connecting to DB: {milvus_uri}")
+ # Define collection name and dimensionality of embeddings
+ collection_name = 'molecule_embeddings'
+ check_collection = utility.has_collection(collection_name)
+ if check_collection:
+     drop_result = utility.drop_collection(collection_name)
+     print("Success!")
+ dim = 768  # Adjust based on the dimensionality of your embeddings
+
+ # Define collection schema
+ molecule_cid = FieldSchema(name="molecule_cid", dtype=DataType.INT64, description="cid", is_primary=True)
+ molecule_name = FieldSchema(name="molecule_name", dtype=DataType.VARCHAR, max_length=256, description="name")
+ molecule_embeddings = FieldSchema(name="molecule_embedding", dtype=DataType.FLOAT_VECTOR, dim=dim)
+ schema = CollectionSchema(fields=[molecule_cid, molecule_name, molecule_embeddings],
+                           auto_id=False,
+                           description="my first collection!")
+
+ print(f"Creating example collection: {collection_name}")
+ collection = Collection(name=collection_name, schema=schema)
+ print(f"Schema: {schema}")
+ print("Success!")
+
+ batch_size = 1000
+ total_rt = 0
+ start = 0
+ print(f"Inserting {len(embeddings_list)} entities... ")
+ for i in tqdm(range(0, len(embeddings_list), batch_size), desc="Inserting Embeddings"):
+     batch_embeddings = embeddings_list[i:i + batch_size]
+     batch_molecules = molecules[i:i + batch_size]
+     batch_cids = cids[i:i + batch_size]
+     entities = [batch_cids, batch_molecules, batch_embeddings]
+     start += batch_size
+     t0 = time.time()
+     ins_resp = collection.insert(entities)
+     ins_rt = time.time() - t0
+     total_rt += ins_rt
+
+
+ print(f"Succeeded in inserting {len(embeddings_list)} entities in {round(total_rt, 4)} seconds!")
+
+ # Flush collection
+ print("Flushing collection...")
+ collection.flush()
+
+ # Build index
+ index_params = {"index_type": "AUTOINDEX", "metric_type": "L2", "params": {}}
+ print("Building index...")
+ collection.create_index(field_name='molecule_embedding', index_params=index_params)
+
+ collection.load()
+
+ # Example search
+ nq = 1
+ search_params = {"metric_type": "L2"}
+ topk = 5
+ search_vec = [[random.random() for _ in range(dim)] for _ in range(nq)]
+ print(f"Searching vector: {search_vec}")
+ results = collection.search(search_vec, anns_field='molecule_embedding', param=search_params, limit=topk)
+ print(f"Search results: {results}")
+
+ # Disconnect from Milvus server
+ connections.disconnect("default")
+ print("Disconnected from Milvus server.")
generator2.py ADDED
@@ -0,0 +1,17 @@
+ import torch
+ from transformers import AutoProcessor, AutoModelForPreTraining
+
+ def response(image2, image):
+
+     processor = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224")
+     model = AutoModelForPreTraining.from_pretrained("google/paligemma-3b-pt-224")
+
+     # Prompt the model to compare the two molecule images
+     model_inputs = processor(text="check whether both molecules have the same chemical structure. if yes, output correct and if not, output incorrect", images=[image, image2], return_tensors="pt")
+     input_len = model_inputs["input_ids"].shape[-1]
+
+     with torch.inference_mode():
+         generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
+     generation = generation[0][input_len:]
+     decoded = processor.decode(generation, skip_special_tokens=True)
+     return decoded
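A minimal usage sketch for response(), assuming two hypothetical local files reference.png and submission.png; the paths are illustrative, and downloading the PaliGemma checkpoint requires accepting its license on Hugging Face.

from PIL import Image
from generator2 import response

# Hypothetical inputs: a reference structure and a student drawing
reference = Image.open("reference.png").convert("RGB")
submission = Image.open("submission.png").convert("RGB")

# The prompt asks the model to answer "correct" or "incorrect"
print(response(reference, submission))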
pages/1_Manual_Image_Upload.py ADDED
@@ -0,0 +1,33 @@
+ import streamlit as st
+ from generator2 import response
+ from PIL import Image
+
+ st.title("Chem 210 Autograder - Manual Upload")
+
+ if "messages" not in st.session_state:
+     st.session_state.messages = []
+
+ for message in st.session_state.messages:
+     with st.chat_message(message["role"]):
+         st.markdown(message["content"])
+
+ # File uploaders for the two images to compare
+ image_input1 = st.file_uploader("Upload the solution molecule", type=["png", "jpg", "jpeg"])
+ image_input2 = st.file_uploader("Upload the molecule to be graded", type=["png", "jpg", "jpeg"])
+
+ if st.button("Submit"):
+     if image_input1 and image_input2:
+         # Open and convert images to RGB format
+         image1 = Image.open(image_input1).convert('RGB')
+         image2 = Image.open(image_input2).convert('RGB')
+
+         # Compare the two structures with the response function
+         answer = response(image1, image2)
+
+         # Display results in chat style
+         st.session_state.messages.append({"role": "user", "content": "User uploaded images."})
+         with st.chat_message("AI"):
+             st.markdown(answer)
+         st.session_state.messages.append({"role": "AI", "content": answer})
+     else:
+         st.warning("Please upload two images.")
pages/2_Retrieval_Demo.py ADDED
@@ -0,0 +1,28 @@
+ import streamlit as st
+ from generator2 import response
+ from retrieve import retrieve_molecule_index
+ from PIL import Image
+
+ st.title("Retrieval System Demo")
+
+ if "messages" not in st.session_state:
+     st.session_state.messages = []
+
+ for message in st.session_state.messages:
+     with st.chat_message(message["role"]):
+         st.markdown(message["content"])
+
+ # Text input for the molecule to look up
+ text_input = st.text_input("Indicate a desired molecule from our database")
+
+ if st.button("Submit"):
+     if text_input:
+         with st.chat_message("user"):
+             st.markdown("Sending request to Milvus...")
+         index = int(retrieve_molecule_index(text_input)[0])
+         image_path = f"test/CID_{index}.png"
+         image = Image.open(image_path).convert('RGB')
+
+         with st.chat_message("AI"):
+             st.write("Retrieved from our database:")
+             st.image(image, use_column_width=True)
pages/3_About.py ADDED
@@ -0,0 +1,50 @@
+ import streamlit as st
+
+ paragraph_1 = """
+ The Autograder's underlying software design uses a RAG (Retrieval-Augmented Generation) pipeline.
+
+ The context items consist of 3,717 2-D molecule images sourced from PubChem and saved to a
+ local directory (storing them in a relational database is something to explore). PubChem provides a downloadable
+ CSV file that can easily be converted to a dataframe, which gives each molecule's CID and compound name.
+ The compound names are encoded with BERT, and these embeddings are stored in the vector database Milvus, where the molecule's
+ CID serves as the primary key.
+
+ The retrieval system returns the CID of the molecule with the highest semantic similarity score. The intuition behind this design
+ choice was that molecules can have multiple names, so a simple keyword search would not be as versatile. The system has
+ not gone through much testing yet, but it can sometimes produce accurate results when given a compound
+ synonym.
+
+ The underlying VLM is Google's PaliGemma-3B and is currently not fine-tuned (hence the inaccurate results). The datasets that
+ we have access to are too noisy, and we simply do not have enough images to fine-tune even a "lightweight" VLM such as
+ PaliGemma.
+
+ """
+ paragraph_2 = """
+ The Custom Image Upload Autograder's underlying software design employs a straightforward prompt-engineering approach.
+
+ The core model used is PaliGemma. This page was developed as a contingency measure to address potential reliability issues
+ that may arise with the retrieval system in the future.
+ """
+ paragraph_3 = """
+ The next step would be to fine-tune our underlying VLM. However, before this is feasible, we would need to gather
+ a large corpus of images that is clean and well organized.
+
+ To improve the retrieval system, many more molecules would need to be uploaded and stored in Milvus. Mapping
+ these to a relational database would also be necessary, as storing them in a local directory at that scale
+ is unfeasible.
+
+ The Streamlit frontend could also be improved to make it more user-friendly and functional. This could include a
+ better layout, more interactive features, and smooth integration with the backend for efficient data retrieval
+ and processing.
+ """
+
+ st.title("About")
+ st.markdown("""DISCLAIMER!
+ The underlying VLM is not fine-tuned, so the quality of the outputs is unreliable.
+ """)
+ st.header("About the Autograder")
+ st.markdown(paragraph_1)
+ st.header("About Manual Upload")
+ st.markdown(paragraph_2)
+ st.header("Moving Onwards")
+ st.markdown(paragraph_3)
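The pipeline that paragraph_1 describes can be condensed into a single helper. Below is a sketch under the assumption that the Milvus collection from embeddings.py is built and loaded and that the referenced CID image exists under test/; grade_submission and the file paths are illustrative names, not part of the repo.

from PIL import Image
from retrieve import retrieve_molecule_index
from generator2 import response

def grade_submission(molecule_name: str, submission_path: str) -> str:
    # 1. Semantic retrieval: map the (possibly synonymous) name to a PubChem CID via Milvus
    cid = int(retrieve_molecule_index(molecule_name)[0])
    # 2. Load the locally stored 2-D reference structure for that CID
    reference = Image.open(f"test/CID_{cid}.png").convert("RGB")
    # 3. Ask PaliGemma whether the two drawings show the same structure
    submission = Image.open(submission_path).convert("RGB")
    return response(reference, submission)

# Hypothetical call:
# print(grade_submission("acetylsalicylic acid", "student_drawing.png"))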
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ streamlit
+ Pillow
+ pandas
+ sentence-transformers
+ configparser
+ pymilvus
+ tqdm
+ transformers
+ torch
retrieve.py ADDED
@@ -0,0 +1,38 @@
+ from pymilvus import Collection, DataType, FieldSchema, CollectionSchema, connections
+ from sentence_transformers import SentenceTransformer
+ import configparser
+
+
+ def retrieve_molecule_index(molecule):
+     model = SentenceTransformer(model_name_or_path="bert-base-uncased")
+     search_vector = model.encode(molecule).reshape(1, -1)
+     cfp = configparser.RawConfigParser()
+     cfp.read('config.ini')
+     milvus_uri = cfp.get('example', 'uri')
+     token = cfp.get('example', 'token')
+     connections.connect("default",
+                         uri=milvus_uri,
+                         token=token)
+     print(f"Connecting to DB: {milvus_uri}")
+     collection_name = "molecule_embeddings"
+     dim = 768  # Adjust based on the dimensionality of your embeddings
+
+     # Define collection schema (must match the one created in embeddings.py)
+     molecule_cid = FieldSchema(name="molecule_cid", dtype=DataType.INT64, description="cid", is_primary=True)
+     molecule_name = FieldSchema(name="molecule_name", dtype=DataType.VARCHAR, max_length=256, description="name")
+     molecule_embeddings = FieldSchema(name="molecule_embedding", dtype=DataType.FLOAT_VECTOR, dim=dim)
+     schema = CollectionSchema(fields=[molecule_cid, molecule_name, molecule_embeddings],
+                               auto_id=False,
+                               description="my first collection!")
+
+     print(f"Using existing collection: {collection_name}")
+     collection = Collection(name=collection_name, schema=schema)
+
+     search_params = {"metric_type": "L2"}  # must match the metric the index was built with
+     topk = 1
+     results = collection.search(search_vector, anns_field='molecule_embedding', param=search_params, limit=topk)
+     print(results)
+     # Disconnect from Milvus server
+     connections.disconnect("default")
+     print("Disconnected from Milvus server.")
+     return results[0].ids
setup.py ADDED
@@ -0,0 +1,11 @@
+ import os
+
+ # Print the first five entries of the local image directory as a sanity check
+ index = 0
+ for item in os.scandir("test"):
+     print(item)
+     index += 1
+     if index == 5:
+         break
+
+
test/CID_10005.png ADDED
test/CID_10035350.png ADDED
test/CID_1004.png ADDED
test/CID_10040.png ADDED
test/CID_100633.png ADDED
test/CID_10085783.png ADDED
test/CID_10087.png ADDED
test/CID_100877.png ADDED
test/CID_100975558.png ADDED
test/CID_100975560.png ADDED
test/CID_100975561.png ADDED
test/CID_100975564.png ADDED
test/CID_100975565.png ADDED
test/CID_101.png ADDED
test/CID_101012.png ADDED
test/CID_101067786.png ADDED
test/CID_10107.png ADDED
test/CID_10110536.png ADDED
test/CID_10111.png ADDED
test/CID_10112.png ADDED
test/CID_101126841.png ADDED
test/CID_101126842.png ADDED
test/CID_101126861.png ADDED
test/CID_101134236.png ADDED
test/CID_101134254.png ADDED
test/CID_10115786.png ADDED
test/CID_101182.png ADDED
test/CID_10130527.png ADDED
test/CID_10140464.png ADDED
test/CID_101525.png ADDED
test/CID_10154195.png ADDED
test/CID_10157484.png ADDED
test/CID_101693946.png ADDED
test/CID_1017.png ADDED
test/CID_101913340.png ADDED
test/CID_101920481.png ADDED
test/CID_101920482.png ADDED
test/CID_101920483.png ADDED
test/CID_101926290.png ADDED