Skip to content

Commit

Permalink
fix docker (zhanghang1989#310)
Browse files Browse the repository at this point in the history
* fix docker

* broken api
  • Loading branch information
zhanghang1989 authored Aug 9, 2020
1 parent 1235f3b commit fed540f
Show file tree
Hide file tree
Showing 7 changed files with 58 additions and 52 deletions.
8 changes: 5 additions & 3 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM nvcr.io/nvidia/pytorch:19.05-py3
FROM nvcr.io/nvidia/pytorch:20.06-py3

# Set working directory
WORKDIR /workspace
Expand All @@ -8,14 +8,16 @@ WORKDIR /workspace
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y python3-tk python-pip git tmux htop tree

RUN python -m pip install --upgrade pip
RUN python -m pip install torch==1.4.0
RUN python -m pip install torchvision==0.5.0
#RUN python -m pip install torch==1.4.0
#RUN python -m pip install torchvision==0.5.0
RUN python -m pip install pycocotools==2.0.0

#RUN chmod a+rwx -R /opt/conda/

COPY ./setup.py .
COPY ./encoding ./encoding

ENV FORCE_CUDA="1"
RUN python setup.py develop

COPY ./experiments ./experiments
14 changes: 7 additions & 7 deletions encoding/lib/cpu/roi_align_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -413,17 +413,17 @@ at::Tensor ROIAlign_Forward_CPU(
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "ROIAlign_Forward_CPU", ([&] {
ROIAlignForwardCompute<scalar_t>(
output.numel(),
input.data<scalar_t>(),
input.data_ptr<scalar_t>(),
static_cast<scalar_t>(spatial_scale),
channels,
height,
width,
pooled_height,
pooled_width,
sampling_ratio,
bottom_rois.data<scalar_t>(),
bottom_rois.data_ptr<scalar_t>(),
roi_cols,
output.data<scalar_t>());
output.data_ptr<scalar_t>());
}));

