"""Gradio demo for Sapiens body-part segmentation.

Loads a TorchScript checkpoint, runs it on an uploaded image, and returns
both a colour overlay of the predicted classes and the raw class-index mask
as a downloadable ``.npy`` file.
"""

import os
import tempfile

# NOTE(review): third-party imports keep their original relative order
# (in particular `spaces` stays ahead of `torch`) — confirm whether the
# order matters before re-sorting them.
import cv2
import gradio as gr
import numpy as np
import spaces
import torch
import torch.nn.functional as F
from gradio.themes.utils import sizes
from torchvision import transforms
from PIL import Image

from classes_and_palettes import GOLIATH_PALETTE, GOLIATH_CLASSES


class Config:
    """Static locations of the demo assets and the available checkpoints."""

    ASSETS_DIR = os.path.join(os.path.dirname(__file__), 'assets')
    CHECKPOINTS_DIR = os.path.join(ASSETS_DIR, "checkpoints")
    # Maps the model-size label shown in the UI to its checkpoint file name.
    CHECKPOINTS = {
        "0.3b": "sapiens_0.3b_goliath_best_goliath_mIoU_7673_epoch_194_torchscript.pt2",
        "0.6b": "sapiens_0.6b_goliath_best_goliath_mIoU_7777_epoch_178_torchscript.pt2",
        "1b": "sapiens_1b_goliath_best_goliath_mIoU_7994_epoch_151_torchscript.pt2",
    }


class ModelManager:
    """Helpers for loading a checkpoint and running inference with it."""

    @staticmethod
    def load_model(checkpoint_name: str):
        """Load the TorchScript checkpoint registered under *checkpoint_name*.

        Args:
            checkpoint_name: A key of ``Config.CHECKPOINTS``.

        Returns:
            The loaded module, switched to eval mode and moved to ``"cuda"``.

        Raises:
            KeyError: If *checkpoint_name* is not a key of
                ``Config.CHECKPOINTS``.
        """
        checkpoint_path = os.path.join(
            Config.CHECKPOINTS_DIR, Config.CHECKPOINTS[checkpoint_name]
        )
        model = torch.jit.load(checkpoint_path)
        model.eval()
        model.to("cuda")
        return model

    @staticmethod
    @torch.inference_mode()
    def run_model(model, input_tensor, height, width):
        """Run *model* and return the per-pixel arg-max class indices.

        The model output is bilinearly resized to ``(height, width)`` before
        the maximum over dimension 1 is taken.

        Args:
            model: Callable module producing the raw output for
                *input_tensor*.
            input_tensor: Batched model input.
            height: Target output height in pixels.
            width: Target output width in pixels.

        Returns:
            Tensor of indices of the maximum along dimension 1.
        """
        output = model(input_tensor)
        output = F.interpolate(
            output, size=(height, width), mode="bilinear", align_corners=False
        )
        _, preds = torch.max(output, 1)
        return preds


class ImageProcessor:
    """Preprocesses images, runs segmentation and renders the result."""

    def __init__(self):
        self.transform_fn = transforms.Compose([
            transforms.Resize((1024, 768)),
            transforms.ToTensor(),
            # Mean/std are given on the 0-255 scale and divided by 255
            # because ToTensor has already scaled pixels to [0, 1].
            transforms.Normalize(
                mean=[123.5/255, 116.5/255, 103.5/255],
                std=[58.5/255, 57.0/255, 57.5/255],
            ),
        ])

    @spaces.GPU
    def process_image(self, image: Image.Image, model_name: str):
        """Segment *image* with the checkpoint selected by *model_name*.

        Args:
            image: Input PIL image; it is converted to RGB before
                preprocessing.
            model_name: A key of ``Config.CHECKPOINTS``.

        Returns:
            A ``(blended_image, npy_path)`` tuple: the overlay visualisation
            as a PIL image, and the path of a temporary ``.npy`` file holding
            the predicted mask at the input image's resolution.
        """
        model = ModelManager.load_model(model_name)

        # Normalize uses three channels, so RGBA / greyscale / palette
        # inputs must be converted before the transform is applied.
        rgb_image = image.convert("RGB")
        input_tensor = self.transform_fn(rgb_image).unsqueeze(0).to("cuda")

        preds = ModelManager.run_model(
            model, input_tensor, image.height, image.width
        )
        mask = preds.squeeze(0).cpu().numpy()

        # Visualize the segmentation
        blended_image = self.visualize_pred_with_overlay(image, mask)

        # Create downloadable .npy file. NamedTemporaryFile creates the file
        # atomically (unlike the deprecated, race-prone tempfile.mktemp);
        # delete=False keeps it on disk so it can be served afterwards.
        with tempfile.NamedTemporaryFile(suffix='.npy', delete=False) as tmp_file:
            np.save(tmp_file, mask)
            npy_path = tmp_file.name

        return blended_image, npy_path

    @staticmethod
    def visualize_pred_with_overlay(img, sem_seg, alpha=0.5):
        """Blend a colour-coded segmentation mask over *img*.

        Args:
            img: PIL image; converted to RGB for blending.
            sem_seg: Array-like of class indices with the same height and
                width as *img*.
            alpha: Weight of the colour overlay; the image gets
                ``1 - alpha``.

        Returns:
            The blended visualisation as a PIL image.
        """
        img_np = np.array(img.convert("RGB"))
        sem_seg = np.array(sem_seg)

        # Only paint class ids that have an entry in GOLIATH_CLASSES.
        num_classes = len(GOLIATH_CLASSES)
        ids = np.unique(sem_seg)[::-1]
        labels = np.array(ids[ids < num_classes], dtype=np.int64)

        overlay = np.zeros((*sem_seg.shape, 3), dtype=np.uint8)
        for label in labels:
            overlay[sem_seg == label, :] = GOLIATH_PALETTE[label]

        blended = np.uint8(img_np * (1 - alpha) + overlay * alpha)
        return Image.fromarray(blended)


