AllOneAINew / app.py
OOFMAN29803's picture
Update app.py
6703501 verified
raw
history blame contribute delete
No virus
8.25 kB
import math

import torch
import torch.nn as nn
from torch.utils.data import Dataset
import gradio as gr

# Suppress torchtext deprecation warnings.
# torchtext emits its deprecation warning when one of its submodules
# (torchtext.data, torchtext.vocab, ...) is first imported, unless the warning
# has already been disabled. The switch therefore has to be flipped BEFORE the
# submodule imports below; calling it afterwards has no effect on them.
import torchtext
torchtext.disable_torchtext_deprecation_warning()

from torchtext.data.utils import get_tokenizer  # noqa: E402
from torchtext.vocab import build_vocab_from_iterator  # noqa: E402
# Custom CSS handed to gr.Interface(css=...) further down. It loads the
# "Plus Jakarta Sans" web font for the whole app, styles <button> elements with
# the class "primary-button", and restyles text inputs / textareas, including a
# darker colour scheme while they have focus.
# NOTE(review): confirm that the buttons Gradio renders actually carry a
# "primary-button" class; if they do not, that rule never matches.
# NOTE(review): the unfocused input text colour (#B4CFBB) is very close to its
# background (#DEFFE7), which may make typed text hard to read.
css_styles = '''
@import url('https://fonts.googleapis.com/css2?family=Plus+Jakarta+Sans:wght@400;600;700;800&display=swap');
.gradio-container {
font-family: 'Plus Jakarta Sans', sans-serif;
}
button.primary-button {
width: 300px;
height: 48px;
padding: 2px;
font-weight: 700;
border: 3px solid #5964C2;
border-radius: 5px;
background-color: #7583FF;
color: white;
font-size: 20px;
transition: 0.3s ease;
}
button.primary-button:hover {
background-color: #5C67C9;
border: 3px solid #31376B;
}
input[type="text"], textarea {
width: 100%;
outline: none;
border: 3px solid #B4CFBB !important;
background-color: #DEFFE7 !important;
border-radius: 10px !important;
color: #B4CFBB !important;
padding: 2px !important;
font-weight: 600 !important;
transition: 0.3s ease;
}
input[type="text"]:focus, textarea:focus {
background-color: #88A88D !important;
border: 3px solid #657D69 !important;
color: #657D69 !important;
font-size: 16px !important;
}
'''
# Vocabulary container: rebuilds the source/target vocabularies from the
# training file. Only the vocabularies are used, not Dataset indexing.
class TranslationDataset(Dataset):
    """Rebuild source and target vocabularies from a parallel-text file.

    Every non-blank line is split on the literal separator ``","``. Lines
    that do not yield exactly two parts are skipped. The first two characters
    of the source part and the last three characters of the target part are
    discarded before tokenisation (presumably surrounding quote/tuple
    punctuation -- TODO confirm against the data file).

    Both vocabularies begin with the specials ``<unk>``, ``<pad>``, ``<bos>``
    and ``<eos>``, and fall back to ``<unk>`` for unseen tokens. The resulting
    token order must match training exactly, because the vocabulary sizes fix
    the model's embedding/output shapes.

    Note: no ``__len__``/``__getitem__`` are defined, so instances cannot be
    used for ``Dataset`` indexing.

    Args:
        file_path: Path to the parallel-text file. It is read once per
            vocabulary (twice in total).
    """

    def __init__(self, file_path):
        self.src_tokenizer = get_tokenizer('basic_english')
        self.tgt_tokenizer = get_tokenizer('basic_english')
        specials = ["<unk>", "<pad>", "<bos>", "<eos>"]
        self.src_vocab = build_vocab_from_iterator(
            self._yield_tokens(file_path, 0), specials=list(specials)
        )
        self.tgt_vocab = build_vocab_from_iterator(
            self._yield_tokens(file_path, 1), specials=list(specials)
        )
        # Unseen tokens map to the <unk> id.
        for vocab in (self.src_vocab, self.tgt_vocab):
            vocab.set_default_index(vocab["<unk>"])

    def _yield_tokens(self, file_path, index):
        """Yield the token list for one side of every parsable line.

        Args:
            file_path: Parallel-text file, read as UTF-8 with undecodable
                bytes replaced.
            index: ``0`` selects the source side; any other value selects the
                target side.

        Yields:
            The tokenizer output for the selected side of each line.
        """
        with open(file_path, 'r', encoding='utf-8', errors='replace') as handle:
            for raw_line in handle:
                text = raw_line.strip()
                if not text:
                    continue
                try:
                    src_part, tgt_part = text.split('","')
                    if index == 0:
                        tokens = self.src_tokenizer(src_part[2:])
                    else:
                        tokens = self.tgt_tokenizer(tgt_part[:-3])
                except ValueError:
                    # Not exactly one '","' separator on this line: skip it.
                    continue
                yield tokens
