From 36747ca280d016df62e1cbeea8428f34eb25ae4a Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Tue, 14 Feb 2023 16:50:09 -0800 Subject: [PATCH] Fix operator-precedence error caused by missing parentheses around a ternary expression. --- onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc index 111fed04639e..9627a1f7c374 100644 --- a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc +++ b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc @@ -146,7 +146,8 @@ Status GatedRelativePositionBias::ComputeInternal(OpKernelContext* context) c const auto BNS = batch_size * num_heads_ * seq_len; const size_t elements_in_query = (size_t)BNS * (size_t)head_size; const size_t elements_after_gemm = (size_t)BNS *(size_t)D; - size_t workspace_size = sizeof(T) * (elements_in_query + (seq_len < D) ? elements_after_gemm : (size_t)0); + bool reuse_output = (seq_len >= D); + size_t workspace_size = sizeof(T) * (elements_in_query + (reuse_output ? (size_t)0 : elements_after_gemm)); auto workspace = GetScratchBuffer(workspace_size, context->GetComputeStream()); // format 1: BxSx(NH * total_matrix) => matrix_to_transpose * (BxNxSxH) @@ -161,9 +162,9 @@ Status GatedRelativePositionBias::ComputeInternal(OpKernelContext* context) c false, head_size, reinterpret_cast(static_cast(nullptr)), total_maxtrix); // reuse output if possible - CudaT* gemm_output = (seq_len < D) ? (reinterpret_cast(workspace.get()) + elements_in_query) - : reinterpret_cast(output->template MutableData()); - int ld_gemm_output = max(seq_len, D); + CudaT* gemm_output = reuse_output ? reinterpret_cast(output->template MutableData()) + : (reinterpret_cast(workspace.get()) + elements_in_query); + int ld_gemm_output = reuse_output ? seq_len : D; const CudaT one = ToCudaType::FromFloat(1.0f); const CudaT zero = ToCudaType::FromFloat(0.0f);