From 8a261802790401b0572a4a0ebc90995354aa6965 Mon Sep 17 00:00:00 2001 From: kwea123 Date: Fri, 22 Jul 2022 20:09:10 +0900 Subject: [PATCH] add depth_sq output (not used yet...) --- models/csrc/volumerendering.cu | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/models/csrc/volumerendering.cu b/models/csrc/volumerendering.cu index 739580ed..c0188b57 100644 --- a/models/csrc/volumerendering.cu +++ b/models/csrc/volumerendering.cu @@ -11,6 +11,7 @@ __global__ void composite_train_fw_kernel( const scalar_t T_threshold, torch::PackedTensorAccessor opacity, torch::PackedTensorAccessor depth, + torch::PackedTensorAccessor depth_sq, torch::PackedTensorAccessor rgb ){ const int n = blockIdx.x * blockDim.x + threadIdx.x; @@ -30,6 +31,7 @@ __global__ void composite_train_fw_kernel( rgb[ray_idx][1] += w*rgbs[s][1]; rgb[ray_idx][2] += w*rgbs[s][2]; depth[ray_idx] += w*ts[s]; + depth_sq[ray_idx] += w*ts[s]*ts[s]; opacity[ray_idx] += w; T *= 1.0f-a; @@ -51,6 +53,7 @@ std::vector composite_train_fw_cu( auto opacity = torch::zeros({N_rays}, sigmas.options()); auto depth = torch::zeros({N_rays}, sigmas.options()); + auto depth_sq = torch::zeros({N_rays}, sigmas.options()); auto rgb = torch::zeros({N_rays, 3}, sigmas.options()); const int threads = 256, blocks = (N_rays+threads-1)/threads; @@ -66,11 +69,12 @@ std::vector composite_train_fw_cu( T_threshold, opacity.packed_accessor(), depth.packed_accessor(), + depth_sq.packed_accessor(), rgb.packed_accessor() ); })); - return {opacity, depth, rgb}; + return {opacity, depth, depth_sq, rgb}; } @@ -78,6 +82,7 @@ template __global__ void composite_train_bw_kernel( const torch::PackedTensorAccessor dL_dopacity, const torch::PackedTensorAccessor dL_ddepth, + const torch::PackedTensorAccessor dL_ddepth_sq, const torch::PackedTensorAccessor dL_drgb, const torch::PackedTensorAccessor sigmas, const torch::PackedTensorAccessor rgbs, @@ -86,6 +91,7 @@ __global__ void composite_train_bw_kernel( const torch::PackedTensorAccessor32 rays_a, const torch::PackedTensorAccessor opacity, const torch::PackedTensorAccessor depth, + const torch::PackedTensorAccessor depth_sq, const torch::PackedTensorAccessor rgb, const scalar_t T_threshold, torch::PackedTensorAccessor dL_dsigmas, @@ -99,8 +105,8 @@ __global__ void composite_train_bw_kernel( // front to back compositing int samples = 0; scalar_t R = rgb[ray_idx][0], G = rgb[ray_idx][1], B = rgb[ray_idx][2]; - scalar_t O = opacity[ray_idx], D = depth[ray_idx]; - scalar_t T = 1.0f, r = 0.0f, g = 0.0f, b = 0.0f, d = 0.0f; + scalar_t O = opacity[ray_idx], D = depth[ray_idx], Dsq = depth_sq[ray_idx]; + scalar_t T = 1.0f, r = 0.0f, g = 0.0f, b = 0.0f, d = 0.0f, dsq = 0.0f; while (samples < N_samples) { const int s = start_idx + samples; @@ -108,7 +114,7 @@ __global__ void composite_train_bw_kernel( const scalar_t w = a * T; r += w*rgbs[s][0]; g += w*rgbs[s][1]; b += w*rgbs[s][2]; - d += w*ts[s]; + d += w*ts[s]; dsq += w*ts[s]*ts[s]; T *= 1.0f-a; // compute gradients by math... @@ -121,7 +127,8 @@ __global__ void composite_train_bw_kernel( dL_drgb[ray_idx][1]*(rgbs[s][1]*T-(G-g)) + dL_drgb[ray_idx][2]*(rgbs[s][2]*T-(B-b)) + dL_dopacity[ray_idx]*(1-O) + - dL_ddepth[ray_idx]*(ts[s]*T-(D-d)) + dL_ddepth[ray_idx]*(ts[s]*T-(D-d)) + + dL_ddepth_sq[ray_idx]*(ts[s]*ts[s]*T-(Dsq-dsq)) ); if (T <= T_threshold) break; // ray has enough opacity @@ -133,6 +140,7 @@ __global__ void composite_train_bw_kernel( std::vector composite_train_bw_cu( const torch::Tensor dL_dopacity, const torch::Tensor dL_ddepth, + const torch::Tensor dL_ddepth_sq, const torch::Tensor dL_drgb, const torch::Tensor sigmas, const torch::Tensor rgbs, @@ -141,6 +149,7 @@ std::vector composite_train_bw_cu( const torch::Tensor rays_a, const torch::Tensor opacity, const torch::Tensor depth, + const torch::Tensor depth_sq, const torch::Tensor rgb, const float T_threshold ){ @@ -156,6 +165,7 @@ std::vector composite_train_bw_cu( composite_train_bw_kernel<<>>( dL_dopacity.packed_accessor(), dL_ddepth.packed_accessor(), + dL_ddepth_sq.packed_accessor(), dL_drgb.packed_accessor(), sigmas.packed_accessor(), rgbs.packed_accessor(), @@ -164,6 +174,7 @@ std::vector composite_train_bw_cu( rays_a.packed_accessor32(), opacity.packed_accessor(), depth.packed_accessor(), + depth_sq.packed_accessor(), rgb.packed_accessor(), T_threshold, dL_dsigmas.packed_accessor(),