# AptaBLE / utils.py
# Author: AtomBio
# Commit: "Create utils.py" (2616ade, verified)
# Size: 8.11 kB
import numpy as np
import random
import math
from sklearn.metrics import *
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
import pickle
def word2idx(word, words):
    """Return the integer index of ``word`` in the ``words`` vocabulary.

    Args:
        word: token to look up.
        words: mapping from token to an index (any value ``int()`` accepts).

    Returns:
        ``int(words[word])`` when the token is present, otherwise 0
        (so 0 acts as the out-of-vocabulary index).
    """
    return int(words[word]) if word in words.keys() else 0
def pad_seq(dataset, max_len):
    """Right-pad every sequence in ``dataset`` with zeros to ``max_len``.

    Sequences longer than ``max_len`` are truncated to their first
    ``max_len`` elements rather than raising a broadcast error.

    Args:
        dataset: iterable of numeric sequences (lists or 1-D arrays).
        max_len: target length of every output row.

    Returns:
        A float64 array of shape ``(len(dataset), max_len)``; an empty
        ``dataset`` yields shape ``(0, max_len)``.
    """
    seqs = list(dataset)
    output = np.zeros((len(seqs), max_len))
    # Iterating a 2-D array yields row views, so writing into `row`
    # fills `output` in place.
    for row, seq in zip(output, seqs):
        clipped = np.asarray(seq)[:max_len]
        row[:len(clipped)] = clipped
    return output
def str2bool(seq):
    """Convert ``"positive"``/``"negative"`` labels into 1/0.

    Args:
        seq: iterable of label strings.

    Returns:
        int64 NumPy array with one entry per input label, in order.

    Raises:
        ValueError: if a label is neither ``"positive"`` nor ``"negative"``.
            Dropping such labels would make the output shorter than the
            feature arrays it is paired with and silently misalign them.
    """
    label_map = {"positive": 1, "negative": 0}
    out = []
    for s in seq:
        try:
            out.append(label_map[s])
        except KeyError:
            raise ValueError(
                f"unexpected label {s!r}; expected 'positive' or 'negative'"
            ) from None
    return np.array(out, dtype=np.int64)
class API_Dataset(Dataset):
    """Map-style dataset of aptamer/protein token pairs with labels.

    Each item is a 5-tuple of int64 tensors:
    ``(aptamer tokens, protein tokens, label, aptamer mask, protein mask)``.
    """

    def __init__(self, apta, esm_prot, y, apta_attn_mask, prot_attn_mask):
        # Zero-argument super() starts the MRO lookup *after* API_Dataset,
        # so Dataset's own initialiser is not skipped.
        super().__init__()
        self.apta = np.array(apta, dtype=np.int64)
        self.esm_prot = np.array(esm_prot, dtype=np.int64)
        self.y = np.array(y, dtype=np.int64)
        self.apta_attn_mask = np.array(apta_attn_mask)
        self.prot_attn_mask = np.array(prot_attn_mask)
        # All arrays are indexed with the same index in __getitem__, so a
        # length mismatch would otherwise surface only as an IndexError
        # mid-epoch (or as silently ignored trailing rows).
        lengths = {
            "apta": len(self.apta),
            "esm_prot": len(self.esm_prot),
            "y": len(self.y),
            "apta_attn_mask": len(self.apta_attn_mask),
            "prot_attn_mask": len(self.prot_attn_mask),
        }
        if len(set(lengths.values())) != 1:
            raise ValueError(f"API_Dataset inputs have mismatched lengths: {lengths}")
        self.len = len(self.apta)

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        # Masks may be boolean arrays; they are cast to int64 like the rest.
        return tuple(
            torch.tensor(arr[index], dtype=torch.int64)
            for arr in (
                self.apta,
                self.esm_prot,
                self.y,
                self.apta_attn_mask,
                self.prot_attn_mask,
            )
        )
