|
import re |
|
import torch |
|
import gradio as gr |
|
from huggingface_hub import InferenceClient |
|
from transformers import ( |
|
AutoModelForCausalLM, |
|
AutoTokenizer, |
|
pipeline, |
|
EsmForProteinFolding, |
|
) |
|
from transformers.models.esm.openfold_utils.protein import to_pdb, Protein as OFProtein |
|
from transformers.models.esm.openfold_utils.feats import atom14_to_atom37 |
|
from proteins_viz import * |
|
import spaces |
|
import os |
|
|
|
|
|
# Startup diagnostic: record the process's working directory when the module
# is loaded and echo it to stdout.
current_working_directory = os.getcwd()
print("Current working directory:", current_working_directory)
|
|
|
|
|
def extract_nucleotide_sequences(text):
    """Collect every DNA-like run in *text* and return them joined together.

    A run is a whole word (delimited by regex word boundaries) made up of
    three or more of the letters A, T, C and G, matched case-insensitively.
    All runs are concatenated in order of appearance, with nothing between
    them, and the result is upper-cased so that it is in the canonical form
    used by upper-case codon tables.

    Note: because matching ignores case, ordinary words spelled only with
    these letters (for example "cat" or "tag") are picked up as well.

    Args:
        text: Free-form text to scan.

    Returns:
        The upper-cased concatenation of all matches, or ``""`` when nothing
        matches.
    """
    print("Extracting nucleotide sequences...")
    sequences = re.findall(r'\b[ATCG]{3,}\b', text, re.IGNORECASE)
    # Matching is case-insensitive, so normalise the case here; otherwise
    # lower-case hits cannot be looked up in an upper-case codon table.
    joined_sequence = ''.join(sequences).upper()
    print(f"Extracted sequence: {joined_sequence}")
    return joined_sequence
|
|
|
|
|
# DNA codon -> one-letter amino-acid code for the 61 sense codons.
# The stop codons (TAA, TAG, TGA) are not listed, so a lookup for them falls
# back to whatever default the caller supplies.
genetic_code = {
    'ATA': 'I', 'ATC': 'I', 'ATT': 'I', 'ATG': 'M',
    'ACA': 'T', 'ACC': 'T', 'ACG': 'T', 'ACT': 'T',
    'AAC': 'N', 'AAT': 'N', 'AAA': 'K', 'AAG': 'K',
    'AGC': 'S', 'AGT': 'S', 'AGA': 'R', 'AGG': 'R',
    'CTA': 'L', 'CTC': 'L', 'CTG': 'L', 'CTT': 'L',
    'CCA': 'P', 'CCC': 'P', 'CCG': 'P', 'CCT': 'P',
    'CAC': 'H', 'CAT': 'H', 'CAA': 'Q', 'CAG': 'Q',
    'CGA': 'R', 'CGC': 'R', 'CGG': 'R', 'CGT': 'R',
    'GTA': 'V', 'GTC': 'V', 'GTG': 'V', 'GTT': 'V',
    'GCA': 'A', 'GCC': 'A', 'GCG': 'A', 'GCT': 'A',
    'GAC': 'D', 'GAT': 'D', 'GAA': 'E', 'GAG': 'E',
    'GGA': 'G', 'GGC': 'G', 'GGG': 'G', 'GGT': 'G',
    'TCA': 'S', 'TCC': 'S', 'TCG': 'S', 'TCT': 'S',
    'TTC': 'F', 'TTT': 'F', 'TTA': 'L', 'TTG': 'L',
    'TAC': 'Y', 'TAT': 'Y', 'TGC': 'C', 'TGT': 'C', 'TGG': 'W'
}


def translate_nucleotide_sequence(nucleotide_seq):
    """Translate a DNA string into a one-letter protein string.

    The input is read codon by codon from its first character (no reading
    frame or start-codon search). Trailing bases that do not form a complete
    codon are dropped. Lookup is case-insensitive. Any codon missing from
    ``genetic_code`` -- including the three stop codons and codons holding
    non-ACGT characters -- is rendered as ``'X'``; translation does not stop
    at a stop codon.

    Args:
        nucleotide_seq: DNA sequence, in any letter case.

    Returns:
        The translated sequence, one character per complete codon; ``""``
        when the input holds fewer than three characters.
    """
    print("Translating nucleotide sequence to protein sequence...")
    # The codon table has upper-case keys only; without this normalisation
    # every lower-case codon would be reported as 'X'.
    normalized_seq = nucleotide_seq.upper()
    usable_length = len(normalized_seq) // 3 * 3
    protein_seq = ''.join(
        genetic_code.get(normalized_seq[start:start + 3], 'X')
        for start in range(0, usable_length, 3)
    )
    print(f"Translated protein sequence: {protein_seq}")
    return protein_seq
