# chatwithpdfs/model.py
from typing import Optional

import torch
from langchain.llms import HuggingFacePipeline
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
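
# BitsAndBytesConfig is imported but not wired in below. A minimal, hypothetical sketch
# of a 4-bit quantization config (assumes the bitsandbytes package and a CUDA GPU are
# available) that could be passed to from_pretrained(quantization_config=...) would be:
# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_compute_dtype=torch.float16,
# )
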
def initialize_llmchain(
    llm_model: str,
    temperature: float,
    max_tokens: int,
    top_k: int,
    access_token: Optional[str] = None,
    torch_dtype: str = "auto",
) -> HuggingFacePipeline:
"""
Initializes a language model chain based on the provided parameters.
Args:
- llm_model (str): The name of the language model to initialize.
- temperature (float): The temperature parameter for text generation.
- max_tokens (int): The maximum number of tokens to generate.
- top_k (int): The top-k parameter for token selection during generation.
- torch_dtype (str): The torch dtype to be used for model inference (default is "auto").
Returns:
- HuggingFacePipeline: Initialized language model pipeline.
"""
    # Initialize model and tokenizer. torch_dtype is a loading argument, so it is
    # passed to from_pretrained rather than bundled with the generation parameters.
    model = AutoModelForCausalLM.from_pretrained(
        llm_model,
        low_cpu_mem_usage=True,
        torch_dtype=torch_dtype,
        token=access_token,
    )
    tokenizer = AutoTokenizer.from_pretrained(llm_model, token=access_token)
    # Initialize the text-generation pipeline. Sampling parameters are generation
    # arguments, so they are passed directly instead of via model_kwargs (which are
    # only applied when the pipeline loads the model itself).
    pipe = pipeline(
        task="text-generation",
        model=model,
        tokenizer=tokenizer,
        do_sample=True,  # sampling must be enabled for temperature/top_k to take effect
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
        pad_token_id=tokenizer.eos_token_id,
    )
    llm = HuggingFacePipeline(pipeline=pipe)
    return llm
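

# Minimal usage sketch (illustrative only): the model name below is a hypothetical
# placeholder, not a value taken from this repository, and gated models additionally
# require a valid Hugging Face access token.
if __name__ == "__main__":
    llm = initialize_llmchain(
        llm_model="mistralai/Mistral-7B-Instruct-v0.2",  # hypothetical example model
        temperature=0.7,
        max_tokens=256,
        top_k=50,
        access_token=None,
    )
    # Older LangChain versions call the LLM directly; newer ones prefer llm.invoke(...)
    print(llm("Summarize what this module does in one sentence."))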