add depth_sq output (not used yet...)
kwea123 committed Jul 22, 2022
1 parent 2b1567a commit 8a26180
Showing 1 changed file with 16 additions and 5 deletions.
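
This commit threads a new per-ray output, depth_sq, through both compositing kernels: the forward kernel accumulates it next to depth, composite_train_fw_cu returns it alongside opacity, depth and rgb, and the backward kernel takes the matching upstream gradient dL_ddepth_sq. As the commit message notes, nothing consumes it yet. With the usual compositing weights w_s = a_s * T_s, where a_s = 1 - exp(-sigma_s * delta_s) and T_s is the transmittance reaching sample s, the two depth outputs are the first and second moments of the sample distances along the ray:

    depth[ray]    = sum_s w_s * t_s        (already present)
    depth_sq[ray] = sum_s w_s * t_s^2      (added by this commit)

One plausible follow-up use, an assumption here rather than anything stated in the commit, is a per-ray depth variance depth_sq - depth^2 (exact when the accumulated opacity reaches 1), e.g. for a depth-uncertainty or regularization term.
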
models/csrc/volumerendering.cu (16 additions, 5 deletions)
@@ -11,6 +11,7 @@ __global__ void composite_train_fw_kernel(
const scalar_t T_threshold,
torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> opacity,
torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> depth,
+ torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> depth_sq,
torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> rgb
){
const int n = blockIdx.x * blockDim.x + threadIdx.x;
@@ -30,6 +31,7 @@ __global__ void composite_train_fw_kernel(
rgb[ray_idx][1] += w*rgbs[s][1];
rgb[ray_idx][2] += w*rgbs[s][2];
depth[ray_idx] += w*ts[s];
+ depth_sq[ray_idx] += w*ts[s]*ts[s];
opacity[ray_idx] += w;
T *= 1.0f-a;

@@ -51,6 +53,7 @@ std::vector<torch::Tensor> composite_train_fw_cu(

auto opacity = torch::zeros({N_rays}, sigmas.options());
auto depth = torch::zeros({N_rays}, sigmas.options());
+ auto depth_sq = torch::zeros({N_rays}, sigmas.options());
auto rgb = torch::zeros({N_rays, 3}, sigmas.options());

const int threads = 256, blocks = (N_rays+threads-1)/threads;
@@ -66,18 +69,20 @@
T_threshold,
opacity.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits, size_t>(),
depth.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits, size_t>(),
+ depth_sq.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits, size_t>(),
rgb.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>()
);
}));

- return {opacity, depth, rgb};
+ return {opacity, depth, depth_sq, rgb};
}


template <typename scalar_t>
__global__ void composite_train_bw_kernel(
const torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> dL_dopacity,
const torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> dL_ddepth,
+ const torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> dL_ddepth_sq,
const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> dL_drgb,
const torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> sigmas,
const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> rgbs,
@@ -86,6 +91,7 @@ __global__ void composite_train_bw_kernel(
const torch::PackedTensorAccessor32<int, 2, torch::RestrictPtrTraits> rays_a,
const torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> opacity,
const torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> depth,
+ const torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> depth_sq,
const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> rgb,
const scalar_t T_threshold,
torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> dL_dsigmas,
@@ -99,16 +105,16 @@ __global__ void composite_train_bw_kernel(
// front to back compositing
int samples = 0;
scalar_t R = rgb[ray_idx][0], G = rgb[ray_idx][1], B = rgb[ray_idx][2];
- scalar_t O = opacity[ray_idx], D = depth[ray_idx];
- scalar_t T = 1.0f, r = 0.0f, g = 0.0f, b = 0.0f, d = 0.0f;
+ scalar_t O = opacity[ray_idx], D = depth[ray_idx], Dsq = depth_sq[ray_idx];
+ scalar_t T = 1.0f, r = 0.0f, g = 0.0f, b = 0.0f, d = 0.0f, dsq = 0.0f;

while (samples < N_samples) {
const int s = start_idx + samples;
const scalar_t a = 1.0f - __expf(-sigmas[s]*deltas[s]);
const scalar_t w = a * T;

r += w*rgbs[s][0]; g += w*rgbs[s][1]; b += w*rgbs[s][2];
- d += w*ts[s];
+ d += w*ts[s]; dsq += w*ts[s]*ts[s];
T *= 1.0f-a;

// compute gradients by math...
@@ -121,7 +127,8 @@ __global__ void composite_train_bw_kernel(
dL_drgb[ray_idx][1]*(rgbs[s][1]*T-(G-g)) +
dL_drgb[ray_idx][2]*(rgbs[s][2]*T-(B-b)) +
dL_dopacity[ray_idx]*(1-O) +
- dL_ddepth[ray_idx]*(ts[s]*T-(D-d))
+ dL_ddepth[ray_idx]*(ts[s]*T-(D-d)) +
+ dL_ddepth_sq[ray_idx]*(ts[s]*ts[s]*T-(Dsq-dsq))
);

if (T <= T_threshold) break; // ray has enough opacity
@@ -133,6 +140,7 @@ __global__ void composite_train_bw_kernel(
std::vector<torch::Tensor> composite_train_bw_cu(
const torch::Tensor dL_dopacity,
const torch::Tensor dL_ddepth,
+ const torch::Tensor dL_ddepth_sq,
const torch::Tensor dL_drgb,
const torch::Tensor sigmas,
const torch::Tensor rgbs,
@@ -141,6 +149,7 @@ std::vector<torch::Tensor> composite_train_bw_cu(
const torch::Tensor rays_a,
const torch::Tensor opacity,
const torch::Tensor depth,
+ const torch::Tensor depth_sq,
const torch::Tensor rgb,
const float T_threshold
){
@@ -156,6 +165,7 @@ std::vector<torch::Tensor> composite_train_bw_cu(
composite_train_bw_kernel<scalar_t><<<blocks, threads>>>(
dL_dopacity.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits, size_t>(),
dL_ddepth.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits, size_t>(),
+ dL_ddepth_sq.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits, size_t>(),
dL_drgb.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
sigmas.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits, size_t>(),
rgbs.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
@@ -164,6 +174,7 @@ std::vector<torch::Tensor> composite_train_bw_cu(
rays_a.packed_accessor32<int, 2, torch::RestrictPtrTraits>(),
opacity.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits, size_t>(),
depth.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits, size_t>(),
+ depth_sq.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits, size_t>(),
rgb.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
T_threshold,
dL_dsigmas.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits, size_t>(),
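
For the backward kernel, the new term mirrors the existing depth gradient. Writing a_s = 1 - exp(-sigma_s * delta_s), T_s = prod_{j<s} (1 - a_j) and w_s = a_s * T_s, differentiating depth_sq = sum_i w_i * t_i^2 with respect to sigma_s gives

    d(depth_sq)/d(sigma_s) = delta_s * ( t_s^2 * T_{s+1} - sum_{i>s} w_i * t_i^2 )

which is what the added expression ts[s]*ts[s]*T - (Dsq - dsq) computes: by that point in the loop T already holds the post-sample transmittance T_{s+1}, and Dsq - dsq is the weighted sum of t_i^2 over the samples still ahead; the shared deltas[s] factor multiplying the whole bracket sits in the collapsed context just above that hunk.

The sketch below is a standalone illustration in plain C++ (not part of the commit) of the same front-to-back accumulation and of how a caller might turn the two moments into a depth variance; the toy numbers and the variance use are assumptions for illustration only.

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    // Toy ray: per-sample densities, step sizes and distances (made-up values).
    std::vector<float> sigmas = {0.5f, 1.0f, 2.0f, 4.0f};
    std::vector<float> deltas = {0.1f, 0.1f, 0.1f, 0.1f};
    std::vector<float> ts     = {1.0f, 1.1f, 1.2f, 1.3f};

    float T = 1.0f, opacity = 0.0f, depth = 0.0f, depth_sq = 0.0f;
    for (size_t s = 0; s < sigmas.size(); ++s) {
        const float a = 1.0f - std::exp(-sigmas[s] * deltas[s]);
        const float w = a * T;             // compositing weight of sample s
        depth    += w * ts[s];             // first moment (existing output)
        depth_sq += w * ts[s] * ts[s];     // second moment (output added by this commit)
        opacity  += w;
        T *= 1.0f - a;                     // remaining transmittance
    }
    // Possible caller-side use: per-ray depth variance from the two moments
    // (exact only when opacity reaches 1, otherwise an approximation).
    const float var = depth_sq - depth * depth;
    std::printf("depth=%f depth_sq=%f var=%f\n", depth, depth_sq, var);
    return 0;
}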