def find_opt_threshold(target, pred, n_steps=1000):
    """Grid-search the decision threshold that maximises F1.

    Candidate thresholds are ``i / n_steps`` for ``i`` in
    ``0 .. n_steps - 1``. A prediction counts as positive when it is
    strictly greater than the threshold.

    Args:
        target: ground-truth binary labels.
        pred: predicted scores. Any array-like works, including a plain list.
        n_steps: number of grid points (default 1000, the original grid).

    Returns:
        The smallest threshold reaching the best F1. Returns 0 if no
        threshold gives an F1 above 0.
    """
    # A plain list would raise TypeError on `pred > threshold`.
    pred = np.asarray(pred)
    best_threshold = 0
    best_f1 = 0
    for i in range(n_steps):
        threshold = i / n_steps
        # zero_division=0 gives the same value sklearn's default ("warn")
        # yields for undefined F1, without emitting a warning per step.
        current = f1_score(target, np.where(pred > threshold, 1, 0), zero_division=0)
        if current > best_f1:  # strict: ties keep the earlier, smaller threshold
            best_threshold = threshold
            best_f1 = current
    return best_threshold
def argument_seqset(seqset):
    """Double a set of sequence pairs by adding reversed copies.

    Args:
        seqset: iterable of 2-element pairs ``(s, ss)`` of sliceable
            sequences. Presumably a sequence and its secondary-structure
            string; TODO confirm against callers.

    Returns:
        A list holding ``[s, ss]`` and then ``[s[::-1], ss[::-1]]`` for
        each input pair, in input order.
    """
    augmented = []
    for seq, struct in seqset:
        augmented.extend(([seq, struct], [seq[::-1], struct[::-1]]))
    return augmented
def augment_apis(apta, prot, ys):
    """Augment aptamer-protein interaction samples with reversed aptamers.

    Every sample ``(a, p, y)`` becomes two consecutive samples:
    ``(a, p, y)`` and ``(a[::-1], p, y)``. The protein and label are
    duplicated unchanged. Inputs are zipped, so the shortest one sets how
    many samples are used.

    Returns:
        ``(aptamers, proteins, labels)`` as NumPy arrays of length
        ``2 * n_samples``.
    """
    rows = [
        (variant, p, y)
        for a, p, y in zip(apta, prot, ys)
        for variant in (a, a[::-1])
    ]
    aug_apta = np.array([row[0] for row in rows])
    aug_prot = np.array([row[1] for row in rows])
    aug_y = np.array([row[2] for row in rows])
    return aug_apta, aug_prot, aug_y
def load_data_source(filepath):
    """Unpickle the dataset table and split it by its ``"dataset"`` column.

    The unpickled object is indexed with a boolean mask built from
    ``table["dataset"]``. Presumably it is a pandas DataFrame; TODO confirm.

    Returns:
        ``(train_rows, test_rows, bench_rows)``: NumPy arrays holding the
        rows whose ``"dataset"`` value is ``"training dataset"``,
        ``"test dataset"`` and ``"benchmark dataset"`` respectively. All
        columns are kept. The ``get_*_dataset`` helpers read columns 0, 1
        and 2 as aptamer, protein and label.
    """
    # NOTE(security): pickle.load can execute arbitrary code embedded in
    # the file -- only point this at trusted data files.
    with open(filepath, "rb") as handle:
        table = pickle.load(handle)
    split_column = table["dataset"]
    return tuple(
        np.array(table[split_column == split_name])
        for split_name in ("training dataset", "test dataset", "benchmark dataset")
    )
