OOFMAN29803 committed
Commit a1c209b
Parent: 1a34d05

Update app.py

Files changed (1)
  1. app.py +38 -14
app.py CHANGED
@@ -6,6 +6,10 @@ import torch.nn as nn
 import math
 import gradio as gr
 
+# Suppress torchtext deprecation warnings
+import torchtext
+torchtext.disable_torchtext_deprecation_warning()
+
 # Define the CSS styles
 css_styles = '''
 @import url('https://fonts.googleapis.com/css2?family=Plus+Jakarta+Sans:wght@400;600;700;800&display=swap');
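
Note that `disable_torchtext_deprecation_warning()` only exists in recent torchtext releases, so this import block assumes such a version is installed. A guarded sketch (an assumption, not part of the commit) degrades gracefully on older installs:

```python
import torchtext

# Only silence the deprecation warning if this torchtext build exposes the
# helper; older releases do not have it.
if hasattr(torchtext, "disable_torchtext_deprecation_warning"):
    torchtext.disable_torchtext_deprecation_warning()
```
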
@@ -84,10 +88,14 @@ class PositionalEncoding(nn.Module):
         self.dropout = nn.Dropout(p=dropout)
 
         pe = torch.zeros(max_len, d_model)
-        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
-        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
+        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
+        div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model))
         pe[:, 0::2] = torch.sin(position * div_term)
-        pe[:, 1::2] = torch.cos(position * div_term)
+        if d_model % 2 == 1:
+            # For odd d_model, handle the last column
+            pe[:, 1::2] = torch.cos(position * div_term[:-1])
+        else:
+            pe[:, 1::2] = torch.cos(position * div_term)
         pe = pe.unsqueeze(1)
         self.register_buffer('pe', pe)
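
The new branch handles odd `d_model`: `pe[:, 0::2]` has one more column than `pe[:, 1::2]`, so the cosine term needs `div_term[:-1]`. A small standalone check with illustrative values (not taken from the commit) confirms the shapes line up:

```python
import math
import torch

max_len, d_model = 16, 7  # odd d_model exercises the new branch

pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model))

pe[:, 0::2] = torch.sin(position * div_term)       # 4 even-indexed columns
pe[:, 1::2] = torch.cos(position * div_term[:-1])  # 3 odd-indexed columns
assert pe.shape == (max_len, d_model)
```
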
 
@@ -119,7 +127,10 @@ class TransformerModel(nn.Module):
         tgt = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
         src = self.pos_encoder(src)
         tgt = self.pos_encoder(tgt)
-        memory = self.transformer(src, tgt, src_mask, tgt_mask, None, src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
+        memory = self.transformer(
+            src, tgt, src_mask, tgt_mask, None,
+            src_padding_mask, tgt_padding_mask, memory_key_padding_mask
+        )
         output = self.fc_out(memory)
         return output
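
The reformatted call still passes every argument positionally; the order follows `nn.Transformer.forward` (src, tgt, src_mask, tgt_mask, memory_mask, then the three key-padding masks). Inside `forward`, the same call could be written with keywords to make that mapping explicit (an equivalent sketch, not what the commit does):

```python
memory = self.transformer(
    src, tgt,
    src_mask=src_mask,
    tgt_mask=tgt_mask,
    memory_mask=None,
    src_key_padding_mask=src_padding_mask,
    tgt_key_padding_mask=tgt_padding_mask,
    memory_key_padding_mask=memory_key_padding_mask,
)
```
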
 
@@ -133,15 +144,21 @@ def translate(model, src_sentence, src_vocab, tgt_vocab, device, max_len=50):
     src_mask = torch.zeros((src_tensor.size(0), src_tensor.size(0)), device=device).type(torch.bool)
 
     with torch.no_grad():
-        memory = model.transformer.encoder(model.pos_encoder(model.src_embedding(src_tensor) * math.sqrt(model.d_model)), src_mask)
+        memory = model.transformer.encoder(
+            model.pos_encoder(model.src_embedding(src_tensor) * math.sqrt(model.d_model)),
+            src_mask
+        )
 
     ys = torch.ones(1, 1).fill_(tgt_vocab["<bos>"]).type(torch.long).to(device)
-    for i in range(max_len-1):
+    for _ in range(max_len-1):
         tgt_mask = nn.Transformer.generate_square_subsequent_mask(ys.size(0)).to(device)
 
         with torch.no_grad():
-            out = model.transformer.decoder(model.pos_encoder(model.tgt_embedding(ys) * math.sqrt(model.d_model)),
-                                            memory, tgt_mask)
+            out = model.transformer.decoder(
+                model.pos_encoder(model.tgt_embedding(ys) * math.sqrt(model.d_model)),
+                memory,
+                tgt_mask
+            )
         out = model.fc_out(out)
 
         prob = out[-1].detach()
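
`nn.Transformer.generate_square_subsequent_mask` builds the causal mask used in this loop: zeros on and below the diagonal, `-inf` above it, so each position can only attend to earlier ones. A quick illustration:

```python
import torch.nn as nn

# Causal mask for a 3-token prefix.
print(nn.Transformer.generate_square_subsequent_mask(3))
# tensor([[0., -inf, -inf],
#         [0., 0., -inf],
#         [0., 0., 0.]])
```
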
@@ -153,7 +170,11 @@ def translate(model, src_sentence, src_vocab, tgt_vocab, device, max_len=50):
             break
 
     ys = ys.flatten()
-    translated_tokens = [tgt_vocab.get_itos()[token] for token in ys if token not in [tgt_vocab["<bos>"], tgt_vocab["<eos>"], tgt_vocab["<pad>"]]]
+    translated_tokens = [
+        tgt_vocab.get_itos()[token]
+        for token in ys
+        if token not in [tgt_vocab["<bos>"], tgt_vocab["<eos>"], tgt_vocab["<pad>"]]
+    ]
     return " ".join(translated_tokens)
 
 # Load the model and dataset
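
Once the loader below has been pointed at real files, `translate` can be called directly; a hypothetical invocation (the input sentence is illustrative only) looks like:

```python
model, src_vocab, tgt_vocab, device = load_model_and_data()
print(translate(model, "hello world", src_vocab, tgt_vocab, device, max_len=50))
```
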
@@ -162,7 +183,7 @@ def load_model_and_data():
     print(f"Using device: {device}")
 
     # Load the dataset (for vocabulary)
-    file_path = 'newcode15M.txt'  # Replace with the path to your dataset file
+    file_path = 'path_to_your_dataset.txt'  # Replace with the path to your dataset file
     dataset = TranslationDataset(file_path)
 
     # Model hyperparameters (make sure these match your trained model)
@@ -176,10 +197,13 @@ def load_model_and_data():
     DROPOUT = 0.2
 
     # Initialize the model
-    model = TransformerModel(SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, D_MODEL, NHEAD, NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, DIM_FEEDFORWARD, DROPOUT).to(device)
+    model = TransformerModel(
+        SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, D_MODEL, NHEAD,
+        NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, DIM_FEEDFORWARD, DROPOUT
+    ).to(device)
 
     # Load the trained model
-    model.load_state_dict(torch.load('AllOneLM.pth', map_location=device))
+    model.load_state_dict(torch.load('path_to_your_model.pth', map_location=device))
     model.eval()
 
     return model, dataset.src_vocab, dataset.tgt_vocab, device
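
`torch.load(..., map_location=device)` is kept as-is; on PyTorch versions that support it, `weights_only=True` is a safer way to load a plain `state_dict` checkpoint (an optional variant, assuming the `.pth` file holds only weights):

```python
# Optional hardening: refuse to unpickle arbitrary objects from the checkpoint.
state_dict = torch.load('path_to_your_model.pth', map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()
```
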
@@ -195,8 +219,8 @@ def translate_sentence(src_sentence):
 # Create Gradio interface
 iface = gr.Interface(
     fn=translate_sentence,
-    inputs=gr.inputs.Textbox(label="Enter a sentence:", lines=2, placeholder="Type here..."),
-    outputs=gr.outputs.Textbox(label="Translated:"),
+    inputs=gr.Textbox(label="Enter a sentence:", lines=2, placeholder="Type here..."),
+    outputs=gr.Textbox(label="Translated:"),
     title="Translation Talking Script",
     description="Enter a sentence to translate.",
     css=css_styles
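
The switch from `gr.inputs.Textbox` / `gr.outputs.Textbox` to `gr.Textbox` is what recent Gradio releases expect, since the old `inputs`/`outputs` namespaces have been removed. The interface is presumably launched further down in app.py; a minimal sketch of that step (assumed, not shown in this diff) would be:

```python
if __name__ == "__main__":
    iface.launch()
```
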
 