"""Gradio demo for IAT: low-light enhancement and exposure correction."""
import functools
import os
import torch
import cv2
from torchvision.transforms import Compose, ToTensor, Resize, Normalize, ConvertImageDtype
from PIL import Image
import numpy as np
import gradio as gr
from model import IAT  # Ensure the correct import path

# Checkpoint files for the two demo tasks (relative to the working directory).
DARK_CHECKPOINT = './checkpoint/best_Epoch_lol.pth'
EXPOSURE_CHECKPOINT = './checkpoint/best_Epoch_exposure.pth'

# Preprocessing shared by both tasks. It is built once instead of per request.
#   ToTensor  : HWC uint8 array -> CHW float tensor scaled to [0, 1]
#   Resize    : shorter side resized to 384 px, aspect ratio preserved
#   Normalize : (x - 0.5) / 0.5, i.e. [0, 1] -> [-1, 1]
# NOTE(review): tensor_to_numpy reads the model output back as [0, 1] while the
# input is normalized to [-1, 1]; confirm this matches how the checkpoints were trained.
_TRANSFORM = Compose([
    ToTensor(),
    Resize(384),
    Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ConvertImageDtype(torch.float),
])


def set_example_image(example: list) -> dict:
    """Return an update that loads a clicked example into the input image.

    Args:
        example: the sample row delivered by a ``gr.Dataset`` click event; its
            first element (an image path here) becomes the image's new value.
    """
    return gr.Image.update(value=example[0])


def tensor_to_numpy(tensor):
    """Convert an image tensor into a uint8 NumPy array suitable for display.

    A 3-D tensor whose first dimension is 3 (CHW) is transposed to HWC. Values
    are multiplied by 255 and clipped to [0, 255], so the tensor is expected to
    hold intensities in [0, 1].
    """
    print("Converting tensor to numpy array...")
    array = tensor.detach().cpu().numpy()
    if array.ndim == 3 and array.shape[0] == 3:
        # CHW -> HWC, the layout image widgets expect for numpy arrays.
        array = array.transpose(1, 2, 0)
    return np.clip(array * 255, 0, 255).astype(np.uint8)


@functools.lru_cache(maxsize=None)
def _load_model(checkpoint_file_path):
    """Build an eval-mode IAT model from ``checkpoint_file_path`` once and reuse it.

    Previously every button click rebuilt the network and re-read the checkpoint
    from disk. Only the two checkpoint constants above are ever passed, so the
    cache holds at most two models.
    """
    model = IAT()
    state_dict = torch.load(checkpoint_file_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()
    print(f'Load model from {checkpoint_file_path}')
    return model


def _enhance(img, checkpoint_file_path, task_name):
    """Run the IAT model stored at ``checkpoint_file_path`` on ``img``.

    Args:
        img: input image; presumably an HWC uint8 RGB array coming from a
            ``gr.Image(type='numpy')`` component, or None if nothing was uploaded.
        checkpoint_file_path: which trained weights to use.
        task_name: label used in the completion log line.

    Returns:
        The enhanced image as a uint8 NumPy array (see ``tensor_to_numpy``).

    Raises:
        gr.Error: if no image was provided.
    """
    if img is None:
        # Clicking a button before uploading passes None; ToTensor would fail
        # with an opaque TypeError, so report a readable message instead.
        raise gr.Error("Please upload an image first.")
    model = _load_model(checkpoint_file_path)
    input_img = _TRANSFORM(img)
    print(f'Image shape after transform: {input_img.shape}')
    with torch.no_grad():
        enhanced_img = model(input_img.unsqueeze(0))  # add a batch dimension
    if isinstance(enhanced_img, (tuple, list)):
        # NOTE(review): assumes a tuple-returning IAT.forward puts the enhanced
        # image last (e.g. (mul, add, enhanced)); confirm against model.py.
        enhanced_img = enhanced_img[-1]
    result_img = tensor_to_numpy(enhanced_img[0])  # drop the batch dimension
    print(f"{task_name} inference completed.")
    return result_img


def dark_inference(img):
    """Enhance a low-light image using the './checkpoint/best_Epoch_lol.pth' weights.

    (The checkpoint name suggests training on the LOL dataset.)
    """
    print("Starting dark inference...")
    return _enhance(img, DARK_CHECKPOINT, "Dark")


def exposure_inference(img):
    """Correct exposure using the './checkpoint/best_Epoch_exposure.pth' weights."""
    print("Starting exposure inference...")
    return _enhance(img, EXPOSURE_CHECKPOINT, "Exposure")
# Gradio UI. It has one input image, two task buttons, one result image, and an
# example gallery per task.
demo = gr.Blocks()

with demo:
    gr.Markdown(
        """
        # IAT
        Gradio demo for IAT: To use it, simply upload your image, or click one of the examples to load them. Read more at the links below.
        """
    )

    with gr.Box():
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    input_image = gr.Image(label='Input Image', type='numpy')
                with gr.Row():
                    dark_button = gr.Button('Low-light Enhancement')
                with gr.Row():
                    exposure_button = gr.Button('Exposure Correction')
            with gr.Column():
                res_image = gr.Image(type='numpy', label='Results')
        with gr.Row():
            dark_example_images = gr.Dataset(
                components=[input_image],
                samples=[['dark_imgs/1.jpg'], ['dark_imgs/2.jpg'], ['dark_imgs/3.jpg']]
            )
        with gr.Row():
            exposure_example_images = gr.Dataset(
                components=[input_image],
                samples=[['exposure_imgs/1.jpg'], ['exposure_imgs/2.jpg'], ['exposure_imgs/3.jpeg']]
            )

    # Footer. The intro text promises links here, so include them.
    gr.Markdown(
        """
        [You Only Need 90K Parameters to Adapt Light: A Light Weight Transformer for Image Enhancement and Exposure Correction](https://arxiv.org/abs/2205.14871) | [Github Repo](https://github.com/cuiziteng/Illumination-Adaptive-Transformer)
        """
    )

    # Both tasks read the same input image and write to the same result image.
    dark_button.click(fn=dark_inference, inputs=input_image, outputs=res_image)
    exposure_button.click(fn=exposure_inference, inputs=input_image, outputs=res_image)
    # Clicking an example copies it into the input image component.
    dark_example_images.click(fn=set_example_image, inputs=dark_example_images, outputs=dark_example_images.components)
    exposure_example_images.click(fn=set_example_image, inputs=exposure_example_images, outputs=exposure_example_images.components)

# `launch(enable_queue=True)` is deprecated in Gradio 3.x; enable the queue on
# the Blocks object instead, then launch.
demo.queue()
demo.launch()