import torch
from torch.utils.data import Dataset
import torch.nn as nn
import math
import gradio as gr

# Suppress torchtext deprecation warnings
# (disable before importing the torchtext submodules that emit them)
import torchtext
torchtext.disable_torchtext_deprecation_warning()

from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
# Define the CSS styles
css_styles = '''
@import url('https://fonts.googleapis.com/css2?family=Plus+Jakarta+Sans:wght@400;600;700;800&display=swap');

.gradio-container {
    font-family: 'Plus Jakarta Sans', sans-serif;
}

button.primary-button {
    width: 300px;
    height: 48px;
    padding: 2px;
    font-weight: 700;
    border: 3px solid #5964C2;
    border-radius: 5px;
    background-color: #7583FF;
    color: white;
    font-size: 20px;
    transition: 0.3s ease;
}

button.primary-button:hover {
    background-color: #5C67C9;
    border: 3px solid #31376B;
}

input[type="text"], textarea {
    width: 100%;
    outline: none;
    border: 3px solid #B4CFBB !important;
    background-color: #DEFFE7 !important;
    border-radius: 10px !important;
    color: #B4CFBB !important;
    padding: 2px !important;
    font-weight: 600 !important;
    transition: 0.3s ease;
}

input[type="text"]:focus, textarea:focus {
    background-color: #88A88D !important;
    border: 3px solid #657D69 !important;
    color: #657D69 !important;
    font-size: 16px !important;
}
'''
# Define the TranslationDataset class (simplified for vocab loading)
class TranslationDataset(Dataset):
    def __init__(self, file_path):
        self.src_tokenizer = get_tokenizer('basic_english')
        self.tgt_tokenizer = get_tokenizer('basic_english')
        self.src_vocab = build_vocab_from_iterator(self._yield_tokens(file_path, 0), specials=["<unk>", "<pad>", "<bos>", "<eos>"])
        self.tgt_vocab = build_vocab_from_iterator(self._yield_tokens(file_path, 1), specials=["<unk>", "<pad>", "<bos>", "<eos>"])
        self.src_vocab.set_default_index(self.src_vocab["<unk>"])
        self.tgt_vocab.set_default_index(self.tgt_vocab["<unk>"])

    def _yield_tokens(self, file_path, index):
        with open(file_path, 'r', encoding='utf-8', errors='replace') as f:
            for line in f:
                line = line.strip()
                if line:
                    try:
                        # Each line holds a quoted source/target pair; split on the
                        # '","' separator and trim the surrounding wrapper characters
                        src, tgt = line.split('","')
                        src = src[2:]
                        tgt = tgt[:-3]
                        yield self.src_tokenizer(src) if index == 0 else self.tgt_tokenizer(tgt)
                    except ValueError:
                        # Skip malformed lines that do not split into exactly two fields
                        continue
# Define the PositionalEncoding class
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
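        # Sinusoidal encoding as in "Attention Is All You Need":
        #   PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
        #   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
        # div_term below precomputes the 10000^(2i / d_model) factor in log space.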
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        if d_model % 2 == 1:
            # For odd d_model, the cosine channels have one fewer column
            pe[:, 1::2] = torch.cos(position * div_term[:-1])
        else:
            pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(1)  # shape: (max_len, 1, d_model) for sequence-first inputs
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)
# Define the TransformerModel class
class TransformerModel(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=1024, dropout=0.1):
        super(TransformerModel, self).__init__()
        self.model_type = 'Transformer'
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        self.transformer = nn.Transformer(d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, dropout)
        self.fc_out = nn.Linear(d_model, tgt_vocab_size)
        self.d_model = d_model
        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, tgt, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, memory_key_padding_mask):
        src = self.src_embedding(src) * math.sqrt(self.d_model)
        tgt = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        tgt = self.pos_encoder(tgt)
        memory = self.transformer(
            src, tgt,
            src_mask=src_mask, tgt_mask=tgt_mask, memory_mask=None,
            src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        output = self.fc_out(memory)
        return output
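# Note: nn.Transformer is used with its default batch_first=False, so embedded
# inputs are expected as (seq_len, batch, d_model) and the raw token tensors as
# (seq_len, batch), which is the layout translate() below produces.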
# Translation function (greedy decoding)
def translate(model, src_sentence, src_vocab, tgt_vocab, device, max_len=50):
    model.eval()
    src_tokenizer = get_tokenizer('basic_english')
    # Encode the source sentence, framed by <bos>/<eos>, as a (seq_len, 1) tensor
    src_tokens = [src_vocab["<bos>"]] + [src_vocab[token] for token in src_tokenizer(src_sentence)] + [src_vocab["<eos>"]]
    src_tensor = torch.LongTensor(src_tokens).unsqueeze(1).to(device)
    src_mask = torch.zeros((src_tensor.size(0), src_tensor.size(0)), device=device).type(torch.bool)
    with torch.no_grad():
        memory = model.transformer.encoder(
            model.pos_encoder(model.src_embedding(src_tensor) * math.sqrt(model.d_model)),
            src_mask
        )
    # Start the target sequence with <bos> and extend it one token at a time
    ys = torch.ones(1, 1).fill_(tgt_vocab["<bos>"]).type(torch.long).to(device)
    for _ in range(max_len - 1):
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(ys.size(0)).to(device)
        with torch.no_grad():
            out = model.transformer.decoder(
                model.pos_encoder(model.tgt_embedding(ys) * math.sqrt(model.d_model)),
                memory,
                tgt_mask
            )
            out = model.fc_out(out)
        # Greedily pick the most probable next token from the last decoder position
        prob = out[-1].detach()
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.item()
        ys = torch.cat([ys, torch.ones(1, 1).type_as(src_tensor.data).fill_(next_word)], dim=0)
        if next_word == tgt_vocab["<eos>"]:
            break
    ys = ys.flatten()
    # Map token ids back to strings, dropping the special tokens
    translated_tokens = [
        tgt_vocab.get_itos()[token]
        for token in ys.tolist()
        if token not in (tgt_vocab["<bos>"], tgt_vocab["<eos>"], tgt_vocab["<pad>"])
    ]
    return " ".join(translated_tokens)
# Load the model and dataset
def load_model_and_data():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load the dataset (for vocabulary)
    file_path = 'newcode15M.txt'  # Replace with the path to your dataset file
    dataset = TranslationDataset(file_path)

    # Model hyperparameters (make sure these match your trained model)
    SRC_VOCAB_SIZE = len(dataset.src_vocab)
    TGT_VOCAB_SIZE = len(dataset.tgt_vocab)
    D_MODEL = 256
    NHEAD = 8
    NUM_ENCODER_LAYERS = 6
    NUM_DECODER_LAYERS = 6
    DIM_FEEDFORWARD = 512
    DROPOUT = 0.2

    # Initialize the model
    model = TransformerModel(
        SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, D_MODEL, NHEAD,
        NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, DIM_FEEDFORWARD, DROPOUT
    ).to(device)

    # Load the trained weights
    model.load_state_dict(torch.load('AllOneLM.pth', map_location=device))
    model.eval()
    return model, dataset.src_vocab, dataset.tgt_vocab, device
# Load model and data
model, src_vocab, tgt_vocab, device = load_model_and_data()

# Define the translation function for Gradio
def translate_sentence(src_sentence):
    translated_sentence = translate(model, src_sentence, src_vocab, tgt_vocab, device)
    return translated_sentence

# Create Gradio interface
iface = gr.Interface(
    fn=translate_sentence,
    inputs=gr.Textbox(label="Enter a sentence:", lines=2, placeholder="Type here..."),
    outputs=gr.Textbox(label="Translated:"),
    title="Translation Talking Script",
    description="Enter a sentence to translate.",
    css=css_styles
)

# Launch the interface
if __name__ == "__main__":
    iface.launch()