BioGeek's picture
Fix typo in title
5d238ae
raw
history blame
No virus
3.25 kB
import os
import subprocess
import sys

# Install the pinned runtime dependencies before the third-party imports that
# follow need them. Invoking pip as ``sys.executable -m pip`` guarantees the
# packages go into the interpreter running this script, and ``check_call``
# raises CalledProcessError on a failed install instead of silently carrying
# on (os.system discarded the exit status) and failing later at import time.
# Order matters: torch must be present before detectron2 is built from source.
for _pip_args in (
    [
        "torch==1.13.1",
        "torchvision==0.14.1",
        "torchaudio==0.13.1",
        "--index-url",
        "https://download.pytorch.org/whl/cpu",
    ],
    ["git+https://github.com/facebookresearch/detectron2.git"],
    ["opencv-python-headless==4.8.1.78"],
):
    subprocess.check_call([sys.executable, "-m", "pip", "install", *_pip_args])
import gradio as gr
import cv2
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import Visualizer
from detectron2.utils.visualizer import ColorMode
from detectron2.data import MetadataCatalog
import numpy as np
# Path to the trained model weights. Resolved against this script's directory
# (the same way the example images are located) rather than the current
# working directory, so the app also starts when launched from elsewhere.
model_path = os.path.join(
    os.path.dirname(__file__), "model", "keypoint_rcnn_X_101_32x8d_FPN_3x.pth"
)
# Number of landmarks the custom keypoint head predicts per instance.
number_of_keypoints = 22

# Setup the configuration for the model: start from the model-zoo
# Keypoint R-CNN X101-FPN config, then override it for this custom model.
cfg = get_cfg()
cfg.merge_from_file(
    model_zoo.get_config_file("COCO-Keypoints/keypoint_rcnn_X_101_32x8d_FPN_3x.yaml")
)
cfg.MODEL.DEVICE = "cpu"
# NOTE(review): looks like a training-time RoI sampling setting kept to mirror
# the training config; presumably it has no effect on inference — confirm.
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
cfg.MODEL.ROI_KEYPOINT_HEAD.NUM_KEYPOINTS = number_of_keypoints
# One OKS sigma per keypoint, as a flat list of floats. A nested
# [[1.0], [1.0], ...] list would broadcast to an (N, N) array when COCO
# keypoint OKS is computed, silently corrupting evaluation scores.
cfg.TEST.KEYPOINT_OKS_SIGMAS = [1.0] * number_of_keypoints
# Load the trained model weights
cfg.MODEL.WEIGHTS = model_path
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.6  # set a custom testing threshold
predictor = DefaultPredictor(cfg)
# Visualization metadata: fetch (creating on first access) the catalog entry
# named "spot", label its single class "wing", and keep a handle to it for
# the Visualizer.
metadata = MetadataCatalog.get("spot")
metadata.set(thing_classes=["wing"])
def markin(image_path):
    """Run the wing keypoint model on an image file and return an annotated image.

    Args:
        image_path: Filesystem path of the image, as supplied by the
            ``gr.Image(type="filepath")`` input component.

    Returns:
        The image produced by detectron2's ``Visualizer`` with the predicted
        instances drawn over the input.

    Raises:
        gr.Error: If no image was provided or the file cannot be decoded.
    """
    # Gradio passes None when the user submits without uploading anything.
    if not image_path:
        raise gr.Error("Please upload an image.")
    # IMREAD_COLOR always yields an 8-bit, 3-channel BGR array, which is what
    # the predictor and the BGR->RGB flip below expect. IMREAD_UNCHANGED would
    # keep an alpha channel (4 channels), a single grayscale plane (2-D array)
    # or 16-bit depth, each of which breaks inference or visualization.
    im = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if im is None:
        # cv2.imread reports unreadable/unsupported files by returning None.
        raise gr.Error("Could not read the uploaded image.")
    outputs = predictor(im)
    # Visualizer takes RGB, so reverse OpenCV's BGR channel order.
    visualizer = Visualizer(
        im[:, :, ::-1],
        metadata=metadata,
        instance_mode=ColorMode.SEGMENTATION,
    )
    out = visualizer.draw_instance_predictions(outputs["instances"].to("cpu"))
    return out.get_image()
# Sample images bundled next to this script, offered as clickable examples.
_EXAMPLE_DIR = os.path.dirname(__file__)
_EXAMPLE_IMAGES = (
    "images/mosquito-wing-1.jpg",
    "images/mosquito-wing-2.jpg",
    "images/mosquito-wing-3.jpg",
    "images/mosquito-wing-4.jpg",
    "images/mosquito-wing-5.jpg",
)

# HTML blurb shown under the title (adjacent literals concatenate into one string).
_DESCRIPTION = (
    "Mosquitoes are a group of about 3,500 species of small insects, known "
    "widely for their role as vectors for numerous diseases. Studying "
    "mosquitoes, particularly their wings, is crucial in scientific research, "
    "primarily within the fields of entomology, genetics, and evolutionary "
    "biology. The wings of mosquitoes not only play a role in their mobility "
    "but also may hold important genetic information about their evolution, "
    "resistance, and even disease transmission. <br> "
    '<a href="https://datamarkin.com/models/automated-measurement-of-mosquito-wings" '
    'class="navbar-item "> More about mosquito wing project </a>'
)

# Gradio UI: upload an image file and get back the annotated prediction.
demo = gr.Interface(
    fn=markin,
    inputs=gr.Image(type="filepath", sources=["upload"]),
    outputs="image",
    examples=[os.path.join(_EXAMPLE_DIR, name) for name in _EXAMPLE_IMAGES],
    title="Mosquito wing landmarking",
    description=_DESCRIPTION,
)

if __name__ == "__main__":
    demo.launch()