---
license: mit
---

# Only two steps are needed

Step 1: clone the repository and install the required packages.

```shell
git clone https://github.com/ByungKwanLee/TroL
bash install
```

Step 2: open, edit, and run `demo.py`.

```python
import torch
from config import *
from PIL import Image
from utils.utils import *
import torch.nn.functional as F
from trol.load_trol import load_trol
from torchvision.transforms.functional import pil_to_tensor

# model selection
link = "TroL-1.8B" # [Select One] 'TroL-1.8B' | 'TroL-3.8B' | 'TroL-7B'

# User prompt
prompt_type="with_image" # Select one option "text_only", "with_image"
img_path='figures/demo.png'
question="What is the troll doing? Provide the detail in the image and imagine what the event happens."

# loading model
model, tokenizer = load_trol(link=link)

# cpu -> gpu
for param in model.parameters():
    if not param.is_cuda:
        param.data = param.to('cuda:0')

# prompt type -> input prompt
image_token_number = None
if prompt_type == 'with_image':
    # Image Load
    image = pil_to_tensor(Image.open(img_path).convert("RGB"))
    if not "3.8B" in link:
        image_token_number = 1225
        image = F.interpolate(image.unsqueeze(0), size=(490, 490), mode='bicubic').squeeze(0)
    inputs = [{'image': image, 'question': question}]
elif prompt_type=='text_only':
    inputs = [{'question': question}]

# Generate
with torch.inference_mode():
    _inputs = model.eval_process(inputs=inputs, data='demo', tokenizer=tokenizer, device='cuda:0', img_token_number=image_token_number)
    generate_ids = model.generate(**_inputs, max_new_tokens=256, use_cache=True)
response = output_filtering(tokenizer.batch_decode(generate_ids, skip_special_tokens=False)[0], model)
print(response)
```

So easy! Let's say TroL!

# Paper

https://arxiv.org/abs/2406.12246