class GradioInterface:
    """Builds the Gradio Blocks UI around an ``ImageProcessor``."""

    def __init__(self):
        self.image_processor = ImageProcessor()

    def create_interface(self):
        """Assemble and return the ``gr.Blocks`` demo (not yet launched)."""
        app_styles = " "
        header_html = (
            f" {app_styles}\n"
            "\n"
            "Sapiens:Body-Part Segmentation\n"
            "\n"
            "ECCV 2024 (Oral)\n"
            "\n"
            "Meta presents Sapiens, foundation models for human tasks "
            "pretrained on 300 million human images. "
            "This demo showcases the finetuned body-part segmentation model.\n"
            "\n"
        )

        # Reloads the page with ?__theme=dark unless it is already set.
        js_func = (
            " function refresh() { const url = new URL(window.location); "
            "if (url.searchParams.get('__theme') !== 'dark') { "
            "url.searchParams.set('__theme', 'dark'); "
            "window.location.href = url.href; } } "
        )

        def process_image(image, model_name):
            # Fail with a readable message (and before the GPU-decorated
            # call) when "Run" is pressed without an image.
            if image is None:
                raise gr.Error("Please provide an input image first.")
            result, npy_path = self.image_processor.process_image(image, model_name)
            return result, npy_path

        images_dir = os.path.join(Config.ASSETS_DIR, "images")

        with gr.Blocks(js=js_func, theme=gr.themes.Default()) as demo:
            gr.HTML(header_html)
            with gr.Row(elem_classes="content-container"):
                with gr.Column():
                    input_image = gr.Image(
                        label="Input Image",
                        type="pil",
                        format="png",
                        elem_classes="image-preview",
                    )
                    model_name = gr.Dropdown(
                        label="Model Size",
                        choices=list(Config.CHECKPOINTS.keys()),
                        value="1b",
                    )
                    gr.Examples(
                        inputs=input_image,
                        examples_per_page=14,
                        # Sorted so the gallery order does not depend on the
                        # arbitrary order returned by os.listdir.
                        examples=[
                            os.path.join(images_dir, img)
                            for img in sorted(os.listdir(images_dir))
                        ],
                    )
                with gr.Column():
                    result_image = gr.Image(
                        label="Segmentation Result",
                        type="pil",
                        elem_classes="image-preview",
                    )
                    npy_output = gr.File(label="Segmentation (.npy)")
                    run_button = gr.Button("Run")
                    gr.Image(
                        os.path.join(Config.ASSETS_DIR, "palette.jpg"),
                        label="Class Palette",
                        type="filepath",
                        elem_classes="image-preview",
                    )

            run_button.click(
                fn=process_image,
                inputs=[input_image, model_name],
                outputs=[result_image, npy_output],
            )

        return demo


def main():
    """Enable TF32 where the GPU reports major version >= 8, then launch."""
    # Configure CUDA if available
    if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    interface = GradioInterface()
    demo = interface.create_interface()
    demo.launch(share=False)


if __name__ == "__main__":
    main()