|
|
|
|
|
def read_mol(molpath):
    """Return the entire text content of the file at *molpath*.

    Two progress lines are printed to stdout: one before the read and one
    after it, the latter showing the first 100 characters that were read.

    Args:
        molpath: Path of the text file to load.

    Returns:
        The file content as a single string.
    """
    print(f"Reading molecule from {molpath}...")
    with open(molpath, "r") as handle:
        contents = handle.read()
    preview = contents[:100]
    print(f"Read molecule: {preview}...")
    return contents
|
|
|
|
|
def molecule(input_pdb):
    """Build an embeddable 3Dmol.js viewer for the PDB file at *input_pdb*.

    The file is loaded with ``read_mol`` and embedded in a small stand-alone
    HTML document that pulls jQuery and 3Dmol.js from their CDNs and renders
    the model as a chain-coloured cartoon. That document is placed in the
    ``srcdoc`` attribute of an ``<iframe>``, preceded by a heading.

    The file content is embedded as a JSON-encoded JavaScript string, and the
    whole inner document is HTML-escaped before going into ``srcdoc``, so
    characters such as quotes, backticks, backslashes or ``</script>`` in the
    file cannot terminate the script or the attribute early.

    Args:
        input_pdb: Path of the PDB text file to display.

    Returns:
        An HTML fragment (heading plus iframe) as a string.
    """
    # Function-scope imports keep this block self-contained.
    import json
    from html import escape

    mol = read_mol(input_pdb)

    # json.dumps yields a valid double-quoted JS string literal (quotes,
    # backslashes and newlines escaped). "</" is additionally split so the
    # data can never close the surrounding <script> element.
    pdb_literal = json.dumps(mol).replace("</", "<\\/")

    x = (
        """<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html; charset=UTF-8" />
<style>
body{
font-family:sans-serif
}
.mol-container {
width: 100%;
height: 230px;
position: relative;
}
.mol-container select{
background-image:None;
}
</style>
<script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js" integrity="sha512-STof4xm1wgkfm7heWqFJVn58Hm3EtS31XFaagaa8VMReCXAkQnJZ+jEy8PCC/iT18dFy95WcExNHFTqLyp72eQ==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
<script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
</head>
<body>
<div id="container" class="mol-container"></div>
<script>
let pdb = """
        + pdb_literal
        + """;
$(document).ready(function () {
let element = $("#container");
let config = { backgroundColor: "white" };
let viewer = $3Dmol.createViewer(element, config);
viewer.addModel(pdb, "pdb");
viewer.getModel(0).setStyle({}, { cartoon: { colorscheme:"chain" } });
viewer.zoomTo();
viewer.render();
viewer.zoom(0.3, 1000);
})
</script>
</body></html>"""
    )

    # srcdoc is a single-quoted attribute: escape the document so that any
    # quote or markup character inside it is carried as an entity and decoded
    # back by the browser when the iframe is created.
    escaped_doc = escape(x, quote=True)

    return f"""<h3 style="text-align: center;">Protein Visualization</h3> <iframe style="width: 100%; height: 250px" name="result" allow="midi; geolocation; microphone; camera;
display-capture; encrypted-media;" sandbox="allow-modals allow-forms
allow-scripts allow-same-origin allow-popups
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
allowpaymentrequest="" frameborder="0" srcdoc='{escaped_doc}' ></iframe>"""
|
|
|
|
|
def convert_outputs_to_pdb(outputs):
    """Turn a batch of folding-model outputs into PDB-format strings.

    Args:
        outputs: Mapping of tensors produced by the folding model. The keys
            ``positions``, ``aatype``, ``atom37_atom_exists``,
            ``residue_index`` and ``plddt`` are read; ``chain_index`` is used
            when present.

    Returns:
        A list with one PDB string per entry along the first axis of
        ``outputs["aatype"]``.
    """
    # atom14_to_atom37 needs the original tensor mapping, so it is called
    # before anything is converted to numpy. Only the last entry of
    # "positions" is used (presumably the final refinement step -- confirm).
    atom37_positions = atom14_to_atom37(outputs["positions"][-1], outputs)
    atom37_positions = atom37_positions.cpu().numpy()

    arrays = {name: tensor.to("cpu").numpy() for name, tensor in outputs.items()}
    atom_exists = arrays["atom37_atom_exists"]
    has_chain_index = "chain_index" in arrays

    def _protein_at(idx):
        # residue_index is shifted by one (presumably to get 1-based residue
        # numbering in the PDB -- confirm); plddt is stored as the B-factor.
        return OFProtein(
            aatype=arrays["aatype"][idx],
            atom_positions=atom37_positions[idx],
            atom_mask=atom_exists[idx],
            residue_index=arrays["residue_index"][idx] + 1,
            b_factors=arrays["plddt"][idx],
            chain_index=arrays["chain_index"][idx] if has_chain_index else None,
        )

    batch_size = arrays["aatype"].shape[0]
    pdbs = [to_pdb(_protein_at(idx)) for idx in range(batch_size)]
    print("Conversion to PDB format completed.")
    return pdbs
