# AllOneAI111M / app.py
import math

import gradio as gr
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

# Step 1: Define the Dataset Class (for loading vocabularies)
class TranslationDataset(Dataset):
def __init__(self, file_path, src_vocab=None, tgt_vocab=None):
self.data = []
self.src_tokenizer = get_tokenizer('basic_english')
self.tgt_tokenizer = get_tokenizer('basic_english')
with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if line:
                    try:
                        # Each line is expected to look like: ("source", "target"),
                        # so split on '", "' and strip the wrapping characters.
                        src, tgt = line.split('", "')
                        src = src[2:]   # drop the leading ("
                        tgt = tgt[:-3]  # drop the trailing "),
                        self.data.append((src, tgt))
                    except ValueError:
                        print(f"Skipping malformed line: {line}")
if src_vocab is None or tgt_vocab is None:
self.src_vocab = build_vocab_from_iterator(self._yield_tokens(self.data, 0), specials=["<unk>", "<pad>", "<bos>", "<eos>"])
self.tgt_vocab = build_vocab_from_iterator(self._yield_tokens(self.data, 1), specials=["<unk>", "<pad>", "<bos>", "<eos>"])
else:
self.src_vocab = src_vocab
self.tgt_vocab = tgt_vocab
self.src_vocab.set_default_index(self.src_vocab["<unk>"])
self.tgt_vocab.set_default_index(self.tgt_vocab["<unk>"])
def _yield_tokens(self, data, index):
for src, tgt in data:
yield self.src_tokenizer(src) if index == 0 else self.tgt_tokenizer(tgt)
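
    # Dataset subclasses conventionally also implement __len__/__getitem__.
    # A minimal sketch (not in the original file; the app itself only uses
    # this class for its vocabularies):
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]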
# Step 2: Define the Transformer Model
class PositionalEncoding(nn.Module):
def __init__(self, d_model, dropout=0.1, max_len=5000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
        # Precompute the sinusoidal table once and register it as a buffer so
        # it moves with the module but is not a learnable parameter.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)  # (max_len, 1, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (seq_len, batch, d_model); pe broadcasts over the batch dimension.
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
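
# Quick shape check for the encoding above (illustrative only; the sizes are
# assumptions matching the d_model=512 default used below):
#   pe = PositionalEncoding(d_model=512)
#   x = torch.zeros(10, 2, 512)        # (seq_len, batch, d_model)
#   assert pe(x).shape == (10, 2, 512)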
class TransformerModel(nn.Module):
def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=2048, dropout=0.1):
super(TransformerModel, self).__init__()
self.model_type = 'Transformer'
self.src_embedding = nn.Embedding(src_vocab_size, d_model)
self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
self.pos_encoder = PositionalEncoding(d_model, dropout)
self.transformer = nn.Transformer(d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, dropout)
self.fc_out = nn.Linear(d_model, tgt_vocab_size)
self.d_model = d_model
self._reset_parameters()
def _reset_parameters(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
    def forward(self, src, tgt, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, memory_key_padding_mask):
        src = self.src_embedding(src) * math.sqrt(self.d_model)
        tgt = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        tgt = self.pos_encoder(tgt)
        # nn.Transformer runs the full encoder-decoder stack, so this is the
        # decoder output rather than just the encoder memory.
        outs = self.transformer(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask,
                                memory_mask=None,
                                src_key_padding_mask=src_padding_mask,
                                tgt_key_padding_mask=tgt_padding_mask,
                                memory_key_padding_mask=memory_key_padding_mask)
        return self.fc_out(outs)
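
# TransformerModel.forward expects the caller to supply the masks. A minimal
# sketch of how they could be built for training (this helper is not part of
# the original file; pad_idx is assumed to be vocab["<pad>"]):
def create_masks(src, tgt, pad_idx, device):
    # src: (src_len, batch), tgt: (tgt_len, batch) token-index tensors
    src_mask = torch.zeros((src.size(0), src.size(0)), dtype=torch.bool, device=device)
    tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(0)).to(device)
    src_padding_mask = (src == pad_idx).transpose(0, 1)  # (batch, src_len)
    tgt_padding_mask = (tgt == pad_idx).transpose(0, 1)  # (batch, tgt_len)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask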
# Function to load the trained model
def load_model(model_path, src_vocab_size, tgt_vocab_size):
model = TransformerModel(src_vocab_size, tgt_vocab_size)
    # map_location='cpu' lets the checkpoint load on CPU-only machines; the
    # model is moved to the active device later in the script.
    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
return model
# Function to translate a sentence with greedy decoding
@torch.no_grad()  # inference only, so gradient tracking is unnecessary
def translate(model, src_sentence, dataset, device, max_len=50):
    model.eval()
    # Encode <bos> + source tokens + <eos>, shaped (src_len, 1) for a batch of one.
    src_tokens = [dataset.src_vocab["<bos>"]] + [dataset.src_vocab[token] for token in dataset.src_tokenizer(src_sentence)] + [dataset.src_vocab["<eos>"]]
    src_tensor = torch.LongTensor(src_tokens).unsqueeze(1).to(device)
    src_mask = torch.zeros((src_tensor.size(0), src_tensor.size(0)), device=device).type(torch.bool)
    memory = model.transformer.encoder(model.pos_encoder(model.src_embedding(src_tensor) * math.sqrt(model.d_model)), src_mask)
    # Decode one token at a time, always feeding back the highest-probability token.
    ys = torch.ones(1, 1).fill_(dataset.tgt_vocab["<bos>"]).type(torch.long).to(device)
    for _ in range(max_len - 1):
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(ys.size(0)).to(device)
        out = model.transformer.decoder(model.pos_encoder(model.tgt_embedding(ys) * math.sqrt(model.d_model)),
                                        memory, tgt_mask)
        out = model.fc_out(out)
        prob = out[-1]  # logits for the most recently generated position
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.item()
        ys = torch.cat([ys, torch.ones(1, 1).type_as(src_tensor.data).fill_(next_word)], dim=0)
        if next_word == dataset.tgt_vocab["<eos>"]:
            break
    # Map indices back to strings, dropping the special tokens.
    specials = {dataset.tgt_vocab["<bos>"], dataset.tgt_vocab["<eos>"], dataset.tgt_vocab["<pad>"]}
    itos = dataset.tgt_vocab.get_itos()
    translated_tokens = [itos[token] for token in ys.flatten().tolist() if token not in specials]
    return " ".join(translated_tokens)
# Gradio Interface
def translate_with_model(src_sentence):
translation = translate(model, src_sentence, dataset, device)
return translation
# Load dataset (which contains vocabularies)
file_path = 'data.txt' # Make sure this points to your data file
dataset = TranslationDataset(file_path)
# Load the trained model
model_path = 'SolconLM12B.pth'
model = load_model(model_path, len(dataset.src_vocab), len(dataset.tgt_vocab))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
# Gradio interface
iface = gr.Interface(fn=translate_with_model, inputs="text", outputs="text", title="Solocon12B Model", description="Enter a question")
if __name__ == "__main__":
iface.launch()
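
# Run locally with `python app.py`; Gradio serves the interface on
# http://127.0.0.1:7860 by default.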