Update for at::getCurrentCUDAStream getting moved to at::cuda.
dukebw committed Aug 19, 2018
1 parent 0403d28 commit de4f24a
Showing 4 changed files with 30 additions and 24 deletions.
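All four files apply the same two mechanical updates: the current-stream getter moves from at::globalContext() into the at::cuda namespace (pulling in ATen/cuda/CUDAContext.h), and the ScalarConvert helper is renamed to ScalarConverter. Below is a minimal sketch of the stream change, assuming an ATen build where the getter already lives in at::cuda; the wrapper function and the no-op kernel are illustrative only and not part of this commit.

#include "ATen/ATen.h"
#include "ATen/cuda/CUDAContext.h"   // declares at::cuda::getCurrentCUDAStream()

static __global__ void noop_kernel() {}   // placeholder kernel, only for illustration

void launch_on_current_stream() {
  // Before this commit these files called:
  //   cudaStream_t stream = at::globalContext().getCurrentCUDAStream();
  // After the move into the at::cuda namespace:
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // The stream is then passed as the fourth kernel-launch argument.
  noop_kernel<<<1, 1, 0, stream>>>();
}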
encoding/lib/gpu/common.h (8 changes: 4 additions & 4 deletions)
@@ -7,7 +7,7 @@ static const unsigned WARP_SIZE = 32;
static const unsigned MAX_BLOCK_SIZE = 512U;

template<typename In, typename Out>
-struct ScalarConvert {
+struct ScalarConverter {
static __host__ __device__ __forceinline__ Out to(const In v) { return (Out) v; }
};

@@ -60,9 +60,9 @@ template <typename DType, typename Acctype>
struct Float2 {
Acctype v1, v2;
__device__ Float2() {}
-__device__ Float2(DType v1, DType v2) : v1(ScalarConvert<DType, Acctype>::to(v1)), v2(ScalarConvert<DType, Acctype>::to(v2)) {}
-__device__ Float2(DType v) : v1(ScalarConvert<DType, Acctype>::to(v)), v2(ScalarConvert<DType, Acctype>::to(v)) {}
-__device__ Float2(int v) : v1(ScalarConvert<int, Acctype>::to(v)), v2(ScalarConvert<int, Acctype>::to(v)) {}
+__device__ Float2(DType v1, DType v2) : v1(ScalarConverter<DType, Acctype>::to(v1)), v2(ScalarConverter<DType, Acctype>::to(v2)) {}
+__device__ Float2(DType v) : v1(ScalarConverter<DType, Acctype>::to(v)), v2(ScalarConverter<DType, Acctype>::to(v)) {}
+__device__ Float2(int v) : v1(ScalarConverter<int, Acctype>::to(v)), v2(ScalarConverter<int, Acctype>::to(v)) {}
__device__ Float2& operator+=(const Float2& a) {
v1 += a.v1;
v2 += a.v2;
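The rename is purely mechanical and the trait is used exactly as before. A tiny usage sketch building on the ScalarConverter definition above; widen is an illustrative name rather than anything defined in the repository.

// Widen a value to the accumulation type, the same pattern the Float2
// constructors above use.
template <typename DType, typename Acctype>
__device__ __forceinline__ Acctype widen(DType v) {
  return ScalarConverter<DType, Acctype>::to(v);
}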
encoding/lib/gpu/encoding_kernel.cu (20 changes: 11 additions & 9 deletions)
@@ -1,4 +1,6 @@
-#include <ATen/ATen.h>
+#include "ATen/ATen.h"
+#include "ATen/cuda/CUDAContext.h"
+#include "ATen/cuda/CUDAApplyUtils.cuh"
#include <vector>

#include "common.h"
@@ -12,7 +14,7 @@ struct AggOp {
DeviceTensor<DType, 3> x,
DeviceTensor<DType, 2> c) : A(a), X(x), C(c) {}
__device__ __forceinline__ Acctype operator()(int b, int i, int k, int d) {
-return ScalarConvert<DType, Acctype>::to(A[b][i][k] * (X[b][i][d] - C[k][d]));
+return ScalarConverter<DType, Acctype>::to(A[b][i][k] * (X[b][i][d] - C[k][d]));
}
DeviceTensor<DType, 3> A;
DeviceTensor<DType, 3> X;
@@ -25,7 +27,7 @@ struct AggBackOp {
DeviceTensor<DType, 3> x,
DeviceTensor<DType, 2> c) : G(g), X(x), C(c) {}
__device__ __forceinline__ Acctype operator()(int b, int i, int k, int d) {
-return ScalarConvert<DType, Acctype>::to(G[b][k][d] * (X[b][i][d] - C[k][d]));
+return ScalarConverter<DType, Acctype>::to(G[b][k][d] * (X[b][i][d] - C[k][d]));
}
DeviceTensor<DType, 3> G;
DeviceTensor<DType, 3> X;
@@ -39,7 +41,7 @@ struct SL2Op {
__device__ __forceinline__ Acctype operator()(int b, int i, int k, int d)
{
DType r = X[b][i][d] - C[k][d];
-return ScalarConvert<DType, Acctype>::to(r * r);
+return ScalarConverter<DType, Acctype>::to(r * r);
}
DeviceTensor<DType, 3> X;
DeviceTensor<DType, 2> C;
@@ -55,7 +57,7 @@ struct SL2GradXOp {
) : GSL(gsl), X(x), C(c), S(s) {}
__device__ __forceinline__ Acctype operator()(int b, int i, int k, int d)
{
-return ScalarConvert<DType, Acctype>::to(
+return ScalarConverter<DType, Acctype>::to(
2 * S[k] * GSL[b][i][k] * (X[b][i][d]-C[k][d]));
}
DeviceTensor<DType, 3> GSL;
@@ -312,7 +314,7 @@ at::Tensor Aggregate_Forward_CUDA(
const at::Tensor C_) {
/* Device tensors */
auto E_ = A_.type().tensor({A_.size(0), C_.size(0), C_.size(1)}).zero_();
-cudaStream_t stream = at::globalContext().getCurrentCUDAStream();
+cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// B, K, D
dim3 blocks(C_.size(1), C_.size(0), X_.size(0));
dim3 threads(getNumThreads(X_.size(1)));
@@ -338,7 +340,7 @@ std::vector<at::Tensor> Aggregate_Backward_CUDA(
auto gradA_ = at::zeros_like(A_);
auto gradX_ = at::bmm(A_, GE_);
auto gradC_ = (-GE_ * A_.sum(1).unsqueeze(2)).sum(0);
-cudaStream_t stream = at::globalContext().getCurrentCUDAStream();
+cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// B, K, D
dim3 blocks(C_.size(0), X_.size(1), X_.size(0));
dim3 threads(getNumThreads(C_.size(1)));
@@ -361,7 +363,7 @@ at::Tensor ScaledL2_Forward_CUDA(
const at::Tensor C_,
const at::Tensor S_) {
auto SL_ = X_.type().tensor({X_.size(0), X_.size(1), C_.size(0)}).zero_();
-cudaStream_t stream = at::globalContext().getCurrentCUDAStream();
+cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 blocks(C_.size(0), X_.size(1), X_.size(0));
dim3 threads(getNumThreads(C_.size(1)));

@@ -388,7 +390,7 @@ std::vector<at::Tensor> ScaledL2_Backward_CUDA(
auto GX_ = at::zeros_like(X_);
auto GC_ = at::zeros_like(C_);
/* kernel function */
-cudaStream_t stream = at::globalContext().getCurrentCUDAStream();
+cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 blocks1(X_.size(2), X_.size(1), X_.size(0));
dim3 threads1(getNumThreads(C_.size(0)));
dim3 blocks2(C_.size(1), C_.size(0));
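The hunks in this file stop just short of the launches themselves. Assuming they follow the same AT_DISPATCH_FLOATING_TYPES pattern shown in syncbn_kernel.cu below, the stream obtained from at::cuda becomes the fourth launch argument, roughly as in this sketch; scaled_l2_kernel, its argument list, and the wrapper scaled_l2_launch_sketch are placeholders, not the repository's actual kernels.

template <typename scalar_t>
__global__ void scaled_l2_kernel(scalar_t* out) { /* placeholder body */ }

at::Tensor scaled_l2_launch_sketch(const at::Tensor& X_, const at::Tensor& C_) {
  // Output buffer and launch geometry mirror ScaledL2_Forward_CUDA above.
  auto SL_ = X_.type().tensor({X_.size(0), X_.size(1), C_.size(0)}).zero_();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  dim3 blocks(C_.size(0), X_.size(1), X_.size(0));
  dim3 threads(getNumThreads(C_.size(1)));
  AT_DISPATCH_FLOATING_TYPES(X_.type(), "scaled_l2_launch_sketch", ([&] {
    // The fourth <<<>>> argument is the current stream from the new at::cuda API.
    scaled_l2_kernel<scalar_t><<<blocks, threads, 0, stream>>>(SL_.data<scalar_t>());
  }));
  return SL_;
}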
encoding/lib/gpu/roi_align_kernel.cu (12 changes: 7 additions & 5 deletions)
@@ -1,7 +1,9 @@
-#include <ATen/ATen.h>
+#include "ATen/ATen.h"
+#include "ATen/cuda/CUDAContext.h"
+#include "ATen/cuda/CUDAApplyUtils.cuh"

-#include <cuda.h>
-#include <cuda_runtime.h>
+#include "cuda.h"
+#include "cuda_runtime.h"

namespace {

@@ -375,7 +377,7 @@ at::Tensor ROIAlignForwardCUDA(
<<<ROI_GET_BLOCKS(count),
ROI_CUDA_NUM_THREADS,
0,
-at::globalContext().getCurrentCUDAStream()>>>(
+at::cuda::getCurrentCUDAStream()>>>(
count,
input.data<scalar_t>(),
static_cast<scalar_t>(spatial_scale),
@@ -422,7 +424,7 @@ at::Tensor ROIAlignBackwardCUDA(
<<<ROI_GET_BLOCKS(count),
ROI_CUDA_NUM_THREADS,
0,
-at::globalContext().getCurrentCUDAStream()>>>(
+at::cuda::getCurrentCUDAStream()>>>(
count,
grad_output.data<scalar_t>(),
num_rois,
encoding/lib/gpu/syncbn_kernel.cu (14 changes: 8 additions & 6 deletions)
@@ -1,4 +1,6 @@
-#include <ATen/ATen.h>
+#include "ATen/ATen.h"
+#include "ATen/cuda/CUDAContext.h"
+#include "ATen/cuda/CUDAApplyUtils.cuh"
#include <vector>

#include "common.h"
@@ -12,7 +14,7 @@ struct GradOp {
: mean(m), input(i), gradOutput(g) {}
__device__ __forceinline__ Float2<DType, Acctype> operator()(int batch, int plane, int n) {
DType g = gradOutput[batch][plane][n];
-DType c = ScalarConvert<Acctype, DType>::to(input[batch][plane][n] - mean);
+DType c = ScalarConverter<Acctype, DType>::to(input[batch][plane][n] - mean);
return Float2<DType, Acctype>(g, g * c);
}
const Acctype mean;
@@ -180,7 +182,7 @@ at::Tensor BatchNorm_Forward_CUDA(
const at::Tensor gamma_,
const at::Tensor beta_) {
auto output_ = at::zeros_like(input_);
-cudaStream_t stream = at::globalContext().getCurrentCUDAStream();
+cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 blocks(input_.size(1));
dim3 threads(getNumThreads(input_.size(2)));
AT_DISPATCH_FLOATING_TYPES(input_.type(), "BatchNorm_Forward_CUDA", ([&] {
@@ -214,7 +216,7 @@ std::vector<at::Tensor> BatchNorm_Backward_CUDA(
at::Tensor gradMean_ = at::zeros_like(mean_);
at::Tensor gradStd_ = at::zeros_like(std_);
/* cuda utils*/
-cudaStream_t stream = at::globalContext().getCurrentCUDAStream();
+cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 blocks(input_.size(1));
dim3 threads(getNumThreads(input_.size(2)));
AT_DISPATCH_FLOATING_TYPES(input_.type(), "BatchNorm_Backward_CUDA", ([&] {
@@ -246,7 +248,7 @@ std::vector<at::Tensor> Sum_Square_Forward_CUDA(
at::Tensor sum_ = input_.type().tensor({input_.size(1)}).zero_();
at::Tensor square_ = input_.type().tensor({input_.size(1)}).zero_();
/* cuda utils*/
-cudaStream_t stream = at::globalContext().getCurrentCUDAStream();
+cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 blocks(input_.size(1));
dim3 threads(getNumThreads(input_.size(2)));
AT_DISPATCH_FLOATING_TYPES(input_.type(), "BatchNorm_Backward_CUDA", ([&] {
@@ -269,7 +271,7 @@ at::Tensor Sum_Square_Backward_CUDA(
/* outputs */
at::Tensor gradInput_ = at::zeros_like(input_);
/* cuda utils*/
-cudaStream_t stream = at::globalContext().getCurrentCUDAStream();
+cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 blocks(input_.size(1));
dim3 threads(getNumThreads(input_.size(2)));
AT_DISPATCH_FLOATING_TYPES(input_.type(), "BatchNorm_Backward_CUDA", ([&] {