def get_dataset(filepath, prot_max_len, n_prot_vocabs, prot_words):
    """Build train/test/benchmark splits with vocabulary-tokenised proteins.

    Each split is encoded as ``[rna2vec(aptamers),
    tokenize_sequences(proteins, prot_max_len, n_prot_vocabs, prot_words),
    str2bool(labels)]``. The training split is first augmented with
    reversed aptamers via ``augment_apis``.

    NOTE(review): ``rna2vec`` and ``tokenize_sequences`` are not defined in
    the visible source. Presumably they are provided elsewhere; TODO
    confirm, or this raises NameError.

    Returns:
        ``(train, test, bench)``, each a 3-element list as described above.
    """
    train_rows, test_rows, bench_rows = load_data_source(filepath)

    def encode(apta_col, prot_col, label_col):
        # Shared per-split encoding, evaluated in the order aptamer, protein, label.
        return [
            rna2vec(apta_col),
            tokenize_sequences(prot_col, prot_max_len, n_prot_vocabs, prot_words),
            str2bool(label_col),
        ]

    train_split = encode(*augment_apis(train_rows[:, 0], train_rows[:, 1], train_rows[:, 2]))
    test_split = encode(test_rows[:, 0], test_rows[:, 1], test_rows[:, 2])
    bench_split = encode(bench_rows[:, 0], bench_rows[:, 1], bench_rows[:, 2])
    return train_split, test_split, bench_split
def get_esm_dataset(filepath, batch_converter, alphabet):
    """Build train/test/benchmark splits with ESM-tokenised proteins.

    Aptamers are encoded with ``rna2vec`` and labels with ``str2bool``.
    Proteins are tokenised with ``batch_converter``, and only its third
    return value (the token tensor) is kept. The training split is first
    augmented with reversed aptamers via ``augment_apis``.

    Benchmark protein tokens are truncated and right-padded with
    ``alphabet.padding_idx`` to exactly 1678 columns. Training and test
    tokens keep whatever width ``batch_converter`` produced.

    NOTE(review): ``rna2vec`` is not defined in the visible source.
    Presumably it is provided elsewhere; TODO confirm.

    Returns:
        ``(train, test, bench)``, each ``[aptamer_encoding, protein_tokens,
        labels]``.
    """
    train_rows, test_rows, bench_rows = load_data_source(filepath)

    def esm_tokens(labelled_seqs):
        # batch_converter yields (labels, strings, tokens); keep the tokens.
        _, _, tokens = batch_converter(labelled_seqs)
        return tokens

    aug_apta, aug_prot, aug_y = augment_apis(train_rows[:, 0], train_rows[:, 1], train_rows[:, 2])
    # The training labels double as the ESM batch labels here.
    train_tokens = esm_tokens(list(zip(aug_y, aug_prot)))
    train_split = [rna2vec(aug_apta), train_tokens, str2bool(aug_y)]

    test_tokens = esm_tokens(list(enumerate(test_rows[:, 1])))
    test_split = [rna2vec(test_rows[:, 0]), test_tokens, str2bool(test_rows[:, 2])]

    bench_width = 1678
    clipped = esm_tokens(list(enumerate(bench_rows[:, 1])))[:, :bench_width]
    bench_tokens = torch.ones((clipped.shape[0], bench_width), dtype=torch.int64) * alphabet.padding_idx
    bench_tokens[:, :clipped.shape[1]] = clipped
    bench_split = [rna2vec(bench_rows[:, 0]), bench_tokens, str2bool(bench_rows[:, 2])]

    return train_split, test_split, bench_split
