import warnings

import matplotlib.pyplot as plt


def _update_axis(
    axis, image, title=None, fontsize=18, remove_axis=True, title_loc="center"
):
    """Draw ``image`` on ``axis``, optionally adding a title and hiding the axis.

    Args:
        axis: Matplotlib ``Axes`` to draw on.
        image: Array-like accepted by ``Axes.imshow``.
        title: Title text; no title is set when ``None``.
        fontsize: Font size for the title.
        remove_axis: When truthy, hide ticks, labels and frame via
            ``Axes.set_axis_off``.
        title_loc: Title alignment forwarded to ``Axes.set_title(loc=...)``.
    """
    # origin="upper" puts row 0 of the array at the top of the axes.
    axis.imshow(image, origin="upper")
    if title is not None:
        axis.set_title(title, fontsize=fontsize, loc=title_loc)
    if remove_axis:
        axis.set_axis_off()


def tensor_image_to_grid(
    images: list,
    transform,
    row_count,
    col_count=None,
    figsize=(20, 20),
    fontsize=None,
):
    """Plot image tensors in a ``row_count`` x ``col_count`` grid and show it.

    Each entry of ``images`` is either an image tensor or an
    ``(image_tensor, title)`` tuple. Every tensor is converted with
    ``t.permute(1, 2, 0).numpy()`` and then passed through
    ``transform(image=...)["image"]`` before being drawn. Presumably the
    tensors are channel-first (C, H, W) and on the CPU, and ``transform`` is
    an albumentations-style callable -- TODO confirm against callers.

    When ``images`` holds exactly one entry it is drawn on the current pyplot
    figure (``figsize`` is not applied). Otherwise a new grid figure is
    created. Entries that fail to convert or draw are skipped with a
    ``RuntimeWarning`` and their cell is left blank. Grid cells without an
    image are hidden, and entries beyond ``row_count * col_count`` are
    dropped with a warning. The result is displayed with ``plt.show()``.

    Args:
        images: Image tensors or ``(tensor, title)`` tuples.
        transform: Callable taking ``image=<HWC numpy array>`` and returning
            a mapping whose ``"image"`` entry is drawable by ``imshow``.
        row_count: Number of grid rows.
        col_count: Number of grid columns; defaults to ``row_count``.
        figsize: Size of the grid figure. ``figsize[0]`` is also the default
            title font size in grid mode.
        fontsize: Title font size. In grid mode ``None`` means
            ``figsize[0]``; for a single image ``None`` keeps matplotlib's
            default.
    """

    def split_image_title(item):
        # A tuple is interpreted as (image, title); anything else is a bare image.
        if isinstance(item, tuple):
            return item[0], item[1]
        return item, None

    def tensor_to_image(tensor):
        # CHW -> HWC before handing the array to the transform.
        return transform(image=tensor.permute(1, 2, 0).numpy())["image"]

    col_count = row_count if col_count is None else col_count

    if len(images) == 1:
        img, title = split_image_title(images[0])
        plt.imshow(tensor_to_image(img))
        if title is not None:
            # Call pyplot.title; assigning to it would clobber the function.
            title_kwargs = {} if fontsize is None else {"fontsize": fontsize}
            plt.title(title, **title_kwargs)
        plt.axis("off")
        plt.tight_layout()
    else:
        title_fontsize = figsize[0] if fontsize is None else fontsize
        # squeeze=False always returns a 2-D ndarray of Axes. Without it a
        # 1x1 grid yields a bare Axes, which has no ``ravel``/``reshape``.
        _, axes_grid = plt.subplots(
            row_count, col_count, figsize=figsize, squeeze=False
        )
        flat_axes = axes_grid.ravel()

        if len(images) > len(flat_axes):
            warnings.warn(
                f"{len(images)} images given but the grid has only "
                f"{len(flat_axes)} cells; extra images are not shown.",
                RuntimeWarning,
                stacklevel=2,
            )

        for ax, item in zip(flat_axes, images):
            try:
                img, title = split_image_title(item)
                _update_axis(
                    axis=ax,
                    image=tensor_to_image(img),
                    remove_axis=True,
                    title=title,
                    fontsize=title_fontsize,
                )
            except Exception as err:
                # Best effort: one bad image should not abort the whole grid,
                # but the failure must not be silent either.
                warnings.warn(
                    f"Skipping image that could not be plotted: {err!r}",
                    RuntimeWarning,
                    stacklevel=2,
                )
                ax.set_axis_off()

        # Hide cells that received no image so they don't render as empty frames.
        for ax in flat_axes[len(images):]:
            ax.set_axis_off()

        plt.tight_layout()
    plt.show()