Spaces:
Running
Running
File size: 4,377 Bytes
19c9e2c 21f63e5 1009662 21f63e5 19c9e2c 2b3e4a6 19c9e2c 21f63e5 53ce97d 21f63e5 19c9e2c 53ce97d 19c9e2c 3ad9c6b 19c9e2c 1009662 19c9e2c 3ad9c6b 53ce97d 19c9e2c 21f63e5 53ce97d 21f63e5 19c9e2c 53ce97d 19c9e2c 3ad9c6b 19c9e2c 1009662 21f63e5 19c9e2c 3ad9c6b 53ce97d 19c9e2c 21f63e5 53ce97d 21f63e5 19c9e2c 21f63e5 19c9e2c 27b5582 19c9e2c 21f63e5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import os
import torch
import cv2
from torchvision.transforms import Compose, ToTensor, Resize, Normalize, ConvertImageDtype
from PIL import Image
import numpy as np
import gradio as gr
from model import IAT # Ensure the correct import path
def set_example_image(example: list) -> dict:
    """Build the Gradio update that puts a clicked example into an image component.

    Args:
        example: The sample row received from the click event; only its first
            item is used.

    Returns:
        The result of ``gr.Image.update`` with that first item as the value.
    """
    selected = example[0]
    return gr.Image.update(value=selected)
def tensor_to_numpy(tensor):
    """Convert a float image tensor into a uint8 NumPy image.

    Values are scaled by 255 and clipped to the 0-255 range, so inputs are
    expected to be roughly in [0, 1]; anything outside saturates.

    Args:
        tensor: A torch tensor. A 3-D tensor whose first dimension has size 3
            is treated as channels-first (CHW) and returned channels-last
            (HWC); any other shape is returned with its layout unchanged.

    Returns:
        A ``numpy.uint8`` array.
    """
    print("Converting tensor to numpy array...")
    array = tensor.detach().cpu().numpy()
    if array.ndim == 3 and array.shape[0] == 3:  # Convert CHW to HWC
        array = array.transpose(1, 2, 0)
    scaled = np.clip(array * 255, 0, 255)
    # Round to the nearest level before casting: a bare astype() truncates, so
    # a float32 value that lands just below an integer (e.g. 199.99998) would
    # come out one intensity level too dark.
    return np.rint(scaled).astype(np.uint8)
# Checkpoints backing the two enhancement modes exposed by the demo.
_DARK_CHECKPOINT = './checkpoint/best_Epoch_lol.pth'
_EXPOSURE_CHECKPOINT = './checkpoint/best_Epoch_exposure.pth'

# Loaded models keyed by checkpoint path, so the weights are read from disk
# once per process instead of on every button click.
_MODEL_CACHE = {}


def _load_model(checkpoint_file_path):
    """Return an eval-mode IAT model for the checkpoint, loading it at most once."""
    model = _MODEL_CACHE.get(checkpoint_file_path)
    if model is None:
        model = IAT()
        state_dict = torch.load(checkpoint_file_path, map_location='cpu')
        model.load_state_dict(state_dict)
        model.eval()
        print(f'Load model from {checkpoint_file_path}')
        _MODEL_CACHE[checkpoint_file_path] = model
    return model


def _run_inference(img, checkpoint_file_path, task_name):
    """Run one image through the IAT model stored at ``checkpoint_file_path``.

    Shared implementation behind ``dark_inference`` and ``exposure_inference``;
    ``task_name`` is only used in the log messages.
    """
    if img is None:
        # The button can be clicked before any image is uploaded; passing None
        # to the transform would raise, so leave the output empty instead.
        print(f"No input image provided; skipping {task_name} inference.")
        return None

    print(f"Starting {task_name} inference...")
    model = _load_model(checkpoint_file_path)

    # Preprocessing: convert to a tensor, resize with size 384, then normalise
    # each channel with mean 0.5 and std 0.5.
    transform = Compose([
        ToTensor(),
        Resize(384),
        Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ConvertImageDtype(torch.float)
    ])
    input_img = transform(img)
    print(f'Image shape after transform: {input_img.shape}')

    with torch.no_grad():
        enhanced_img = model(input_img.unsqueeze(0))

    # NOTE(review): assumes indexing the model output with [0] yields the
    # enhanced image for the single batch item - confirm against model.IAT.
    result_img = tensor_to_numpy(enhanced_img[0])
    print(f"{task_name.capitalize()} inference completed.")
    return result_img


