File size: 1,518 Bytes
fc262e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import warnings

import matplotlib.pyplot as plt


def _update_axis(
    axis, image, title=None, fontsize=18, remove_axis=True, title_loc="center"
):
    axis.imshow(image, origin="upper")
    if title is not None:
        axis.set_title(title, fontsize=fontsize, loc=title_loc)
    if remove_axis is True:
        axis.set_axis_off()


def tensor_image_to_grid(
    images: list,
    transform,
    row_count,
    col_count=None,
    figsize=(20, 20),
    fontsize=None,
):
    """Draw image tensors on a grid of matplotlib axes and show the figure.

    Each entry of ``images`` is either a tensor or an ``(tensor, title)`` tuple.
    Every tensor is permuted with ``permute(1, 2, 0)`` (presumably CHW -> HWC;
    TODO confirm), converted to a NumPy array and passed through
    ``transform(image=array)["image"]`` before being drawn.

    Args:
        images: Tensors, or ``(tensor, title)`` tuples, to display.
        transform: Callable invoked as ``transform(image=array)`` that returns a
            mapping whose ``"image"`` entry is the array to draw.
        row_count: Number of grid rows.
        col_count: Number of grid columns; defaults to ``row_count``.
        figsize: Figure size for the grid figure (multi-image case only).
        fontsize: Title font size; defaults to ``figsize[0]``.

    In the grid case, an image that fails to convert or draw triggers a
    warning and leaves its cell blank instead of aborting the whole grid.
    """

    def split_image_title(item):
        # Entries may carry an optional title as ``(image, title)``.
        if isinstance(item, tuple):
            return item[0], item[1]
        return item, None

    def torch_to_image(t):
        # detach().cpu() lets tensors that require grad or live on an
        # accelerator be converted; both are no-ops for plain CPU tensors.
        array = t.detach().cpu().permute(1, 2, 0).numpy()
        return transform(image=array)["image"]

    col_count = row_count if col_count is None else col_count
    title_fontsize = figsize[0] if fontsize is None else fontsize

    if len(images) == 1:
        img, title = split_image_title(images[0])
        # Draw through the shared helper. The old ``plt.title = title``
        # replaced pyplot's ``title`` function instead of setting a title.
        _update_axis(
            axis=plt.gca(),
            image=torch_to_image(img),
            remove_axis=True,
            title=title,
            fontsize=title_fontsize,
        )
    else:
        # squeeze=False always returns a 2-D array of Axes, so a 1x1 grid
        # still supports ravel() (a bare Axes object would not).
        _, axes_grid = plt.subplots(
            row_count, col_count, figsize=figsize, squeeze=False
        )
        flat_axes = axes_grid.ravel()
        if len(images) > len(flat_axes):
            warnings.warn(
                f"{len(images)} images given but the grid has only "
                f"{len(flat_axes)} cells; the extra images are not shown.",
                stacklevel=2,
            )
        for ax, image in zip(flat_axes, images):
            try:
                img, title = split_image_title(image)
                _update_axis(
                    axis=ax,
                    image=torch_to_image(img),
                    remove_axis=True,
                    title=title,
                    fontsize=title_fontsize,
                )
            except Exception as err:
                # Best effort: one bad image must not abort the whole grid,
                # but report it instead of silently dropping it.
                warnings.warn(f"Could not draw image: {err!r}", stacklevel=2)
                ax.set_axis_off()
        # Hide the frames of grid cells that received no image.
        for ax in flat_axes[len(images):]:
            ax.set_axis_off()

    plt.tight_layout()
    plt.show()