ViewCrafter / viewcrafter.py
Drexubery's picture
update
600dfac
raw
history blame contribute delete
No virus
23.7 kB
import sys
sys.path.append('./extern/dust3r')
from dust3r.inference import inference, load_model
from dust3r.utils.image import load_images
from dust3r.image_pairs import make_pairs
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
from dust3r.utils.device import to_numpy
import trimesh
import torch
import numpy as np
import torchvision
import os
import copy
import cv2
from PIL import Image
import pytorch3d
print(pytorch3d.__version__)
from pytorch3d.structures import Pointclouds
from torchvision.utils import save_image
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
from utils.pvd_utils import *
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from utils.diffusion_utils import instantiate_from_config,load_model_checkpoint,image_guided_synthesis
from pathlib import Path
from torchvision.utils import save_image
class ViewCrafter:
def __init__(self, opts, gradio = False):
self.opts = opts
self.device = opts.device
self.setup_dust3r()
self.setup_diffusion() #fixme
# initialize ref images, pcd
if not gradio:
self.images, self.img_ori = self.load_initial_images(image_dir=self.opts.image_dir)
self.run_dust3r(input_images=self.images)
def run_dust3r(self, input_images,clean_pc = False):
pairs = make_pairs(input_images, scene_graph='complete', prefilter=None, symmetrize=True)
output = inference(pairs, self.dust3r, self.device, batch_size=self.opts.batch_size)
mode = GlobalAlignerMode.PointCloudOptimizer #if len(self.images) > 2 else GlobalAlignerMode.PairViewer
scene = global_aligner(output, device=self.device, mode=mode)
if mode == GlobalAlignerMode.PointCloudOptimizer:
loss = scene.compute_global_alignment(init='mst', niter=self.opts.niter, schedule=self.opts.schedule, lr=self.opts.lr)
if clean_pc:
self.scene = scene.clean_pointcloud()
else:
self.scene = scene
def render_pcd(self,pts3d,imgs,masks,views,renderer,device):
imgs = to_numpy(imgs)
pts3d = to_numpy(pts3d)
if masks == None:
pts = torch.from_numpy(np.concatenate([p for p in pts3d])).view(-1, 3).to(device)
col = torch.from_numpy(np.concatenate([p for p in imgs])).view(-1, 3).to(device)
else:
# masks = to_numpy(masks)
pts = torch.from_numpy(np.concatenate([p[m] for p, m in zip(pts3d, masks)])).to(device)
col = torch.from_numpy(np.concatenate([p[m] for p, m in zip(imgs, masks)])).to(device)
color_mask = torch.ones(col.shape).to(device)
# point_cloud_mask = Pointclouds(points=[pts],features=[color_mask]).extend(views)
point_cloud = Pointclouds(points=[pts], features=[col]).extend(views)
images = renderer(point_cloud)
# view_masks = renderer(point_cloud_mask)
return images, None
def run_render(self, pcd, imgs,masks, H, W, camera_traj,num_views,use_cpu=False):
if use_cpu:
device = torch.device("cpu")
else:
device = self.device
render_setup = setup_renderer(camera_traj, image_size=(H,W))
renderer = render_setup['renderer']
render_results, viewmask = self.render_pcd(pcd, imgs, masks, num_views,renderer,device)
return render_results, viewmask
def run_diffusion(self, renderings):
prompts = [self.opts.prompt]
videos = (renderings * 2. - 1.).permute(3,0,1,2).unsqueeze(0).to(self.device)
condition_index = [0]
with torch.no_grad(), torch.cuda.amp.autocast():
# [1,1,c,t,h,w]
batch_samples = image_guided_synthesis(self.diffusion, prompts, videos, self.noise_shape, self.opts.n_samples, self.opts.ddim_steps, self.opts.ddim_eta, \
self.opts.unconditional_guidance_scale, self.opts.cfg_img, self.opts.frame_stride, self.opts.text_input, self.opts.multiple_cond_cfg, self.opts.timestep_spacing, self.opts.guidance_rescale, condition_index)
# save_results_seperate(batch_samples[0], self.opts.save_dir, fps=8)
# torch.Size([1, 3, 25, 576, 1024]) [-1,1]
return torch.clamp(batch_samples[0][0].permute(1,2,3,0), -1., 1.)
def nvs_single_view(self, gradio=False):
# 最后一个view为 0 pose
c2ws = self.scene.get_im_poses().detach()[1:]
principal_points = self.scene.get_principal_points().detach()[1:] #cx cy
focals = self.scene.get_focals().detach()[1:]
shape = self.images[0]['true_shape']
H, W = int(shape[0][0]), int(shape[0][1])
pcd = [i.detach() for i in self.scene.get_pts3d(clip_thred=self.opts.dpt_trd)] # a list of points of size whc
depth = [i.detach() for i in self.scene.get_depthmaps()]
depth_avg = depth[-1][H//2,W//2] #以图像中心处的depth(z)为球心旋转
radius = depth_avg*self.opts.center_scale #缩放调整
## change coordinate
c2ws,pcd = world_point_to_obj(poses=c2ws, points=torch.stack(pcd), k=-1, r=radius, elevation=self.opts.elevation, device=self.device)
imgs = np.array(self.scene.imgs)
masks = None
if self.opts.mode == 'single_view_nbv':
## 输入candidate->渲染mask->最大mask对应的pose作为nbv
## nbv模式下self.opts.d_theta[0], self.opts.d_phi[0]代表search space中的网格theta, phi之间的间距; self.opts.d_phi[0]的符号代表方向,分为左右两个方向
## FIXME hard coded candidate view数量, 以left为例,第一次迭代从[左,左上]中选取, 从第二次开始可以从[左,左上,左下]中选取
num_candidates = 2
candidate_poses,thetas,phis = generate_candidate_poses(c2ws, H, W, focals, principal_points, self.opts.d_theta[0], self.opts.d_phi[0],num_candidates, self.device)
_, viewmask = self.run_render([pcd[-1]], [imgs[-1]],masks, H, W, candidate_poses,num_candidates,use_cpu=False)
nbv_id = torch.argmin(viewmask.sum(dim=[1,2,3])).item()
save_image( viewmask.permute(0,3,1,2), os.path.join(self.opts.save_dir,f"candidate_mask0_nbv{nbv_id}.png"), normalize=True, value_range=(0, 1))
theta_nbv = thetas[nbv_id]
phi_nbv = phis[nbv_id]
# generate camera trajectory from T_curr to T_nbv
camera_traj,num_views = generate_traj_specified(c2ws, H, W, focals, principal_points, theta_nbv, phi_nbv, self.opts.d_r[0],self.opts.video_length, self.device)
# 重置elevation
self.opts.elevation -= theta_nbv
elif self.opts.mode == 'single_view_target':
camera_traj,num_views = generate_traj_specified(c2ws, H, W, focals, principal_points, self.opts.d_theta[0], self.opts.d_phi[0], self.opts.d_r[0],self.opts.video_length, self.device)
elif self.opts.mode == 'single_view_txt':
if not gradio:
with open(self.opts.traj_txt, 'r') as file:
lines = file.readlines()
phi = [float(i) for i in lines[0].split()]
theta = [float(i) for i in lines[1].split()]
r = [float(i) for i in lines[2].split()]
else:
phi, theta, r = self.gradio_traj
# device = torch.device("cpu")
device = self.device
camera_traj,num_views = generate_traj_txt(c2ws, H, W, focals, principal_points, phi, theta, r,self.opts.video_length, device,viz_traj=True, save_dir = self.opts.save_dir)
# camera_traj,num_views = generate_traj_txt(c2ws, H, W, focals, principal_points, phi, theta, r,self.opts.video_length, self.device,viz_traj=True, save_dir = self.opts.save_dir)
else:
raise KeyError(f"Invalid Mode: {self.opts.mode}")
render_results, viewmask = self.run_render([pcd[-1]], [imgs[-1]],masks, H, W, camera_traj,num_views,use_cpu=False)
render_results = render_results.to(self.device)
render_results = F.interpolate(render_results.permute(0,3,1,2), size=(576, 1024), mode='bilinear', align_corners=False).permute(0,2,3,1)
render_results[0] = self.img_ori
if self.opts.mode == 'single_view_txt':
if phi[-1]==0. and theta[-1]==0. and r[-1]==0.:
render_results[-1] = self.img_ori
# torch.Size([25, 576, 1024, 3]), [0,1]
# save_pointcloud_with_normals([imgs[-1]], [pcd[-1]], msk=None, save_path=os.path.join(self.opts.save_dir,'pcd0.ply') , mask_pc=False, reduce_pc=False)
diffusion_results = self.run_diffusion(render_results)
save_video((diffusion_results + 1.0) / 2.0, os.path.join(self.opts.save_dir, 'diffusion0.mp4'))
return render_results
def nvs_sparse_view(self,iter):
c2ws = self.scene.get_im_poses().detach()
principal_points = self.scene.get_principal_points().detach()
focals = self.scene.get_focals().detach()
shape = self.images[0]['true_shape']
H, W = int(shape[0][0]), int(shape[0][1])
pcd = [i.detach() for i in self.scene.get_pts3d(clip_thred=self.opts.dpt_trd)] # a list of points of size whc
depth = [i.detach() for i in self.scene.get_depthmaps()]
depth_avg = depth[0][H//2,W//2] #以ref图像中心处的depth(z)为球心旋转
radius = depth_avg*self.opts.center_scale #缩放调整
## masks for cleaner point cloud
self.scene.min_conf_thr = float(self.scene.conf_trf(torch.tensor(self.opts.min_conf_thr)))
masks = self.scene.get_masks()
depth = self.scene.get_depthmaps()
bgs_mask = [dpt > self.opts.bg_trd*(torch.max(dpt[40:-40,:])+torch.min(dpt[40:-40,:])) for dpt in depth]
masks_new = [m+mb for m, mb in zip(masks,bgs_mask)]
masks = to_numpy(masks_new)
## render, 从c2ws[0]即ref image对应的相机开始
imgs = np.array(self.scene.imgs)
if self.opts.mode == 'single_view_ref_iterative':
c2ws,pcd = world_point_to_obj(poses=c2ws, points=torch.stack(pcd), k=0, r=radius, elevation=self.opts.elevation, device=self.device)
camera_traj,num_views = generate_traj_specified(c2ws[0:1], H, W, focals[0:1], principal_points[0:1], self.opts.d_theta[iter], self.opts.d_phi[iter], self.opts.d_r[iter],self.opts.video_length, self.device)
render_results, viewmask = self.run_render(pcd, imgs,masks, H, W, camera_traj,num_views)
render_results = F.interpolate(render_results.permute(0,3,1,2), size=(576, 1024), mode='bilinear', align_corners=False).permute(0,2,3,1)
render_results[0] = self.img_ori
elif self.opts.mode == 'single_view_1drc_iterative':
self.opts.elevation -= self.opts.d_theta[iter-1]
c2ws,pcd = world_point_to_obj(poses=c2ws, points=torch.stack(pcd), k=-1, r=radius, elevation=self.opts.elevation, device=self.device)
camera_traj,num_views = generate_traj_specified(c2ws[-1:], H, W, focals[-1:], principal_points[-1:], self.opts.d_theta[iter], self.opts.d_phi[iter], self.opts.d_r[iter],self.opts.video_length, self.device)
render_results, viewmask = self.run_render(pcd, imgs,masks, H, W, camera_traj,num_views)
render_results = F.interpolate(render_results.permute(0,3,1,2), size=(576, 1024), mode='bilinear', align_corners=False).permute(0,2,3,1)
render_results[0] = (self.images[-1]['img_ori'].squeeze(0).permute(1,2,0)+1.)/2.
elif self.opts.mode == 'single_view_nbv':
c2ws,pcd = world_point_to_obj(poses=c2ws, points=torch.stack(pcd), k=-1, r=radius, elevation=self.opts.elevation, device=self.device)
## 输入candidate->渲染mask->最大mask对应的pose作为nbv
## nbv模式下self.opts.d_theta[0], self.opts.d_phi[0]代表search space中的网格theta, phi之间的间距; self.opts.d_phi[0]的符号代表方向,分为左右两个方向
## FIXME hard coded candidate view数量, 以left为例,第一次迭代从[左,左上]中选取, 从第二次开始可以从[左,左上,左下]中选取
num_candidates = 3
candidate_poses,thetas,phis = generate_candidate_poses(c2ws[-1:], H, W, focals[-1:], principal_points[-1:], self.opts.d_theta[0], self.opts.d_phi[0], num_candidates, self.device)
_, viewmask = self.run_render(pcd, imgs,masks, H, W, candidate_poses,num_candidates)
nbv_id = torch.argmin(viewmask.sum(dim=[1,2,3])).item()
save_image(viewmask.permute(0,3,1,2), os.path.join(self.opts.save_dir,f"candidate_mask{iter}_nbv{nbv_id}.png"), normalize=True, value_range=(0, 1))
theta_nbv = thetas[nbv_id]
phi_nbv = phis[nbv_id]
# generate camera trajectory from T_curr to T_nbv
camera_traj,num_views = generate_traj_specified(c2ws[-1:], H, W, focals[-1:], principal_points[-1:], theta_nbv, phi_nbv, self.opts.d_r[0],self.opts.video_length, self.device)
# 重置elevation
self.opts.elevation -= theta_nbv
render_results, viewmask = self.run_render(pcd, imgs,masks, H, W, camera_traj,num_views)
render_results = F.interpolate(render_results.permute(0,3,1,2), size=(576, 1024), mode='bilinear', align_corners=False).permute(0,2,3,1)
render_results[0] = (self.images[-1]['img_ori'].squeeze(0).permute(1,2,0)+1.)/2.
else:
raise KeyError(f"Invalid Mode: {self.opts.mode}")
save_video(render_results, os.path.join(self.opts.save_dir, f'render{iter}.mp4'))
save_pointcloud_with_normals(imgs, pcd, msk=masks, save_path=os.path.join(self.opts.save_dir, f'pcd{iter}.ply') , mask_pc=True, reduce_pc=False)
diffusion_results = self.run_diffusion(render_results)
save_video((diffusion_results + 1.0) / 2.0, os.path.join(self.opts.save_dir, f'diffusion{iter}.mp4'))
# torch.Size([25, 576, 1024, 3])
return diffusion_results
def nvs_single_view_ref_iterative(self):
all_results = []
sample_rate = 6
idx = 1 #初始包含1张ref image
for itr in range(0, len(self.opts.d_phi)):
if itr == 0:
self.images = [self.images[0]] #去掉后一份copy
diffusion_results_itr = self.nvs_single_view()
# diffusion_results_itr = torch.randn([25, 576, 1024, 3]).to(self.device)
diffusion_results_itr = diffusion_results_itr.permute(0,3,1,2)
all_results.append(diffusion_results_itr)
else:
for i in range(0+sample_rate, diffusion_results_itr.shape[0], sample_rate):
self.images.append(get_input_dict(diffusion_results_itr[i:i+1,...],idx,dtype = torch.float32))
idx += 1
self.run_dust3r(input_images=self.images, clean_pc=True)
diffusion_results_itr = self.nvs_sparse_view(itr)
# diffusion_results_itr = torch.randn([25, 576, 1024, 3]).to(self.device)
diffusion_results_itr = diffusion_results_itr.permute(0,3,1,2)
all_results.append(diffusion_results_itr)
return all_results
def nvs_single_view_1drc_iterative(self):
all_results = []
sample_rate = 6
idx = 1 #初始包含1张ref image
for itr in range(0, len(self.opts.d_phi)):
if itr == 0:
self.images = [self.images[0]] #去掉后一份copy
diffusion_results_itr = self.nvs_single_view()
# diffusion_results_itr = torch.randn([25, 576, 1024, 3]).to(self.device)
diffusion_results_itr = diffusion_results_itr.permute(0,3,1,2)
all_results.append(diffusion_results_itr)
else:
for i in range(0+sample_rate, diffusion_results_itr.shape[0], sample_rate):
self.images.append(get_input_dict(diffusion_results_itr[i:i+1,...],idx,dtype = torch.float32))
idx += 1
self.run_dust3r(input_images=self.images, clean_pc=True)
diffusion_results_itr = self.nvs_sparse_view(itr)
# diffusion_results_itr = torch.randn([25, 576, 1024, 3]).to(self.device)
diffusion_results_itr = diffusion_results_itr.permute(0,3,1,2)
all_results.append(diffusion_results_itr)
return all_results
def nvs_single_view_nbv(self):
# lef and right
# d_theta and a_phi 是搜索空间的顶点间隔
all_results = []
## FIXME: hard coded
sample_rate = 6
max_itr = 3
idx = 1 #初始包含1张ref image
for itr in range(0, max_itr):
if itr == 0:
self.images = [self.images[0]] #去掉后一份copy
diffusion_results_itr = self.nvs_single_view()
# diffusion_results_itr = torch.randn([25, 576, 1024, 3]).to(self.device)
diffusion_results_itr = diffusion_results_itr.permute(0,3,1,2)
all_results.append(diffusion_results_itr)
else:
for i in range(0+sample_rate, diffusion_results_itr.shape[0], sample_rate):
self.images.append(get_input_dict(diffusion_results_itr[i:i+1,...],idx,dtype = torch.float32))
idx += 1
self.run_dust3r(input_images=self.images, clean_pc=True)
diffusion_results_itr = self.nvs_sparse_view(itr)
# diffusion_results_itr = torch.randn([25, 576, 1024, 3]).to(self.device)
diffusion_results_itr = diffusion_results_itr.permute(0,3,1,2)
all_results.append(diffusion_results_itr)
return all_results
def setup_diffusion(self):
seed_everything(self.opts.seed)
config = OmegaConf.load(self.opts.config)
model_config = config.pop("model", OmegaConf.create())
## set use_checkpoint as False as when using deepspeed, it encounters an error "deepspeed backend not set"
model_config['params']['unet_config']['params']['use_checkpoint'] = False
model = instantiate_from_config(model_config)
model = model.to(self.device)
model.cond_stage_model.device = self.device
model.perframe_ae = self.opts.perframe_ae
assert os.path.exists(self.opts.ckpt_path), "Error: checkpoint Not Found!"
model = load_model_checkpoint(model, self.opts.ckpt_path)
model.eval()
self.diffusion = model
h, w = self.opts.height // 8, self.opts.width // 8
channels = model.model.diffusion_model.out_channels
n_frames = self.opts.video_length
self.noise_shape = [self.opts.bs, channels, n_frames, h, w]
def setup_dust3r(self):
self.dust3r = load_model(self.opts.model_path, self.device)
def load_initial_images(self, image_dir):
## load images
## dict_keys(['img', 'true_shape', 'idx', 'instance', 'img_ori']),张量形式
images = load_images([image_dir], size=512,force_1024 = True)
img_ori = (images[0]['img_ori'].squeeze(0).permute(1,2,0)+1.)/2. # [576,1024,3] [0,1]
# img_ori = Image.open(image_dir).convert('RGB')
# transform = transforms.Compose([
# transforms.Resize((576, 1024)),
# transforms.ToTensor(),
# transforms.Normalize((0., 0., 0.), (1., 1., 1.)) # 归一化到[-1,1],如果要归一化到[0,1],请使用transforms.Normalize((0., 0., 0.), (1., 1., 1.))
# ])
# img_ori = transform(img_ori).permute(1,2,0).to(self.device)
if len(images) == 1:
images = [images[0], copy.deepcopy(images[0])]
images[1]['idx'] = 1
return images, img_ori
def run_traj(self,i2v_input_image, i2v_elevation, i2v_center_scale, i2v_pose):
self.opts.elevation = float(i2v_elevation)
self.opts.center_scale = float(i2v_center_scale)
i2v_d_phi,i2v_d_theta,i2v_d_r = [i for i in i2v_pose.split(';')]
self.gradio_traj = [float(i) for i in i2v_d_phi.split()],[float(i) for i in i2v_d_theta.split()],[float(i) for i in i2v_d_r.split()]
transform = transforms.Compose([
transforms.Resize((576,1024)),
# transforms.CenterCrop((576,1024)),
])
torch.cuda.empty_cache()
img_tensor = torch.from_numpy(i2v_input_image).permute(2, 0, 1).unsqueeze(0).float().to(self.device)
img_tensor = (img_tensor / 255. - 0.5) * 2
image_tensor_resized = transform(img_tensor) #1,3,h,w
images = get_input_dict(image_tensor_resized,idx = 0,dtype = torch.float32)
images = [images, copy.deepcopy(images)]
images[1]['idx'] = 1
self.images = images
self.img_ori = (image_tensor_resized.squeeze(0).permute(1,2,0) + 1.)/2.
# self.images: torch.Size([1, 3, 288, 512]), [-1,1]
# self.img_ori: torch.Size([576, 1024, 3]), [0,1]
# self.images, self.img_ori = self.load_initial_images(image_dir=i2v_input_image)
self.run_dust3r(input_images=self.images)
render_results = self.nvs_single_view(gradio=True)
save_video(render_results, os.path.join(self.opts.save_dir, 'render0.mp4'))
traj_dir = os.path.join(self.opts.save_dir, "viz_traj.mp4")
return traj_dir
def run_gen(self,i2v_steps, i2v_seed):
self.opts.ddim_steps = i2v_steps
seed_everything(i2v_seed)
render_dir = os.path.join(self.opts.save_dir, 'render0.mp4')
video = imageio.get_reader(render_dir, 'ffmpeg')
frames = []
for frame in video:
frame = frame / 255.0
frames.append(frame)
frames = np.array(frames)
##torch.Size([25, 576, 1024, 3])
render_results = torch.from_numpy(frames).to(self.device).half()
gen_dir = os.path.join(self.opts.save_dir, "diffusion0.mp4")
diffusion_results = self.run_diffusion(render_results)
save_video((diffusion_results + 1.0) / 2.0, os.path.join(self.opts.save_dir, 'diffusion0.mp4'))
return gen_dir
def run_both(self,i2v_input_image, i2v_elevation, i2v_center_scale, i2v_pose,i2v_steps, i2v_seed):
self.opts.ddim_steps = i2v_steps
seed_everything(i2v_seed)
self.opts.elevation = float(i2v_elevation)
self.opts.center_scale = float(i2v_center_scale)
i2v_d_phi,i2v_d_theta,i2v_d_r = [i for i in i2v_pose.split(';')]
self.gradio_traj = [float(i) for i in i2v_d_phi.split()],[float(i) for i in i2v_d_theta.split()],[float(i) for i in i2v_d_r.split()]
torch.cuda.empty_cache()
img_tensor = torch.from_numpy(i2v_input_image).permute(2, 0, 1).unsqueeze(0).float().to(self.device)
img_tensor = (img_tensor / 255. - 0.5) * 2
image_tensor_resized = center_crop_image(img_tensor) #1,3,h,w
images = get_input_dict(image_tensor_resized,idx = 0,dtype = torch.float32)
images = [images, copy.deepcopy(images)]
images[1]['idx'] = 1
self.images = images
self.img_ori = (image_tensor_resized.squeeze(0).permute(1,2,0) + 1.)/2.
# self.images: torch.Size([1, 3, 288, 512]), [-1,1]
# self.img_ori: torch.Size([576, 1024, 3]), [0,1]
# self.images, self.img_ori = self.load_initial_images(image_dir=i2v_input_image)
self.run_dust3r(input_images=self.images)
self.nvs_single_view(gradio=True)
traj_dir = os.path.join(self.opts.save_dir, "viz_traj.mp4")
gen_dir = os.path.join(self.opts.save_dir, "diffusion0.mp4")
return gen_dir,traj_dir,