philipp-zettl
committed on
Commit
•
5545230
1
Parent(s):
401978c
Upload folder using huggingface_hub
Browse files
- README.md +13 -13
- assets/confusion_matrix_GGU.png +0 -0
- assets/confusion_matrix_sentiment.png +0 -0
- assets/loss_plot_GGU.png +0 -0
- assets/loss_plot_sentiment.png +0 -0
- heads/GGU.pth +1 -1
- heads/sentiment.pth +1 -1
- multi-head-sequence-classification-model-model.pth +1 -1
- train.py +12 -4
README.md
CHANGED
@@ -42,7 +42,7 @@ The model is a simple sequence classification model based on hidden output layer
|
|
42 |
|
43 |
The backbone of the model is BAAI/bge-m3 with 1024.
|
44 |
|
45 |
-
An additional layer of (GGU: 3) is added to the output of the backbone to classify the input sequence.
|
46 |
|
47 |
Using the provided implementation (in repository) of `MultiHeadClassificationTrainer`.
|
48 |
|
@@ -231,20 +231,20 @@ def _eval_model(self, dataloader, label_map, sample_key, label_key):
|
|
231 |
For evaluation, we used the following metrics: accuracy, precision, recall, f1-score. You can find a detailed classification report here:
|
232 |
|
233 |
**GGU:**
|
234 |
-
| | index | precision |
|
235 |
-
|
236 |
-
| 0 | Greeting |
|
237 |
-
| 1 | Gratitude |
|
238 |
-
| 2 | Other |
|
239 |
-
| 3 | macro avg |
|
240 |
-
| 4 | weighted avg |
|
241 |
|
242 |
**sentiment:**
|
243 |
| | index | precision | recall | f1-score | support |
|
244 |
|---:|:-------------|------------:|---------:|-----------:|----------:|
|
245 |
-
| 0 | Positive | 0.
|
246 |
-
| 1 | Negative | 0.
|
247 |
-
| 2 | Neutral | 0.
|
248 |
-
| 3 | macro avg | 0.
|
249 |
-
| 4 | weighted avg | 0.
|
250 |
|
|
|
42 |
|
43 |
The backbone of the model is BAAI/bge-m3 with 1024.
|
44 |
|
45 |
+
An additional layer of (GGU: 3,sentiment: 3) is added to the output of the backbone to classify the input sequence.
|
46 |
|
47 |
Using the provided implementation (in repository) of `MultiHeadClassificationTrainer`.
|
48 |
|
|
|
231 |
For evaluation, we used the following metrics: accuracy, precision, recall, f1-score. You can find a detailed classification report here:
|
232 |
|
233 |
**GGU:**
|
234 |
+
| | index | precision | recall | f1-score | support |
|
235 |
+
|---:|:-------------|------------:|----------:|-----------:|----------:|
|
236 |
+
| 0 | Greeting | 0.0555556 | 0.03125 | 0.04 | 32 |
|
237 |
+
| 1 | Gratitude | 0.320513 | 0.892857 | 0.471698 | 28 |
|
238 |
+
| 2 | Other | 0.111111 | 0.0222222 | 0.037037 | 45 |
|
239 |
+
| 3 | macro avg | 0.162393 | 0.315443 | 0.182912 | 105 |
|
240 |
+
| 4 | weighted avg | 0.15002 | 0.257143 | 0.15385 | 105 |
|
241 |
|
242 |
**sentiment:**
|
243 |
| | index | precision | recall | f1-score | support |
|
244 |
|---:|:-------------|------------:|---------:|-----------:|----------:|
|
245 |
+
| 0 | Positive | 0.653846 | 0.586207 | 0.618182 | 29 |
|
246 |
+
| 1 | Negative | 0.777778 | 0.736842 | 0.756757 | 38 |
|
247 |
+
| 2 | Neutral | 0.72093 | 0.815789 | 0.765432 | 38 |
|
248 |
+
| 3 | macro avg | 0.717518 | 0.712946 | 0.713457 | 105 |
|
249 |
+
| 4 | weighted avg | 0.722976 | 0.72381 | 0.721623 | 105 |
|
250 |
|
assets/confusion_matrix_GGU.png
CHANGED
assets/confusion_matrix_sentiment.png
CHANGED
assets/loss_plot_GGU.png
CHANGED
assets/loss_plot_sentiment.png
CHANGED
heads/GGU.pth
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 7552
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ff2b644eeb54e3b01ec332e5979b5e98477b4a29464b8db4e2beb29fe1548f27
|
3 |
size 7552
|
heads/sentiment.pth
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 7652
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:446843989a3b3ca5b766463a54e00f443a928a6a131e7d6e51e5238657b62758
|
3 |
size 7652
|
multi-head-sequence-classification-model-model.pth
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 1135701541
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6400c69f1f1efffd158930fc7be49cf7147fe2fba2764bdf3f12d39235060521
|
3 |
size 1135701541
|
train.py
CHANGED
@@ -167,14 +167,19 @@ class MultiHeadClassification(nn.Module):
|
|
167 |
Returns:
|
168 |
None
|
169 |
"""
|
|
|
170 |
if head_name in self.heads:
|
171 |
-
|
172 |
-
self.
|
|
|
|
|
173 |
return
|
174 |
-
|
175 |
assert model['weight'].shape[1] == self.backbone.config.hidden_size
|
176 |
-
|
|
|
177 |
self.heads[head_name].load_state_dict(model)
|
|
|
178 |
|
179 |
self.to(self.torch_dtype).to(self.device)
|
180 |
|
@@ -286,6 +291,7 @@ class MultiHeadClassification(nn.Module):
|
|
286 |
"""
|
287 |
self.heads[head_name] = nn.Linear(self.backbone.config.hidden_size, num_classes)
|
288 |
self.heads[head_name].to(self.torch_dtype).to(self.device)
|
|
|
289 |
|
290 |
def remove_head(self, head_name):
|
291 |
"""
|
@@ -294,6 +300,7 @@ class MultiHeadClassification(nn.Module):
|
|
294 |
if head_name not in self.heads:
|
295 |
raise ValueError(f'Head {head_name} not found')
|
296 |
del self.heads[head_name]
|
|
|
297 |
|
298 |
@classmethod
|
299 |
def from_pretrained(cls, model_name, head_config=None, dropout=0.1, l2_reg=0.01):
|
@@ -331,6 +338,7 @@ class MultiHeadClassification(nn.Module):
|
|
331 |
backbone = AutoModel.from_pretrained(os.path.join(model_path, 'pretrained/backbone.pth'))
|
332 |
instance = cls(backbone, head_config, dropout, l2_reg)
|
333 |
instance.load(os.path.join(model_path, 'pretrained/model.pth'))
|
|
|
334 |
return instance
|
335 |
|
336 |
class MultiHeadClassificationTrainer:
|
|
|
167 |
Returns:
|
168 |
None
|
169 |
"""
|
170 |
+
model = torch.load(path)
|
171 |
if head_name in self.heads:
|
172 |
+
num_classes = model['weight'].shape[0]
|
173 |
+
self.heads[head_name].load_state_dict(model)
|
174 |
+
self.to(self.torch_dtype).to(self.device)
|
175 |
+
self.head_config[head_name] = num_classes
|
176 |
return
|
177 |
+
|
178 |
assert model['weight'].shape[1] == self.backbone.config.hidden_size
|
179 |
+
num_classes = model['weight'].shape[0]
|
180 |
+
self.heads[head_name] = nn.Linear(self.backbone.config.hidden_size, num_classes)
|
181 |
self.heads[head_name].load_state_dict(model)
|
182 |
+
self.head_config[head_name] = num_classes
|
183 |
|
184 |
self.to(self.torch_dtype).to(self.device)
|
185 |
|
|
|
291 |
"""
|
292 |
self.heads[head_name] = nn.Linear(self.backbone.config.hidden_size, num_classes)
|
293 |
self.heads[head_name].to(self.torch_dtype).to(self.device)
|
294 |
+
self.head_config[head_name] = num_classes
|
295 |
|
296 |
def remove_head(self, head_name):
|
297 |
"""
|
|
|
300 |
if head_name not in self.heads:
|
301 |
raise ValueError(f'Head {head_name} not found')
|
302 |
del self.heads[head_name]
|
303 |
+
del self.head_config[head_name]
|
304 |
|
305 |
@classmethod
|
306 |
def from_pretrained(cls, model_name, head_config=None, dropout=0.1, l2_reg=0.01):
|
|
|
338 |
backbone = AutoModel.from_pretrained(os.path.join(model_path, 'pretrained/backbone.pth'))
|
339 |
instance = cls(backbone, head_config, dropout, l2_reg)
|
340 |
instance.load(os.path.join(model_path, 'pretrained/model.pth'))
|
341 |
+
instance.head_config = {k: v. instance.heads}
|
342 |
return instance
|
343 |
|
344 |
class MultiHeadClassificationTrainer:
|