tomaarsen HF staff commited on
Commit
2fd38a5
1 Parent(s): 247f25d

Create train_script.py

Browse files
Files changed (1) hide show
  1. train_script.py +111 -0
train_script.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+ from typing import Dict
3
+ import datasets
4
+ from datasets import Dataset
5
+ from sentence_transformers import (
6
+ SentenceTransformer,
7
+ SentenceTransformerTrainer,
8
+ losses,
9
+ evaluation,
10
+ SentenceTransformerTrainingArguments
11
+ )
12
+ from sentence_transformers.models import Transformer, Pooling, Dense, Normalize
13
+
14
+ def to_triplets(dataset):
15
+ premises = defaultdict(dict)
16
+ for sample in dataset:
17
+ premises[sample["premise"]][sample["label"]] = sample["hypothesis"]
18
+ queries = []
19
+ positives = []
20
+ negatives = []
21
+ for premise, sentences in premises.items():
22
+ if 0 in sentences and 2 in sentences:
23
+ queries.append(premise)
24
+ positives.append(sentences[0]) # <- entailment
25
+ negatives.append(sentences[2]) # <- contradiction
26
+ return Dataset.from_dict({
27
+ "anchor": queries,
28
+ "positive": positives,
29
+ "negative": negatives,
30
+ })
31
+
32
+ snli_ds = datasets.load_dataset("snli")
33
+ snli_ds = datasets.DatasetDict({
34
+ "train": to_triplets(snli_ds["train"]),
35
+ "validation": to_triplets(snli_ds["validation"]),
36
+ "test": to_triplets(snli_ds["test"]),
37
+ })
38
+ multi_nli_ds = datasets.load_dataset("multi_nli")
39
+ multi_nli_ds = datasets.DatasetDict({
40
+ "train": to_triplets(multi_nli_ds["train"]),
41
+ "validation_matched": to_triplets(multi_nli_ds["validation_matched"]),
42
+ })
43
+
44
+ all_nli_ds = datasets.DatasetDict({
45
+ "train": datasets.concatenate_datasets([snli_ds["train"], multi_nli_ds["train"]]),#.select(range(10000)),
46
+ "validation": datasets.concatenate_datasets([snli_ds["validation"], multi_nli_ds["validation_matched"]]),#.select(range(1000)),
47
+ "test": snli_ds["test"]
48
+ })
49
+
50
+ stsb_dev = datasets.load_dataset("mteb/stsbenchmark-sts", split="validation")
51
+ stsb_test = datasets.load_dataset("mteb/stsbenchmark-sts", split="test")
52
+
53
+ training_args = SentenceTransformerTrainingArguments(
54
+ output_dir="checkpoints",
55
+ num_train_epochs=1,
56
+ seed=42,
57
+ per_device_train_batch_size=256,
58
+ per_device_eval_batch_size=256,
59
+ learning_rate=2e-5,
60
+ warmup_ratio=0.1,
61
+ bf16=True,
62
+ logging_steps=100,
63
+ eval_strategy="steps",
64
+ eval_steps=100,
65
+ save_steps=100,
66
+ save_total_limit=2,
67
+ metric_for_best_model="sts-dev_spearman_cosine",
68
+ greater_is_better=True,
69
+ )
70
+
71
+ transformer = Transformer("prajjwal1/bert-tiny", max_seq_length=384)
72
+ pooling = Pooling(transformer.get_word_embedding_dimension(), pooling_mode="mean")
73
+ dense = Dense(128, 256)
74
+ normalize = Normalize()
75
+ model = SentenceTransformer(modules=[transformer, pooling, dense, normalize])
76
+ # Ensure all tensors in the model are contiguous
77
+ for param in model.parameters():
78
+ param.data = param.data.contiguous()
79
+
80
+ loss = losses.MultipleNegativesRankingLoss(model)
81
+ # loss = losses.MatryoshkaLoss(model, loss, [256, 128, 64, 32, 16, 8])
82
+
83
+ dev_evaluator = evaluation.EmbeddingSimilarityEvaluator(
84
+ stsb_dev["sentence1"],
85
+ stsb_dev["sentence2"],
86
+ [score / 5 for score in stsb_dev["score"]],
87
+ main_similarity=evaluation.SimilarityFunction.COSINE,
88
+ name="sts-dev",
89
+ )
90
+
91
+ trainer = SentenceTransformerTrainer(
92
+ model=model,
93
+ evaluator=dev_evaluator,
94
+ args=training_args,
95
+ train_dataset=all_nli_ds["train"],
96
+ eval_dataset=all_nli_ds["validation"],
97
+ loss=loss,
98
+ )
99
+ trainer.train()
100
+
101
+ test_evaluator = evaluation.EmbeddingSimilarityEvaluator(
102
+ stsb_test["sentence1"],
103
+ stsb_test["sentence2"],
104
+ [score / 5 for score in stsb_test["score"]],
105
+ main_similarity=evaluation.SimilarityFunction.COSINE,
106
+ name="sts-test",
107
+ )
108
+ results = test_evaluator(model)
109
+
110
+ breakpoint()
111
+ model.push_to_hub("sentence-transformers-testing/all-nli-bert-tiny-dense", private=True)