|
|
|
|
|
# Load the ESMFold tokenizer and folding model once, when the module is
# imported; both are kept as module-level globals.
print("Initializing tokenizer and model...")
tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1", low_cpu_mem_usage=True)
# Move the whole model to the GPU, then cast only its `esm` submodule to
# half precision.
model = model.cuda()
model.esm = model.esm.half()
# Allow TF32 for CUDA matrix multiplications.
torch.backends.cuda.matmul.allow_tf32 = True
# NOTE(review): chunk size 64 presumably trades speed for lower peak memory
# in the folding trunk -- confirm against the ESMFold documentation.
model.trunk.set_chunk_size(64)
print("Tokenizer and model initialized.")


# Text-generation pipeline for the chat model, loaded in bfloat16 with
# device placement delegated to device_map="auto".
print("Initializing text generation pipeline...")
pipe = pipeline("text-generation", model="kimou605/shadow-clown-BioMistral-7B-DARE", torch_dtype=torch.bfloat16, device_map="auto")
print("Text generation pipeline initialized.")
|
|
|
|
|
# Custom stylesheet handed to gr.Blocks(css=...): page background and font,
# a full-height column layout for the container, green buttons / text boxes /
# sliders with red slider thumbs, "pdf-bubble" badges, and a hidden footer.
css = """
body, html {
height: 100%;
margin: 0;
background-color: #f0f0f0;
font-family: Arial, sans-serif;
}
.gradio-container {
display: flex;
flex-direction: column;
height: 100%;
}
.pdf-bubble {
display: inline-block;
background-color: #4CAF50; /* Green */
color: white;
padding: 5px 10px;
border-radius: 12px;
margin: 2px;
font-size: 12px;
line-height: 1.4;
text-decoration: none;
}
.pdf-bubble:hover {
background-color: #45a049; /* Darker green */
}
footer {
display: none !important;
}
.gr-button {
background-color: #4CAF50; /* Green */
color: white;
border: none;
padding: 10px 20px;
text-align: center;
text-decoration: none;
display: inline-block;
font-size: 16px;
margin: 4px 2px;
cursor: pointer;
border-radius: 16px;
}
.gr-button:hover {
background-color: #45a049; /* Darker green */
}
.gr-textbox, .gr-slider {
border: 2px solid #4CAF50; /* Green border */
border-radius: 4px;
padding: 10px;
margin-bottom: 10px;
width: 100%;
}
.gr-textbox input, .gr-slider input {
border: none;
outline: none;
width: 100%;
padding: 8px;
}
.gr-slider input[type=range] {
appearance: none;
width: 100%;
height: 8px;
background: #4CAF50; /* Green */
outline: none;
opacity: 0.7;
transition: opacity .2s;
}
.gr-slider input[type=range]:hover {
opacity: 1;
}
.gr-slider input[type=range]::-webkit-slider-thumb {
appearance: none;
width: 25px;
height: 25px;
background: #f44336; /* Red */
cursor: pointer;
border-radius: 50%;
}
.gr-slider input[type=range]::-moz-range-thumb {
width: 25px;
height: 25px;
background: #f44336; /* Red */
cursor: pointer;
border-radius: 50%;
}
"""
|
|
|
def _build_chat_messages(system_message, chat_history, message):
    """Return the role/content list: system prompt, prior turns, new user message."""
    messages = [{"role": "system", "content": system_message}]
    for user_msg, assistant_msg in chat_history:
        # Empty or missing halves of a turn are skipped.
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})
    return messages


def _fold_and_render(nucleotide_seq):
    """Translate *nucleotide_seq*, fold the protein and return the viewer HTML."""
    # Function-scope import keeps this block self-contained.
    import tempfile

    protein_seq = translate_nucleotide_sequence(nucleotide_seq)
    tokenized_input = tokenizer(
        [protein_seq], return_tensors="pt", add_special_tokens=False
    )['input_ids'].cuda()

    with torch.no_grad():
        output = model(tokenized_input)

    pdb = convert_outputs_to_pdb(output)

    # A per-request temporary file replaces a fixed file name in the working
    # directory, so overlapping requests cannot overwrite each other's
    # structure and no stale file is left behind.
    with tempfile.NamedTemporaryFile("w", suffix=".pdb", delete=False) as f:
        f.write("".join(pdb))
        output_pdb_path = f.name
    try:
        return molecule(output_pdb_path)
    finally:
        os.remove(output_pdb_path)


