Skip to content

Commit

Permalink
force update cudnn v3
Browse files Browse the repository at this point in the history
  • Loading branch information
antinucleon committed Aug 4, 2015
1 parent 219d568 commit a174aa3
Showing 1 changed file with 57 additions and 9 deletions.
66 changes: 57 additions & 9 deletions src/layer/cudnn_convolution_layer-inl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,6 @@ class CuDNNConvolutionLayer<gpu> : public ConvolutionLayer<gpu> {
float beta = 0.0f;
if (!init_cudnn_) {
init_cudnn_ = true;
if (use_fast_algo_) {
algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
} else {
algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
}
temp_.set_stream(nodes_out[0]->data.stream_);
utils::Check(cudnnSetStream(handle_, nodes_out[0]->data.stream_->stream_) == CUDNN_STATUS_SUCCESS, "cudnn failed");
utils::Check(cudnnSetFilter4dDescriptor(filter_desc_, dtype_,
Expand All @@ -90,10 +85,55 @@ class CuDNNConvolutionLayer<gpu> : public ConvolutionLayer<gpu> {
out.shape_[2], out.shape_[3]) == CUDNN_STATUS_SUCCESS, "cudnn failed");
utils::Check(cudnnSetTensor4dDescriptor(bias_desc_, CUDNN_TENSOR_NCHW, dtype_,
1, Parent::bias_.shape_[0], 1, 1) == CUDNN_STATUS_SUCCESS, "cudnn failed");
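// cudnn v3: pick the forward, backward-filter and backward-data algorithms,
// then size a single workspace large enough for all three.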
utils::Check(cudnnGetConvolutionForwardAlgorithm(handle_,
in_desc_,
filter_desc_,
conv_desc_,
out_desc_,
CUDNN_CONVOLUTION_FWD_PREFER_FASTEST,
512<<20,
&algo_) == CUDNN_STATUS_SUCCESS, "cudnn fail");

utils::Check(cudnnGetConvolutionBackwardFilterAlgorithm(handle_,
in_desc_,
out_desc_,
conv_desc_,
filter_desc_,
CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST,
512<<20,
&back_algo_w_) == CUDNN_STATUS_SUCCESS, "cudnn fail");

utils::Check(cudnnGetConvolutionBackwardDataAlgorithm(handle_,
filter_desc_,
out_desc_,
conv_desc_,
in_desc_,
CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST,
512<<20,
&back_algo_) == CUDNN_STATUS_SUCCESS, "cudnn fail");
size_t back_size = 0;
size_t back_size_w = 0;
utils::Check(cudnnGetConvolutionBackwardDataWorkspaceSize(handle_,
filter_desc_,
out_desc_,
conv_desc_,
in_desc_,
back_algo_,
&back_size) == CUDNN_STATUS_SUCCESS, "cudnn fail");
utils::Check(cudnnGetConvolutionBackwardFilterWorkspaceSize(handle_,
in_desc_,
out_desc_,
conv_desc_,
filter_desc_,
back_algo_w_,
&back_size_w) == CUDNN_STATUS_SUCCESS, "cudnn fail");
back_size = std::max(back_size, back_size_w);
utils::Check(cudnnGetConvolutionForwardWorkspaceSize(handle_, in_desc_,
filter_desc_, conv_desc_,
out_desc_, algo_,
&workspace_size_) == CUDNN_STATUS_SUCCESS, "cudnn failed");
workspace_size_ = std::max(back_size, workspace_size_);
temp_.Resize(mshadow::Shape1(workspace_size_ / sizeof(float) + 1), 0.0f);
}
CHECK(nodes_in[0]->data.CheckContiguous());
Expand Down Expand Up @@ -124,15 +164,19 @@ class CuDNNConvolutionLayer<gpu> : public ConvolutionLayer<gpu> {
&beta,
bias_desc_, Parent::gbias_.dptr_) == CUDNN_STATUS_SUCCESS, "cudnn failed");
}
utils::Check(cudnnConvolutionBackwardFilter(handle_, &alpha,
utils::Check(cudnnConvolutionBackwardFilter_v3(handle_, &alpha,
in_desc_, nodes_in[0]->data.dptr_,
out_desc_, nodes_out[0]->data.dptr_,
conv_desc_, &beta,
conv_desc_, back_algo_w_,
temp_.dptr_, workspace_size_,
&beta,
filter_desc_, Parent::gwmat_.dptr_) == CUDNN_STATUS_SUCCESS, "cudnn failed");
utils::Check(cudnnConvolutionBackwardData(handle_, &alpha,
utils::Check(cudnnConvolutionBackwardData_v3(handle_, &alpha,
filter_desc_, Parent::wmat_.dptr_,
out_desc_, nodes_out[0]->data.dptr_,
conv_desc_, &beta,
conv_desc_, back_algo_,
temp_.dptr_, workspace_size_,
&beta,
in_desc_, nodes_in[0]->data.dptr_) == CUDNN_STATUS_SUCCESS, "cudnn failed");
}
private:
Expand Down Expand Up @@ -165,6 +209,10 @@ class CuDNNConvolutionLayer<gpu> : public ConvolutionLayer<gpu> {
cudnnConvolutionDescriptor_t conv_desc_;
/*! \brief cuDNN forward convolution algorithm */
cudnnConvolutionFwdAlgo_t algo_;
/*! \brief cuDNN back algo for data */
cudnnConvolutionBwdDataAlgo_t back_algo_;
/*! \brief cuDNN back algo for filter */
cudnnConvolutionBwdFilterAlgo_t back_algo_w_;
/*! \brief cuDNN workspace size (maximum over the forward and backward algorithms) */
size_t workspace_size_;
/*! \brief cuDNN workspace */
Expand Down

0 comments on commit a174aa3

Please sign in to comment.