From a174aa375e76da9c8bc05402fd480354a859327d Mon Sep 17 00:00:00 2001 From: Bing Xu Date: Tue, 4 Aug 2015 17:24:16 -0600 Subject: [PATCH] force update cudnn v3 --- src/layer/cudnn_convolution_layer-inl.hpp | 66 +++++++++++++++++++---- 1 file changed, 57 insertions(+), 9 deletions(-) diff --git a/src/layer/cudnn_convolution_layer-inl.hpp b/src/layer/cudnn_convolution_layer-inl.hpp index c3af4b22..72a29206 100644 --- a/src/layer/cudnn_convolution_layer-inl.hpp +++ b/src/layer/cudnn_convolution_layer-inl.hpp @@ -62,11 +62,6 @@ class CuDNNConvolutionLayer : public ConvolutionLayer { float beta = 0.0f; if (!init_cudnn_) { init_cudnn_ = true; - if (use_fast_algo_) { - algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; - } else { - algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM; - } temp_.set_stream(nodes_out[0]->data.stream_); utils::Check(cudnnSetStream(handle_, nodes_out[0]->data.stream_->stream_) == CUDNN_STATUS_SUCCESS, "cudnn failed"); utils::Check(cudnnSetFilter4dDescriptor(filter_desc_, dtype_, @@ -90,10 +85,55 @@ class CuDNNConvolutionLayer : public ConvolutionLayer { out.shape_[2], out.shape_[3]) == CUDNN_STATUS_SUCCESS, "cudnn failed"); utils::Check(cudnnSetTensor4dDescriptor(bias_desc_, CUDNN_TENSOR_NCHW, dtype_, 1, Parent::bias_.shape_[0], 1, 1) == CUDNN_STATUS_SUCCESS, "cudnn failed"); + // cudnn v3 + utils::Check(cudnnGetConvolutionForwardAlgorithm(handle_, + in_desc_, + filter_desc_, + conv_desc_, + out_desc_, + CUDNN_CONVOLUTION_FWD_PREFER_FASTEST, + 512<<20, + &algo_) == CUDNN_STATUS_SUCCESS, "cudnn fail"); + + utils::Check(cudnnGetConvolutionBackwardFilterAlgorithm(handle_, + in_desc_, + out_desc_, + conv_desc_, + filter_desc_, + CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST, + 512<<20, + &back_algo_w_) == CUDNN_STATUS_SUCCESS, "cudnn fail"); + + utils::Check(cudnnGetConvolutionBackwardDataAlgorithm(handle_, + filter_desc_, + out_desc_, + conv_desc_, + in_desc_, + CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST, + 512<<20, + &back_algo_) == CUDNN_STATUS_SUCCESS, "cudnn fail"); + size_t back_size = 0; + size_t back_size_w = 0; + utils::Check(cudnnGetConvolutionBackwardDataWorkspaceSize(handle_, + filter_desc_, + out_desc_, + conv_desc_, + in_desc_, + back_algo_, + &back_size) == CUDNN_STATUS_SUCCESS, "cudnn fail"); + utils::Check(cudnnGetConvolutionBackwardFilterWorkspaceSize(handle_, + in_desc_, + out_desc_, + conv_desc_, + filter_desc_, + back_algo_w_, + &back_size_w) == CUDNN_STATUS_SUCCESS, "cudnn fail"); + back_size = std::max(back_size, back_size_w); utils::Check(cudnnGetConvolutionForwardWorkspaceSize(handle_, in_desc_, filter_desc_, conv_desc_, out_desc_, algo_, &workspace_size_) == CUDNN_STATUS_SUCCESS, "cudnn failed"); + workspace_size_ = std::max(back_size, workspace_size_); temp_.Resize(mshadow::Shape1(workspace_size_ / sizeof(float) + 1), 0.0f); } CHECK(nodes_in[0]->data.CheckContiguous()); @@ -124,15 +164,19 @@ class CuDNNConvolutionLayer : public ConvolutionLayer { &beta, bias_desc_, Parent::gbias_.dptr_) == CUDNN_STATUS_SUCCESS, "cudnn failed"); } - utils::Check(cudnnConvolutionBackwardFilter(handle_, &alpha, + utils::Check(cudnnConvolutionBackwardFilter_v3(handle_, &alpha, in_desc_, nodes_in[0]->data.dptr_, out_desc_, nodes_out[0]->data.dptr_, - conv_desc_, &beta, + conv_desc_, back_algo_w_, + temp_.dptr_, workspace_size_, + &beta, filter_desc_, Parent::gwmat_.dptr_) == CUDNN_STATUS_SUCCESS, "cudnn failed"); - utils::Check(cudnnConvolutionBackwardData(handle_, &alpha, + utils::Check(cudnnConvolutionBackwardData_v3(handle_, &alpha, filter_desc_, Parent::wmat_.dptr_, out_desc_, nodes_out[0]->data.dptr_, - conv_desc_, &beta, + conv_desc_, back_algo_, + temp_.dptr_, workspace_size_, + &beta, in_desc_, nodes_in[0]->data.dptr_) == CUDNN_STATUS_SUCCESS, "cudnn failed"); } private: @@ -165,6 +209,10 @@ class CuDNNConvolutionLayer : public ConvolutionLayer { cudnnConvolutionDescriptor_t conv_desc_; /*! \brief cuDNN conv algorithm */ cudnnConvolutionFwdAlgo_t algo_; + /*! \brief cuDNN back algo for data */ + cudnnConvolutionBwdDataAlgo_t back_algo_; + /*! \brief cuDNN back algo for filter */ + cudnnConvolutionBwdFilterAlgo_t back_algo_w_; /*! \brief cuDNN workspace size */ size_t workspace_size_; /*! \brief cuDNN workspace */