File size: 3,175 Bytes
a726fe4
 
 
b461c84
 
 
a726fe4
 
b461c84
 
b6828c2
b461c84
 
e3b5f05
b461c84
a726fe4
034aca0
 
 
a726fe4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7acd96f
 
a726fe4
 
 
 
 
 
 
 
7acd96f
 
 
 
 
a726fe4
7acd96f
 
 
a726fe4
 
 
 
 
034aca0
 
 
 
 
 
 
 
 
 
 
 
a726fe4
034aca0
 
a726fe4
 
 
7acd96f
a726fe4
 
7acd96f
a726fe4
7acd96f
 
a726fe4
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
from transformers import AutoConfig

# Hugging Face Hub id of the checkpoint this app serves.
model_name = "farhan2206/dnabert2fourth"

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Configuration first: it is handed to both the tokenizer and the model below.
# trust_remote_code=True is passed to every from_pretrained call.
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)

tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    config=config,
    trust_remote_code=True,
)

# Load the sequence-classification model and place it on the selected device.
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    config=config,
    trust_remote_code=True,
).to(device)

# Streamlit UI
def main():
    """Render the Streamlit page and classify the submitted DNA sequence.

    Draws the title, an "About" sidebar and a text area. Once the
    "Classify Sequence" button is pressed, the entered text (with surrounding
    whitespace removed) is passed to ``pred``: a result of ``1`` is reported
    as a detected epigenetic mark, anything else as no mark found. Empty or
    whitespace-only input produces a warning instead of a prediction.
    """
    st.title("Epigenetic Marks Prediction")
    st.write("An application of DNA BERT2")

    # Sidebar with information
    st.sidebar.header("About")
    st.sidebar.write("This app uses DNA BERT2 to predict the presence of epigenetic marks in a given DNA sequence.")

    # User input
    user_input = st.text_area("Enter a DNA sequence:", height=150)

    # Nothing to classify until the button is pressed.
    if not st.button("Classify Sequence"):
        return

    # Strip surrounding whitespace/newlines so that blank input is rejected
    # rather than sent to the model, and pasted sequences are classified
    # without stray leading/trailing characters.
    sequence = (user_input or "").strip()
    if not sequence:
        st.warning("Please enter a DNA sequence for classification.")
        return

    # pred returns only the predicted class index, so no confidence is shown.
    predicted_class = pred(sequence)

    # Display the result
    st.subheader("Prediction Result")
    if predicted_class == 1:
        st.success("Epigenetic Mark detected!")
    else:
        st.info("No epigenetic mark found.")

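# Earlier version of pred, kept for reference only (disabled): it does not move
# the inputs to `device` and also returns the softmax confidence of class 1.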
# def pred(sequence):
#     encoded_input = tokenizer(sequence, return_tensors='pt')
    
#     # Pass the encoded input through the model
#     with torch.no_grad():
#         outputs = model(input_ids=encoded_input['input_ids'], attention_mask=encoded_input['attention_mask'])
#         logits = outputs[0]
#         predicted_class = logits.argmax(-1).item()
#         confidence = logits.softmax(dim=-1)[0, 1].item()

#     return predicted_class, confidence

def pred(sequence):
    """Predict the class of a single DNA sequence.

    Uses the module-level ``tokenizer``, ``model`` and ``device``.

    Args:
        sequence: The DNA sequence text to tokenize and classify.

    Returns:
        The index of the highest-scoring class, as a plain Python number
        (the UI treats ``1`` as "epigenetic mark detected").
    """
    # Tokenize and move the input tensors to the device the model lives on.
    encoded_input = tokenizer(sequence, return_tensors='pt').to(device)

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        # The value returned by the model is an output container (indexed
        # with [0] below), not a tensor, so it must not be chained with
        # .to(device); doing so raised AttributeError on every call.
        outputs = model(
            input_ids=encoded_input['input_ids'],
            attention_mask=encoded_input['attention_mask'],
        )
        logits = outputs[0]
        # .item() requires a single value, i.e. one sequence per call.
        predicted_class = logits.argmax(-1).item()

    return predicted_class

# Script entry point: build the UI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()


# To launch the app locally: streamlit run app.py --server.port 9000