# NOTE: the original paste began with "Spaces: / Running / Running" — Hugging
# Face Spaces UI status text captured by the copy, not part of the program.
import math

import gradio as gr
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
# Step 1: Define the Dataset Class (for loading vocabularies)
class TranslationDataset(Dataset):
    """Parallel-corpus dataset that parses (src, tgt) pairs from a text file
    and builds (or reuses) the source/target vocabularies used elsewhere.

    Each non-empty line is expected to look like ("<src>", "<tgt>") — the
    slicing below strips the surrounding punctuation.
    NOTE(review): the [2:] / [:-3] offsets assume that exact format,
    including a trailing character after the closing quote — confirm against
    the real data file.
    """

    def __init__(self, file_path, src_vocab=None, tgt_vocab=None):
        # List of (src_str, tgt_str) pairs parsed from the file.
        self.data = []
        self.src_tokenizer = get_tokenizer('basic_english')
        self.tgt_tokenizer = get_tokenizer('basic_english')
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if line:
                    try:
                        # Split on the `", "` separator between the two quoted fields;
                        # unpacking raises ValueError when there are not exactly two.
                        src, tgt = line.split('", "')
                        src = src[2:]   # drop the leading `("`
                        tgt = tgt[:-3]  # drop the trailing `"),` (presumably — verify)
                        self.data.append((src, tgt))
                    except ValueError:
                        print(f"Skipping malformed line: {line}")
        # NOTE(review): if only ONE of the two vocabs is supplied, BOTH are
        # rebuilt from the data and the supplied one is discarded — confirm
        # this is intended.
        if src_vocab is None or tgt_vocab is None:
            self.src_vocab = build_vocab_from_iterator(self._yield_tokens(self.data, 0), specials=["<unk>", "<pad>", "<bos>", "<eos>"])
            self.tgt_vocab = build_vocab_from_iterator(self._yield_tokens(self.data, 1), specials=["<unk>", "<pad>", "<bos>", "<eos>"])
        else:
            self.src_vocab = src_vocab
            self.tgt_vocab = tgt_vocab
        # Unknown tokens map to <unk> instead of raising KeyError.
        self.src_vocab.set_default_index(self.src_vocab["<unk>"])
        self.tgt_vocab.set_default_index(self.tgt_vocab["<unk>"])

    def _yield_tokens(self, data, index):
        """Yield token lists for vocab building; index 0 = source, 1 = target."""
        for src, tgt in data:
            yield self.src_tokenizer(src) if index == 0 else self.tgt_tokenizer(tgt)