def get_nt_esm_dataset(filepath, nt_tokenizer, batch_converter, alphabet,
                       apta_max_len=275, prot_max_len=1680):
    """Build train/test/benchmark splits with nucleotide- and ESM-tokenised inputs.

    Aptamers are tokenised with ``nt_tokenizer.batch_encode_plus`` and padded
    to ``apta_max_len`` tokens. Proteins are tokenised with
    ``batch_converter``. The training split is first augmented with
    reversed aptamers via ``augment_apis``. Test and benchmark protein
    tokens are truncated, then right-padded with ``alphabet.padding_idx``,
    to exactly ``prot_max_len`` columns. Training protein tokens keep the
    width ``batch_converter`` produced.

    Args:
        filepath: pickle file read by ``load_data_source``.
        nt_tokenizer: object exposing ``batch_encode_plus`` and
            ``pad_token_id``. Presumably a Hugging Face tokenizer; TODO confirm.
        batch_converter: callable returning ``(labels, strings, tokens)``.
        alphabet: object exposing ``padding_idx``.
        apta_max_len: aptamer token width (default 275).
        prot_max_len: test/benchmark protein token width (default 1680).

    Returns:
        ``(train, test, bench)``, each ``[apta_tokens, prot_tokens, labels,
        apta_mask, prot_mask]``. The masks are boolean tensors that are True
        on non-padding positions.
    """
    dataset_train, dataset_test, dataset_bench = load_data_source(filepath)

    def encode_aptamers(seqs):
        # Fixed-width aptamer token ids plus a boolean non-padding mask.
        toks = nt_tokenizer.batch_encode_plus(
            seqs, return_tensors='pt', padding='max_length', max_length=apta_max_len
        )['input_ids']
        return toks, toks != nt_tokenizer.pad_token_id

    def fit_protein_width(tokens):
        # Truncate first so proteins with more than prot_max_len tokens do not
        # raise a shape-mismatch error on the slice assignment below (the same
        # truncate-then-pad approach get_esm_dataset uses for its benchmark).
        tokens = tokens[:, :prot_max_len]
        padded = torch.ones((tokens.shape[0], prot_max_len), dtype=torch.int64) * alphabet.padding_idx
        padded[:, :tokens.shape[1]] = tokens
        return padded

    def encode_eval_split(rows):
        # Test/benchmark split: integer ESM batch labels, fixed-width proteins.
        _, _, prot_tokens = batch_converter(list(enumerate(rows[:, 1])))
        prot_tokens = fit_protein_width(prot_tokens)
        apta_toks, apta_mask = encode_aptamers(rows[:, 0])
        return [apta_toks, prot_tokens, str2bool(rows[:, 2]), apta_mask,
                prot_tokens != alphabet.padding_idx]

    arg_apta, arg_prot, arg_y = augment_apis(dataset_train[:, 0], dataset_train[:, 1], dataset_train[:, 2])
    # The training labels double as the ESM batch labels here.
    _, _, train_prot_tokens = batch_converter(list(zip(arg_y, arg_prot)))
    train_apta_toks, train_apta_mask = encode_aptamers(arg_apta)
    datasets_train = [train_apta_toks, train_prot_tokens, str2bool(arg_y), train_apta_mask,
                      train_prot_tokens != alphabet.padding_idx]

    datasets_test = encode_eval_split(dataset_test)
    datasets_bench = encode_eval_split(dataset_bench)
    return datasets_train, datasets_test, datasets_bench
def get_scores(target, pred):
    """Compute binary-classification metrics at the F1-optimal threshold.

    The threshold comes from ``find_opt_threshold``. Threshold-dependent
    metrics (accuracy, MCC, F1, classification report) use
    ``pred > threshold``. ROC-AUC and average precision use the raw scores.

    Returns:
        Dict with keys ``threshold``, ``acc``, ``roc_auc``, ``mcc``, ``f1``,
        ``pr_auc`` and ``cls_report``.
    """
    threshold = find_opt_threshold(target, pred)
    hard_pred = np.where(pred > threshold, 1, 0)
    # Dict-literal values are evaluated top to bottom, so the metric calls
    # run in the same order as the keys appear.
    return {
        'threshold': threshold,
        'acc': accuracy_score(target, hard_pred),
        'roc_auc': roc_auc_score(target, pred),
        'mcc': matthews_corrcoef(target, hard_pred),
        'f1': f1_score(target, hard_pred),
        'pr_auc': average_precision_score(target, pred),
        'cls_report': classification_report(target, hard_pred),
    }