[CUDA] FusedMHARunnerFP16v2 thread-safe (microsoft#21420)
### Description
- [x] Rewrite FusedMHARunnerFP16v2 to make it thread-safe. 
- [x] Add multi-threading tests

Previously, the kernel parameters `params` were stored as a member of the MHA
runner, which means that different threads could change the params at the
same time and impact each other.

For example, if `batch_size` and `seq_len` were changed by another thread to
larger values in `setup(...)`, a buffer overrun could happen in `run(...)`
because the kernel might read or write memory outside the range of the
allocated buffers.

In the new implementation, the API is changed and the mutable member
variable is removed to make the runner thread-safe. Below is a summary of the change:

Before:
```
class FusedMHARunnerFP16v2::mhaImpl {
   void setup(int seq_len, int batch_size) {
      // change scalar params
   }

   void run(input, output) {
      // change params for input and output pointers
      // launch kernel using params
   }

   Fused_multihead_attention_params_v2 params; // mutable, not thread-safe
};
```

After:
```
class FusedMHARunnerFP16v2::FmhaImpl {
   void setup(int seq_len, int batch_size, Fused_multihead_attention_params_v2& params) {
      // change params
   }

   void run(params, input, output) {
      // change params with input and output pointers
      // launch kernel using params
   }
};
```
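
For illustration only (not part of this PR), here is a minimal sketch of why the new shape composes safely across threads: each caller fills its own `Fused_multihead_attention_params_v2` on the stack, so concurrent calls never touch shared mutable state. It assumes the simplified `setup`/`run` interface from the summary above; the real signatures in onnxruntime differ (for example, `Run` also takes batch size, sequence length, sequence offsets and a CUDA stream), and the input/output device buffers are omitted here.

```cpp
// Minimal sketch, assuming the simplified FmhaImpl interface from the summary above.
// Real onnxruntime signatures differ; device buffers and the CUDA stream are omitted.
#include <thread>
#include <vector>

void RunFromManyThreads(FusedMHARunnerFP16v2::FmhaImpl& runner) {
  std::vector<std::thread> workers;
  for (int t = 0; t < 4; ++t) {
    workers.emplace_back([&runner, t]() {
      // Thread-local kernel parameters: another thread changing its own
      // seq_len/batch_size can no longer affect this launch.
      Fused_multihead_attention_params_v2 params{};
      const int seq_len = 128 * (t + 1);
      const int batch_size = t + 1;
      runner.setup(seq_len, batch_size, params);  // fills scalar fields of the local params
      // runner.run(params, input, output);       // launch with the thread-local params
    });
  }
  for (auto& w : workers) {
    w.join();
  }
}
```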

### Motivation and Context
microsoft#18854
microsoft#21413
tianleiwu committed Jul 22, 2024
1 parent 11bf309 commit a6c5e2c
Showing 10 changed files with 534 additions and 234 deletions.
12 changes: 6 additions & 6 deletions onnxruntime/contrib_ops/cuda/bert/attention.cc
@@ -149,8 +149,8 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
nullptr == relative_position_bias &&
parameters.past_sequence_length == 0 &&
parameters.hidden_size == parameters.v_hidden_size &&
FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length,
enable_trt_flash_attention_, true);
FusedMHARunnerFP16v2::IsSupported(sm, parameters.head_size, sequence_length,
enable_trt_flash_attention_, true);
if (use_causal_fused_runner) {
// Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node.
if (nullptr == fused_fp16_runner_.get()) {
@@ -171,8 +171,8 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
nullptr == present &&
nullptr == relative_position_bias &&
parameters.hidden_size == parameters.v_hidden_size &&
FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length,
enable_trt_flash_attention_, false);
FusedMHARunnerFP16v2::IsSupported(sm, parameters.head_size, sequence_length,
enable_trt_flash_attention_, false);

if (use_fused_runner) {
// Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node.
@@ -184,8 +184,8 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
}

// In case some kernel not loaded due to shared memory limit, we need to double check here.
const int S = fused_fp16_runner_->getSFromMaxSeqLen(sequence_length);
if (fused_fp16_runner_->isValid(S)) {
const int normalized_seq_len = fused_fp16_runner_->NormalizeSequenceLength(sequence_length);
if (fused_fp16_runner_->IsValid(normalized_seq_len)) {
fused_runner = fused_fp16_runner_.get();
}
}
10 changes: 4 additions & 6 deletions onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -245,12 +245,10 @@ Status FusedTrtSelfAttention(

FusedMHARunnerFP16v2* fused_fp16_runner = reinterpret_cast<FusedMHARunnerFP16v2*>(data.fused_runner);

const int S = causal ? sequence_length : fused_fp16_runner->getSFromMaxSeqLen(sequence_length);
const int s = causal ? sequence_length : fused_fp16_runner->NormalizeSequenceLength(sequence_length);

// B = 2 * batch_size when there is padding in input, and B = batch_size when padding is removed.
const int B = (nullptr == data.mask_index ? batch_size : 2 * batch_size);

fused_fp16_runner->setup(S, B);
const int b = (nullptr == data.mask_index ? batch_size : 2 * batch_size);

if (!causal) {
assert(data.qkv_format == AttentionQkvFormat::QKV_BSN3H);
@@ -261,12 +259,12 @@ Status FusedTrtSelfAttention(
packed_qkv = data.query;
}

fused_fp16_runner->run(packed_qkv, sequence_offset, data.output, stream);
fused_fp16_runner->Run(b, s, packed_qkv, sequence_offset, data.output, stream);
DUMP_TENSOR("fused output", data.output,
batch_size, sequence_length, parameters.num_heads, parameters.v_head_size);
} else {
assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH);
fused_fp16_runner->run(data.gemm_buffer, sequence_offset, data.output, stream);
fused_fp16_runner->Run(b, s, data.gemm_buffer, sequence_offset, data.output, stream);
DUMP_TENSOR("fused causal output", data.output,
batch_size, sequence_length, parameters.num_heads, parameters.v_head_size);
}
8 changes: 4 additions & 4 deletions onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
@@ -193,8 +193,8 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
(nullptr == key_padding_mask || is_mask_1d_seq_len) &&
parameters.hidden_size == parameters.v_hidden_size &&
parameters.sequence_length == parameters.kv_sequence_length &&
FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length,
enable_trt_flash_attention_, false);
FusedMHARunnerFP16v2::IsSupported(sm, parameters.head_size, sequence_length,
enable_trt_flash_attention_, false);
if (use_fused_runner) {
// Here we assume that num_heads and head_size does not change for a MultiHeadAttention node.
if (nullptr == fused_fp16_runner_.get()) {
@@ -206,8 +206,8 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
}

// In case some kernel not loaded due to shared memory limit, we need to double check here.
const int S = fused_fp16_runner_->getSFromMaxSeqLen(sequence_length);
if (fused_fp16_runner_->isValid(S)) {
const int normalized_seq_len = fused_fp16_runner_->NormalizeSequenceLength(sequence_length);
if (fused_fp16_runner_->IsValid(normalized_seq_len)) {
fused_runner = fused_fp16_runner_.get();
}
}
14 changes: 7 additions & 7 deletions onnxruntime/contrib_ops/cuda/bert/packed_attention.cc
@@ -55,11 +55,11 @@ MHARunner* TrtFusedAttention<T>::GetFusedRunner(const cudaDeviceProp& device_pro

// Check whether we can use fused kernel
int sm = device_prop.major * 10 + device_prop.minor;
bool is_fMHA_supported = FusedMHARunnerFP16v2::is_supported(sm,
parameters.head_size,
parameters.sequence_length,
enable_trt_flash_attention_,
false /*causal*/);
bool is_fMHA_supported = FusedMHARunnerFP16v2::IsSupported(sm,
parameters.head_size,
parameters.sequence_length,
enable_trt_flash_attention_,
false /*causal*/);

if (!is_fMHA_supported) {
return fused_runner;
@@ -72,8 +72,8 @@ MHARunner* TrtFusedAttention<T>::GetFusedRunner(const cudaDeviceProp& device_pro
}

// In case some kernel not loaded due to shared memory limit, we need to double check here.
const int S = fused_fp16_runner_->getSFromMaxSeqLen(parameters.sequence_length);
if (fused_fp16_runner_->isValid(S)) {
const int normalized_seq_len = fused_fp16_runner_->NormalizeSequenceLength(parameters.sequence_length);
if (fused_fp16_runner_->IsValid(normalized_seq_len)) {
fused_runner = fused_fp16_runner_.get();
}

7 changes: 3 additions & 4 deletions onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu
@@ -459,10 +459,9 @@ Status FusedScaledDotProductAttention(
parameters.token_count, stream);

FusedMHARunnerFP16v2* fused_fp16_runner = reinterpret_cast<FusedMHARunnerFP16v2*>(fused_runner);
const int S = fused_fp16_runner->getSFromMaxSeqLen(sequence_length);
fused_fp16_runner->setup(S, batch_size);

fused_fp16_runner->run(data.workspace, data.cumulative_sequence_length, data.output, stream);
const int normalized_seq_len = fused_fp16_runner->NormalizeSequenceLength(sequence_length);
fused_fp16_runner->Run(batch_size, normalized_seq_len,
data.workspace, data.cumulative_sequence_length, data.output, stream);
return Status::OK();
}

@@ -575,10 +575,8 @@ Status FusedAttentionTrt(
}

FusedMHARunnerFP16v2* fused_fp16_runner = reinterpret_cast<FusedMHARunnerFP16v2*>(fused_runner);
const int S = fused_fp16_runner->getSFromMaxSeqLen(sequence_length);
fused_fp16_runner->setup(S, batch_size);

fused_fp16_runner->run(qkv, data.cumulative_sequence_length, data.output, stream);
const int normalized_seq_len = fused_fp16_runner->NormalizeSequenceLength(sequence_length);
fused_fp16_runner->Run(batch_size, normalized_seq_len, qkv, data.cumulative_sequence_length, data.output, stream);
return Status::OK();
}