# Fixed sinusoidal positional encoding for sequence-first inputs.
class PositionalEncoding(nn.Module):
    """Add precomputed sinusoidal position signals to embeddings, then dropout.

    A table of shape ``(max_len, 1, d_model)`` is registered as the buffer
    ``pe`` (persistent, so it appears in the ``state_dict``). Even feature
    columns hold ``sin(pos * freq)`` and odd columns ``cos(pos * freq)``.
    When ``d_model`` is odd there is one fewer cosine column than sine
    columns.

    Args:
        d_model: Embedding width.
        dropout: Dropout probability applied after adding the encoding.
        max_len: Number of positions precomputed; longer inputs cannot be
            encoded.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        positions = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        frequencies = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model)
        )
        # One column per even feature index: (max_len, ceil(d_model / 2)).
        angles = positions * frequencies
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(angles)
        # There are d_model // 2 odd columns, so trim the last frequency when
        # d_model is odd.
        table[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', table.unsqueeze(1))

    def forward(self, x):
        """Add the encoding for the first ``x.size(0)`` positions, then apply dropout.

        ``x`` is indexed sequence-first (``x.size(0)`` is the length) and must
        broadcast against the ``(seq_len, 1, d_model)`` slice of the table.
        """
        return self.dropout(x + self.pe[: x.size(0)])
# Encoder-decoder model: embeddings + positional encoding + nn.Transformer +
# linear projection onto the target vocabulary.
class TransformerModel(nn.Module):
    """Sequence-to-sequence model built on :class:`torch.nn.Transformer`.

    The submodule attribute names (``src_embedding``, ``tgt_embedding``,
    ``pos_encoder``, ``transformer``, ``fc_out``) determine the
    ``state_dict`` keys, so they must match any checkpoint being loaded.

    Args:
        src_vocab_size: Number of source tokens (source embedding rows).
        tgt_vocab_size: Number of target tokens (target embedding rows and
            output logits).
        d_model: Model width.
        nhead: Attention heads per layer.
        num_encoder_layers: Encoder depth.
        num_decoder_layers: Decoder depth.
        dim_feedforward: Hidden width of each feed-forward sublayer.
        dropout: Dropout probability (positional encoding and transformer).
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, nhead=8,
                 num_encoder_layers=6, num_decoder_layers=6,
                 dim_feedforward=1024, dropout=0.1):
        super().__init__()
        self.model_type = 'Transformer'
        # Submodules are registered in a fixed order. That order defines the
        # state_dict layout and the order _reset_parameters visits weights.
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        self.transformer = nn.Transformer(
            d_model, nhead, num_encoder_layers, num_decoder_layers,
            dim_feedforward, dropout,
        )
        self.fc_out = nn.Linear(d_model, tgt_vocab_size)
        self.d_model = d_model
        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-uniform initialise every parameter with more than one dimension."""
        matrices = (param for param in self.parameters() if param.dim() > 1)
        for weight in matrices:
            nn.init.xavier_uniform_(weight)

    def forward(self, src, tgt, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, memory_key_padding_mask):
        """Return target-vocabulary logits for ``tgt`` conditioned on ``src``.

        Embeddings are scaled by ``sqrt(d_model)`` before the positional
        encoding is added. The masks are forwarded unchanged to
        ``nn.Transformer``, and no memory mask is used.
        """
        scale = math.sqrt(self.d_model)
        src_emb = self.pos_encoder(self.src_embedding(src) * scale)
        tgt_emb = self.pos_encoder(self.tgt_embedding(tgt) * scale)
        decoded = self.transformer(
            src_emb,
            tgt_emb,
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            memory_mask=None,
            src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        return self.fc_out(decoded)
# Greedy-decoding translation of a single sentence.
def translate(model, src_sentence, src_vocab, tgt_vocab, device, max_len=50, tokenizer=None):
    """Translate ``src_sentence`` by greedy decoding with ``model``.

    The source is tokenised, wrapped in ``<bos>``/``<eos>`` and encoded once.
    The decoder is then run step by step, always appending the highest-scoring
    token, until ``<eos>`` is produced or ``max_len - 1`` tokens have been
    generated after the initial ``<bos>``.

    Args:
        model: Object exposing ``src_embedding``, ``tgt_embedding``,
            ``pos_encoder``, ``transformer`` (an ``nn.Transformer``),
            ``fc_out`` and ``d_model`` -- e.g. ``TransformerModel``. It is
            switched to eval mode.
        src_sentence: Raw source text.
        src_vocab: Token-to-id mapping for the source side; must contain
            ``<bos>`` and ``<eos>``.
        tgt_vocab: Token-to-id mapping for the target side; must contain
            ``<bos>``, ``<eos>`` and ``<pad>`` and provide ``get_itos()``.
        device: Torch device on which tensors are created.
        max_len: Maximum decoded length, counting the leading ``<bos>``.
        tokenizer: Callable splitting ``src_sentence`` into tokens. Defaults
            to torchtext's ``basic_english`` tokenizer.

    Returns:
        The decoded tokens joined by single spaces, with ``<bos>``, ``<eos>``
        and ``<pad>`` removed.
    """
    if tokenizer is None:
        tokenizer = get_tokenizer('basic_english')
    model.eval()
    scale = math.sqrt(model.d_model)
    bos_id = tgt_vocab["<bos>"]
    eos_id = tgt_vocab["<eos>"]
    skip_ids = {bos_id, eos_id, tgt_vocab["<pad>"]}

    src_ids = (
        [src_vocab["<bos>"]]
        + [src_vocab[token] for token in tokenizer(src_sentence)]
        + [src_vocab["<eos>"]]
    )
    # (src_len, 1): sequence-first with a batch of one.
    src_tensor = torch.tensor(src_ids, dtype=torch.long, device=device).unsqueeze(1)
    # All-False boolean mask: no source position is masked out.
    src_mask = torch.zeros(
        (src_tensor.size(0), src_tensor.size(0)), dtype=torch.bool, device=device
    )

    # Pure inference: run the whole decode, including fc_out, without autograd.
    with torch.no_grad():
        memory = model.transformer.encoder(
            model.pos_encoder(model.src_embedding(src_tensor) * scale),
            src_mask,
        )
        ys = torch.full((1, 1), bos_id, dtype=torch.long, device=device)
        for _ in range(max_len - 1):
            tgt_mask = nn.Transformer.generate_square_subsequent_mask(ys.size(0)).to(device)
            out = model.transformer.decoder(
                model.pos_encoder(model.tgt_embedding(ys) * scale),
                memory,
                tgt_mask,
            )
            logits = model.fc_out(out)
            # Scores for the newest position only; greedy pick.
            _, next_word = torch.max(logits[-1], dim=1)
            next_word = next_word.item()
            ys = torch.cat(
                [ys, torch.full((1, 1), next_word, dtype=torch.long, device=device)],
                dim=0,
            )
            if next_word == eos_id:
                break

    # Fetch the id->token list once. Calling get_itos() per output token
    # rebuilt the whole vocabulary list every time.
    itos = tgt_vocab.get_itos()
    return " ".join(itos[idx] for idx in ys.flatten().tolist() if idx not in skip_ids)
# Build the model and the vocabularies used by the Gradio app.
def load_model_and_data(file_path='newcode15M.txt', checkpoint_path='AllOneLM.pth'):
    """Rebuild the vocabularies, construct the model and load its weights.

    Args:
        file_path: Parallel-text file the vocabularies are rebuilt from. It
            must be the data the checkpoint was trained on, because the
            vocabulary sizes fix the embedding and output-layer shapes.
        checkpoint_path: File containing the model's saved ``state_dict``.

    Returns:
        ``(model, src_vocab, tgt_vocab, device)``, with the model in eval mode
        on ``device`` (CUDA when available, otherwise CPU).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Only the vocabularies are needed from the dataset.
    dataset = TranslationDataset(file_path)

    # Model hyperparameters (must match the trained checkpoint).
    SRC_VOCAB_SIZE = len(dataset.src_vocab)
    TGT_VOCAB_SIZE = len(dataset.tgt_vocab)
    D_MODEL = 256
    NHEAD = 8
    NUM_ENCODER_LAYERS = 6
    NUM_DECODER_LAYERS = 6
    DIM_FEEDFORWARD = 512
    DROPOUT = 0.2

    model = TransformerModel(
        src_vocab_size=SRC_VOCAB_SIZE,
        tgt_vocab_size=TGT_VOCAB_SIZE,
        d_model=D_MODEL,
        nhead=NHEAD,
        num_encoder_layers=NUM_ENCODER_LAYERS,
        num_decoder_layers=NUM_DECODER_LAYERS,
        dim_feedforward=DIM_FEEDFORWARD,
        dropout=DROPOUT,
    ).to(device)

    # The loaded object is only used as a state_dict (tensors), so the
    # restricted tensor-only unpickler suffices and refuses arbitrary
    # pickled objects.
    state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model, dataset.src_vocab, dataset.tgt_vocab, device
# Load the model and vocabularies once, at import time. The Gradio callback
# below reuses these module-level objects on every request.
model, src_vocab, tgt_vocab, device = load_model_and_data()


def translate_sentence(src_sentence):
    """Gradio callback: translate ``src_sentence`` with the preloaded model."""
    return translate(model, src_sentence, src_vocab, tgt_vocab, device)
# Gradio front-end: one text box in, the translation out, styled by css_styles.
_input_box = gr.Textbox(label="Enter a sentence:", lines=2, placeholder="Type here...")
_output_box = gr.Textbox(label="Translated:")

iface = gr.Interface(
    fn=translate_sentence,
    inputs=_input_box,
    outputs=_output_box,
    title="Translation Talking Script",
    description="Enter a sentence to translate.",
    css=css_styles,
)

if __name__ == "__main__":
    # Start the Gradio server only when this file is executed as a script.
    iface.launch()