Skip to content

Commit

Permalink
add lib (zhanghang1989#61)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhanghang1989 committed Jun 4, 2018
1 parent d8abf50 commit ed5456d
Show file tree
Hide file tree
Showing 20 changed files with 2,404 additions and 2 deletions.
2 changes: 0 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,5 @@ build/
data/
docs/src/
docs/html/
encoding/lib/
encoding/_ext/
encoding.egg-info/
experiments/segmentation/
20 changes: 20 additions & 0 deletions encoding/lib/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import os
import torch
from torch.utils.cpp_extension import load

cwd = os.path.dirname(os.path.realpath(__file__))
cpu_path = os.path.join(cwd, 'cpu')
gpu_path = os.path.join(cwd, 'gpu')

cpu = load( 'enclib_cpu', [
os.path.join(cpu_path, 'roi_align.cpp'),
os.path.join(cpu_path, 'roi_align_cpu.cpp'),
], build_directory=cpu_path, verbose=False)

if torch.cuda.is_available():
gpu = load( 'enclib_gpu', [
os.path.join(gpu_path, 'operator.cpp'),
os.path.join(gpu_path, 'encoding_kernel.cu'),
os.path.join(gpu_path, 'syncbn_kernel.cu'),
os.path.join(gpu_path, 'roi_align_kernel.cu'),
], build_directory=gpu_path, verbose=False)
Empty file added encoding/lib/cpu/__init__.py
Empty file.
28 changes: 28 additions & 0 deletions encoding/lib/cpu/roi_align.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#include <torch/torch.h>
// CPU declarations

at::Tensor ROIAlignForwardCPU(
const at::Tensor& input,
const at::Tensor& bottom_rois,
int64_t pooled_height,
int64_t pooled_width,
double spatial_scale,
int64_t sampling_ratio);

at::Tensor ROIAlignBackwardCPU(
const at::Tensor& bottom_rois,
const at::Tensor& grad_output, // gradient of the output of the layer
int64_t b_size,
int64_t channels,
int64_t height,
int64_t width,
int64_t pooled_height,
int64_t pooled_width,
double spatial_scale,
int64_t sampling_ratio);


PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("roi_align_forward", &ROIAlignForwardCPU, "ROI Align forward (CPU)");
m.def("roi_align_backward", &ROIAlignBackwardCPU, "ROI Align backward (CPU)");
}
Loading

0 comments on commit ed5456d

Please sign in to comment.