return output;
Expand Down Expand Up @@ -456,10 +456,10 @@ at::Tensor ROIAlign_Backward_CPU(

AT_ASSERT(bottom_rois.is_contiguous());

AT_DISPATCH_FLOATING_TYPES(bottom_rois.type(), "ROIAlign_Backward_CPU", ([&] {
AT_DISPATCH_FLOATING_TYPES(bottom_rois.scalar_type(), "ROIAlign_Backward_CPU", ([&] {
ROIAlignBackwardCompute<scalar_t>(
grad_output.numel(),
grad_output.data<scalar_t>(),
grad_output.data_ptr<scalar_t>(),
num_rois,
static_cast<scalar_t>(spatial_scale),
channels,
Expand All @@ -468,8 +468,8 @@ at::Tensor ROIAlign_Backward_CPU(
pooled_height,
pooled_width,
sampling_ratio,
grad_in.data<scalar_t>(),
bottom_rois.data<scalar_t>(),
grad_in.data_ptr<scalar_t>(),
bottom_rois.data_ptr<scalar_t>(),
roi_cols);
}));

Expand Down
6 changes: 5 additions & 1 deletion encoding/lib/gpu/activation_kernel.cu
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include <exception>
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <vector>
Expand All @@ -7,6 +8,7 @@
#include <thrust/transform.h>
#include "common.h"

using namespace std;

namespace {

Expand Down Expand Up @@ -40,5 +42,7 @@ void LeakyRelu_Backward_CUDA(at::Tensor z, at::Tensor dz, float slope) {
*/
// unstable after scaling
at::leaky_relu_(z, 1.0 / slope);
at::leaky_relu_backward(dz, z, slope);
// The leaky_relu_backward API changed on the PyTorch side, so this feature is currently broken.
throw "PyTorch API break, Don't use InplaceABN for now.";
// at::leaky_relu_backward(dz, z, slope, false);
}
2 changes: 1 addition & 1 deletion encoding/lib/gpu/device_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ struct DeviceTensor<DType, 1> {

template<typename DType, int Dim>
static DeviceTensor<DType, Dim> devicetensor(const at::Tensor &blob) {
DType *data = blob.data<DType>();
DType *data = blob.data_ptr<DType>();
DeviceTensor<DType, Dim> tensor(data, nullptr);
for (int i = 0; i < Dim; ++i) {
tensor.size_[i] = blob.size(i);
Expand Down
50 changes: 25 additions & 25 deletions encoding/lib/gpu/lib_ssd.cu
Original file line number Diff line number Diff line change
Expand Up @@ -348,10 +348,10 @@ std::vector<at::Tensor> box_encoder(const int N_img,
const at::Tensor& dbox,
float criteria) {
// Check everything is on the device
AT_ASSERTM(bbox_input.type().is_cuda(), "bboxes must be a CUDA tensor");
AT_ASSERTM(bbox_offsets.type().is_cuda(), "bbox offsets must be a CUDA tensor");
AT_ASSERTM(labels_input.type().is_cuda(), "labels must be a CUDA tensor");
AT_ASSERTM(dbox.type().is_cuda(), "dboxes must be a CUDA tensor");
AT_ASSERTM(bbox_input.is_cuda(), "bboxes must be a CUDA tensor");
AT_ASSERTM(bbox_offsets.is_cuda(), "bbox offsets must be a CUDA tensor");
AT_ASSERTM(labels_input.is_cuda(), "labels must be a CUDA tensor");
AT_ASSERTM(dbox.is_cuda(), "dboxes must be a CUDA tensor");

// Check at least offsets, bboxes and labels are consistent
// Note: offsets is N+1 vs. N for labels
Expand All @@ -374,7 +374,7 @@ std::vector<at::Tensor> box_encoder(const int N_img,
// allocate final outputs (known size)
#ifdef DEBUG
printf("%d x %d\n", N_img * M, 4);
// at::Tensor bbox_out = dbox.type().tensor({N_img * M, 4});
// at::Tensor bbox_out = dbox.scalar_type().tensor({N_img * M, 4});
printf("allocating %lu bytes for output labels\n", N_img*M*sizeof(long));
#endif
at::Tensor labels_out = at::empty({N_img * M}, labels_input.options());
Expand All @@ -398,15 +398,15 @@ std::vector<at::Tensor> box_encoder(const int N_img,
// Encode the inputs
const int THREADS_PER_BLOCK = 256;
encode<THREADS_PER_BLOCK, 256><<<N_img, THREADS_PER_BLOCK, 0, stream.stream()>>>(N_img,
(float4*)bbox_input.data<float>(),
labels_input.data<long>(),
bbox_offsets.data<int>(),
(float4*)bbox_input.data_ptr<float>(),
labels_input.data_ptr<long>(),
bbox_offsets.data_ptr<int>(),
M,
(float4*)dbox.data<float>(),
(float4*)dbox.data_ptr<float>(),
criteria,
workspace.data<uint8_t>(),
(float4*)bbox_out.data<float>(),
labels_out.data<long>());
workspace.data_ptr<uint8_t>(),
(float4*)bbox_out.data_ptr<float>(),
labels_out.data_ptr<long>());

THCudaCheck(cudaGetLastError());
return {bbox_out, labels_out};
Expand All @@ -429,11 +429,11 @@ at::Tensor calc_ious(const int N_img,
// Get IoU of all source x default box pairs
calc_ious_kernel<<<N_img, 256, 0, stream.stream()>>>(
N_img,
(float4*)boxes1.data<float>(),
boxes1_offsets.data<int>(),
(float4*)boxes1.data_ptr<float>(),
boxes1_offsets.data_ptr<int>(),
M,
(float4*)boxes2.data<float>(),
ious.data<float>());
(float4*)boxes2.data_ptr<float>(),
ious.data_ptr<float>());

THCudaCheck(cudaGetLastError());
return ious;
Expand Down Expand Up @@ -543,9 +543,9 @@ std::vector<at::Tensor> random_horiz_flip(
W = img.size(3);
}

assert(img.type().is_cuda());
assert(bboxes.type().is_cuda());
assert(bbox_offsets.type().is_cuda());
assert(img.is_cuda());
assert(bboxes.is_cuda());
assert(bbox_offsets.is_cuda());

// printf("%d %d %d %d\n", N, C, H, W);
// Need temp storage of size img
Expand All @@ -554,20 +554,20 @@ std::vector<at::Tensor> random_horiz_flip(

auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
img.type(),
img.scalar_type(),
"HorizFlipImagesAndBoxes",
[&] {
HorizFlipImagesAndBoxes<scalar_t><<<N, dim3(16, 16), 0, stream.stream()>>>(
N,
C,
H,
W,
img.data<scalar_t>(),
bboxes.data<float>(),
bbox_offsets.data<int>(),
img.data_ptr<scalar_t>(),
bboxes.data_ptr<float>(),
bbox_offsets.data_ptr<int>(),
p,
flip.data<float>(),
tmp_img.data<scalar_t>(),
flip.data_ptr<float>(),
tmp_img.data_ptr<scalar_t>(),
nhwc);
THCudaCheck(cudaGetLastError());
});
Expand Down
18 changes: 9 additions & 9 deletions encoding/lib/gpu/nms_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ std::vector<at::Tensor> Non_Max_Suppression_CUDA(
AT_ASSERT(input.size(2) == 4);
AT_ASSERT(input.is_contiguous());
AT_ASSERT(scores.is_contiguous());
AT_ASSERT(input.type().scalarType() == at::kFloat || input.type().scalarType() == at::kDouble);
AT_ASSERT(scores.type().scalarType() == at::kFloat || scores.type().scalarType() == at::kDouble);
AT_ASSERT(input.scalar_type() == at::kFloat || input.scalar_type() == at::kDouble);
AT_ASSERT(scores.scalar_type() == at::kFloat || scores.scalar_type() == at::kDouble);

auto num_boxes = input.size(1);
auto batch_size = input.size(0);
Expand All @@ -89,22 +89,22 @@ std::vector<at::Tensor> Non_Max_Suppression_CUDA(
//cudaGetDeviceProperties in the function body...

dim3 mask_grid(batch_size);
if(input.type().scalarType() == at::kFloat)
if(input.scalar_type() == at::kFloat)
{
nms_kernel<<<mask_grid, mask_block, 0, at::cuda::getCurrentCUDAStream()>>>(
mask.data<unsigned char>(),
input.data<float>(),
sorted_inds.data<int64_t>(),
mask.data_ptr<unsigned char>(),
input.data_ptr<float>(),
sorted_inds.data_ptr<int64_t>(),
num_boxes,
thresh);
AT_ASSERT(cudaGetLastError() == cudaSuccess);
}
else
{
nms_kernel<<<mask_grid, mask_block, 0, at::cuda::getCurrentCUDAStream()>>>(
mask.data<unsigned char>(),
input.data<double>(),
sorted_inds.data<int64_t>(),
mask.data_ptr<unsigned char>(),
input.data_ptr<double>(),
sorted_inds.data_ptr<int64_t>(),
num_boxes,
thresh);
AT_ASSERT(cudaGetLastError() == cudaSuccess);
Expand Down
12 changes: 6 additions & 6 deletions encoding/lib/gpu/roi_align_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -379,16 +379,16 @@ at::Tensor ROIAlign_Forward_CUDA(
0,
at::cuda::getCurrentCUDAStream()>>>(
count,
input.data<scalar_t>(),
input.data_ptr<scalar_t>(),
static_cast<scalar_t>(spatial_scale),
channels,
height,
width,
pooled_height,
pooled_width,
sampling_ratio,
rois.data<scalar_t>(),
output.data<scalar_t>());
rois.data_ptr<scalar_t>(),
output.data_ptr<scalar_t>());
}));
AT_ASSERT(cudaGetLastError() == cudaSuccess);
return output;
Expand Down Expand Up @@ -426,7 +426,7 @@ at::Tensor ROIAlign_Backward_CUDA(
0,
at::cuda::getCurrentCUDAStream()>>>(
count,
grad_output.data<scalar_t>(),
grad_output.data_ptr<scalar_t>(),
num_rois,
static_cast<scalar_t>(spatial_scale),
channels,
Expand All @@ -435,8 +435,8 @@ at::Tensor ROIAlign_Backward_CUDA(
pooled_height,
pooled_width,
sampling_ratio,
grad_in.data<scalar_t>(),
rois.data<scalar_t>());
grad_in.data_ptr<scalar_t>(),
rois.data_ptr<scalar_t>());
}));
AT_ASSERT(cudaGetLastError() == cudaSuccess);
Expand Down

0 comments on commit fed540f

Please sign in to comment.