@spaces.GPU(duration=120)
def respond(message, chat_history, system_message, max_tokens, temperature, top_p, hf_token):
    """Answer *message* with the chat model and, if possible, show a structure.

    The reply is generated from the system prompt, the previous turns and the
    new message. The reply is then scanned for nucleotide sequences, falling
    back to the user's message when the reply contains none. When a sequence
    is found it is translated, folded and rendered as an HTML viewer.

    Args:
        message: The new user message.
        chat_history: Previous ``(user, assistant)`` turns; the new turn is
            appended to it. ``None`` is treated as an empty history.
        system_message: System prompt placed first in the conversation.
        max_tokens: Passed to the pipeline as ``max_new_tokens``.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.
        hf_token: Accepted because the UI passes it; not used.

    Returns:
        A tuple ``("", chat_history, html)`` matching the outputs wired in
        ``msg.submit``: an empty string for the input box, the updated
        history, and the viewer HTML (``""`` when no sequence was found).
    """
    print("Responding to user input...")
    if chat_history is None:
        chat_history = []

    messages = _build_chat_messages(system_message, chat_history, message)
    print(f"Constructed messages: {messages}")

    prompt = pipe.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    print(f"Generated prompt: {prompt}")
    # return_full_text=False makes the pipeline return only the newly
    # generated text, so isolating the reply no longer depends on the chat
    # template containing an "[/INST]" marker. The split is kept as a
    # fallback for replies that themselves contain the marker.
    outputs = pipe(
        prompt,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_k=50,
        top_p=top_p,
        return_full_text=False,
    )
    response_text = outputs[0]["generated_text"].split('[/INST]')[-1]
    print(f"Generated response text: {response_text}")
    chat_history.append((message, response_text))

    # Prefer a sequence from the model's reply; otherwise use one the user
    # supplied in the message itself.
    sequences = extract_nucleotide_sequences(response_text) or extract_nucleotide_sequences(message)
    if not sequences:
        return "", chat_history, ""

    return "", chat_history, _fold_and_render(sequences)


# UI layout. Components are created in display order; the system prompt and
# the generation controls exist but are hidden (visible=False).
with gr.Blocks(css=css,theme=gr.themes.Soft()) as demo:
    chatbot = gr.Chatbot(height=500)
    html_output = gr.HTML()
    system_message = gr.Textbox(value='''You are a highly knowledgeable medical assistant with expertise in biochemistry and molecular biology. When a user provides a medical query, you should respond with comprehensive information, including definitions, functions, applications, and relevant molecular details. If the query involves a specific molecule, provide its nucleotide or protein structure where applicable. Ensure your responses are clear, precise, and written in English.

#Example User Input and Model Response:
#User Input:
"Tell me about insulin"
#Model Response:
Insulin is a peptide hormone produced by the pancreas that regulates glucose levels in the blood. It is essential for maintaining homeostasis and plays a crucial role in metabolism. Insulin facilitates the uptake of glucose into cells, especially in the liver, muscle, and fat tissue, allowing the body to use glucose for energy or store it as glycogen.
The structure of insulin consists of two peptide chains, A and B, linked by disulfide bonds. The nucleotide sequence encoding human insulin is found on the INS gene, which includes several exons and introns.
Here is a detailed nucleotide sequence of the insulin gene (INS):
ATGGCCCCTGTGGATGCGCCTGACCTGCCCAGGCTGGGCCCTGAGTGA
This sequence represents the coding region of the insulin gene, which translates into the insulin protein necessary for glucose metabolism.


''', label="System message",visible=False)
    msg = gr.Textbox(label="User Input")
    max_tokens = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens",visible=False)
    temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature",visible=False)
    top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)",visible=False)
    hf_token = gr.Textbox(label="Hugging Face Token", placeholder="Enter your Hugging Face token here",visible=False)
    clear = gr.ClearButton([msg, chatbot, html_output])

    examples = gr.Examples(
        examples = [
            ["Tell me about insulin"],
        ],
        inputs=[msg],
    )

    # Submitting the text box runs respond(); its three return values update
    # the text box, the chat window and the structure viewer respectively.
    msg.submit(respond, [msg, chatbot, system_message, max_tokens, temperature, top_p, hf_token], [msg, chatbot, html_output])
|
|
|
# Script entry point: enable the request queue on the Blocks app and launch it.
if __name__ == "__main__":
    print("Launching demo...")
    demo.queue().launch()
    # Runs only after launch() has returned.
    print("Demo launched.")
|
|