# GenSeq / app.py
import re
import torch
import gradio as gr
from transformers import (
    AutoTokenizer,
    pipeline,
    EsmForProteinFolding,
)
from transformers.models.esm.openfold_utils.protein import to_pdb, Protein as OFProtein
from transformers.models.esm.openfold_utils.feats import atom14_to_atom37
from proteins_viz import *
import spaces
import os
# Get the current working directory
current_working_directory = os.getcwd()
print("Current working directory:", current_working_directory)
# Utility function to extract nucleotide sequences from text
def extract_nucleotide_sequences(text):
    print("Extracting nucleotide sequences...")
    sequences = re.findall(r'\b[ATCG]{3,}\b', text, re.IGNORECASE)
    joined_sequence = ''.join(sequences)
    print(f"Extracted sequence: {joined_sequence}")
    return joined_sequence
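# A quick sanity example: extract_nucleotide_sequences("gene A: ATGGCC, gene B: TTTAAA")
# returns "ATGGCCTTTAAA" -- every standalone run of 3+ A/T/C/G bases, concatenated in order.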
# Genetic code dictionary for translating nucleotide sequences to protein sequences
genetic_code = {
    'ATA': 'I', 'ATC': 'I', 'ATT': 'I', 'ATG': 'M',
    'ACA': 'T', 'ACC': 'T', 'ACG': 'T', 'ACT': 'T',
    'AAC': 'N', 'AAT': 'N', 'AAA': 'K', 'AAG': 'K',
    'AGC': 'S', 'AGT': 'S', 'AGA': 'R', 'AGG': 'R',
    'CTA': 'L', 'CTC': 'L', 'CTG': 'L', 'CTT': 'L',
    'CCA': 'P', 'CCC': 'P', 'CCG': 'P', 'CCT': 'P',
    'CAC': 'H', 'CAT': 'H', 'CAA': 'Q', 'CAG': 'Q',
    'CGA': 'R', 'CGC': 'R', 'CGG': 'R', 'CGT': 'R',
    'GTA': 'V', 'GTC': 'V', 'GTG': 'V', 'GTT': 'V',
    'GCA': 'A', 'GCC': 'A', 'GCG': 'A', 'GCT': 'A',
    'GAC': 'D', 'GAT': 'D', 'GAA': 'E', 'GAG': 'E',
    'GGA': 'G', 'GGC': 'G', 'GGG': 'G', 'GGT': 'G',
    'TCA': 'S', 'TCC': 'S', 'TCG': 'S', 'TCT': 'S',
    'TTC': 'F', 'TTT': 'F', 'TTA': 'L', 'TTG': 'L',
    'TAC': 'Y', 'TAT': 'Y', 'TGC': 'C', 'TGT': 'C', 'TGG': 'W',
}
# Function to translate nucleotide sequences to protein sequences
def translate_nucleotide_sequence(nucleotide_seq):
    print("Translating nucleotide sequence to protein sequence...")
    # Drop trailing bases so the length is a multiple of 3, then map codon by codon.
    truncated_seq = nucleotide_seq[:len(nucleotide_seq) // 3 * 3]
    protein_seq = ''.join([genetic_code.get(truncated_seq[i:i+3], 'X') for i in range(0, len(truncated_seq), 3)])
    print(f"Translated protein sequence: {protein_seq}")
    return protein_seq
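# Worked example: "ATGGCCCCT" splits into codons ATG/GCC/CCT and translates to "MAP".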
# Utility function to read PDB file
def read_mol(molpath):
    print(f"Reading molecule from {molpath}...")
    with open(molpath, "r") as fp:
        mol = fp.read()
    print(f"Read molecule: {mol[:100]}...")  # Print first 100 characters for brevity
    return mol
# Function to create an HTML iframe for molecule visualization
def molecule(input_pdb):
    mol = read_mol(input_pdb)
    x = (
        """<!DOCTYPE html>
<html>
<head>
    <meta http-equiv="content-type" content="text/html; charset=UTF-8" />
    <style>
        body {
            font-family: sans-serif;
        }
        .mol-container {
            width: 100%;
            height: 230px;
            position: relative;
        }
        .mol-container select {
            background-image: none;
        }
    </style>
    <script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js" integrity="sha512-STof4xm1wgkfm7heWqFJVn58Hm3EtS31XFaagaa8VMReCXAkQnJZ+jEy8PCC/iT18dFy95WcExNHFTqLyp72eQ==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
    <script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
</head>
<body>
    <div id="container" class="mol-container"></div>
    <script>
        let pdb = `"""
        + mol
        + """`
        $(document).ready(function () {
            let element = $("#container");
            let config = { backgroundColor: "white" };
            let viewer = $3Dmol.createViewer(element, config);
            viewer.addModel(pdb, "pdb");
            viewer.getModel(0).setStyle({}, { cartoon: { colorscheme: "chain" } });
            viewer.zoomTo();
            viewer.render();
            viewer.zoom(0.3, 1000);
        })
    </script>
</body>
</html>"""
    )
    return f"""<h3 style="text-align: center;">Protein Visualization</h3> <iframe style="width: 100%; height: 250px" name="result" allow="midi; geolocation; microphone; camera;
    display-capture; encrypted-media;" sandbox="allow-modals allow-forms
    allow-scripts allow-same-origin allow-popups
    allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
    allowpaymentrequest="" frameborder="0" srcdoc='{x}' ></iframe>"""
# Function to convert outputs to PDB format
def convert_outputs_to_pdb(outputs):
    final_atom_positions = atom14_to_atom37(outputs["positions"][-1], outputs)
    outputs = {k: v.to("cpu").numpy() for k, v in outputs.items()}
    final_atom_positions = final_atom_positions.cpu().numpy()
    final_atom_mask = outputs["atom37_atom_exists"]
    pdbs = []
    for i in range(outputs["aatype"].shape[0]):
        aa = outputs["aatype"][i]
        pred_pos = final_atom_positions[i]
        mask = final_atom_mask[i]
        resid = outputs["residue_index"][i] + 1
        pred = OFProtein(
            aatype=aa,
            atom_positions=pred_pos,
            atom_mask=mask,
            residue_index=resid,
            b_factors=outputs["plddt"][i],
            chain_index=outputs["chain_index"][i] if "chain_index" in outputs else None,
        )
        pdbs.append(to_pdb(pred))
    print("Conversion to PDB format completed.")
    return pdbs
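# This follows the ESMFold example in the transformers docs: atom14_to_atom37 expands
# the model's compact 14-atom residue layout to the standard 37-atom layout, and the
# per-residue pLDDT confidence is written into the PDB B-factor column, one PDB string
# per sequence in the batch.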
# Initialize the tokenizer and model
print("Initializing tokenizer and model...")
tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1", low_cpu_mem_usage=True)
model = model.cuda()
model.esm = model.esm.half()
torch.backends.cuda.matmul.allow_tf32 = True
model.trunk.set_chunk_size(64)
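# Memory/speed notes: the ESM stem runs in fp16, TF32 matmuls are enabled, and
# chunking the trunk attention (chunk size 64) lowers peak GPU memory for long
# sequences at some cost in speed.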
print("Tokenizer and model initialized.")
# Initialize the text generation pipeline
print("Initializing text generation pipeline...")
pipe = pipeline("text-generation", model="kimou605/shadow-clown-BioMistral-7B-DARE", torch_dtype=torch.bfloat16, device_map="auto")
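# torch_dtype=torch.bfloat16 halves the 7B model's memory footprint versus fp32,
# and device_map="auto" lets accelerate place the layers across available devices.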
print("Text generation pipeline initialized.")
css = """
body, html {
height: 100%;
margin: 0;
background-color: #f0f0f0;
font-family: Arial, sans-serif;
}
.gradio-container {
display: flex;
flex-direction: column;
height: 100%;
}
.pdf-bubble {
display: inline-block;
background-color: #4CAF50; /* Green */
color: white;
padding: 5px 10px;
border-radius: 12px;
margin: 2px;
font-size: 12px;
line-height: 1.4;
text-decoration: none;
}
.pdf-bubble:hover {
background-color: #45a049; /* Darker green */
}
footer {
display: none !important;
}
.gr-button {
background-color: #4CAF50; /* Green */
color: white;
border: none;
padding: 10px 20px;
text-align: center;
text-decoration: none;
display: inline-block;
font-size: 16px;
margin: 4px 2px;
cursor: pointer;
border-radius: 16px;
}
.gr-button:hover {
background-color: #45a049; /* Darker green */
}
.gr-textbox, .gr-slider {
border: 2px solid #4CAF50; /* Green border */
border-radius: 4px;
padding: 10px;
margin-bottom: 10px;
width: 100%;
}
.gr-textbox input, .gr-slider input {
border: none;
outline: none;
width: 100%;
padding: 8px;
}
.gr-slider input[type=range] {
appearance: none;
width: 100%;
height: 8px;
background: #4CAF50; /* Green */
outline: none;
opacity: 0.7;
transition: opacity .2s;
}
.gr-slider input[type=range]:hover {
opacity: 1;
}
.gr-slider input[type=range]::-webkit-slider-thumb {
appearance: none;
width: 25px;
height: 25px;
background: #f44336; /* Red */
cursor: pointer;
border-radius: 50%;
}
.gr-slider input[type=range]::-moz-range-thumb {
width: 25px;
height: 25px;
background: #f44336; /* Red */
cursor: pointer;
border-radius: 50%;
}
"""
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    @spaces.GPU(duration=120)
    def respond(message, chat_history, system_message, max_tokens, temperature, top_p, hf_token):
        """Generate a chat reply; if it (or the user message) contains a nucleotide
        sequence, fold the translated protein and return a 3D visualization."""
        print("Responding to user input...")
        messages = [{"role": "system", "content": system_message}]
        for user_msg, assistant_msg in chat_history:
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if assistant_msg:
                messages.append({"role": "assistant", "content": assistant_msg})
        messages.append({"role": "user", "content": message})
        print(f"Constructed messages: {messages}")
        prompt = pipe.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        print(f"Generated prompt: {prompt}")
        outputs = pipe(prompt, max_new_tokens=max_tokens, do_sample=True, temperature=temperature, top_k=50, top_p=top_p)
        # The pipeline echoes the prompt; keep only the text after the final [/INST] tag.
        response_text = outputs[0]["generated_text"].split('[/INST]')[-1]
        print(f"Generated response text: {response_text}")
        chat_history.append((message, response_text))
        # Look for a nucleotide sequence in the reply first, then fall back to the user message.
        sequences = extract_nucleotide_sequences(response_text)
        if not sequences:
            sequences = extract_nucleotide_sequences(message)
        if sequences:
            protein_seq = translate_nucleotide_sequence(sequences)
            tokenized_input = tokenizer([protein_seq], return_tensors="pt", add_special_tokens=False)['input_ids'].cuda()
            with torch.no_grad():
                output = model(tokenized_input)
            pdb = convert_outputs_to_pdb(output)
            output_pdb_path = "output_structure.pdb"
            with open(output_pdb_path, "w") as f:
                f.write("".join(pdb))
            html = molecule(output_pdb_path)
            return "", chat_history, html
        else:
            return "", chat_history, ""
    chatbot = gr.Chatbot(height=500)
    html_output = gr.HTML()
    system_message = gr.Textbox(value='''You are a highly knowledgeable medical assistant with expertise in biochemistry and molecular biology. When a user provides a medical query, you should respond with comprehensive information, including definitions, functions, applications, and relevant molecular details. If the query involves a specific molecule, provide its nucleotide or protein structure where applicable. Ensure your responses are clear, precise, and written in English.
#Example User Input and Model Response:
#User Input:
"Tell me about insulin"
#Model Response:
Insulin is a peptide hormone produced by the pancreas that regulates glucose levels in the blood. It is essential for maintaining homeostasis and plays a crucial role in metabolism. Insulin facilitates the uptake of glucose into cells, especially in the liver, muscle, and fat tissue, allowing the body to use glucose for energy or store it as glycogen.
The structure of insulin consists of two peptide chains, A and B, linked by disulfide bonds. The nucleotide sequence encoding human insulin is found on the INS gene, which includes several exons and introns.
Here is a detailed nucleotide sequence of the insulin gene (INS):
ATGGCCCCTGTGGATGCGCCTGACCTGCCCAGGCTGGGCCCTGAGTGA
This sequence represents the coding region of the insulin gene, which translates into the insulin protein necessary for glucose metabolism.
''', label="System message", visible=False)
    msg = gr.Textbox(label="User Input")
    max_tokens = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens", visible=False)
    temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature", visible=False)
    top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)", visible=False)
    hf_token = gr.Textbox(label="Hugging Face Token", placeholder="Enter your Hugging Face token here", visible=False)
    clear = gr.ClearButton([msg, chatbot, html_output])
    examples = gr.Examples(
        examples=[
            ["Tell me about insulin"],
        ],
        inputs=[msg],
    )
    msg.submit(respond, [msg, chatbot, system_message, max_tokens, temperature, top_p, hf_token], [msg, chatbot, html_output])
if __name__ == "__main__":
    print("Launching demo...")
    demo.queue().launch()
    print("Demo launched.")