# hallucination_evaluation_model / modeling_hhem_v2.py
# Repository page metadata: commit e2b6d9e by ofermend
# ("updated files for hhem-2.1-open"); file size 2.4 kB.
import torch
from peft import PeftModel
from transformers import PreTrainedModel, AutoConfig, T5ForTokenClassification, AutoModel, AutoTokenizer, AutoModelForTokenClassification
from .configuration_hhem_v2 import HHEMv2Config
class HHEMv2Model(PreTrainedModel):
    """Bare HHEM v2 model shell configured by ``HHEMv2Config``.

    The class currently defines no submodules and no forward pass of its
    own; a T5-backed draft is kept below, commented out.
    """

    config_class = HHEMv2Config

    def __init__(self, config):
        """Initialise the base class with ``config``; nothing else is built."""
        super().__init__(config)
        # Disabled draft: build the T5 backbone from the foundation config.
        # self.t5 = T5ForTokenClassification.from_config(
        #     AutoConfig.from_pretrained(config.foundation)
        # )

    # Disabled draft methods:
    # def populate(self, model):
    #     self.t5 = model

    # def forward(self, **kwarg):
    #     return self.t5.transformer(**kwarg)
class HHEMv2ForSequenceClassification(PreTrainedModel):
    """Text-pair scorer built on a T5 token-classification backbone.

    The model formats each pair with the prompt template from the config,
    runs the backbone, and reads the class-1 probability from the logits at
    the first token position (see ``predict``).
    """

    config_class = HHEMv2Config

    def __init__(self, config=None):
        """Build the backbone and tokenizer named by ``config.foundation``.

        Args:
            config: An ``HHEMv2Config``. When omitted, a fresh default
                config is created for this instance.
        """
        # Create the default per call instead of in the signature: a default
        # evaluated at definition time would be one object shared (and
        # mutated) by every default-constructed instance.
        if config is None:
            config = HHEMv2Config()
        super().__init__(config)
        # The backbone is constructed from the foundation's config only;
        # trained weights are supplied later (e.g. via ``populate``).
        self.t5 = T5ForTokenClassification(
            AutoConfig.from_pretrained(config.foundation)
        )
        self.prompt = config.prompt
        # NOTE: the attribute name is misspelled ("tokenzier"). It is kept
        # unchanged because external code may already read it.
        self.tokenzier = AutoTokenizer.from_pretrained(config.foundation)

    def populate(self, model: AutoModel):
        """Replace the backbone with an already prepared model.

        This method should only be called by the Vectara employee who
        prepares the model for publishing. Users do not need to call it.
        """
        self.t5 = model

    # TODO: Figure out how to publish only the adapter yet still able to do
    # end-to-end pulling and inference.
    # def populate_lora(self, checkpoint: str):
    #     base_model = AutoModelForTokenClassification.from_pretrained(self.config.foundation)
    #     combined_model = PeftModel.from_pretrained(base_model, checkpoint, is_trainable=False)
    #     self.t5 = combined_model

    def forward(self, **kwargs):
        """Delegate to the backbone, passing all keyword arguments through."""
        return self.t5(**kwargs)

    def predict(self, text_pairs):
        """Score each text pair with the backbone.

        Args:
            text_pairs: Iterable of indexable pairs. Element 0 fills the
                prompt's ``text1`` field and element 1 its ``text2`` field.

        Returns:
            A 1-D tensor with one value per pair: the softmax probability
            of class index 1, computed from the logits at the first token
            position. An empty input yields an empty tensor.
        """
        # Inputs must live on the same device as the backbone, otherwise the
        # forward pass fails once the model has been moved to an accelerator.
        device = next(self.t5.parameters()).device

        # Materialise once so generators work and emptiness can be checked.
        text_pairs = list(text_pairs)
        if not text_pairs:
            return torch.empty(0, device=device)

        prompts = [
            self.prompt.format(text1=pair[0], text2=pair[1])
            for pair in text_pairs
        ]
        inputs = self.tokenzier(prompts, return_tensors='pt', padding=True)
        inputs = inputs.to(device)

        self.t5.eval()
        with torch.no_grad():
            outputs = self.t5(**inputs)

        # Token-classification head: only the first position carries the
        # pair-level prediction.
        first_token_logits = outputs.logits[:, 0, :]
        class_probs = torch.softmax(first_token_logits, dim=-1)
        return class_probs[:, 1]  # probability of class 1