# Step 2: Define the Transformer Model
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a (seq, batch, d_model)
    tensor, then apply dropout (Vaswani et al., "Attention Is All You Need")."""

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        # One column index per position, one frequency per even dimension.
        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        freqs = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = positions * freqs
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(angles)  # even dims: sine
        table[:, 1::2] = torch.cos(angles)  # odd dims: cosine
        # Shape (max_len, 1, d_model) so it broadcasts across the batch dim.
        self.register_buffer('pe', table.unsqueeze(0).transpose(0, 1))

    def forward(self, x):
        # Add the first seq_len rows of the table, then regularise.
        return self.dropout(x + self.pe[:x.size(0), :])
class TransformerModel(nn.Module):
    """Seq2seq Transformer: token embeddings plus sinusoidal positions feeding
    torch's nn.Transformer, projected back to target-vocabulary logits.

    Tensors follow the (seq, batch) layout expected by nn.Transformer.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=2048, dropout=0.1):
        super(TransformerModel, self).__init__()
        self.model_type = 'Transformer'
        # Submodules registered in a fixed order (parameters() iteration
        # order below depends on it).
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        self.transformer = nn.Transformer(d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, dropout)
        self.fc_out = nn.Linear(d_model, tgt_vocab_size)
        self.d_model = d_model
        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-initialise every weight matrix; 1-D params (biases) untouched."""
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)

    def forward(self, src, tgt, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, memory_key_padding_mask):
        """Run a full encode/decode pass; returns (tgt_len, batch, tgt_vocab) logits."""
        scale = math.sqrt(self.d_model)
        src_emb = self.pos_encoder(self.src_embedding(src) * scale)
        tgt_emb = self.pos_encoder(self.tgt_embedding(tgt) * scale)
        decoded = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None, src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
        return self.fc_out(decoded)
# Function to load the trained model
def load_model(model_path, src_vocab_size, tgt_vocab_size):
    """Build a TransformerModel and load trained weights, mapped onto the CPU.

    NOTE(review): torch.load unpickles arbitrary objects — only load
    checkpoints from a trusted source.
    """
    net = TransformerModel(src_vocab_size, tgt_vocab_size)
    state = torch.load(model_path, map_location=torch.device('cpu'))
    net.load_state_dict(state)
    return net
# Function to translate a sentence
def translate(model, src_sentence, dataset, device, max_len=50):
    """Greedy-decode a translation of `src_sentence`.

    Encodes the source once, then repeatedly runs the decoder, appending the
    argmax token each step until <eos> or `max_len` tokens. Returns the
    translation as a space-joined string with special tokens removed.
    """
    model.eval()
    # <bos> + token ids + <eos>; out-of-vocab words fall back to <unk> via the
    # default index set in TranslationDataset.__init__.
    src_tokens = [dataset.src_vocab["<bos>"]] + [dataset.src_vocab[token] for token in dataset.src_tokenizer(src_sentence)] + [dataset.src_vocab["<eos>"]]
    src_tensor = torch.LongTensor(src_tokens).unsqueeze(1).to(device)  # (src_len, 1)
    # All-False mask: the encoder may attend to every source position.
    src_mask = torch.zeros((src_tensor.size(0), src_tensor.size(0)), device=device).type(torch.bool)
    # Encode once, scaling embeddings by sqrt(d_model) as in the forward pass.
    memory = model.transformer.encoder(model.pos_encoder(model.src_embedding(src_tensor) * math.sqrt(model.d_model)), src_mask)
    # ys accumulates the generated sequence, starting with <bos>; shape (t, 1).
    ys = torch.ones(1, 1).fill_(dataset.tgt_vocab["<bos>"]).type(torch.long).to(device)
    for i in range(max_len-1):
        # Causal mask: position t may only attend to positions <= t.
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(ys.size(0)).to(device)
        out = model.transformer.decoder(model.pos_encoder(model.tgt_embedding(ys) * math.sqrt(model.d_model)),
                                        memory, tgt_mask)
        out = model.fc_out(out)
        prob = out[-1]  # logits for the most recent position, shape (1, tgt_vocab)
        _, next_word = torch.max(prob, dim=1)  # greedy choice
        next_word = next_word.item()
        ys = torch.cat([ys, torch.ones(1, 1).type_as(src_tensor.data).fill_(next_word)], dim=0)
        if next_word == dataset.tgt_vocab["<eos>"]:
            break
    ys = ys.flatten()
    # Map ids back to surface strings, dropping the special tokens.
    translated_tokens = [dataset.tgt_vocab.get_itos()[token] for token in ys if token not in [dataset.tgt_vocab["<bos>"], dataset.tgt_vocab["<eos>"], dataset.tgt_vocab["<pad>"]]]
    return " ".join(translated_tokens)
# Gradio Interface
def translate_with_model(src_sentence):
    """Gradio callback: translate using the module-level model/dataset/device."""
    return translate(model, src_sentence, dataset, device)
# Load dataset (which contains vocabularies)
file_path = 'data.txt' # Make sure this points to your data file
dataset = TranslationDataset(file_path)
# Load the trained model
model_path = 'SolconLM12B.pth'
model = load_model(model_path, len(dataset.src_vocab), len(dataset.tgt_vocab))
# Run on GPU when available; load_model maps the checkpoint to CPU first.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
# Gradio interface
iface = gr.Interface(fn=translate_with_model, inputs="text", outputs="text", title="Solocon12B Model", description="Enter Question")
# NOTE: dataset/model loading above runs at import time; only the UI launch
# is guarded by the __main__ check.
if __name__ == "__main__":
    iface.launch()