Skip to content

Commit

Permalink
Revert pytorch#62143, the new CUDNN_RNN_ALGO_PERSIST_STATIC_SMALL_H algorithm (pytorch#72089)

Browse files Browse the repository at this point in the history

Summary:
Revert "[cuDNN] Add a new optimized cuDNN RNN algorithm for small RNN hidden_size (pytorch#62143)"

This reverts commit 965b9f4.

This new cuDNN RNN algorithm is causing some failures in our internal testing.

Pull Request resolved: pytorch#72089

Reviewed By: mruberry

Differential Revision: D33905226

Pulled By: ngimel

fbshipit-source-id: 5563a2c275e697477cf79bada3b81a33f1bf2aaa
  • Loading branch information
xwang233 authored and facebook-github-bot committed Feb 1, 2022
1 parent 456d1eb commit 35c240a
Showing 1 changed file with 4 additions and 40 deletions.
44 changes: 4 additions & 40 deletions aten/src/ATen/native/cudnn/RNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -753,55 +753,19 @@ namespace {
}
}

inline bool use_rnn_persist_small_h(const RNNDescriptorParams& rnn,
const TensorDescriptorListParams& tensors,
bool forward) {
#if CUDNN_VERSION >= 8201 // 8.2.1
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
if (prop->major < 6) return false;

if (forward) {
if (rnn.mode == CUDNN_RNN_RELU || rnn.mode == CUDNN_RNN_TANH) {
return rnn.hidden_size <= 384;
}
if (rnn.mode == CUDNN_LSTM || rnn.mode == CUDNN_GRU) {
return rnn.hidden_size <= 192;
}
} else /* backward */ {
if (rnn.mode == CUDNN_RNN_RELU || rnn.mode == CUDNN_RNN_TANH) {
return rnn.hidden_size <= 256;
}
if (rnn.mode == CUDNN_LSTM || rnn.mode == CUDNN_GRU) {
return rnn.hidden_size <= 128;
}
}

return false;
#else
return false;
#endif
}

cudnnRNNAlgo_t get_algo(const RNNDescriptorParams& rnn, const TensorDescriptorListParams& tensors, const Tensor input, bool forward) {
cudnnRNNAlgo_t get_algo(const RNNDescriptorParams& rnn, const TensorDescriptorListParams& tensors, const Tensor input) {
// LSTM with projections only works with standard algorithm
if (rnn.proj_size != 0) {
return CUDNN_RNN_ALGO_STANDARD;
}

#if CUDNN_VERSION >= 8201 // 8.2.1
if (use_rnn_persist_small_h(rnn, tensors, forward)) {
return CUDNN_RNN_ALGO_PERSIST_STATIC_SMALL_H;
}
#endif

if (getCudnnDataType(input) == CUDNN_DATA_HALF &&
!tensors.is_input_packed()) {
if (use_persist_common_heuristics(rnn, tensors) &&
use_persist_device_heuristics(rnn, tensors)) {
return CUDNN_RNN_ALGO_PERSIST_STATIC;
}
}

return CUDNN_RNN_ALGO_STANDARD;
}

Expand Down Expand Up @@ -1006,7 +970,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _cudnn_rnn(
auto y = output;

auto handle = getCudnnHandle();
cudnnRNNAlgo_t algo = get_algo(fn.rnn, fn.tensors, input, true);
cudnnRNNAlgo_t algo = get_algo(fn.rnn, fn.tensors, input);
fn.rnn.set_algo(algo);
RNNDescriptors descs(fn, handle, x, y, hx, cx);

Expand Down Expand Up @@ -1167,7 +1131,7 @@ std::tuple<Tensor, Tensor, Tensor> _cudnn_rnn_backward_input(
TORCH_CHECK(dhy.is_cuda() && dy.is_cuda() && (!dcy.defined() || dcy.is_cuda()),
"Gradients aren't CUDA tensors");

cudnnRNNAlgo_t algo = get_algo(fn.rnn, fn.tensors, input, false);
cudnnRNNAlgo_t algo = get_algo(fn.rnn, fn.tensors, input);
fn.rnn.set_algo(algo);
RNNDescriptors descs(fn, handle, x, y, hx, cx);

Expand Down Expand Up @@ -1270,7 +1234,7 @@ std::vector<Tensor> _cudnn_rnn_backward_weight(
const auto& y = output;
auto dw = at::zeros(weight_buf.sizes(), weight_buf.options());

cudnnRNNAlgo_t algo = get_algo(fn.rnn, fn.tensors, input, false);
cudnnRNNAlgo_t algo = get_algo(fn.rnn, fn.tensors, input);
fn.rnn.set_algo(algo);
RNNDescriptors descs(fn, handle, x, y, hx, cx);

Expand Down

0 comments on commit 35c240a

Please sign in to comment.