'''
save_sam_masks.py (FrozenSeg): precompute SAM automatic mask predictions
for a dataset and cache them to disk as per-image pickle files, splitting
the image list across the visible GPUs (one worker process per GPU).
'''
import argparse
import json
import os
import pickle

import cv2
import numpy as np
import torch
import torch.multiprocessing as mp
from tqdm import tqdm

from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
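# Dataset name -> [image directory, panoptic annotation JSON], both relative
# to $DETECTRON2_DATASETS. An empty JSON entry means the image directory is
# listed directly instead of being read from an annotation file.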
img_anno = {
    'ade20k_val': ['ADEChallengeData2016/images/validation', 'ADEChallengeData2016/ade20k_panoptic_val.json'],
    'pc_val': ['pascal_ctx_d2/images/validation', ''],
    'pas_val': ['pascal_voc_d2/images/validation', ''],
}
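# SAM backbone name -> checkpoint path ('vit_t' is the MobileSAM weights).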
sam_checkpoint_dict = {
    'vit_b': 'pretrained_checkpoint/sam_vit_b_01ec64.pth',
    'vit_h': 'pretrained_checkpoint/sam_vit_h_4b8939.pth',
    'vit_l': 'pretrained_checkpoint/sam_vit_l_0b3195.pth',
    'vit_t': 'pretrained_checkpoint/mobile_sam.pt',
}

def process_images(args, gpu, data_chunk, save_path, if_parallel):
    def build_sam(if_parallel):
        sam_checkpoint = sam_checkpoint_dict[args.sam_model]
        sam = sam_model_registry[args.sam_model](checkpoint=sam_checkpoint)
        if not if_parallel:
            # One worker per GPU: pin this process to its assigned device.
            torch.cuda.set_device(gpu)
            sam = sam.cuda()
        else:
            # Note: wrapping in DataParallel and immediately taking .module
            # yields the bare single-GPU model, so this branch is effectively
            # equivalent to the one above; kept as in the original.
            sam = sam.cuda()
            sam = torch.nn.DataParallel(sam).module
        return sam

    sam = build_sam(if_parallel)
    # Automatic ("segment everything") mask generation, RLE-encoded output.
    mask_generator = SamAutomaticMaskGenerator(
        model=sam,
        pred_iou_thresh=0.8,
        stability_score_thresh=0.7,
        crop_n_layers=0,
        crop_n_points_downscale_factor=2,
        min_mask_region_area=100,
        output_mode='coco_rle',
    )
    # Process each image in this worker's chunk.
    for image_info in tqdm(data_chunk):
        # `dataset_path` is set in __main__ and inherited by this worker via
        # fork-based multiprocessing.
        file_path = os.path.join(dataset_path, img_anno[args.data_name][0])
        if isinstance(image_info, dict):
            # COCO-style records: derive the .jpg name from the URL or the
            # annotated file name.
            if 'coco_url' in image_info:
                file_name = os.path.splitext(image_info['coco_url'].split('/')[-1])[0] + '.jpg'
            elif 'file_name' in image_info:
                file_name = os.path.splitext(image_info['file_name'])[0] + '.jpg'
        else:
            # Plain file listing: image_info is just the file name.
            assert isinstance(image_info, str)
            file_name = os.path.splitext(image_info)[0] + '.jpg'
        image_path = f'{file_path}/{file_name}'
        try:
            image_id = os.path.splitext(file_name)[0].replace('/', '_')
            savepath = f'{save_path}/{image_id}.pkl'
            if not os.path.exists(savepath):  # skip already-cached images
                image = cv2.imread(image_path)
                if image is None:
                    raise IOError('cv2.imread returned None')
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                # `train=False` follows this repo's modified SAM interface;
                # the upstream SamAutomaticMaskGenerator.generate() takes
                # only the image.
                everything_mask = mask_generator.generate(image, train=False)
                # Keep at most the 50 largest masks, biggest first.
                everything_mask = sorted(everything_mask, key=lambda x: x['area'], reverse=True)[:50]
                with open(savepath, 'wb') as f:
                    pickle.dump(everything_mask, f)
        except Exception as e:
            print(f"Failed to load or convert image at {image_path}. Error: {e}")

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_name', type=str, default='pas_val')
    parser.add_argument('--sam_model', type=str, default='vit_h')
    args = parser.parse_args()

    gpus = os.getenv("CUDA_VISIBLE_DEVICES", "")
    dataset_path = os.getenv("DETECTRON2_DATASETS", "/users/cx_xchen/DATASETS/")
    num_gpus = len([x.strip() for x in gpus.split(",") if x.strip().isdigit()])
    assert num_gpus > 0, "Set CUDA_VISIBLE_DEVICES to at least one GPU id"
    print(f"Using {num_gpus} GPUs")
    # Build per-GPU work lists, either from the annotation JSON or by
    # listing the image directory when no JSON is configured.
    if img_anno[args.data_name][1] != '':
        json_file_path = os.path.join(dataset_path, img_anno[args.data_name][1])
        with open(json_file_path, 'r') as file:
            data = json.load(file)
        data_chunks = np.array_split(data['images'], num_gpus)
    else:
        image_dir = os.path.join(dataset_path, img_anno[args.data_name][0])
        image_files = os.listdir(image_dir)
        data_chunks = np.array_split(image_files, num_gpus)
    save_path = f'output/SAM_masks_pred/{args.sam_model}_{args.data_name}'
    os.makedirs(save_path, exist_ok=True)

    # Launch one worker process per visible GPU. The single-process
    # DataParallel path (if_parallel=True) inside process_images is unused.
    processes = []
    for gpu in range(num_gpus):
        p = mp.Process(target=process_images, args=(args, gpu, data_chunks[gpu], save_path, False))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
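
# Example invocation (paths and environment values are illustrative):
#   CUDA_VISIBLE_DEVICES=0,1 DETECTRON2_DATASETS=/path/to/datasets \
#       python save_sam_masks.py --data_name ade20k_val --sam_model vit_h
# Masks land in output/SAM_masks_pred/vit_h_ade20k_val/<image_id>.pkl.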