def dark_inference(img):
    """Enhance a low-light image with the LOL checkpoint.

    Args:
        img: The input image as delivered by the Gradio image component, or
            ``None`` when nothing has been uploaded.

    Returns:
        The array produced by ``tensor_to_numpy`` for the model output, or
        ``None`` when ``img`` is ``None``.
    """
    return _run_inference(img, _DARK_CHECKPOINT, "dark")


def exposure_inference(img):
    """Correct the exposure of an image with the exposure checkpoint.

    Args:
        img: The input image as delivered by the Gradio image component, or
            ``None`` when nothing has been uploaded.

    Returns:
        The array produced by ``tensor_to_numpy`` for the model output, or
        ``None`` when ``img`` is ``None``.
    """
    return _run_inference(img, _EXPOSURE_CHECKPOINT, "exposure")
# Example images offered under the input; each inner list is one Dataset row.
dark_samples = [['dark_imgs/1.jpg'], ['dark_imgs/2.jpg'], ['dark_imgs/3.jpg']]
exposure_samples = [['exposure_imgs/1.jpg'], ['exposure_imgs/2.jpg'], ['exposure_imgs/3.jpeg']]

# NOTE(review): the original indentation was lost, so the layout nesting below
# is reconstructed from the order of the `with` statements - confirm it matches
# the intended page layout.
with gr.Blocks() as demo:
    # Page header.
    gr.Markdown(
        """
# IAT
Gradio demo for <a href='https://github.com/cuiziteng/Illumination-Adaptive-Transformer' target='_blank'>IAT</a>: To use it, simply upload your image, or click one of the examples to load them. Read more at the links below.
"""
    )

    with gr.Box():
        with gr.Row():
            # Left column: input image plus one button per enhancement mode.
            with gr.Column():
                with gr.Row():
                    input_image = gr.Image(label='Input Image', type='numpy')
                with gr.Row():
                    dark_button = gr.Button('Low-light Enhancement')
                with gr.Row():
                    exposure_button = gr.Button('Exposure Correction')
            # Right column: the enhanced result.
            with gr.Column():
                res_image = gr.Image(type='numpy', label='Results')
        with gr.Row():
            dark_example_images = gr.Dataset(
                components=[input_image],
                samples=dark_samples
            )
        with gr.Row():
            exposure_example_images = gr.Dataset(
                components=[input_image],
                samples=exposure_samples
            )

    # Page footer with the paper and repository links.
    gr.Markdown(
        """
<p style='text-align: center'><a href='https://arxiv.org/abs/2205.14871' target='_blank'>You Only Need 90K Parameters to Adapt Light: A Light Weight Transformer for Image Enhancement and Exposure Correction</a> | <a href='https://github.com/cuiziteng/Illumination-Adaptive-Transformer' target='_blank'>Github Repo</a></p>
"""
    )

    # Each button runs its model on the input image and shows the result.
    dark_button.click(fn=dark_inference, inputs=input_image, outputs=res_image)
    exposure_button.click(fn=exposure_inference, inputs=input_image, outputs=res_image)

    # Clicking an example row loads that image into the Dataset's component.
    dark_example_images.click(
        fn=set_example_image,
        inputs=dark_example_images,
        outputs=dark_example_images.components
    )
    exposure_example_images.click(
        fn=set_example_image,
        inputs=exposure_example_images,
        outputs=exposure_example_images.components
    )

demo.launch(enable_queue=True)
|