|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os
|
|
import numpy as np
|
|
import random
|
|
import torch
|
|
import torchvision.transforms as tvf
|
|
import argparse
|
|
from tqdm import tqdm
|
|
from PIL import Image
|
|
import math
|
|
|
|
from mast3r.model import AsymmetricMASt3R
|
|
from mast3r.fast_nn import fast_reciprocal_NNs
|
|
from mast3r.utils.coarse_to_fine import select_pairs_of_crops, crop_slice
|
|
from mast3r.utils.collate import cat_collate, cat_collate_fn_map
|
|
from mast3r.utils.misc import mkdir_for
|
|
from mast3r.datasets.utils.cropping import crop_to_homography
|
|
|
|
import mast3r.utils.path_to_dust3r
|
|
from dust3r.inference import inference, loss_of_one_batch
|
|
from dust3r.utils.geometry import geotrf, colmap_to_opencv_intrinsics, opencv_to_colmap_intrinsics
|
|
from dust3r.datasets.utils.transforms import ImgNorm
|
|
from dust3r_visloc.datasets import *
|
|
from dust3r_visloc.localization import run_pnp
|
|
from dust3r_visloc.evaluation import get_pose_error, aggregate_stats, export_results
|
|
from dust3r_visloc.datasets.utils import get_HW_resolution, rescale_points3d
|
|
|
|
|
|
def get_args_parser():
    """Build the command-line parser of the visual-localization evaluation script.

    Returns:
        argparse.ArgumentParser: parser exposing the dataset / weights
        selection, the matching options (confidence threshold, pixel
        tolerance, coarse-to-fine), the PnP options and the output options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, required=True, help="visloc dataset to eval")

    # exactly one of --weights / --model_name has to be given
    parser_weights = parser.add_mutually_exclusive_group(required=True)
    parser_weights.add_argument("--weights", type=str, help="path to the model weights", default=None)
    parser_weights.add_argument("--model_name", type=str, help="name of the model weights",
                                choices=["MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"])

    # matching options
    # NOTE: the script keeps matches whose confidence is >= the threshold,
    # hence the help text talks about *lower* values being discarded.
    parser.add_argument("--confidence_threshold", type=float, default=1.001,
                        help="matches with a confidence lower than the threshold are discarded")
    parser.add_argument('--pixel_tol', default=5, type=int,
                        help="pixel tolerance forwarded to the reciprocal matching (must be > 0)")

    parser.add_argument("--coarse_to_fine", action='store_true', default=False,
                        help="do the matching from coarse to fine")
    parser.add_argument("--max_image_size", type=int, default=None,
                        help="max image size for the fine resolution")
    parser.add_argument("--c2f_crop_with_homography", action='store_true', default=False,
                        help="when using coarse to fine, crop with homographies to keep cx, cy centered")

    # pose estimation options
    parser.add_argument("--device", type=str, default='cuda', help="pytorch device")
    parser.add_argument("--pnp_mode", type=str, default="cv2", choices=['cv2', 'poselib', 'pycolmap'],
                        help="pnp lib to use")
    # the reprojection error is given either in absolute value or as a ratio of the image diagonal
    parser_reproj = parser.add_mutually_exclusive_group()
    parser_reproj.add_argument("--reprojection_error", type=float, default=5.0, help="pnp reprojection error")
    parser_reproj.add_argument("--reprojection_error_diag_ratio", type=float, default=None,
                               help="pnp reprojection error as a ratio of the diagonal of the image")

    parser.add_argument("--max_batch_size", type=int, default=48,
                        help="max batch size for inference on crops when using coarse to fine")
    parser.add_argument("--pnp_max_points", type=int, default=100_000, help="pnp maximum number of points kept")
    parser.add_argument("--viz_matches", type=int, default=0, help="debug matches")

    # output options
    parser.add_argument("--output_dir", type=str, default=None, help="output path")
    parser.add_argument("--output_label", type=str, default='', help="prefix for results files")
    return parser
|
|
|
|
|
|
@torch.no_grad()
def coarse_matching(query_view, map_view, model, device, pixel_tol, fast_nn_params):
    """Match one query image against one map image at their rescaled resolution.

    Args:
        query_view: dict of the query; reads 'rgb_rescaled' and 'to_orig'.
        map_view: dict of the map image; reads 'rgb_rescaled' and 'to_orig',
            plus 'valid_rescaled' and 'pts3d_rescaled' when pixel_tol != 0.
        model: network forwarded to `inference`.
        device: device forwarded to `inference`.
        pixel_tol: when 0, the matching is seeded with `subsample_or_initxy1=8`,
            matches closer than 3 pixels to an image border are dropped and no
            3D point / confidence is returned. Otherwise the matching is seeded
            with the valid map pixels and `pixel_tol` is forwarded to
            `fast_reciprocal_NNs`.
        fast_nn_params: extra keyword arguments for `fast_reciprocal_NNs`.

    Returns:
        (valid_pts3d, matches_im_query, matches_im_map, matches_confs).
        The 2D matches are mapped through each view's 'to_orig' transform.
        Four empty lists are returned when one of the descriptor maps is empty.
    """
    # build the (query, map) pair as a single batch element
    rescaled_rgbs = (query_view['rgb_rescaled'], map_view['rgb_rescaled'])
    pair = tuple(dict(img=rgb.unsqueeze(0), true_shape=np.int32([rgb.shape[1:]]),
                      idx=view_id, instance=str(view_id))
                 for view_id, rgb in enumerate(rescaled_rgbs))
    output = inference([pair], model, device, batch_size=1, verbose=False)
    pred_query, pred_map = output['pred1'], output['pred2']
    conf_query = pred_query['desc_conf'].squeeze(0).cpu().numpy()
    conf_map = pred_map['desc_conf'].squeeze(0).cpu().numpy()
    desc_query = pred_query['desc'].squeeze(0).detach()
    desc_map = pred_map['desc'].squeeze(0).detach()

    # nothing to match
    if len(desc_query) == 0 or len(desc_map) == 0:
        return [], [], [], []

    if pixel_tol == 0:
        matches_im_map, matches_im_query = fast_reciprocal_NNs(desc_map, desc_query,
                                                               subsample_or_initxy1=8, **fast_nn_params)
        HM, WM = map_view['rgb_rescaled'].shape[1:]
        HQ, WQ = query_view['rgb_rescaled'].shape[1:]

        def away_from_border(xy, width, height):
            # column 0 is compared to the width (x), column 1 to the height (y)
            return (xy[:, 0] >= 3) & (xy[:, 0] < width - 3) & (xy[:, 1] >= 3) & (xy[:, 1] < height - 3)

        keep = away_from_border(matches_im_map, WM, HM) & away_from_border(matches_im_query, WQ, HQ)
        matches_im_map = matches_im_map[keep]
        matches_im_query = matches_im_query[keep]
        valid_pts3d = []
        matches_confs = []
    else:
        # seed the matching with the map pixels flagged as valid
        rows_map, cols_map = torch.where(map_view['valid_rescaled'])
        matches_im_map, matches_im_query = fast_reciprocal_NNs(desc_map, desc_query, (cols_map, rows_map),
                                                               pixel_tol=pixel_tol, **fast_nn_params)
        map_x, map_y = matches_im_map[:, 0], matches_im_map[:, 1]
        query_x, query_y = matches_im_query[:, 0], matches_im_query[:, 1]
        valid_pts3d = map_view['pts3d_rescaled'].cpu().numpy()[map_y, map_x]
        # a match is only as confident as its least confident endpoint
        matches_confs = np.minimum(conf_map[map_y, map_x], conf_query[query_y, query_x])

    # shift x and y by half a pixel before applying 'to_orig', and back afterwards
    # (presumably a pixel-corner vs pixel-centre convention change -- confirm)
    matches_im_query = matches_im_query.astype(np.float64)
    matches_im_map = matches_im_map.astype(np.float64)
    matches_im_query[:, :2] += 0.5
    matches_im_map[:, :2] += 0.5

    matches_im_query = geotrf(query_view['to_orig'], matches_im_query, norm=True)
    matches_im_map = geotrf(map_view['to_orig'], matches_im_map, norm=True)

    matches_im_query[:, :2] -= 0.5
    matches_im_map[:, :2] -= 0.5
    return valid_pts3d, matches_im_query, matches_im_map, matches_confs
|
|
|
|
|
|
@torch.no_grad()
def crops_inference(pairs, model, device, batch_size=48, verbose=True):
    """Run the model on a batch of crop pairs, chunk by chunk if the batch is large.

    Args:
        pairs: two dicts (one per view) holding the batched crops; the batch
            size is read from `pairs[0]['img'].shape[0]`.
        model: network forwarded to `loss_of_one_batch`.
        device: device forwarded to `loss_of_one_batch`.
        batch_size: maximum number of pairs forwarded at once.
        verbose: unused, kept for interface compatibility.

    Returns:
        The output of `loss_of_one_batch` when the whole batch is smaller than
        `batch_size`, otherwise the per-chunk outputs merged with `cat_collate`.
    """
    assert len(pairs) == 2, "Error, data should be a tuple of dicts containing the batch of image pairs"

    num_pairs = pairs[0]['img'].shape[0]
    if num_pairs < batch_size:
        # small enough for a single forward pass
        return loss_of_one_batch(pairs, model, None, device=device, symmetrize_batch=False)

    chunk_preds = []
    for start in range(0, num_pairs, batch_size):
        stop = start + min(num_pairs - start, batch_size)
        # slice every non-None entry of both views to the current chunk
        chunk = [{key: value[start:stop] for key, value in view.items() if value is not None}
                 for view in (pairs[0], pairs[1])]
        chunk_preds.append(loss_of_one_batch(chunk, model, None, device=device, symmetrize_batch=False))

    # merge the per-chunk predictions
    return cat_collate(chunk_preds, collate_fn_map=cat_collate_fn_map)
|
|
|
|
|
|
@torch.no_grad()
def fine_matching(query_views, map_views, model, device, max_batch_size, pixel_tol, fast_nn_params):
    """Match every (query crop, map crop) pair and gather the matches of all crops.

    Args:
        query_views: batched query crops; reads 'to_orig' (plus whatever
            `crops_inference` consumes).
        map_views: batched map crops; reads 'valid', 'pts3d' and 'to_orig'.
        model: network forwarded to `crops_inference`.
        device: device forwarded to `crops_inference`.
        max_batch_size: chunk size used by `crops_inference`.
        pixel_tol: tolerance forwarded to `fast_reciprocal_NNs`, must be > 0.
        fast_nn_params: extra keyword arguments for `fast_reciprocal_NNs`.

    Returns:
        (valid_pts3d, matches_im_query, matches_im_map, matches_confs), each
        concatenated over the crops; the 2D matches are mapped through the
        per-crop 'to_orig' transforms. Four empty lists if there was no crop.
    """
    assert pixel_tol > 0
    output = crops_inference([query_views, map_views], model, device,
                             batch_size=max_batch_size, verbose=False)
    pred_query, pred_map = output['pred1'], output['pred2']
    descs_query = pred_query['desc'].clone()
    descs_map = pred_map['desc'].clone()
    confs_query = pred_query['desc_conf'].clone()
    confs_map = pred_map['desc_conf'].clone()

    all_pts3d, all_xy_map, all_xy_query, all_confs = [], [], [], []
    per_crop = zip(descs_query, descs_map, confs_query, confs_map)
    for crop_id, (desc_q, desc_m, conf_q, conf_m) in enumerate(per_crop):
        pts3d_crop = map_views['pts3d'][crop_id].cpu().numpy()
        conf_q_np = conf_q.cpu().numpy()
        conf_m_np = conf_m.cpu().numpy()

        # seed the matching with the valid pixels of the map crop
        rows, cols = torch.where(map_views['valid'][crop_id])
        xy_map, xy_query = fast_reciprocal_NNs(desc_m, desc_q, (cols, rows),
                                               pixel_tol=pixel_tol, **fast_nn_params)

        pts3d_matched = pts3d_crop[xy_map[:, 1], xy_map[:, 0]]
        # a match is only as confident as its least confident endpoint
        confs_matched = np.minimum(conf_m_np[xy_map[:, 1], xy_map[:, 0]],
                                   conf_q_np[xy_query[:, 1], xy_query[:, 0]])

        # bring the crop coordinates back through the per-crop 'to_orig' transforms
        xy_map = geotrf(map_views['to_orig'][crop_id].cpu().numpy(), xy_map.copy(), norm=True)
        xy_query = geotrf(query_views['to_orig'][crop_id].cpu().numpy(), xy_query.copy(), norm=True)

        all_xy_map.append(xy_map)
        all_xy_query.append(xy_query)
        all_pts3d.append(pts3d_matched)
        all_confs.append(confs_matched)

    if not all_pts3d:
        return [], [], [], []

    matches_im_map = np.concatenate(all_xy_map, axis=0)
    matches_im_query = np.concatenate(all_xy_query, axis=0)
    valid_pts3d = np.concatenate(all_pts3d, axis=0)
    matches_confs = np.concatenate(all_confs, axis=0)
    return valid_pts3d, matches_im_query, matches_im_map, matches_confs
|
|
|
|
|
|
def crop(img, mask, pts3d, crop, intrinsics=None):
    """Extract one crop from an image and, optionally, from its mask and 3D point map.

    Without intrinsics the crop is a plain slice (`crop_slice(crop)`). With
    intrinsics, `crop_to_homography` provides a homography which is applied
    with PIL perspective warps (bicubic for the image, nearest for the mask
    and the 3D points).

    Args:
        img: image tensor; the homography path converts it with
            `255 * (img + 1) / 2`, i.e. it presumably is an HxWx3 float tensor
            in [-1, 1] -- confirm against the caller.
        mask: optional validity mask, or None.
        pts3d: optional point map whose last axis holds at least x, y, z, or None.
        crop: crop description understood by `crop_slice` / `crop_to_homography`;
            its first two entries are used as the translation of `to_orig`
            in the plain-slice case.
        intrinsics: optional camera intrinsics enabling the homography crop.

    Returns:
        (cropped_img, cropped_mask, cropped_pts3d, to_orig) where `to_orig` is
        a 3x3 transform from crop coordinates: the homography when intrinsics
        are given, a translation by `crop[:2]` otherwise. Mask / pts3d outputs
        are None when the matching input is None.
    """
    if intrinsics is None:
        window = crop_slice(crop)
        # as before, the image crop is a view of `img`; mask and points are
        # copies, but only the cropped window is copied (not the full-size tensor)
        out_cropped_img = img[window]
        out_cropped_mask = mask[window].clone() if mask is not None else None
        out_cropped_pts3d = pts3d[window].clone() if pts3d is not None else None
        to_orig = torch.eye(3, device=img.device)
        to_orig[:2, -1] = torch.tensor(crop[:2])
        return out_cropped_img, out_cropped_mask, out_cropped_pts3d, to_orig

    imsize, _K_new, _R, H = crop_to_homography(intrinsics, crop)
    H /= H[2, 2]
    # PIL expects the first 8 coefficients of the normalized homography
    homo8 = H.ravel().tolist()[:8]

    def warp(array, resample):
        """Apply the perspective warp to a numpy array through PIL."""
        warped = Image.fromarray(array).transform(imsize, Image.Transform.PERSPECTIVE,
                                                  homo8, resample=resample)
        return np.array(warped)

    # float tensor -> uint8 image -> warp -> float tensor (same dtype/device as img)
    img_u8 = (255 * (img + 1.) / 2).to(torch.uint8).numpy()
    out_cropped_img = 2. * torch.tensor(warp(img_u8, Image.Resampling.BICUBIC)).to(img) / 255. - 1.

    out_cropped_mask = None
    if mask is not None:
        mask_u8 = (255 * mask).to(torch.uint8).numpy()
        out_cropped_mask = torch.from_numpy(warp(mask_u8, Image.Resampling.NEAREST) > 0).to(mask.dtype)

    out_cropped_pts3d = None
    if pts3d is not None:
        pts3d_np = pts3d.numpy()
        # each coordinate channel is warped separately, without interpolation
        warped_xyz = [warp(pts3d_np[:, :, axis], Image.Resampling.NEAREST) for axis in range(3)]
        out_cropped_pts3d = torch.from_numpy(np.stack(warped_xyz, axis=-1))

    to_orig = torch.tensor(H, device=img.device)
    return out_cropped_img, out_cropped_mask, out_cropped_pts3d, to_orig
|
|
|
|
|
|
def resize_image_to_max(max_image_size, rgb, K):
    """Normalize an image and shrink it so that its largest side is at most `max_image_size`.

    Args:
        max_image_size: maximum size of the largest side; a falsy value
            (None or 0) disables the resizing.
        rgb: image exposing `.size` as (width, height), accepted by `ImgNorm`.
        K: camera intrinsics, converted through `opencv_to_colmap_intrinsics`
            and back when the image is resized.

    Returns:
        (rgb_tensor, new_K, to_orig_max, to_resize_max, (HMax, WMax)) where
        `rgb_tensor` is the normalized image permuted with (1, 2, 0),
        `to_orig_max` / `to_resize_max` are the 3x3 scalings from the resized
        to the original image and back, and `new_K` holds the rescaled
        intrinsics. When no resizing happens, `K` itself is returned together
        with identity transforms.
    """
    W, H = rgb.size

    # guard clause: no limit requested, or the image is already small enough
    if not max_image_size or max(W, H) <= max_image_size:
        return ImgNorm(rgb).permute(1, 2, 0), K, np.eye(3), np.eye(3), (H, W)

    # the largest side is clamped, the other one follows the aspect ratio
    if W >= H:
        WMax = max_image_size
        HMax = int(H * (WMax / W))
    else:
        HMax = max_image_size
        WMax = int(W * (HMax / H))

    shrink_and_normalize = tvf.Compose([ImgNorm, tvf.Resize(size=[HMax, WMax])])
    rgb_tensor = shrink_and_normalize(rgb).permute(1, 2, 0)

    scale_x, scale_y = WMax / W, HMax / H
    to_orig_max = np.diag([W / WMax, H / HMax, 1.0])
    to_resize_max = np.diag([scale_x, scale_y, 1.0])

    # rescale the intrinsics in the colmap convention, then convert them back
    new_K = opencv_to_colmap_intrinsics(K)
    new_K[0, :] *= scale_x
    new_K[1, :] *= scale_y
    new_K = colmap_to_opencv_intrinsics(new_K)
    return rgb_tensor, new_K, to_orig_max, to_resize_max, (HMax, WMax)
|
|
|
|
|
|
def plot_match_sample(query_rgb, map_rgb, matches_im_query, matches_im_map, n_viz):
    """Debug helper: draw `n_viz` evenly spaced matches over the two images side by side.

    Args:
        query_rgb: query image (anything `np.array` turns into an HxWx3 array).
        map_rgb: map image, same kind as `query_rgb`.
        matches_im_query: matched 2D points (x, y) in the query image.
        matches_im_map: matched 2D points (x, y) in the map image.
        n_viz: number of matches to draw.

    The number of matches is always printed. Nothing is drawn when there is
    no match (or `n_viz` is not positive), which used to raise an IndexError.
    """
    num_matches = len(matches_im_map)
    print(f'found {num_matches} matches')
    if num_matches == 0 or n_viz <= 0:
        return

    from matplotlib import pyplot as pl  # only needed when debugging matches

    viz_imgs = [np.array(query_rgb), np.array(map_rgb)]
    match_idx_to_viz = np.round(np.linspace(0, num_matches - 1, n_viz)).astype(int)
    viz_matches_im_query = np.asarray(matches_im_query)[match_idx_to_viz]
    viz_matches_im_map = np.asarray(matches_im_map)[match_idx_to_viz]

    # pad the shortest image at the bottom so that both can be concatenated horizontally
    H0, W0, H1, W1 = *viz_imgs[0].shape[:2], *viz_imgs[1].shape[:2]
    img0 = np.pad(viz_imgs[0], ((0, max(H1 - H0, 0)), (0, 0), (0, 0)), 'constant', constant_values=0)
    img1 = np.pad(viz_imgs[1], ((0, max(H0 - H1, 0)), (0, 0), (0, 0)), 'constant', constant_values=0)
    img = np.concatenate((img0, img1), axis=1)
    pl.figure()
    pl.imshow(img)
    cmap = pl.get_cmap('jet')
    # avoid a division by zero when a single match is displayed
    color_denom = max(n_viz - 1, 1)
    for i in range(n_viz):
        (x0, y0), (x1, y1) = viz_matches_im_query[i].T, viz_matches_im_map[i].T
        # the map image is drawn to the right of the query image, hence the W0 offset
        pl.plot([x0, x1 + W0], [y0, y1], '-+', color=cmap(i / color_denom), scalex=False, scaley=False)
    pl.show(block=True)


if __name__ == '__main__':
    # Script entry point: for every query of the dataset, match it against its
    # map views, estimate the query pose with PnP and accumulate pose errors.
    parser = get_args_parser()
    args = parser.parse_args()
    conf_thr = args.confidence_threshold
    device = args.device
    pnp_mode = args.pnp_mode
    assert args.pixel_tol > 0, "--pixel_tol must be > 0"
    reprojection_error = args.reprojection_error
    reprojection_error_diag_ratio = args.reprojection_error_diag_ratio
    pnp_max_points = args.pnp_max_points
    viz_matches = args.viz_matches

    if args.weights is not None:
        weights_path = args.weights
    else:
        weights_path = "naver/" + args.model_name
    model = AsymmetricMASt3R.from_pretrained(weights_path).to(args.device)
    fast_nn_params = dict(device=device, dist='dot', block_size=2**13)
    # SECURITY NOTE: --dataset is evaluated as a Python expression (it is expected
    # to build one of the datasets star-imported from dust3r_visloc.datasets).
    # Never feed this script a --dataset string coming from an untrusted source.
    dataset = eval(args.dataset)
    dataset.set_resolution(model)

    query_names = []
    poses_pred = []
    pose_errors = []
    angular_errors = []

    # label identifying the matching configuration, used for the cache and result files
    params_str = f'tol_{args.pixel_tol}' + ("_c2f" if args.coarse_to_fine else '')
    if args.max_image_size is not None:
        params_str = params_str + f'_{args.max_image_size}'
    if args.coarse_to_fine and args.c2f_crop_with_homography:
        params_str = params_str + '_with_homography'

    for idx in tqdm(range(len(dataset))):
        views = dataset[(idx)]
        query_view = views[0]  # the first view is used as the query, the others as map views
        map_views = views[1:]
        query_names.append(query_view['image_name'])

        query_pts2d = []
        query_pts3d = []
        maxdim = max(model.patch_embed.img_size)
        query_rgb_tensor, query_K, query_to_orig_max, query_to_resize_max, (HQ, WQ) = resize_image_to_max(
            args.max_image_size, query_view['rgb'], query_view['intrinsics'])

        query_resolution = get_HW_resolution(HQ, WQ, maxdim=maxdim, patchsize=model.patch_embed.patch_size)
        for map_view in map_views:
            if args.output_dir is not None:
                cache_file = os.path.join(args.output_dir, 'matches', params_str,
                                          query_view['image_name'], map_view['image_name'] + '.npz')
            else:
                cache_file = None

            if cache_file is not None and os.path.isfile(cache_file):
                # reuse the matches computed by a previous run
                matches = np.load(cache_file)
                valid_pts3d = matches['valid_pts3d']
                matches_im_query = matches['matches_im_query']
                matches_im_map = matches['matches_im_map']
                matches_conf = matches['matches_conf']
            else:
                if args.coarse_to_fine and (maxdim < max(WQ, HQ)):
                    # coarse pass (pixel_tol=0): only the 2D-2D matches are used
                    _, coarse_matches_im0, coarse_matches_im1, _ = coarse_matching(query_view, map_view, model, device,
                                                                                   0, fast_nn_params)

                    if viz_matches > 0:
                        plot_match_sample(query_view['rgb'], map_view['rgb'],
                                          coarse_matches_im0, coarse_matches_im1, viz_matches)

                    valid_all = map_view['valid']
                    pts3d = map_view['pts3d']

                    WM_full, HM_full = map_view['rgb'].size
                    map_rgb_tensor, map_K, map_to_orig_max, map_to_resize_max, (HM, WM) = resize_image_to_max(
                        args.max_image_size, map_view['rgb'], map_view['intrinsics'])
                    if WM_full != WM or HM_full != HM:
                        # the map image was resized: rescale its 3D points / validity mask accordingly
                        y_full, x_full = torch.where(valid_all)
                        pos2d_cv2 = torch.stack([x_full, y_full], dim=-1).cpu().numpy().astype(np.float64)
                        sparse_pts3d = pts3d[y_full, x_full].cpu().numpy()
                        _, _, pts3d_max, valid_max = rescale_points3d(
                            pos2d_cv2, sparse_pts3d, map_to_resize_max, HM, WM)
                        pts3d = torch.from_numpy(pts3d_max)
                        valid_all = torch.from_numpy(valid_max)

                    coarse_matches_im0 = geotrf(query_to_resize_max, coarse_matches_im0, norm=True)
                    coarse_matches_im1 = geotrf(map_to_resize_max, coarse_matches_im1, norm=True)

                    crops1, crops2 = [], []
                    crops_v1, crops_p1 = [], []
                    to_orig1, to_orig2 = [], []
                    map_resolution = get_HW_resolution(HM, WM, maxdim=maxdim, patchsize=model.patch_embed.patch_size)

                    # intrinsics are only handed to crop() when homography crops are requested
                    crop_map_K = map_K if args.c2f_crop_with_homography else None
                    crop_query_K = query_K if args.c2f_crop_with_homography else None

                    for crop_map, crop_query, _pair_tag in select_pairs_of_crops(map_rgb_tensor,
                                                                                 query_rgb_tensor,
                                                                                 coarse_matches_im1,
                                                                                 coarse_matches_im0,
                                                                                 maxdim=maxdim,
                                                                                 overlap=.5,
                                                                                 forced_resolution=[map_resolution,
                                                                                                    query_resolution]):
                        c1, v1, p1, trf1 = crop(map_rgb_tensor, valid_all, pts3d, crop_map, crop_map_K)
                        c2, _, _, trf2 = crop(query_rgb_tensor, None, None, crop_query, crop_query_K)
                        crops1.append(c1)
                        crops2.append(c2)
                        crops_v1.append(v1)
                        crops_p1.append(p1)
                        to_orig1.append(trf1)
                        to_orig2.append(trf2)

                    if len(crops1) == 0 or len(crops2) == 0:
                        valid_pts3d, matches_im_query, matches_im_map, matches_conf = [], [], [], []
                    else:
                        crops1, crops2 = torch.stack(crops1), torch.stack(crops2)
                        if len(crops1.shape) == 3:
                            crops1, crops2 = crops1[None], crops2[None]
                        crops_v1 = torch.stack(crops_v1)
                        crops_p1 = torch.stack(crops_p1)
                        to_orig1, to_orig2 = torch.stack(to_orig1), torch.stack(to_orig2)
                        map_crop_view = dict(img=crops1.permute(0, 3, 1, 2),
                                             instance=['1' for _ in range(crops1.shape[0])],
                                             valid=crops_v1, pts3d=crops_p1,
                                             to_orig=to_orig1)
                        query_crop_view = dict(img=crops2.permute(0, 3, 1, 2),
                                               instance=['2' for _ in range(crops2.shape[0])],
                                               to_orig=to_orig2)

                        # fine pass on the crops, then back to the full-resolution images
                        valid_pts3d, matches_im_query, matches_im_map, matches_conf = fine_matching(query_crop_view,
                                                                                                    map_crop_view,
                                                                                                    model, device,
                                                                                                    args.max_batch_size,
                                                                                                    args.pixel_tol,
                                                                                                    fast_nn_params)
                        matches_im_query = geotrf(query_to_orig_max, matches_im_query, norm=True)
                        matches_im_map = geotrf(map_to_orig_max, matches_im_map, norm=True)
                else:
                    # single-scale matching, seeded with the valid map pixels
                    valid_pts3d, matches_im_query, matches_im_map, matches_conf = coarse_matching(query_view, map_view,
                                                                                                  model, device,
                                                                                                  args.pixel_tol,
                                                                                                  fast_nn_params)
                if cache_file is not None:
                    mkdir_for(cache_file)
                    np.savez(cache_file, valid_pts3d=valid_pts3d, matches_im_query=matches_im_query,
                             matches_im_map=matches_im_map, matches_conf=matches_conf)

            # keep only the matches whose confidence reaches the threshold
            if len(matches_conf) > 0:
                mask = matches_conf >= conf_thr
                valid_pts3d = valid_pts3d[mask]
                matches_im_query = matches_im_query[mask]
                matches_im_map = matches_im_map[mask]
                matches_conf = matches_conf[mask]

            if viz_matches > 0:
                plot_match_sample(query_view['rgb'], map_view['rgb'],
                                  matches_im_query, matches_im_map, viz_matches)

            if len(valid_pts3d) > 0:
                query_pts3d.append(valid_pts3d)
                query_pts2d.append(matches_im_query)

        if len(query_pts2d) == 0:
            success = False
            pr_querycam_to_world = None
        else:
            query_pts2d = np.concatenate(query_pts2d, axis=0).astype(np.float32)
            query_pts3d = np.concatenate(query_pts3d, axis=0)
            if len(query_pts2d) > pnp_max_points:
                # random subset to bound the PnP cost
                idxs = random.sample(range(len(query_pts2d)), pnp_max_points)
                query_pts3d = query_pts3d[idxs]
                query_pts2d = query_pts2d[idxs]

            W, H = query_view['rgb'].size
            if reprojection_error_diag_ratio is not None:
                reprojection_error_img = reprojection_error_diag_ratio * math.sqrt(W**2 + H**2)
            else:
                reprojection_error_img = reprojection_error
            success, pr_querycam_to_world = run_pnp(query_pts2d, query_pts3d,
                                                    query_view['intrinsics'], query_view['distortion'],
                                                    pnp_mode, reprojection_error_img, img_size=[W, H])

        if not success:
            abs_transl_error = float('inf')
            abs_angular_error = float('inf')
        else:
            abs_transl_error, abs_angular_error = get_pose_error(pr_querycam_to_world, query_view['cam_to_world'])

        pose_errors.append(abs_transl_error)
        angular_errors.append(abs_angular_error)
        poses_pred.append(pr_querycam_to_world)

    xp_label = params_str + f'_conf_{conf_thr}'
    if args.output_label:
        xp_label = args.output_label + "_" + xp_label
    if reprojection_error_diag_ratio is not None:
        xp_label = xp_label + f'_reproj_diag_{reprojection_error_diag_ratio}'
    else:
        xp_label = xp_label + f'_reproj_err_{reprojection_error}'
    export_results(args.output_dir, xp_label, query_names, poses_pred)
    out_string = aggregate_stats(f'{args.dataset}', pose_errors, angular_errors)
    print(out_string)
|
|
|