Skip to content

Commit

Permalink
fix error due to () not used on operator priority. (#14699)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhanghuanrong authored and PatriceVignola committed Feb 22, 2023
1 parent a0ab9c2 commit 999747d
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,8 @@ Status GatedRelativePositionBias<T>::ComputeInternal(OpKernelContext* context) c
const auto BNS = batch_size * num_heads_ * seq_len;
const size_t elements_in_query = (size_t)BNS * (size_t)head_size;
const size_t elements_after_gemm = (size_t)BNS *(size_t)D;
size_t workspace_size = sizeof(T) * (elements_in_query + (seq_len < D) ? elements_after_gemm : (size_t)0);
bool reuse_output = (seq_len >= D);
size_t workspace_size = sizeof(T) * (elements_in_query + (reuse_output ? (size_t)0 : elements_after_gemm));
auto workspace = GetScratchBuffer<void>(workspace_size, context->GetComputeStream());

// format 1: BxSx(NH * total_matrix) => matrix_to_transpose * (BxNxSxH)
Expand All @@ -161,9 +162,9 @@ Status GatedRelativePositionBias<T>::ComputeInternal(OpKernelContext* context) c
false, head_size, reinterpret_cast<CudaT*>(static_cast<CudaT*>(nullptr)), total_maxtrix);

// reuse output if possible
CudaT* gemm_output = (seq_len < D) ? (reinterpret_cast<CudaT*>(workspace.get()) + elements_in_query)
: reinterpret_cast<CudaT*>(output->template MutableData<T>());
int ld_gemm_output = max(seq_len, D);
CudaT* gemm_output = reuse_output ? reinterpret_cast<CudaT*>(output->template MutableData<T>())
: (reinterpret_cast<CudaT*>(workspace.get()) + elements_in_query);
int ld_gemm_output = reuse_output ? seq_len : D;

const CudaT one = ToCudaType<T>::FromFloat(1.0f);
const CudaT zero = ToCudaType<T>::FromFloat(0.0f);
Expand Down

0 comments on commit 999747d

Please sign in to comment.