philipp-zettl committed on
Commit
95fcb47
1 Parent(s): 5545230

Upload folder using huggingface_hub

README.md CHANGED
@@ -7,6 +7,7 @@ tags: []
 base_model: BAAI/bge-m3
 datasets:
 - philipp-zettl/GGU-xx
+- philipp-zettl/sentiment
 metrics:
 
 - accuracy
@@ -121,7 +122,6 @@ def train_classifier():
 momentum = 0.9
 l2_reg = 0.25
 
-num_epochs = 2
 l2_loss_weight = 0.25
 
 model_conf = {
@@ -149,7 +149,7 @@ def train_classifier():
 model_conf=model_conf,
 optimizer_conf={**optimizer_conf, 'lr': 1e-4},
 scheduler_conf=scheduler_conf,
-num_epochs=num_epochs,
+num_epochs=35,
 l2_loss_weight=l2_loss_weight,
 use_lr_scheduler=True,
 train_run=train_run,
@@ -167,10 +167,10 @@ def train_classifier():
 }
 
 trainer.classifier.add_head('sentiment', 3)
-new_model, history = trainer.train(dataset_name='philipp-zettl/GGU-xx', target_heads=['sentiment'], num_epochs=5)
+new_model, history = trainer.train(dataset_name='philipp-zettl/sentiment', target_heads=['sentiment'], num_epochs=5, sample_key='text')
 metrics = history['metrics']
 history['loss_plot'] = trainer._plot_history(**metrics)
-res = trainer.eval({'sentiment': sentiment_label_map})
+res = trainer.eval({'sentiment': sentiment_label_map}, sample_key='text')
 history['evaluation'] = res['sentiment']
 
 total_history['sentiment'] = deepcopy(history)
@@ -231,20 +231,20 @@ def _eval_model(self, dataloader, label_map, sample_key, label_key):
 For evaluation, we used the following metrics: accuracy, precision, recall, f1-score. You can find a detailed classification report here:
 
 **GGU:**
-| | index | precision | recall | f1-score | support |
-|---:|:-------------|------------:|----------:|-----------:|----------:|
-| 0 | Greeting | 0.0555556 | 0.03125 | 0.04 | 32 |
-| 1 | Gratitude | 0.320513 | 0.892857 | 0.471698 | 28 |
-| 2 | Other | 0.111111 | 0.0222222 | 0.037037 | 45 |
-| 3 | macro avg | 0.162393 | 0.315443 | 0.182912 | 105 |
-| 4 | weighted avg | 0.15002 | 0.257143 | 0.15385 | 105 |
+| | index | precision | recall | f1-score | support |
+|---:|:-------------|------------:|---------:|-----------:|----------:|
+| 0 | Greeting | 0.861111 | 0.837838 | 0.849315 | 37 |
+| 1 | Gratitude | 0.911765 | 0.885714 | 0.898551 | 35 |
+| 2 | Other | 0.914286 | 0.969697 | 0.941176 | 33 |
+| 3 | macro avg | 0.895721 | 0.89775 | 0.896347 | 105 |
+| 4 | weighted avg | 0.894708 | 0.895238 | 0.894598 | 105 |
 
 **sentiment:**
 | | index | precision | recall | f1-score | support |
 |---:|:-------------|------------:|---------:|-----------:|----------:|
-| 0 | Positive | 0.653846 | 0.586207 | 0.618182 | 29 |
-| 1 | Negative | 0.777778 | 0.736842 | 0.756757 | 38 |
-| 2 | Neutral | 0.72093 | 0.815789 | 0.765432 | 38 |
-| 3 | macro avg | 0.717518 | 0.712946 | 0.713457 | 105 |
-| 4 | weighted avg | 0.722976 | 0.72381 | 0.721623 | 105 |
+| 0 | Positive | 0.789794 | 0.848009 | 0.817867 | 12685 |
+| 1 | Negative | 0.793988 | 0.834127 | 0.813563 | 14282 |
+| 2 | Neutral | 0.789069 | 0.690309 | 0.736393 | 13239 |
+| 3 | macro avg | 0.790951 | 0.790815 | 0.789274 | 40206 |
+| 4 | weighted avg | 0.791045 | 0.791151 | 0.78951 | 40206 |
 
 
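For orientation, here is a minimal inference sketch against the `MultiHeadClassification` API defined in `model.py` below. It is an illustration only: the label order and the use of a CUDA device are assumptions, not something fixed by this commit.

```python
import torch
from transformers import AutoModel, AutoTokenizer

from model import MultiHeadClassification  # model.py is added in this commit

# Assumption: the backbone is BAAI/bge-m3 (see base_model above) and the
# sentiment head shipped as heads/sentiment.pth has the 3 classes reported here.
backbone = AutoModel.from_pretrained('BAAI/bge-m3')
clf = MultiHeadClassification(backbone, {}).to('cuda')
clf.load_head('sentiment', 'heads/sentiment.pth')
clf.eval()

tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-m3')
batch = tokenizer(["I really enjoyed this!"], return_tensors='pt',
                  padding=True, truncation=True).to('cuda')
with torch.no_grad():
    logits = clf(batch, head_names=['sentiment'])['sentiment']

# Assumption: the label order matches the tables above (Positive, Negative, Neutral).
print(['Positive', 'Negative', 'Neutral'][logits.argmax(dim=1).item()])
```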
assets/confusion_matrix_GGU.png CHANGED
assets/confusion_matrix_sentiment.png CHANGED
assets/loss_plot_GGU.png CHANGED
assets/loss_plot_sentiment.png CHANGED
heads/GGU.pth CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:ff2b644eeb54e3b01ec332e5979b5e98477b4a29464b8db4e2beb29fe1548f27
+oid sha256:8f0406ada1c1c023a0d2943a98d20f5d0aa1823444dc01fef2c963f1259be4b7
 size 7552
heads/sentiment.pth CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:446843989a3b3ca5b766463a54e00f443a928a6a131e7d6e51e5238657b62758
+oid sha256:c484582887addd315ba9670af09ee288cd30ebde2f47c0fdbd5f5b23f5ef2721
 size 7652
model.py ADDED
@@ -0,0 +1,674 @@
1
+ # Description: Classification models
2
+ from transformers import AutoModel, AutoTokenizer, BatchEncoding, TrainingArguments, Trainer
3
+ from functools import partial
4
+ from huggingface_hub import snapshot_download
5
+ from huggingface_hub.constants import HF_HUB_CACHE
6
+ from accelerate import Accelerator
7
+ from accelerate.utils import find_executable_batch_size as auto_find_batch_size
8
+ from datasets import load_dataset, Dataset
9
+ from torch.utils.data import DataLoader
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.optim as optim
13
+ import numpy as np
14
+ import json
15
+ import os
16
+ from tqdm import tqdm
17
+ import pandas as pd
18
+
19
+ import matplotlib.pyplot as plt
20
+ from sklearn.metrics import (
21
+ ConfusionMatrixDisplay,
22
+ accuracy_score,
23
+ classification_report,
24
+ confusion_matrix,
25
+ f1_score,
26
+ recall_score
27
+ )
28
+
29
+ BASE_PATH = os.path.dirname(os.path.abspath(__file__))
30
+
31
+
32
+ class MultiHeadClassification(nn.Module):
33
+ """
34
+ MultiHeadClassification
35
+
36
+ An easy to use multi-head classification model. It takes a backbone model and a dictionary of head configurations.
37
+ It can be used to train multiple classification tasks at once using a single backbone model.
38
+
39
+ Apart from joint training, it also supports training individual heads separately, providing a simple way to freeze
40
+ and unfreeze heads.
41
+
42
+ Example:
43
+ >>> from transformers import AutoModel, AutoTokenizer
44
+ >>> from torch.optim import AdamW
45
+ >>> import torch
46
+ >>> import time
47
+ >>> import torch.nn as nn
48
+ >>>
49
+ >>> # Manually load backbone model to create model
50
+ >>> backbone = AutoModel.from_pretrained('BAAI/bge-m3')
51
+ >>> model = MultiHeadClassification(backbone, {'binary': 2, 'sentiment': 3, 'something': 4}).to('cuda')
52
+ >>> print(model)
53
+ >>> # Load tokenizer for data preprocessing
54
+ >>> tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-m3')
55
+ >>> # some training data
56
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt", padding=True, truncation=True)
57
+ >>> optimizer = AdamW(model.parameters(), lr=5e-4)
58
+ >>> samples = tokenizer(["Hello, my dog is cute", "Hello, my dog is cute", "I like turtles"], return_tensors="pt", padding=True, truncation=True).to('cuda')
59
+ >>> labels = {'binary': torch.tensor([0, 0, 1]), 'sentiment': torch.tensor([0, 1, 2]), 'something': torch.tensor([0, 1, 2])}
60
+ >>> model.freeze_backbone()
61
+ >>> model.train(True)
62
+ >>> for i in range(10):
63
+ ... optimizer.zero_grad()
64
+ ... outputs = model(samples)
65
+ ... loss = sum([nn.CrossEntropyLoss()(outputs[name].cpu(), labels[name]) for name in model.heads.keys()])
66
+ ... loss.backward()
67
+ ... optimizer.step()
68
+ ... print(loss.item())
69
+ ... #time.sleep(1)
70
+ ... print(model(samples))
71
+ >>> # Save full model
72
+ >>> model.save('model.pth')
73
+ >>> # Save head only
74
+ >>> model.save_head('binary', 'binary.pth')
75
+ >>> # Load full model
76
+ >>> model = MultiHeadClassification(backbone, {}).to('cuda')
77
+ >>> model.load('model.pth')
78
+ >>> # Load head only
79
+ >>> model = MultiHeadClassification(backbone, {}).to('cuda')
80
+ >>> model.load_head('binary', 'binary.pth')
81
+ >>> # Adding new head
82
+ >>> model.add_head('new_head', 3)
83
+ >>> print(model)
84
+ >>> # extend dataset with data for new head
85
+ >>> labels['new_head'] = torch.tensor([0, 1, 2])
86
+ >>> # Freeze all heads and backbone
87
+ >>> model.freeze_all()
88
+ >>> # Only unfreeze new head
89
+ >>> model.unfreeze_head('new_head')
90
+ >>> model.train(True)
91
+ >>> for i in range(10):
92
+ ... optimizer.zero_grad()
93
+ ... outputs = model(samples)
94
+ ... loss = sum([nn.CrossEntropyLoss()(outputs[name].cpu(), labels[name]) for name in model.heads.keys()])
95
+ ... loss.backward()
96
+ ... optimizer.step()
97
+ ... print(loss.item())
98
+ >>> print(model(samples))
99
+
100
+ Args:
101
+ backbone (transformers.PreTrainedModel): A pretrained transformer model
102
+ head_config (dict): A dictionary with head configurations. The key is the head name and the value is the number
103
+ of classes for that head.
104
+ """
105
+ def __init__(self, backbone, head_config, dropout=0.1, l2_reg=0.01):
106
+ super().__init__()
107
+ self.backbone = backbone
108
+ self.num_heads = len(head_config)
109
+ self.heads = nn.ModuleDict({
110
+ name: nn.Linear(backbone.config.hidden_size, num_classes)
111
+ for name, num_classes in head_config.items()
112
+ })
113
+ self.do = nn.Dropout(dropout)
114
+ self.l2_reg = l2_reg
115
+ self.device = 'cpu'
116
+ self.torch_dtype = torch.float16
117
+ self.head_config = head_config
118
+
119
+ def forward(self, x, head_names=None) -> dict:
120
+ """
121
+ Forward pass of the model.
122
+
123
+ Requires tokenizer output as input. The input should be a dictionary with keys 'input_ids', 'attention_mask'.
124
+
125
+ Args:
126
+ x (dict): Tokenizer output
127
+ head_names (list): (optional) List of head names to return logits for. If None, returns logits for all heads.
128
+
129
+ Returns:
130
+ dict: A dictionary with head names as keys and logits as values
131
+ """
132
+ x = self.backbone(**x, return_dict=True, output_hidden_states=True).last_hidden_state[:, 0, :]
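+ # keep only the first ([CLS]) token of the last hidden state as the pooled sentence embedding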
133
+ x = self.do(x)
134
+ if head_names is None:
135
+ return {name: head(x) for name, head in self.heads.items()}
136
+ return {name: head(x) for name, head in self.heads.items() if name in head_names}
137
+
138
+ def get_l2_loss(self):
139
+ """
140
+ Getter for L2 regularization loss
141
+
142
+ Returns:
143
+ torch.Tensor: L2 regularization loss
144
+ """
145
+ l2_loss = torch.tensor(0.).to(self.device)
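+ # accumulate the L2 norm of every trainable parameter (frozen backbone weights are skipped)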
146
+ for param in self.parameters():
147
+ if param.requires_grad:
148
+ l2_loss += torch.norm(param, 2)
149
+ return (self.l2_reg * l2_loss).to(self.device)
150
+
151
+ def to(self, *args, **kwargs):
152
+ super().to(*args, **kwargs)
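+ # remember the requested dtype/device so heads added or loaded later can be moved to match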
153
+ if isinstance(args[0], torch.dtype):
154
+ self.torch_dtype = args[0]
155
+ elif isinstance(args[0], str):
156
+ self.device = args[0]
157
+ return self
158
+
159
+ def load_head(self, head_name, path):
160
+ """
161
+ Load head from a file
162
+
163
+ Args:
164
+ head_name (str): Name of the head
165
+ path (str): Path to the file
166
+
167
+ Returns:
168
+ None
169
+ """
170
+ model = torch.load(path)
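+ # the saved head is a plain nn.Linear state_dict; its weight has shape (num_classes, hidden_size)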
171
+ if head_name in self.heads:
172
+ num_classes = model['weight'].shape[0]
173
+ self.heads[head_name].load_state_dict(model)
174
+ self.to(self.torch_dtype).to(self.device)
175
+ self.head_config[head_name] = num_classes
176
+ return
177
+
178
+ assert model['weight'].shape[1] == self.backbone.config.hidden_size
179
+ num_classes = model['weight'].shape[0]
180
+ self.heads[head_name] = nn.Linear(self.backbone.config.hidden_size, num_classes)
181
+ self.heads[head_name].load_state_dict(model)
182
+ self.head_config[head_name] = num_classes
183
+
184
+ self.to(self.torch_dtype).to(self.device)
185
+
186
+ def save_head(self, head_name, path):
187
+ """
188
+ Save head to a file
189
+
190
+ Args:
191
+ head_name (str): Name of the head
192
+ path (str): Path to the file
193
+ """
194
+ torch.save(self.heads[head_name].state_dict(), path)
195
+
196
+ def save(self, path):
197
+ """
198
+ Save the full model to a file
199
+
200
+ Args:
201
+ path (str): Path to the file
202
+ """
203
+ torch.save(self.state_dict(), path)
204
+
205
+ def load(self, path):
206
+ """
207
+ Load the full model from a file
208
+
209
+ Args:
210
+ path (str): Path to the file
211
+ """
212
+ self.load_state_dict(torch.load(path))
213
+ self.to(self.torch_dtype).to(self.device)
214
+
215
+ def save_backbone(self, path):
216
+ """
217
+ Save the backbone to a file
218
+
219
+ Args:
220
+ path (str): Path to the file
221
+ """
222
+ self.backbone.save_pretrained(path)
223
+
224
+ def load_backbone(self, path):
225
+ """
226
+ Load the backbone from a file
227
+
228
+ Args:
229
+ path (str): Path to the file
230
+ """
231
+ self.backbone = AutoModel.from_pretrained(path)
232
+ self.to(self.torch_dtype).to(self.device)
233
+
234
+ def freeze_backbone(self):
235
+ """ Freeze the backbone """
236
+ for param in self.backbone.parameters():
237
+ param.requires_grad = False
238
+
239
+ def unfreeze_backbone(self):
240
+ """ Unfreeze the backbone """
241
+ for param in self.backbone.parameters():
242
+ param.requires_grad = True
243
+
244
+ def freeze_head(self, head_name):
245
+ """
246
+ Freeze a head by name
247
+
248
+ Args:
249
+ head_name (str): Name of the head
250
+ """
251
+ for param in self.heads[head_name].parameters():
252
+ param.requires_grad = False
253
+
254
+ def unfreeze_head(self, head_name):
255
+ """
256
+ Unfreeze a head by name
257
+
258
+ Args:
259
+ head_name (str): Name of the head
260
+ """
261
+ for param in self.heads[head_name].parameters():
262
+ param.requires_grad = True
263
+
264
+ def freeze_all_heads(self):
265
+ """ Freeze all heads """
266
+ for head_name in self.heads.keys():
267
+ self.freeze_head(head_name)
268
+
269
+ def unfreeze_all_heads(self):
270
+ """ Unfreeze all heads """
271
+ for head_name in self.heads.keys():
272
+ self.unfreeze_head(head_name)
273
+
274
+ def freeze_all(self):
275
+ """ Freeze all """
276
+ self.freeze_backbone()
277
+ self.freeze_all_heads()
278
+
279
+ def unfreeze_all(self):
280
+ """ Unfreeze all """
281
+ self.unfreeze_backbone()
282
+ self.unfreeze_all_heads()
283
+
284
+ def add_head(self, head_name, num_classes):
285
+ """
286
+ Add a new head to the model
287
+
288
+ Args:
289
+ head_name (str): Name of the head
290
+ num_classes (int): Number of classes for the head
291
+ """
292
+ self.heads[head_name] = nn.Linear(self.backbone.config.hidden_size, num_classes)
293
+ self.heads[head_name].to(self.torch_dtype).to(self.device)
294
+ self.head_config[head_name] = num_classes
295
+
296
+ def remove_head(self, head_name):
297
+ """
298
+ Remove a head from the model
299
+ """
300
+ if head_name not in self.heads:
301
+ raise ValueError(f'Head {head_name} not found')
302
+ del self.heads[head_name]
303
+ del self.head_config[head_name]
304
+
305
+ @classmethod
306
+ def from_pretrained(cls, model_name, head_config=None, dropout=0.1, l2_reg=0.01):
307
+ """
308
+ Load a pretrained model from Huggingface model hub
309
+
310
+ Args:
311
+ model_name (str): Name of the model
312
+ head_config (dict): Head configuration
313
+ dropout (float): Dropout rate
314
+ l2_reg (float): L2 regularization rate
315
+ """
316
+ if head_config is None:
317
+ head_config = {}
318
+ # check if model exists locally
319
+ hf_cache_dir = HF_HUB_CACHE
320
+ model_path = os.path.join(hf_cache_dir, model_name)
321
+ if os.path.exists(model_path):
322
+ return cls._from_directory(model_path, head_config, dropout, l2_reg)
323
+
324
+ model_path = snapshot_download(repo_id=model_name, cache_dir=hf_cache_dir)
325
+ return cls._from_directory(model_path, head_config, dropout, l2_reg)
326
+
327
+ @classmethod
328
+ def _from_directory(cls, model_path, head_config, dropout=0.1, l2_reg=0.01):
329
+ """
330
+ Load a model from a directory
331
+
332
+ Args:
333
+ model_path (str): Path to the model directory
334
+ head_config (dict): Head configuration
335
+ dropout (float): Dropout rate
336
+ l2_reg (float): L2 regularization rate
337
+ """
338
+ backbone = AutoModel.from_pretrained(os.path.join(model_path, 'pretrained/backbone.pth'))
339
+ instance = cls(backbone, head_config, dropout, l2_reg)
340
+ instance.load(os.path.join(model_path, 'pretrained/model.pth'))
341
+ instance.head_config = {k: v.out_features for k, v in instance.heads.items()}  # rebuild head_config from the loaded heads
342
+ return instance
343
+
344
+ class MultiHeadClassificationTrainer:
345
+ def __init__(self, **kwargs):
346
+ self.model_conf = kwargs.get('model_conf', {})
347
+ self.optimizer_conf = kwargs.get('optimizer_conf', {})
348
+ self.scheduler_conf = kwargs.get('scheduler_conf', {})
349
+ self.dropout = kwargs.get('dropout', 0.1)
350
+ self.l2_loss_weight = kwargs.get('l2_loss_weight', 0.01)
351
+ self.num_epochs = kwargs.get('num_epochs', 100)
352
+ self.device = kwargs.get('device', 'cuda')
353
+ self.train_run = kwargs.get('train_run', 0)
354
+ self.name_prefix = kwargs.get('name_prefix', 'multihead-classification')
355
+ self.use_lr_scheduler = kwargs.get('use_lr_scheduler', True)
356
+ self.gradient_accumulation_steps = kwargs.get('gradient_accumulation_steps', 1)
357
+ self.batch_size = kwargs.get('batch_size', 4)
358
+ self.train_test_split = kwargs.get('train_test_split', 0.2)
359
+ self.load_best = kwargs.get('load_best', True)
360
+ self.auto_find_batch_size = kwargs.get('auto_find_batch_size', False)
361
+ self.test_data = None
362
+ self.accelerator = Accelerator()
363
+
364
+ self.classifier = MultiHeadClassification(
365
+ **self.model_conf
366
+ ).to(torch.float16)
367
+ self.classifier.freeze_backbone()
368
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_conf.get('tokenizer', self.classifier.backbone.name_or_path), model_max_length=128)
369
+
370
+ def _batch_data(self, batch_size, data):
371
+ return DataLoader(data, shuffle=True, batch_size=batch_size)
372
+
373
+ def train(self, dataset_name: str = None, train_data: DataLoader = None, val_data: DataLoader = None, lr: float = None, num_epochs: int = None, target_heads: list[str] = None, batch_size: int = 4, sample_key=None, label_key=None):
374
+ has_dataset = train_data is not None
375
+ assert (dataset_name is not None and not has_dataset) or (has_dataset and dataset_name is None), 'Must provide either dataset or dataset_name'
376
+ if dataset_name is not None:
377
+ assert target_heads is not None, 'target_heads must be provided when using dataset_name'
378
+
379
+ if sample_key is None:
380
+ sample_key = 'sample'
381
+ if label_key is None:
382
+ label_key = 'label'
383
+
384
+ self.accelerator.free_memory()
385
+ self.classifier = self.accelerator.prepare(self.classifier)
386
+
387
+ if dataset_name is not None:
388
+ dataset = load_dataset(dataset_name)['train'].train_test_split(test_size=self.train_test_split)
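+ # the held-out split is halved below: one half becomes the validation set, the other is kept as self.test_data for eval()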
389
+ train_data = dataset['train']
390
+ val_data = dataset['test'].train_test_split(test_size=0.5)
391
+ self.test_data = val_data['test']
392
+ val_data = val_data['train']
393
+
394
+ if batch_size is not None:
395
+ self.batch_size = batch_size
396
+
397
+ if isinstance(train_data, Dataset):
398
+ sample = next(iter(train_data))
399
+ print('Tokenizing dataset...', sample, type(sample))
400
+ is_string_dataset = isinstance(sample[0], str) if not isinstance(sample, dict) else isinstance(sample[sample_key], str)
401
+
402
+ if is_string_dataset:
403
+ if isinstance(sample, list):
404
+ train_data = train_data.map(lambda x: self.tokenizer([x[0]] if isinstance(x[0], str) else x[0], return_tensors="pt", padding=True, truncation=True), batched=True)
405
+ val_data = val_data.map(lambda x: self.tokenizer([x[0]] if isinstance(x[0], str) else x[0], return_tensors="pt", padding=True, truncation=True), batched=True)
406
+ elif isinstance(sample, dict):
407
+ assert sample_key in sample and label_key in sample, 'Invalid dataset format'
408
+ train_data = train_data.map(lambda x: self.tokenizer([x[sample_key]] if isinstance(x[sample_key], str) else x[sample_key], return_tensors="pt", padding=True, truncation=True), batched=True)
409
+ val_data = val_data.map(lambda x: self.tokenizer([x[sample_key]] if isinstance(x[sample_key], str) else x[sample_key], return_tensors="pt", padding=True, truncation=True), batched=True)
410
+ else:
411
+ raise ValueError('Invalid dataset format')
412
+
413
+ create_train_data = partial(self._batch_data, data=train_data)
414
+ create_val_data = partial(self._batch_data, data=val_data)
415
+
416
+ if self.auto_find_batch_size:
417
+ train_data = auto_find_batch_size(create_train_data)()
418
+ val_data = auto_find_batch_size(create_val_data)()
419
+ else:
420
+ train_data = create_train_data(self.batch_size)
421
+ val_data = create_val_data(self.batch_size)
422
+ # otherwise, assume it's already tokenized
423
+ else:
424
+ assert train_data is not None and val_data is not None, 'train_data and val_data must be provided'
425
+ assert isinstance(train_data, DataLoader) and isinstance(val_data, DataLoader), 'train_data and val_data must be DataLoader instances'
426
+
427
+ optimizer_name = self.optimizer_conf.pop('optimizer', 'sgd')
428
+ loss_name = self.optimizer_conf.pop('loss', 'crossentropy')
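+ # 'optimizer' and 'loss' are meta-keys; popping them leaves only kwargs that go straight into the optimizer constructor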
429
+ if lr:
430
+ self.optimizer_conf['lr'] = lr
431
+ if num_epochs:
432
+ self.num_epochs = num_epochs
433
+
434
+ self.classifier.unfreeze_all()
435
+ # freeze backbone
436
+ print('Freezing backbone')
437
+ self.classifier.freeze_backbone()
438
+ # freeze all heads that are not in the training data
439
+
440
+ if target_heads is None:
441
+ sample = next(iter(train_data))
442
+ if isinstance(sample, dict):
443
+ train_heads = list(sample[label_key].keys())
444
+ elif isinstance(sample, list):
445
+ train_heads = list(sample[1].keys())
446
+ else:
447
+ raise ValueError('Invalid dataset format')
448
+ else:
449
+ train_heads = target_heads
450
+
451
+ for head_name in self.classifier.heads.keys():
452
+ if head_name not in train_heads:
453
+ print(f'Freezing head {head_name}')
454
+ self.classifier.freeze_head(head_name)
455
+
456
+ self.classifier.to(self.device)
457
+ self.classifier.train(True)
458
+ loss_func = {'crossentropy': nn.CrossEntropyLoss, 'bce': nn.BCELoss}.get(loss_name, nn.CrossEntropyLoss)
459
+ optimizer_class = {'sgd': optim.SGD, 'adam': optim.Adam}.get(optimizer_name, optim.SGD)
460
+ optimizer = optimizer_class(self.classifier.parameters(), **self.optimizer_conf)
461
+
462
+ scheduler = None
463
+ if self.use_lr_scheduler:
464
+ scheduler_class = {
465
+ 'plateau': optim.lr_scheduler.ReduceLROnPlateau,
466
+ 'step': optim.lr_scheduler.StepLR,
467
+ }.get(self.scheduler_conf.get('scheduler'), optim.lr_scheduler.ReduceLROnPlateau)
468
+ scheduler = scheduler_class(optimizer, 'min', **self.scheduler_conf)
469
+
470
+ history = self._train(loss_func(), optimizer, scheduler, self.accelerator.prepare(train_data), self.accelerator.prepare(val_data), train_heads, sample_key, label_key)
471
+ if self.load_best:
472
+ self.classifier.load(os.path.join(BASE_PATH, f'../train_runs/{self.name_prefix}-run-{self.train_run-1}-best-model.pth'))
473
+ return self.classifier, history
474
+
475
+ def _train(self, criterion, optimizer, scheduler, dataloader, val_dataloader, head_names, sample_key, label_key):
476
+ average_acc = 0
477
+ losses = []
478
+ precisions = []
479
+ best_prec = 0.0
480
+
481
+ val_losses = []
482
+ val_accs = []
483
+ avg_val_acc = 0.0
484
+
485
+ patience = 50
486
+ reset_patience = 25
487
+ patience_reset_counter = 0
488
+ patience_counter = 0
489
+ current_max = 0
490
+ total_max = 0
491
+ num_samples = len(dataloader)
492
+ pbar = tqdm(total=self.num_epochs * num_samples, desc='Training model...')
493
+ for epoch in range(self.num_epochs):
494
+ self.classifier.train() # Set the model to training mode
495
+ running_loss = 0.0
496
+ all_preds = {name: [] for name in head_names}
497
+ all_labels = {name: [] for name in head_names}
498
+
499
+ for step, sample in enumerate(dataloader):
500
+ labels = {name: sample[label_key] for name in head_names}
501
+ embeddings = BatchEncoding({k: torch.stack(v, dim=1).to(self.device) for k, v in sample.items() if k not in [label_key, sample_key]}).to(self.device)
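+ # the default DataLoader collation yields one (batch,) tensor per token position; stacking on dim=1 rebuilds (batch, seq_len) inputs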
502
+ outputs = self.classifier(embeddings, head_names=head_names) # Forward pass
503
+ loss = sum([criterion(outputs[name].to(self.device), labels[name].to(self.device)) for name in labels.keys()])
504
+ loss += self.l2_loss_weight * self.classifier.get_l2_loss().to(self.device)
505
+ running_loss += loss.item()
506
+ loss.backward() # Backward pass
507
+ if (step + 1) % self.gradient_accumulation_steps == 0:
508
+ optimizer.step() # Update model parameters
509
+ optimizer.zero_grad() # Zero the parameter gradients
510
+ # Store predictions and labels for precision calculation
511
+ for name in labels.keys():
512
+ preds = outputs[name][0].argmax().item()
513
+ all_labels[name].append(labels[name][0].cpu().numpy())
514
+ all_preds[name].append(preds)
515
+
516
+ pbar.update(1)
517
+ # clear memory
518
+ torch.cuda.empty_cache()
519
+
520
+ epoch_loss = running_loss / num_samples
521
+ if scheduler:
522
+ scheduler.step(epoch_loss)
523
+
524
+ average_acc += np.mean([np.mean(np.abs(np.array(all_labels[name]) - np.array(all_preds[name])) == 0) for name in head_names])
525
+ average_acc /= 2.0
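+ # blend this epoch's accuracy into a running average (simple exponential smoothing with factor 0.5)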
526
+ if val_dataloader:
527
+ val_loss, val_acc = self.validate(self.classifier, criterion, val_dataloader, head_names, sample_key, label_key)
528
+ avg_val_acc += val_acc.item()
529
+ avg_val_acc /= 2.0
530
+ val_losses.append(val_loss)
531
+ val_accs.append(val_acc)
532
+ losses.append(epoch_loss)
533
+ precisions.append(average_acc)
534
+ if avg_val_acc > current_max:
535
+ current_max = avg_val_acc
536
+ self.classifier.save(os.path.join(BASE_PATH, f'../train_runs/{self.name_prefix}-run-{self.train_run}-best-model.pth'))
537
+ best_prec = max(average_acc, best_prec)
538
+ #print(f"Epoch {epoch+1}/{num_epochs} (LR: {scheduler.get_last_lr()[0]:.4e}), Loss: {epoch_loss:.4f}, Precision: {l:.4f}")
539
+ pbar_data = {
540
+ 'epoch': epoch + 1,
541
+ 'loss': epoch_loss,
542
+ 'avg_acc': average_acc,
543
+ 'acc_max': best_prec
544
+ }
545
+ if scheduler:
546
+ pbar_data['lr'] = scheduler.get_last_lr()[0]
547
+ if val_dataloader:
548
+ pbar_data['val_loss'] = val_loss
549
+ pbar_data['val_acc'] = val_acc
550
+ pbar_data['avg_val_acc'] = avg_val_acc
551
+ pbar.set_postfix(pbar_data)
552
+ # clear memory
553
+ torch.cuda.empty_cache()
554
+
555
+ pbar.close()
556
+ param_dict = {
557
+ 'dropout': self.dropout,
558
+ 'model_conf': {k:v for k, v in self.model_conf.items() if k not in ['tokenizer', 'backbone']},
559
+ 'optimizer_conf': self.optimizer_conf,
560
+ 'scheduler_conf': self.scheduler_conf,
561
+ 'l2_loss_weight': self.l2_loss_weight,
562
+ 'num_epochs': self.num_epochs,
563
+ 'device': self.device,
564
+ 'train_run': self.train_run,
565
+ 'name_prefix': self.name_prefix,
566
+ 'use_lr_scheduler': self.use_lr_scheduler,
567
+ 'metrics': {
568
+ 'loss': losses,
569
+ 'val_loss': val_losses,
570
+ 'precision': precisions,
571
+ 'val_precision': val_accs
572
+ }
573
+ }
574
+ with open(os.path.join(BASE_PATH, f'../train_runs/{self.name_prefix}-train-run-{self.train_run}.json'), 'w') as f:
575
+ json.dump(param_dict, f)
576
+ print("Training complete!")
577
+ self.train_run += 1
578
+
579
+ return param_dict
580
+
581
+ def _plot_history(self, loss, val_loss, precision, val_precision):
582
+ fig = plt.figure(figsize=(15,7))
583
+ ax = plt.subplot(1,2, 1)
584
+ ax.set_title('loss')
585
+ plt.plot(range(len(loss)), loss, 'g--', label='train_loss')
586
+ plt.plot(range(len(loss)), val_loss, 'r--', label='val_loss')
587
+ plt.yscale('log')
588
+ plt.legend()
589
+ ax = plt.subplot(1,2, 2)
590
+ ax.set_title('accuracy')
591
+ plt.plot(range(len(precision)), precision, 'g--', label='prec')
592
+ plt.plot(range(len(precision)), val_precision, 'r--',label='val_prec')
593
+ plt.legend()
594
+ return fig
595
+
596
+ def validate(self, model, criterion, dataloader, head_names=None, sample_key='sample', label_key='label'):
597
+ running_loss = 0
598
+ num_samples = len(dataloader)
599
+ if head_names is None:
600
+ sample = next(iter(dataloader))[1]
601
+ head_names = list(sample.keys())
602
+
603
+ all_labels = {name: [] for name in head_names}
604
+ all_preds = {name: [] for name in head_names}
605
+
606
+ num_labels = {name: model.heads[name].out_features for name in head_names}
607
+
608
+ model.train(False)
609
+ for sample in dataloader:
610
+ labels = {name: sample[label_key] for name in head_names}
611
+ embeddings = BatchEncoding({k: torch.stack(v, dim=1).to(self.device) for k, v in sample.items() if k not in [label_key, sample_key]})
612
+ outputs = model(embeddings) # Forward pass
613
+ loss = sum([criterion(outputs[name].to(self.device), labels[name].to(self.device)) for name in head_names])
614
+ loss += self.l2_loss_weight * model.get_l2_loss().to(self.device)
615
+ running_loss += loss.item()
616
+ # Store predictions and labels for precision calculation
617
+ for name in head_names:
618
+ preds = outputs[name][0].argmax().item()
619
+ all_labels[name].append(labels[name][0].cpu().numpy())
620
+ all_preds[name].append(preds)
621
+ torch.cuda.empty_cache()
622
+ return running_loss / num_samples, np.mean([np.mean(np.abs(np.array(all_labels[name]) - np.array(all_preds[name])) == 0) for name in head_names])
623
+
624
+ def eval(self, label_map, test_set=None, sample_key='sample', label_key='label'):
625
+ if test_set is None:
626
+ assert self.test_data is not None, 'No test data provided'
627
+ test_set = self.test_data
628
+ sample = next(iter(test_set))
629
+ is_string_dataset = isinstance(sample[0], str) if not isinstance(sample, dict) else isinstance(sample[sample_key], str)
630
+
631
+ if is_string_dataset:
632
+ if isinstance(sample, list):
633
+ test_set = test_set.map(lambda x: self.tokenizer([x[0]] if isinstance(x[0], str) else x[0], return_tensors="pt", padding=True, truncation=True), batched=True)
634
+ elif isinstance(sample, dict):
635
+ assert sample_key in sample and label_key in sample, 'Invalid dataset format'
636
+ test_set = test_set.map(lambda x: self.tokenizer([x[sample_key]] if isinstance(x[sample_key], str) else x[sample_key], return_tensors="pt", padding=True, truncation=True), batched=True)
637
+ else:
638
+ raise ValueError('Invalid dataset format')
639
+
640
+ test_set = DataLoader(test_set, shuffle=True, batch_size=self.batch_size)
641
+ self.classifier.to(self.device)
642
+ return self._eval_model(test_set, label_map, sample_key, label_key)
643
+
644
+ def _eval_model(self, dataloader, label_map, sample_key, label_key):
645
+ self.classifier.train(False)
646
+ eval_heads = list(label_map.keys())
647
+ y_pred = {h: [] for h in eval_heads}
648
+ y_test = {h: [] for h in eval_heads}
649
+ for sample in tqdm(dataloader, total=len(dataloader), desc='Evaluating model...'):
650
+ labels = {name: sample[label_key] for name in eval_heads}
651
+ embeddings = BatchEncoding({k: torch.stack(v, dim=1).to(self.device) for k, v in sample.items() if k not in [label_key, sample_key]})
652
+ output = self.classifier(embeddings.to(self.device), head_names=eval_heads)
653
+ for head in eval_heads:
654
+ y_pred[head].extend(output[head].argmax(dim=1).cpu())
655
+ y_test[head].extend(labels[head])
656
+ torch.cuda.empty_cache()
657
+
658
+ accuracies = {h: accuracy_score(y_test[h], y_pred[h]) for h in eval_heads}
659
+ f1_scores = {h: f1_score(y_test[h], y_pred[h], average="macro") for h in eval_heads}
660
+ recalls = {h: recall_score(y_test[h], y_pred[h], average='macro') for h in eval_heads}
661
+
662
+ report = {}
663
+ for head in eval_heads:
664
+ cm = confusion_matrix(y_test[head], y_pred[head], labels=list(label_map[head].keys()))
665
+ disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=list(label_map[head].values()))
666
+ clf_report = classification_report(
667
+ y_test[head], y_pred[head], output_dict=True, target_names=list(label_map[head].values())
668
+ )
669
+ del clf_report["accuracy"]
670
+ clf_report = pd.DataFrame(clf_report).T.reset_index()
671
+ report[head] = dict(
672
+ clf_report=clf_report, confusion_matrix=disp, metrics={'accuracy': accuracies[head], 'f1': f1_scores[head], 'recall': recalls[head]}
673
+ )
674
+ return report
multi-head-sequence-classification-model-model.pth CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:6400c69f1f1efffd158930fc7be49cf7147fe2fba2764bdf3f12d39235060521
+oid sha256:11a81b327d68a9c64534a68a223337a474abdf1c538eed908add65a2270010e1
 size 1135701541