Skip to content

Commit

Permalink
ssd detection lib (zhanghang1989#270)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhanghang1989 committed May 1, 2020
1 parent e57e90d commit f70fa97
Show file tree
Hide file tree
Showing 6 changed files with 634 additions and 0 deletions.
1 change: 1 addition & 0 deletions encoding/lib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,6 @@
os.path.join(gpu_path, 'roi_align_kernel.cu'),
os.path.join(gpu_path, 'nms_kernel.cu'),
os.path.join(gpu_path, 'rectify_cuda.cu'),
os.path.join(gpu_path, 'lib_ssd.cu'),
], extra_cuda_cflags=["--expt-extended-lambda"],
build_directory=gpu_path, verbose=False)
2 changes: 2 additions & 0 deletions encoding/lib/cpu/operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("sumsquare_backward", &Sum_Square_Backward_CPU, "SumSqu backward (CPU)");
m.def("non_max_suppression", &Non_Max_Suppression_CPU, "NMS (CPU)");
m.def("conv_rectify", &CONV_RECTIFY_CPU, "Convolution Rectifier (CPU)");
// Apply fused color jitter
m.def("apply_transform", &apply_transform, "apply_transform");
}
34 changes: 34 additions & 0 deletions encoding/lib/cpu/operator.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <pybind11/stl.h>

#include <torch/torch.h>
#include <vector>

Expand Down Expand Up @@ -81,3 +85,33 @@ void CONV_RECTIFY_CPU(
at::IntArrayRef padding,
at::IntArrayRef dilation,
bool avg_mode);

// Fused color-jitter application.
//
// Applies one affine color transform to every pixel of an interleaved image:
//   out[k] = ctm[4*k+0]*in[0] + ctm[4*k+1]*in[1] + ctm[4*k+2]*in[2] + ctm[4*k+3]
// for k in {0, 1, 2}.
//
// Parameters:
//   H, W, C  image height, width and channels per pixel; C must be >= 3.
//   img      at least H*W*C floats, interpreted as [H, W, C].
//   ctm      at least 12 floats, read as row-major rows of 4 (documented as
//            [4, 4]; only the first three rows are read).
//
// Returns a newly allocated float array of shape [H, W, C]. Channels with
// index >= 3 (e.g. an alpha channel) are copied through unchanged.
//
// Throws py::value_error when the dimensions are invalid or either input holds
// fewer elements than the loop would read.
//
// Declared `inline` because this definition lives in a header; without it,
// every translation unit including the header emits its own external
// definition.
inline py::array_t<float> apply_transform(int H, int W, int C, py::array_t<float> img, py::array_t<float> ctm) {
    using dense_f32 = py::array_t<float, py::array::c_style | py::array::forcecast>;

    if (H < 0 || W < 0 || C < 3) {
        throw py::value_error("apply_transform: H and W must be non-negative and C must be >= 3");
    }

    // The loop indexes raw memory, so both inputs must be dense row-major.
    // ensure() hands back the same array when it already is, and a packed
    // copy when the caller passed a strided view.
    auto img_dense = dense_f32::ensure(img);
    auto ctm_dense = dense_f32::ensure(ctm);
    if (!img_dense || !ctm_dense) {
        throw py::value_error("apply_transform: img and ctm must be convertible to contiguous float32 arrays");
    }

    auto img_buf = img_dense.request();
    auto ctm_buf = ctm_dense.request();

    // Element counts are computed in ssize_t so large images cannot overflow int.
    const py::ssize_t channels = static_cast<py::ssize_t>(C);
    const py::ssize_t pixel_count = static_cast<py::ssize_t>(H) * static_cast<py::ssize_t>(W);
    const py::ssize_t value_count = pixel_count * channels;

    if (img_buf.size < value_count) {
        throw py::value_error("apply_transform: img holds fewer than H * W * C elements");
    }
    if (ctm_buf.size < 12) {
        throw py::value_error("apply_transform: ctm must hold at least 12 elements");
    }

    py::array_t<float> result({static_cast<py::ssize_t>(H), static_cast<py::ssize_t>(W), channels});
    auto res_buf = result.request();

    const float *src = static_cast<const float *>(img_buf.ptr);
    const float *m = static_cast<const float *>(ctm_buf.ptr);
    float *dst = static_cast<float *>(res_buf.ptr);

    for (py::ssize_t p = 0; p < pixel_count; ++p) {
        const float *in = src + p * channels;
        float *out = dst + p * channels;

        // Manually unrolled over the three color channels; the last entry of
        // each row is the additive offset.
        out[0] = m[0] * in[0] + m[1] * in[1] + m[2] * in[2] + m[3];
        out[1] = m[4] * in[0] + m[5] * in[1] + m[6] * in[2] + m[7];
        out[2] = m[8] * in[0] + m[9] * in[1] + m[10] * in[2] + m[11];

        // Pass any remaining channels through so the output never contains
        // uninitialized memory.
        for (py::ssize_t c = 3; c < channels; ++c) {
            out[c] = in[c];
        }
    }

    return result;
}
Loading

0 comments on commit f70fa97

Please sign in to comment.