SkalskiP committed
Commit fa98faf
1 Parent(s): 33e6030

Point prompt mode ready for review

Files changed (3):
  1. app.py +186 -68
  2. utils/draw.py +32 -0
  3. utils/efficient_sam.py +33 -0
app.py CHANGED
@@ -7,7 +7,8 @@ import torch
 from PIL import Image
 from transformers import SamModel, SamProcessor
 
-from utils.efficient_sam import load, inference_with_box
+from utils.efficient_sam import load, inference_with_box, inference_with_point
+from utils.draw import draw_circle, calculate_dynamic_circle_radius
 
 MARKDOWN = """
 # EfficientSAM sv. SAM
@@ -17,28 +18,74 @@ This is a demo for ⚔️ SAM Battlegrounds - a speed and accuracy comparison be
 [SAM](https://arxiv.org/abs/2304.02643).
 """
 
+BOX_EXAMPLES = [
+    ['https://media.roboflow.com/efficient-sam/corgi.jpg', 801, 510, 1782, 993],
+    ['https://media.roboflow.com/efficient-sam/horses.jpg', 814, 696, 1523, 1183],
+    ['https://media.roboflow.com/efficient-sam/bears.jpg', 653, 874, 1173, 1229]
+]
+
+POINT_EXAMPLES = [
+    ['https://media.roboflow.com/efficient-sam/corgi.jpg', 1291, 751],
+    ['https://media.roboflow.com/efficient-sam/horses.jpg', 1168, 939],
+    ['https://media.roboflow.com/efficient-sam/bears.jpg', 913, 1051]
+]
+
+PROMPT_COLOR = sv.Color.from_hex("#D3D3D3")
+MASK_COLOR = sv.Color.from_hex("#FF0000")
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 SAM_MODEL = SamModel.from_pretrained("facebook/sam-vit-huge").to(DEVICE)
 SAM_PROCESSOR = SamProcessor.from_pretrained("facebook/sam-vit-huge")
 EFFICIENT_SAM_MODEL = load(device=DEVICE)
 MASK_ANNOTATOR = sv.MaskAnnotator(
-    color=sv.Color.red(),
-    color_lookup=sv.ColorLookup.INDEX)
-BOX_ANNOTATOR = sv.BoundingBoxAnnotator(
-    color=sv.Color.red(),
+    color=MASK_COLOR,
     color_lookup=sv.ColorLookup.INDEX)
 
 
-def annotate_image(image: np.ndarray, detections: sv.Detections) -> np.ndarray:
+def annotate_image_with_box_prompt_result(
+    image: np.ndarray,
+    detections: sv.Detections,
+    x_min: int,
+    y_min: int,
+    x_max: int,
+    y_max: int
+) -> np.ndarray:
+    h, w, _ = image.shape
     bgr_image = image[:, :, ::-1]
     annotated_bgr_image = MASK_ANNOTATOR.annotate(
         scene=bgr_image, detections=detections)
-    annotated_bgr_image = BOX_ANNOTATOR.annotate(
-        scene=annotated_bgr_image, detections=detections)
+    annotated_bgr_image = sv.draw_rectangle(
+        scene=annotated_bgr_image,
+        rect=sv.Rect(
+            x=x_min,
+            y=y_min,
+            width=int(x_max - x_min),
+            height=int(y_max - y_min),
+        ),
+        color=PROMPT_COLOR,
+        thickness=sv.calculate_dynamic_line_thickness(resolution_wh=(w, h))
+    )
     return annotated_bgr_image[:, :, ::-1]
 
 
-def efficient_sam_inference(
+def annotate_image_with_point_prompt_result(
+    image: np.ndarray,
+    detections: sv.Detections,
+    x: int,
+    y: int
+) -> np.ndarray:
+    h, w, _ = image.shape
+    bgr_image = image[:, :, ::-1]
+    annotated_bgr_image = MASK_ANNOTATOR.annotate(
+        scene=bgr_image, detections=detections)
+    annotated_bgr_image = draw_circle(
+        scene=annotated_bgr_image,
+        center=sv.Point(x=x, y=y),
+        radius=calculate_dynamic_circle_radius(resolution_wh=(w, h)),
+        color=PROMPT_COLOR)
+    return annotated_bgr_image[:, :, ::-1]
+
+
+def efficient_sam_box_inference(
     image: np.ndarray,
     x_min: int,
     y_min: int,
@@ -49,10 +96,17 @@ def efficient_sam_inference(
     mask = inference_with_box(image, box, EFFICIENT_SAM_MODEL, DEVICE)
     mask = mask[np.newaxis, ...]
     detections = sv.Detections(xyxy=sv.mask_to_xyxy(masks=mask), mask=mask)
-    return annotate_image(image=image, detections=detections)
+    return annotate_image_with_box_prompt_result(
+        image=image,
+        detections=detections,
+        x_max=x_max,
+        x_min=x_min,
+        y_max=y_max,
+        y_min=y_min
+    )
 
 
-def sam_inference(
+def sam_box_inference(
     image: np.ndarray,
     x_min: int,
     y_min: int,
@@ -76,10 +130,17 @@ def sam_inference(
     )[0][0][0].numpy()
     mask = mask[np.newaxis, ...]
     detections = sv.Detections(xyxy=sv.mask_to_xyxy(masks=mask), mask=mask)
-    return annotate_image(image=image, detections=detections)
+    return annotate_image_with_box_prompt_result(
+        image=image,
+        detections=detections,
+        x_max=x_max,
+        x_min=x_min,
+        y_max=y_max,
+        y_min=y_min
+    )
 
 
-def inference(
+def box_inference(
     image: np.ndarray,
     x_min: int,
     y_min: int,
@@ -87,8 +148,46 @@ def inference(
     y_max: int
 ) -> Tuple[np.ndarray, np.ndarray]:
     return (
-        efficient_sam_inference(image, x_min, y_min, x_max, y_max),
-        sam_inference(image, x_min, y_min, x_max, y_max)
+        efficient_sam_box_inference(image, x_min, y_min, x_max, y_max),
+        sam_box_inference(image, x_min, y_min, x_max, y_max)
+    )
+
+
+def efficient_sam_point_inference(image: np.ndarray, x: int, y: int) -> np.ndarray:
+    point = np.array([[x, y]])
+    mask = inference_with_point(image, point, EFFICIENT_SAM_MODEL, DEVICE)
+    mask = mask[np.newaxis, ...]
+    detections = sv.Detections(xyxy=sv.mask_to_xyxy(masks=mask), mask=mask)
+    return annotate_image_with_point_prompt_result(
+        image=image, detections=detections, x=x, y=y)
+
+
+def sam_point_inference(image: np.ndarray, x: int, y: int) -> np.ndarray:
+    input_points = [[[x, y]]]
+    inputs = SAM_PROCESSOR(
+        Image.fromarray(image),
+        input_points=[input_points],
+        return_tensors="pt"
+    ).to(DEVICE)
+
+    with torch.no_grad():
+        outputs = SAM_MODEL(**inputs)
+
+    mask = SAM_PROCESSOR.image_processor.post_process_masks(
+        outputs.pred_masks.cpu(),
+        inputs["original_sizes"].cpu(),
+        inputs["reshaped_input_sizes"].cpu()
+    )[0][0][0].numpy()
+    mask = mask[np.newaxis, ...]
+    detections = sv.Detections(xyxy=sv.mask_to_xyxy(masks=mask), mask=mask)
+    return annotate_image_with_point_prompt_result(
+        image=image, detections=detections, x=x, y=y)
+
+
+def point_inference(image: np.ndarray, x: int, y: int) -> Tuple[np.ndarray, np.ndarray]:
+    return (
+        efficient_sam_point_inference(image, x, y),
+        sam_point_inference(image, x, y)
     )
 
 
@@ -96,73 +195,92 @@ def clear(_: np.ndarray) -> Tuple[None, None]:
     return None, None
 
 
+box_input_image = gr.Image()
+x_min_number = gr.Number(label="x_min")
+y_min_number = gr.Number(label="y_min")
+x_max_number = gr.Number(label="x_max")
+y_max_number = gr.Number(label="y_max")
+box_inputs = [box_input_image, x_min_number, y_min_number, x_max_number, y_max_number]
+
+point_input_image = gr.Image()
+x_number = gr.Number(label="x")
+y_number = gr.Number(label="y")
+point_inputs = [point_input_image, x_number, y_number]
+
+
 with gr.Blocks() as demo:
     gr.Markdown(MARKDOWN)
     with gr.Tab(label="Box prompt"):
         with gr.Row():
             with gr.Column():
-                input_image = gr.Image()
+                box_input_image.render()
                 with gr.Accordion(label="Box", open=False):
                     with gr.Row():
-                        x_min_number = gr.Number(label="x_min")
-                        y_min_number = gr.Number(label="y_min")
-                        x_max_number = gr.Number(label="x_max")
-                        y_max_number = gr.Number(label="y_max")
-            efficient_sam_output_image = gr.Image(label="EfficientSAM")
-            sam_output_image = gr.Image(label="SAM")
+                        x_min_number.render()
+                        y_min_number.render()
+                        x_max_number.render()
+                        y_max_number.render()
+            efficient_sam_box_output_image = gr.Image(label="EfficientSAM")
+            sam_box_output_image = gr.Image(label="SAM")
         with gr.Row():
-            submit_button = gr.Button("Submit")
-
+            submit_box_inference_button = gr.Button("Submit")
+        gr.Examples(
+            fn=box_inference,
+            examples=BOX_EXAMPLES,
+            inputs=box_inputs,
+            outputs=[efficient_sam_box_output_image, sam_box_output_image],
+        )
+    with gr.Tab(label="Point prompt"):
+        with gr.Row():
+            with gr.Column():
+                point_input_image.render()
+                with gr.Accordion(label="Point", open=False):
+                    with gr.Row():
+                        x_number.render()
+                        y_number.render()
+            efficient_sam_point_output_image = gr.Image(label="EfficientSAM")
+            sam_point_output_image = gr.Image(label="SAM")
+        with gr.Row():
+            submit_point_inference_button = gr.Button("Submit")
         gr.Examples(
-            fn=inference,
-            examples=[
-                [
-                    'https://media.roboflow.com/efficient-sam/beagle.jpeg',
-                    69,
-                    26,
-                    625,
-                    704
-                ],
-                [
-                    'https://media.roboflow.com/efficient-sam/corgi.jpg',
-                    801,
-                    510,
-                    1782,
-                    993
-                ],
-                [
-                    'https://media.roboflow.com/efficient-sam/horses.jpg',
-                    814,
-                    696,
-                    1523,
-                    1183
-                ],
-                [
-                    'https://media.roboflow.com/efficient-sam/bears.jpg',
-                    653,
-                    874,
-                    1173,
-                    1229
-                ]
-            ],
-            inputs=[input_image, x_min_number, y_min_number, x_max_number, y_max_number],
-            outputs=[efficient_sam_output_image, sam_output_image],
+            fn=point_inference,
+            examples=POINT_EXAMPLES,
+            inputs=point_inputs,
+            outputs=[efficient_sam_point_output_image, sam_point_output_image],
         )
 
-    submit_button.click(
-        efficient_sam_inference,
-        inputs=[input_image, x_min_number, y_min_number, x_max_number, y_max_number],
-        outputs=efficient_sam_output_image
+    submit_box_inference_button.click(
+        efficient_sam_box_inference,
+        inputs=box_inputs,
+        outputs=efficient_sam_box_output_image
     )
-    submit_button.click(
-        sam_inference,
-        inputs=[input_image, x_min_number, y_min_number, x_max_number, y_max_number],
-        outputs=sam_output_image
+    submit_box_inference_button.click(
+        sam_box_inference,
+        inputs=box_inputs,
+        outputs=sam_box_output_image
     )
-    input_image.change(
+
+    submit_point_inference_button.click(
+        efficient_sam_point_inference,
+        inputs=point_inputs,
+        outputs=efficient_sam_point_output_image
+    )
+    submit_point_inference_button.click(
+        sam_point_inference,
+        inputs=point_inputs,
+        outputs=sam_point_output_image
+    )
+
+    box_input_image.change(
         clear,
-        inputs=input_image,
-        outputs=[efficient_sam_output_image, sam_output_image]
+        inputs=box_input_image,
+        outputs=[efficient_sam_box_output_image, sam_box_output_image]
+    )
+
+    point_input_image.change(
+        clear,
+        inputs=point_input_image,
+        outputs=[efficient_sam_point_output_image, sam_point_output_image]
    )
 
 demo.launch(debug=False, show_error=True)
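For reviewers who want to exercise the new point path without launching the Gradio UI, here is a minimal sketch of the same flow `efficient_sam_point_inference` wires up. It assumes the script runs from the repo root (so `utils` is importable); the local file name `corgi.jpg` is hypothetical, standing in for a downloaded copy of the POINT_EXAMPLES image.

```python
# Minimal sketch of the new point-prompt flow outside Gradio.
# Assumes: run from the repo root so `utils` resolves; `corgi.jpg` is a
# hypothetical local copy of the corgi example image.
import numpy as np
import supervision as sv
import torch
from PIL import Image

from utils.efficient_sam import load, inference_with_point

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = load(device=device)

image = np.asarray(Image.open("corgi.jpg").convert("RGB"))
point = np.array([[1291, 751]])  # (x, y) click, same values as POINT_EXAMPLES

mask = inference_with_point(image, point, model, device)  # (H, W) bool mask
mask = mask[np.newaxis, ...]  # sv.Detections expects masks shaped (N, H, W)
detections = sv.Detections(xyxy=sv.mask_to_xyxy(masks=mask), mask=mask)
print(detections.xyxy)  # bounding box derived from the predicted mask
```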
utils/draw.py ADDED
@@ -0,0 +1,32 @@
+from typing import Tuple
+
+import cv2
+import numpy as np
+import supervision as sv
+
+
+def draw_circle(
+    scene: np.ndarray, center: sv.Point, color: sv.Color, radius: int = 2
+) -> np.ndarray:
+    cv2.circle(
+        scene,
+        center=center.as_xy_int_tuple(),
+        radius=radius,
+        color=color.as_bgr(),
+        thickness=-1,
+    )
+    return scene
+
+
+def calculate_dynamic_circle_radius(resolution_wh: Tuple[int, int]) -> int:
+    min_dimension = min(resolution_wh)
+    if min_dimension < 480:
+        return 4
+    if min_dimension < 720:
+        return 8
+    if min_dimension < 1080:
+        return 8
+    if min_dimension < 2160:
+        return 16
+    else:
+        return 16
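A quick way to sanity-check the new helpers is to draw on a blank canvas; the sketch below is illustrative only (canvas size and point are arbitrary). For a 1280x720 frame the minimum dimension is 720, so `calculate_dynamic_circle_radius` returns 8; note the `< 1080` and `else` branches repeat the values of the branches above them, so the table collapses to three distinct radii (4, 8, 16).

```python
# Illustrative check of the new draw helpers on a blank 1280x720 canvas.
import numpy as np
import supervision as sv

from utils.draw import calculate_dynamic_circle_radius, draw_circle

canvas = np.zeros((720, 1280, 3), dtype=np.uint8)
radius = calculate_dynamic_circle_radius(resolution_wh=(1280, 720))  # -> 8
canvas = draw_circle(
    scene=canvas,
    center=sv.Point(x=640, y=360),       # arbitrary prompt location
    color=sv.Color.from_hex("#D3D3D3"),  # same hex as PROMPT_COLOR in app.py
    radius=radius,
)
```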
utils/efficient_sam.py CHANGED
@@ -45,3 +45,36 @@ def inference_with_box(
             max_predicted_iou = curr_predicted_iou
             selected_mask_using_predicted_iou = all_masks[m]
     return selected_mask_using_predicted_iou
+
+
+def inference_with_point(
+    image: np.ndarray,
+    point: np.ndarray,
+    model: torch.jit.ScriptModule,
+    device: torch.device
+) -> np.ndarray:
+    pts_sampled = torch.reshape(torch.tensor(point), [1, 1, -1, 2])
+    max_num_pts = pts_sampled.shape[2]
+    pts_labels = torch.ones(1, 1, max_num_pts)
+    img_tensor = ToTensor()(image)
+
+    predicted_logits, predicted_iou = model(
+        img_tensor[None, ...].to(device),
+        pts_sampled.to(device),
+        pts_labels.to(device),
+    )
+    predicted_logits = predicted_logits.cpu()
+    all_masks = torch.ge(torch.sigmoid(predicted_logits[0, 0, :, :, :]), 0.5).numpy()
+    predicted_iou = predicted_iou[0, 0, ...].cpu().detach().numpy()
+
+    max_predicted_iou = -1
+    selected_mask_using_predicted_iou = None
+    for m in range(all_masks.shape[0]):
+        curr_predicted_iou = predicted_iou[m]
+        if (
+            curr_predicted_iou > max_predicted_iou
+            or selected_mask_using_predicted_iou is None
+        ):
+            max_predicted_iou = curr_predicted_iou
+            selected_mask_using_predicted_iou = all_masks[m]
+    return selected_mask_using_predicted_iou
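Because the prompt is reshaped to `[1, 1, -1, 2]` and `pts_labels` is all ones (every point marked foreground), `inference_with_point` should also accept several clicks at once. A sketch under that assumption, with illustrative coordinates and a hypothetical local `horses.jpg`:

```python
# Sketch: multi-point prompting. An (N, 2) array becomes [1, 1, N, 2]
# inside inference_with_point, and every point is labeled foreground.
# `horses.jpg` is a hypothetical local copy of the example image.
import numpy as np
import torch
from PIL import Image

from utils.efficient_sam import load, inference_with_point

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = load(device=device)

image = np.asarray(Image.open("horses.jpg").convert("RGB"))
points = np.array([[1168, 939], [1100, 900]])  # two nearby clicks (illustrative)
mask = inference_with_point(image, points, model, device)
```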