philipp-zettl commited on
Commit
5545230
1 Parent(s): 401978c

Upload folder using huggingface_hub

Browse files
README.md CHANGED
@@ -42,7 +42,7 @@ The model is a simple sequence classification model based on hidden output layer
42
 
43
  The backbone of the model is BAAI/bge-m3 with 1024.
44
 
45
- An additional layer of (GGU: 3) is added to the output of the backbone to classify the input sequence.
46
 
47
  Using the provided implementation (in repository) of `MultiHeadClassificationTrainer`.
48
 
@@ -231,20 +231,20 @@ def _eval_model(self, dataloader, label_map, sample_key, label_key):
231
  For evaluation, we used the following metrics: accuracy, precision, recall, f1-score. You can find a detailed classification report here:
232
 
233
  **GGU:**
234
- | | index | precision | recall | f1-score | support |
235
- |---:|:-------------|------------:|---------:|-----------:|----------:|
236
- | 0 | Greeting | 0.278481 | 0.709677 | 0.4 | 31 |
237
- | 1 | Gratitude | 0.428571 | 0.176471 | 0.25 | 34 |
238
- | 2 | Other | 0.25 | 0.075 | 0.115385 | 40 |
239
- | 3 | macro avg | 0.319017 | 0.320383 | 0.255128 | 105 |
240
- | 4 | weighted avg | 0.316232 | 0.295238 | 0.243004 | 105 |
241
 
242
  **sentiment:**
243
  | | index | precision | recall | f1-score | support |
244
  |---:|:-------------|------------:|---------:|-----------:|----------:|
245
- | 0 | Positive | 0.568182 | 0.714286 | 0.632911 | 35 |
246
- | 1 | Negative | 0.605263 | 0.821429 | 0.69697 | 28 |
247
- | 2 | Neutral | 0.869565 | 0.47619 | 0.615385 | 42 |
248
- | 3 | macro avg | 0.681003 | 0.670635 | 0.648422 | 105 |
249
- | 4 | weighted avg | 0.698624 | 0.647619 | 0.642983 | 105 |
250
 
 
42
 
43
  The backbone of the model is BAAI/bge-m3 with 1024.
44
 
45
+ An additional layer of (GGU: 3,sentiment: 3) is added to the output of the backbone to classify the input sequence.
46
 
47
  Using the provided implementation (in repository) of `MultiHeadClassificationTrainer`.
48
 
 
231
  For evaluation, we used the following metrics: accuracy, precision, recall, f1-score. You can find a detailed classification report here:
232
 
233
  **GGU:**
234
+ | | index | precision | recall | f1-score | support |
235
+ |---:|:-------------|------------:|----------:|-----------:|----------:|
236
+ | 0 | Greeting | 0.0555556 | 0.03125 | 0.04 | 32 |
237
+ | 1 | Gratitude | 0.320513 | 0.892857 | 0.471698 | 28 |
238
+ | 2 | Other | 0.111111 | 0.0222222 | 0.037037 | 45 |
239
+ | 3 | macro avg | 0.162393 | 0.315443 | 0.182912 | 105 |
240
+ | 4 | weighted avg | 0.15002 | 0.257143 | 0.15385 | 105 |
241
 
242
  **sentiment:**
243
  | | index | precision | recall | f1-score | support |
244
  |---:|:-------------|------------:|---------:|-----------:|----------:|
245
+ | 0 | Positive | 0.653846 | 0.586207 | 0.618182 | 29 |
246
+ | 1 | Negative | 0.777778 | 0.736842 | 0.756757 | 38 |
247
+ | 2 | Neutral | 0.72093 | 0.815789 | 0.765432 | 38 |
248
+ | 3 | macro avg | 0.717518 | 0.712946 | 0.713457 | 105 |
249
+ | 4 | weighted avg | 0.722976 | 0.72381 | 0.721623 | 105 |
250
 
assets/confusion_matrix_GGU.png CHANGED
assets/confusion_matrix_sentiment.png CHANGED
assets/loss_plot_GGU.png CHANGED
assets/loss_plot_sentiment.png CHANGED
heads/GGU.pth CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:7addbfb97d15aa2703e981078bd21c32b5f2d1783d3b7227e412301eff6f796e
3
  size 7552
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ff2b644eeb54e3b01ec332e5979b5e98477b4a29464b8db4e2beb29fe1548f27
3
  size 7552
