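"""Gradio demo for IAT (Illumination-Adaptive-Transformer).

The app loads two pretrained IAT checkpoints, one for low-light enhancement
and one for exposure correction, and exposes both through a simple Blocks UI.
Written against the Gradio 3.x API (gr.Box, gr.Dataset, gr.Image.update).
"""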
import numpy as np
import torch
from torchvision.transforms import Compose, ToTensor, Resize, Normalize

import gradio as gr
from model import IAT  # ensure model.py from the IAT repo is on the import path

def set_example_image(example: list) -> dict:
    return gr.Image.update(value=example[0])
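# Note: gr.Image.update() is the Gradio 3.x idiom for updating a component from
# a callback; on Gradio 4+ the equivalent would be returning gr.Image(value=...).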

def tensor_to_numpy(tensor):
    """Convert a CHW float tensor to an HWC uint8 image (scaled by 255, clipped)."""
    tensor = tensor.detach().cpu().numpy()
    if tensor.ndim == 3 and tensor.shape[0] == 3:  # CHW -> HWC
        tensor = tensor.transpose(1, 2, 0)
    return np.clip(tensor * 255, 0, 255).astype(np.uint8)

def run_inference(img, checkpoint_file_path):
    """Load an IAT checkpoint and enhance a single image."""
    model = IAT()
    state_dict = torch.load(checkpoint_file_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()
    print(f'Loaded model from {checkpoint_file_path}')

    transform = Compose([
        ToTensor(),
        Resize(384),
        Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    input_img = transform(img)
    print(f'Image shape after transform: {input_img.shape}')

    with torch.no_grad():
        enhanced_img = model(input_img.unsqueeze(0))

    # If your IAT build returns a tuple such as (mul, add, img_high), take the
    # final element (the enhanced image) here instead of indexing the batch.
    return tensor_to_numpy(enhanced_img[0])

def dark_inference(img):
    print('Starting dark inference...')
    return run_inference(img, './checkpoint/best_Epoch_lol.pth')

def exposure_inference(img):
    print('Starting exposure inference...')
    return run_inference(img, './checkpoint/best_Epoch_exposure.pth')
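
# Optional optimization (not wired in, to preserve the original behavior): cache
# each loaded checkpoint so repeated clicks do not reload weights from disk.
# Minimal sketch, assuming CPU-only inference; get_model is a hypothetical helper
# that run_inference could call in place of its load block.
_MODEL_CACHE = {}

def get_model(checkpoint_file_path):
    """Return a cached, eval-mode IAT model for the given checkpoint."""
    if checkpoint_file_path not in _MODEL_CACHE:
        model = IAT()
        model.load_state_dict(torch.load(checkpoint_file_path, map_location='cpu'))
        model.eval()
        _MODEL_CACHE[checkpoint_file_path] = model
    return _MODEL_CACHE[checkpoint_file_path]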

demo = gr.Blocks()
with demo:
    gr.Markdown(
        """
        # IAT
        Gradio demo for <a href='https://github.com/cuiziteng/Illumination-Adaptive-Transformer' target='_blank'>IAT</a>. Upload your image, or click one of the examples to load it, then pick an enhancement mode. Read more at the links below.
        """
    )

    with gr.Box():
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    input_image = gr.Image(label='Input Image', type='numpy')
                with gr.Row():
                    dark_button = gr.Button('Low-light Enhancement')
                with gr.Row():
                    exposure_button = gr.Button('Exposure Correction')
            with gr.Column():
                res_image = gr.Image(type='numpy', label='Results')
        with gr.Row():
            dark_example_images = gr.Dataset(
                components=[input_image], 
                samples=[['dark_imgs/1.jpg'], ['dark_imgs/2.jpg'], ['dark_imgs/3.jpg']]
            )
        with gr.Row():
            exposure_example_images = gr.Dataset(
                components=[input_image], 
                samples=[['exposure_imgs/1.jpg'], ['exposure_imgs/2.jpg'], ['exposure_imgs/3.jpeg']]
            )

    gr.Markdown(
        """
        <p style='text-align: center'><a href='https://arxiv.org/abs/2205.14871' target='_blank'>You Only Need 90K Parameters to Adapt Light: A Light Weight Transformer for Image Enhancement and Exposure Correction</a> | <a href='https://github.com/cuiziteng/Illumination-Adaptive-Transformer' target='_blank'>Github Repo</a></p>
        """
    )

    dark_button.click(fn=dark_inference, inputs=input_image, outputs=res_image)
    exposure_button.click(fn=exposure_inference, inputs=input_image, outputs=res_image)
    dark_example_images.click(fn=set_example_image, inputs=dark_example_images, outputs=dark_example_images.components)
    exposure_example_images.click(fn=set_example_image, inputs=exposure_example_images, outputs=exposure_example_images.components)

# queue() is the supported replacement for the deprecated enable_queue launch
# flag in newer Gradio releases.
demo.queue()
demo.launch()