Commit c2b022a

add files

邱浩楠 authored and 邱浩楠 committed Sep 4, 2023
1 parent 7cdcfe6 commit c2b022a
Showing 2 changed files with 125 additions and 72 deletions.
72 changes: 1 addition & 71 deletions code/utils/light_util.py
@@ -60,67 +60,6 @@ def add_sample_SHlight(constant_factor, normal_images, sh_coeff):
shading = torch.sum(sh_coeff[None,:,None]*sh[:,:,None], 1) # [bz, 9, 1]
return shading # [bz, 1]

def add_SHlight_infer(normal_images, given_sh):
'''
given_sh: [bz, 9, 1] SH coefficients
basis order: 1, Y, Z, X, XY, YZ, 3Z^2-1, XZ, X^2-Y^2
'''
N = normal_images
# sh = torch.stack([
# N[:,0]*0.+1., N[:,0], N[:,1], \
# N[:,2], N[:,0]*N[:,1], N[:,0]*N[:,2],
# N[:,1]*N[:,2], N[:,0]**2 - N[:,1]**2, 3*(N[:,2]**2) - 1
# ],
# 1) # [bz, 9, h, w]
sh = torch.stack([
N[:,0]*0.+1., N[:,1], N[:,2], \
N[:,0], N[:,0]*N[:,1], N[:,1]*N[:,2],
3*(N[:,2]**2) - 1, N[:,0]*N[:,2], N[:,0]**2 - N[:,1]**2
],
1) # [bz, 9, h, w]
sh = sh*given_sh[:, :, :, None]
shading = torch.sum(sh[:,:,None,:,:], 1) # [bz, 1, h, w]
# sh = sh*constant_factor[None,:,None,None]
# shading = torch.sum(sh_coeff[:,:,:,None,None]*sh[:,:,None,:,:], 1) # [bz, 9, 1, h, w]
return shading
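
# Example (added for illustration; not part of this commit): a minimal
# sketch of calling add_SHlight_infer, assuming normal_images is a
# [bz, 3, h, w] normal map and given_sh holds [bz, 9, 1] monochrome SH
# coefficients, as implied by the broadcasting above.
def _demo_add_SHlight_infer():
    normals = torch.nn.functional.normalize(torch.randn(2, 3, 64, 64), dim=1)
    sh_coeff = torch.randn(2, 9, 1)                 # one SH vector per batch item
    shading = add_SHlight_infer(normals, sh_coeff)  # [bz,9,h,w] * [bz,9,1,1], summed over the 9 bands
    assert shading.shape == (2, 1, 64, 64)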

def add_SHlight_infer_env(normal_images, given_sh):
'''
given_sh: [9, 3, h, w] per-pixel RGB SH coefficients
basis order: 1, Y, Z, X, XY, YZ, 3Z^2-1, XZ, X^2-Y^2
'''
N = normal_images

constant_factor = torch.tensor([0.5/np.sqrt(np.pi),
np.sqrt(3)/2/np.sqrt(np.pi),
np.sqrt(3)/2/np.sqrt(np.pi),
np.sqrt(3)/2/np.sqrt(np.pi),
np.sqrt(15)/2/np.sqrt(np.pi),
np.sqrt(15)/2/np.sqrt(np.pi),
np.sqrt(5)/4/np.sqrt(np.pi),
np.sqrt(15)/2/np.sqrt(np.pi),
np.sqrt(15)/4/np.sqrt(np.pi)]).float().view(1, 9, 1, 1)
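# Note (added comment): these are the standard SH band-normalization
# constants (0.5/sqrt(pi) for band 0, sqrt(3)/(2 sqrt(pi)) for band 1, etc.).
# As written, constant_factor is only referenced by the commented-out code
# below, so given_sh is presumably already scaled by these factors.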

# sh = torch.stack([
# N[:,0]*0.+1., N[:,0], N[:,1], \
# N[:,2], N[:,0]*N[:,1], N[:,0]*N[:,2],
# N[:,1]*N[:,2], N[:,0]**2 - N[:,1]**2, 3*(N[:,2]**2) - 1
# ],
# 1) # [bz, 9, h, w]
sh = torch.stack([
N[:,0]*0.+1., N[:,1], N[:,2], \
N[:,0], N[:,0]*N[:,1], N[:,1]*N[:,2],
3*(N[:,2]**2) - 1, N[:,0]*N[:,2], N[:,0]**2 - N[:,1]**2
],
1) # [bz, 9, h, w]

sh = sh[:,:,None,:,:] * given_sh[None,:,:,:,:] # [bz, 9, 3, h, w]
shading = torch.sum(sh, 1) # [bz, 3, h, w]
# sh = sh*constant_factor[None,:,None,None]
# shading = torch.sum(sh_coeff[:,:,:,None,None]*sh[:,:,None,:,:], 1) # [bz, 9, 1, h, w]
return shading
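
# Example (added for illustration; not part of this commit): the _env
# variant, assuming given_sh is a per-pixel [9, 3, h, w] RGB SH map shared
# across the batch, as implied by the [None,:,:,:,:] broadcast above.
def _demo_add_SHlight_infer_env():
    normals = torch.nn.functional.normalize(torch.randn(2, 3, 64, 64), dim=1)
    env_sh = torch.randn(9, 3, 64, 64)                # hypothetical per-pixel RGB SH coefficients
    shading = add_SHlight_infer_env(normals, env_sh)  # [bz,9,1,h,w] * [1,9,3,h,w], summed over bands
    assert shading.shape == (2, 3, 64, 64)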


def SH_basis(normal):
'''
get SH basis based on normal
@@ -193,16 +132,7 @@ def get_shading(normal, SH):
def draw_shading(sh):
# ---------------- create normal for rendering half sphere ------
img_size = 256
# x = np.linspace(-1, 1, img_size)
# z = np.linspace(1, -1, img_size)
# x, z = np.meshgrid(x, z)

# mag = np.sqrt(x**2 + z**2)
# valid = mag <=1
# y = -np.sqrt(1 - (x*valid)**2 - (z*valid)**2)
# x = x * valid
# y = y * valid
# z = z * valid

x = np.linspace(-1, 1, img_size)
y = np.linspace(1, -1, img_size)
x, y = np.meshgrid(x, y)
125 changes: 124 additions & 1 deletion code/utils/plots.py
@@ -85,6 +85,18 @@ def plot_images(model_outputs, depth_image, ground_truth, path, epoch, img_index
rgb_gt = (rgb_gt.cuda() + 1.) / 2.
else:
rgb_gt = None

if 'albedo_values' in model_outputs:
albedo = model_outputs['albedo_values']
spec = model_outputs['spec_values'].repeat(1,3)
specmap = model_outputs['specmap_values'].repeat(1,3)
rough_rgb = model_outputs['rough_rgb_values']
else:
albedo = None
spec = None
specmap = None
rough_rgb = None

rgb_points = model_outputs['rgb_values']
rgb_points = rgb_points.reshape(batch_size, num_samples, 3)

@@ -99,6 +111,15 @@ def plot_images(model_outputs, depth_image, ground_truth, path, epoch, img_index
output_vs_gt = torch.cat((output_vs_gt, rgb_gt, depth_image.repeat(1, 1, 3), normal_points), dim=0)
else:
output_vs_gt = torch.cat((output_vs_gt, depth_image.repeat(1, 1, 3), normal_points), dim=0)

if 'albedo_values' in model_outputs:
rough_rgb_points = rough_rgb.reshape(batch_size, num_samples, 3)
rough_rgb_points = (rough_rgb_points + 1.) / 2.
albedo_points = albedo.reshape(batch_size, num_samples, 3)
spec_points = spec.reshape(batch_size, num_samples, 3)
specmap_points = specmap.reshape(batch_size, num_samples, 3)
output_vs_gt = torch.cat((output_vs_gt, rough_rgb_points, albedo_points, spec_points, specmap_points), dim=0)
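# Note (added comment): rough_rgb is remapped from [-1, 1] to [0, 1] above,
# while albedo/spec/specmap are concatenated as-is, so those buffers are
# presumably already in [0, 1].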

if 'lbs_weight' in model_outputs:
import matplotlib.pyplot as plt

@@ -134,7 +155,7 @@ def plot_images(model_outputs, depth_image, ground_truth, path, epoch, img_index

tensor = torchvision.utils.make_grid(output_vs_gt_plot,
scale_each=False,
-normalize=False,
+normalize=True,
nrow=output_vs_gt.shape[0]).cpu().detach().numpy()

tensor = tensor.transpose(1, 2, 0)
@@ -154,3 +175,105 @@ def plot_images(model_outputs, depth_image, ground_truth, path, epoch, img_index
def lin2img(tensor, img_res):
batch_size, num_samples, channels = tensor.shape
return tensor.permute(0, 2, 1).view(batch_size, channels, img_res[0], img_res[1])
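
# Example (added for illustration; not part of this commit): lin2img turns
# flattened per-pixel samples back into an image tensor; splitting only the
# last dimension keeps permute(...).view(...) valid without a .contiguous() call.
def _demo_lin2img():
    samples = torch.randn(1, 256 * 256, 3)      # one image worth of flattened RGB samples
    img = lin2img(samples, img_res=(256, 256))  # [1, 256*256, 3] -> [1, 3, 256, 256]
    assert img.shape == (1, 3, 256, 256)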

def light_plot(img_index, model_outputs, pose, ground_truth, path, epoch, img_res, plot_nimgs, min_depth, max_depth, res_init, res_up, is_eval=False):
# arrange data to plot
batch_size = pose.shape[0]
num_samples = int(model_outputs['rgb_values'].shape[0] / batch_size)
# plot rendered images

if 'rgb' in ground_truth:
rgb_gt = ground_truth['rgb']
rgb_gt = (rgb_gt.cuda() + 1.) / 2.
else:
rgb_gt = None


if 'bg_img' in model_outputs:
rgb_points = model_outputs['rgb_values'] * (model_outputs['object_mask'] * model_outputs['foreground_mask']) + model_outputs['bg_img'] * (1 - model_outputs['object_mask'] * model_outputs['foreground_mask'])
else:
rgb_points = model_outputs['rgb_values'] * (model_outputs['object_mask'] * model_outputs['foreground_mask']) + (1 - model_outputs['object_mask'] * model_outputs['foreground_mask'])
rgb_points = rgb_points.reshape(batch_size, num_samples, 3)

normal_points = model_outputs['normal_values'] * model_outputs['face_mask'] + (1 - model_outputs['face_mask'])
normal_points = normal_points.reshape(batch_size, num_samples, 3)

fine_normal_points = model_outputs['fine_normal_values'] * model_outputs['face_mask'] + (1 - model_outputs['face_mask'])
fine_normal_points = fine_normal_points.reshape(batch_size, num_samples, 3)

# pre_spec_points = model_outputs['pre_spec_values']
# pre_spec_points = pre_spec_points.reshape(batch_size, num_samples, 3)

albedo_points = model_outputs['albedo_values'] * model_outputs['face_mask'] + (1 - model_outputs['face_mask'])
albedo_points = albedo_points.reshape(batch_size, num_samples, 3)

specmap_points = model_outputs['specmap_values'] * model_outputs['face_mask'] + (1 - model_outputs['face_mask'])
specmap_points = specmap_points.reshape(batch_size, num_samples, 3)

shading_points = model_outputs['shading_values'] * model_outputs['face_mask']
shading_points = shading_points.reshape(batch_size, num_samples, 3)

spec_points = model_outputs['spec_values'] * model_outputs['face_mask']
spec_points = spec_points.reshape(batch_size, num_samples, 3)

light_points = model_outputs['light_values']
light_points = light_points.reshape(batch_size, num_samples, 3)

rgb_points = (torch.clamp(rgb_points, -1.0, 1.0) + 1.) / 2.
specmap_points = (torch.clamp(specmap_points, -1.0, 1.0) + 1.) / 2.
normal_points = (torch.clamp(normal_points, -1.0, 1.0) + 1.) / 2.
fine_normal_points = (torch.clamp(fine_normal_points, -1.0, 1.0) + 1.) / 2.
# albedo_points = (torch.clamp(albedo_points, -1.0, 1.0) + 1.) / 2.
# shading_points = (torch.clamp(shading_points, -1.0, 1.0) + 1.) / 2.

output_vs_gt = rgb_points
if rgb_gt is not None:
output_vs_gt = torch.cat((rgb_gt, output_vs_gt, albedo_points, normal_points, fine_normal_points, specmap_points, spec_points, shading_points, light_points), dim=0)
else:
output_vs_gt = torch.cat((output_vs_gt, albedo_points, normal_points, fine_normal_points, specmap_points, spec_points, shading_points, light_points), dim=0)

if 'sample_light_values' in model_outputs:

sample_rgb_points = model_outputs['sample_rgb_values'] * model_outputs['object_mask'] + (1 - model_outputs['object_mask'])
sample_rgb_points = sample_rgb_points.reshape(batch_size, num_samples, 3)

sample_light_points = model_outputs['sample_light_values']
sample_light_points = sample_light_points.reshape(batch_size, num_samples, 3)

sample_rgb_points = (torch.clamp(sample_rgb_points, -1.0, 1.0) + 1.) / 2.
output_vs_gt = torch.cat((output_vs_gt, sample_rgb_points, sample_light_points), dim = 0)

output_vs_gt_plot = lin2img(output_vs_gt, img_res)

tensor = torchvision.utils.make_grid(output_vs_gt_plot,
scale_each=False,
normalize=False,
nrow=output_vs_gt.shape[0]).cpu().detach().numpy()

tensor = tensor.transpose(1, 2, 0)
scale_factor = 255
tensor = (tensor * scale_factor).astype(np.uint8)

img = Image.fromarray(tensor)
if is_eval:
wo_epoch_path = path.replace('/epoch_{}'.format(epoch), '')
if not os.path.exists('{0}/rendering_test'.format(wo_epoch_path)):
os.mkdir('{0}/rendering_test'.format(wo_epoch_path))
img.save('{0}/rendering_test/epoch_{1}_{2}.png'.format(wo_epoch_path, epoch, img_index))

plot_image(rgb_gt, path, epoch, img_index, 1, img_res, 'gt')
plot_image(rgb_points, path, epoch, img_index, 1, img_res, 'rgb')
plot_image(albedo_points, path, epoch, img_index, 1, img_res, 'albedo')
plot_image(normal_points, path, epoch, img_index, 1, img_res, 'normal')
plot_image(fine_normal_points, path, epoch, img_index, 1, img_res, 'fine_normal')
plot_image(specmap_points, path, epoch, img_index, 1, img_res, 'specmap')
plot_image(spec_points, path, epoch, img_index, 1, img_res, 'spec')
plot_image(shading_points, path, epoch, img_index, 1, img_res, 'shading')
plot_image(light_points, path, epoch, img_index, 1, img_res, 'light')
else:
wo_epoch_path = path.replace('/epoch_{}'.format(epoch), '')
if not os.path.exists('{0}/rendering'.format(wo_epoch_path)):
os.mkdir('{0}/rendering'.format(wo_epoch_path))
img.save('{0}/rendering/epoch_{1}_{2}.png'.format(wo_epoch_path, epoch, img_index))

del output_vs_gt

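# Summary (added comment): light_plot lays the per-buffer images out side by
# side in a single make_grid row (gt, rgb, albedo, normal, fine_normal,
# specmap, spec, shading, light, plus the sampled-light rows when present),
# and in eval mode additionally saves each buffer on its own via plot_image
# under {wo_epoch_path}/rendering_test/.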