heads/sentiment.pth CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:964c8089b6a5252343f3d372a8ffb81a545d2c83ba7d00a5a567ff07083e5118
3
  size 7652
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:446843989a3b3ca5b766463a54e00f443a928a6a131e7d6e51e5238657b62758
3
  size 7652
multi-head-sequence-classification-model-model.pth CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:23318136d699081f48df309d2daa21a9b5bcd6381f483c6682f609597a094f6a
3
  size 1135701541
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6400c69f1f1efffd158930fc7be49cf7147fe2fba2764bdf3f12d39235060521
3
  size 1135701541
train.py CHANGED
@@ -167,14 +167,19 @@ class MultiHeadClassification(nn.Module):
167
  Returns:
168
  None
169
  """
 
170
  if head_name in self.heads:
171
- self.heads[head_name].load_state_dict(torch.load(path))
172
- self.to(self.device)
 
 
173
  return
174
- model = torch.load(path)
175
  assert model['weight'].shape[1] == self.backbone.config.hidden_size
176
- self.heads[head_name] = nn.Linear(self.backbone.config.hidden_size, model['weight'].shape[0])
 
177
  self.heads[head_name].load_state_dict(model)
 
178
 
179
  self.to(self.torch_dtype).to(self.device)
180
 
@@ -286,6 +291,7 @@ class MultiHeadClassification(nn.Module):
286
  """
287
  self.heads[head_name] = nn.Linear(self.backbone.config.hidden_size, num_classes)
288
  self.heads[head_name].to(self.torch_dtype).to(self.device)
 
289
 
290
  def remove_head(self, head_name):
291
  """
@@ -294,6 +300,7 @@ class MultiHeadClassification(nn.Module):
294
  if head_name not in self.heads:
295
  raise ValueError(f'Head {head_name} not found')
296
  del self.heads[head_name]
 
297
 
298
  @classmethod
299
  def from_pretrained(cls, model_name, head_config=None, dropout=0.1, l2_reg=0.01):
@@ -331,6 +338,7 @@ class MultiHeadClassification(nn.Module):
331
  backbone = AutoModel.from_pretrained(os.path.join(model_path, 'pretrained/backbone.pth'))
332
  instance = cls(backbone, head_config, dropout, l2_reg)
333
  instance.load(os.path.join(model_path, 'pretrained/model.pth'))
 
334
  return instance
335
 
336
  class MultiHeadClassificationTrainer:
 
167
  Returns:
168
  None
169
  """
170
+ model = torch.load(path)
171
  if head_name in self.heads:
172
+ num_classes = model['weight'].shape[0]
173
+ self.heads[head_name].load_state_dict(model)
174
+ self.to(self.torch_dtype).to(self.device)
175
+ self.head_config[head_name] = num_classes
176
  return
177
+
178
  assert model['weight'].shape[1] == self.backbone.config.hidden_size
179
+ num_classes = model['weight'].shape[0]
180
+ self.heads[head_name] = nn.Linear(self.backbone.config.hidden_size, num_classes)
181
  self.heads[head_name].load_state_dict(model)
182
+ self.head_config[head_name] = num_classes
183
 
184
  self.to(self.torch_dtype).to(self.device)
185
 
 
291
  """
292
  self.heads[head_name] = nn.Linear(self.backbone.config.hidden_size, num_classes)
293
  self.heads[head_name].to(self.torch_dtype).to(self.device)
294
+ self.head_config[head_name] = num_classes
295
 
296
  def remove_head(self, head_name):
297
  """
 
300
  if head_name not in self.heads:
301
  raise ValueError(f'Head {head_name} not found')
302
  del self.heads[head_name]
303
+ del self.head_config[head_name]
304
 
305
  @classmethod
306
  def from_pretrained(cls, model_name, head_config=None, dropout=0.1, l2_reg=0.01):
 
338
  backbone = AutoModel.from_pretrained(os.path.join(model_path, 'pretrained/backbone.pth'))
339
  instance = cls(backbone, head_config, dropout, l2_reg)
340
  instance.load(os.path.join(model_path, 'pretrained/model.pth'))
341
+ instance.head_config = {k: v. instance.heads}
342
  return instance
343
 
344
  class MultiHeadClassificationTrainer: