Spaces:
Running
Running
OOFMAN29803
committed on
Commit
•
a1c209b
1
Parent(s):
1a34d05
Update app.py
Browse files
app.py
CHANGED
@@ -6,6 +6,10 @@ import torch.nn as nn
|
|
6 |
import math
|
7 |
import gradio as gr
|
8 |
|
|
|
|
|
|
|
|
|
9 |
# Define the CSS styles
|
10 |
css_styles = '''
|
11 |
@import url('https://fonts.googleapis.com/css2?family=Plus+Jakarta+Sans:wght@400;600;700;800&display=swap');
|
@@ -84,10 +88,14 @@ class PositionalEncoding(nn.Module):
|
|
84 |
self.dropout = nn.Dropout(p=dropout)
|
85 |
|
86 |
pe = torch.zeros(max_len, d_model)
|
87 |
-
position = torch.arange(0, max_len, dtype=torch.
|
88 |
-
div_term = torch.exp(torch.arange(0, d_model, 2
|
89 |
pe[:, 0::2] = torch.sin(position * div_term)
|
90 |
-
|
|
|
|
|
|
|
|
|
91 |
pe = pe.unsqueeze(1)
|
92 |
self.register_buffer('pe', pe)
|
93 |
|
@@ -119,7 +127,10 @@ class TransformerModel(nn.Module):
|
|
119 |
tgt = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
|
120 |
src = self.pos_encoder(src)
|
121 |
tgt = self.pos_encoder(tgt)
|
122 |
-
memory = self.transformer(
|
|
|
|
|
|
|
123 |
output = self.fc_out(memory)
|
124 |
return output
|
125 |
|
@@ -133,15 +144,21 @@ def translate(model, src_sentence, src_vocab, tgt_vocab, device, max_len=50):
|
|
133 |
src_mask = torch.zeros((src_tensor.size(0), src_tensor.size(0)), device=device).type(torch.bool)
|
134 |
|
135 |
with torch.no_grad():
|
136 |
-
memory = model.transformer.encoder(
|
|
|
|
|
|
|
137 |
|
138 |
ys = torch.ones(1, 1).fill_(tgt_vocab["<bos>"]).type(torch.long).to(device)
|
139 |
-
for
|
140 |
tgt_mask = nn.Transformer.generate_square_subsequent_mask(ys.size(0)).to(device)
|
141 |
|
142 |
with torch.no_grad():
|
143 |
-
out = model.transformer.decoder(
|
144 |
-
|
|
|
|
|
|
|
145 |
out = model.fc_out(out)
|
146 |
|
147 |
prob = out[-1].detach()
|
@@ -153,7 +170,11 @@ def translate(model, src_sentence, src_vocab, tgt_vocab, device, max_len=50):
|
|
153 |
break
|
154 |
|
155 |
ys = ys.flatten()
|
156 |
-
translated_tokens = [
|
|
|
|
|
|
|
|
|
157 |
return " ".join(translated_tokens)
|
158 |
|
159 |
# Load the model and dataset
|
@@ -162,7 +183,7 @@ def load_model_and_data():
|
|
162 |
print(f"Using device: {device}")
|
163 |
|
164 |
# Load the dataset (for vocabulary)
|
165 |
-
file_path = '
|
166 |
dataset = TranslationDataset(file_path)
|
167 |
|
168 |
# Model hyperparameters (make sure these match your trained model)
|
@@ -176,10 +197,13 @@ def load_model_and_data():
|
|
176 |
DROPOUT = 0.2
|
177 |
|
178 |
# Initialize the model
|
179 |
-
model = TransformerModel(
|
|
|
|
|
|
|
180 |
|
181 |
# Load the trained model
|
182 |
-
model.load_state_dict(torch.load('
|
183 |
model.eval()
|
184 |
|
185 |
return model, dataset.src_vocab, dataset.tgt_vocab, device
|
@@ -195,8 +219,8 @@ def translate_sentence(src_sentence):
|
|
195 |
# Create Gradio interface
|
196 |
iface = gr.Interface(
|
197 |
fn=translate_sentence,
|
198 |
-
inputs=gr.
|
199 |
-
outputs=gr.
|
200 |
title="Translation Talking Script",
|
201 |
description="Enter a sentence to translate.",
|
202 |
css=css_styles
|
|
|
6 |
import math
|
7 |
import gradio as gr
|
8 |
|
9 |
+
# Suppress torchtext deprecation warnings
|
10 |
+
import torchtext
|
11 |
+
torchtext.disable_torchtext_deprecation_warning()
|
12 |
+
|
13 |
# Define the CSS styles
|
14 |
css_styles = '''
|
15 |
@import url('https://fonts.googleapis.com/css2?family=Plus+Jakarta+Sans:wght@400;600;700;800&display=swap');
|
|
|
88 |
self.dropout = nn.Dropout(p=dropout)
|
89 |
|
90 |
pe = torch.zeros(max_len, d_model)
|
91 |
+
position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
|
92 |
+
div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model))
|
93 |
pe[:, 0::2] = torch.sin(position * div_term)
|
94 |
+
if d_model % 2 == 1:
|
95 |
+
# For odd d_model there is one fewer cosine column than sine columns, so drop the last frequency
|
96 |
+
pe[:, 1::2] = torch.cos(position * div_term[:-1])
|
97 |
+
else:
|
98 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
99 |
pe = pe.unsqueeze(1)
|
100 |
self.register_buffer('pe', pe)
|
101 |
|
|
|
127 |
tgt = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
|
128 |
src = self.pos_encoder(src)
|
129 |
tgt = self.pos_encoder(tgt)
|
130 |
+
memory = self.transformer(
|
131 |
+
src, tgt, src_mask, tgt_mask, None,
|
132 |
+
src_padding_mask, tgt_padding_mask, memory_key_padding_mask
|
133 |
+
)
|
134 |
output = self.fc_out(memory)
|
135 |
return output
|
136 |
|
|
|
144 |
src_mask = torch.zeros((src_tensor.size(0), src_tensor.size(0)), device=device).type(torch.bool)
|
145 |
|
146 |
with torch.no_grad():
|
147 |
+
memory = model.transformer.encoder(
|
148 |
+
model.pos_encoder(model.src_embedding(src_tensor) * math.sqrt(model.d_model)),
|
149 |
+
src_mask
|
150 |
+
)
|
151 |
|
152 |
ys = torch.ones(1, 1).fill_(tgt_vocab["<bos>"]).type(torch.long).to(device)
|
153 |
+
for _ in range(max_len-1):
|
154 |
tgt_mask = nn.Transformer.generate_square_subsequent_mask(ys.size(0)).to(device)
|
155 |
|
156 |
with torch.no_grad():
|
157 |
+
out = model.transformer.decoder(
|
158 |
+
model.pos_encoder(model.tgt_embedding(ys) * math.sqrt(model.d_model)),
|
159 |
+
memory,
|
160 |
+
tgt_mask
|
161 |
+
)
|
162 |
out = model.fc_out(out)
|
163 |
|
164 |
prob = out[-1].detach()
|
|
|
170 |
break
|
171 |
|
172 |
ys = ys.flatten()
|
173 |
+
translated_tokens = [
|
174 |
+
tgt_vocab.get_itos()[token]
|
175 |
+
for token in ys
|
176 |
+
if token not in [tgt_vocab["<bos>"], tgt_vocab["<eos>"], tgt_vocab["<pad>"]]
|
177 |
+
]
|
178 |
return " ".join(translated_tokens)
|
179 |
|
180 |
# Load the model and dataset
|
|
|
183 |
print(f"Using device: {device}")
|
184 |
|
185 |
# Load the dataset (for vocabulary)
|
186 |
+
file_path = 'path_to_your_dataset.txt' # Replace with the path to your dataset file
|
187 |
dataset = TranslationDataset(file_path)
|
188 |
|
189 |
# Model hyperparameters (make sure these match your trained model)
|
|
|
197 |
DROPOUT = 0.2
|
198 |
|
199 |
# Initialize the model
|
200 |
+
model = TransformerModel(
|
201 |
+
SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, D_MODEL, NHEAD,
|
202 |
+
NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, DIM_FEEDFORWARD, DROPOUT
|
203 |
+
).to(device)
|
204 |
|
205 |
# Load the trained model
|
206 |
+
model.load_state_dict(torch.load('path_to_your_model.pth', map_location=device))
|
207 |
model.eval()
|
208 |
|
209 |
return model, dataset.src_vocab, dataset.tgt_vocab, device
|
|
|
219 |
# Create Gradio interface
|
220 |
iface = gr.Interface(
|
221 |
fn=translate_sentence,
|
222 |
+
inputs=gr.Textbox(label="Enter a sentence:", lines=2, placeholder="Type here..."),
|
223 |
+
outputs=gr.Textbox(label="Translated:"),
|
224 |
title="Translation Talking Script",
|
225 |
description="Enter a sentence to translate.",
|
226 |
css=css_styles
|