From 45737400a2f3015c11f005ed7603611eaed306a6 Mon Sep 17 00:00:00 2001 From: Yi-Hong Lyu Date: Wed, 15 May 2024 00:14:29 -0700 Subject: [PATCH] [ORT 1.18.0 Release] Cherry pick 3rd/Final round (#20677) Co-authored-by: Rachel Guo <35738743+YUNQIUGUO@users.noreply.github.com> Co-authored-by: rachguo Co-authored-by: rachguo Co-authored-by: Tianlei Wu Co-authored-by: George Wu Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com> Co-authored-by: Jian Chen --- csharp/OnnxRuntime.CSharp.proj | 4 + docs/ContribOperators.md | 46 +- docs/OperatorKernels.md | 2 +- .../contrib_ops/cpu/bert/attention_common.h | 8 +- .../cuda/sparse/sparse_attention.cc | 63 +- .../cuda/sparse/sparse_attention_helper.h | 140 ++-- .../cuda/sparse/sparse_attention_impl.cu | 60 +- .../cuda/sparse/sparse_attention_impl.h | 10 +- .../sparse_attention_common.h | 26 +- .../sparse_attention_v2_common.h | 32 +- .../core/graph/contrib_ops/bert_defs.cc | 60 +- .../transformers/test_sparse_attention.py | 74 ++- .../github/apple/package_release_tasks.py | 20 +- .../c-api-noopenmp-packaging-pipelines.yml | 619 ++---------------- .../cuda-packaging-pipeline.yml | 199 ------ .../nuget/templates/dml-vs-2022.yml | 4 +- .../qnn-ep-nuget-packaging-pipeline.yml | 2 + .../stages/java-cuda-packaging-stage.yml | 187 ++++++ .../jobs/linux-gpu-tensorrt-packaging-job.yml | 108 +++ .../stages/nuget-combine-cuda-stage.yml | 68 +- .../nuget-linux-cuda-packaging-stage.yml | 238 +++---- .../stages/nuget-win-cuda-packaging-stage.yml | 39 +- .../azure-pipelines/templates/c-api-cpu.yml | 6 +- .../linux-gpu-tensorrt-packaging-pipeline.yml | 117 ---- .../templates/py-win-arm64-qnn.yml | 2 +- .../templates/py-win-x64-qnn.yml | 2 +- .../templates/react-native-ci.yml | 22 +- .../azure-pipelines/templates/win-ci.yml | 24 +- 28 files changed, 892 insertions(+), 1290 deletions(-) delete mode 100644 tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml create mode 100644 tools/ci_build/github/azure-pipelines/stages/java-cuda-packaging-stage.yml create mode 100644 tools/ci_build/github/azure-pipelines/stages/jobs/linux-gpu-tensorrt-packaging-job.yml delete mode 100644 tools/ci_build/github/azure-pipelines/templates/linux-gpu-tensorrt-packaging-pipeline.yml diff --git a/csharp/OnnxRuntime.CSharp.proj b/csharp/OnnxRuntime.CSharp.proj index ae9b3750ec89..e09c865a8d16 100644 --- a/csharp/OnnxRuntime.CSharp.proj +++ b/csharp/OnnxRuntime.CSharp.proj @@ -50,8 +50,12 @@ CMake creates a target to this project + $(BuildDate) + $(BuildTime) $([System.DateTime]::UtcNow.ToString(yyyyMMdd)) $([System.DateTime]::UtcNow.ToString(hhmm)) + + diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 9dc183e0e73b..fc559411df19 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -5553,11 +5553,29 @@ This version of the operator has been available since version 1 of the 'com.micr When number of sparse layout is 1, all heads have same sparse layout. Otherwise, different layouts are used cyclically. For example, given 4 layouts (S0, S1, S2, S3), 8 heads will have layouts like (S0, S1, S2, S3, S0, S1, S2, S3). - Padding shall be on the right side. + The block_row_indices and block_col_indices are the CSR representation of block mask. The block_col_indices might contain + paddings at the right side when different layout has different number of non-zeros in block mask. - When do_rotary is True, cos_cache and sin_cache are required. + An example of block mask with 2 layouts where each layout is 4 x 4 blocks: + [[[1, 0, 0, 0], + [1, 1, 0, 0], + [0, 1, 1, 0], + [0, 1, 1, 1]], + + [[1, 0, 0, 0], + [1, 1, 0, 0], + [1, 1, 1, 0], + [1, 0, 1, 1]]] + + The corresponding CSR format: + block_col_indices = [[0, 0, 1, 1, 2, 1, 2, 3, -1], [0, 0, 1, 0, 1, 2, 0, 2, 3]] + block_row_indices = [[0, 1, 3, 5, 8], [0, 1, 3, 6, 9]] + + When do_rotary is True, cos_cache and sin_cache are required. Note that the maximum sequence length supported by cos + or sin cache can be different from the maximum sequence length used by kv cache. Only supports unidirectional attention with cache of past key and value in linear buffers. + For performance, past_key and present_key share same memory buffer, and past_value and present_value too. #### Version @@ -5581,7 +5599,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Number of tokens per sparse block. Choices: 16, 32, 64, 128
-#### Inputs (8 - 10) +#### Inputs (9 - 11)
query : T
@@ -5590,20 +5608,22 @@ This version of the operator has been available since version 1 of the 'com.micr
Key with shape (batch_size, sequence_length, kv_num_heads * head_size)
value (optional) : T
Value with shape (batch_size, sequence_length, kv_num_heads * head_size)
-
past_key (optional) : T
-
Key cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size)
-
past_value (optional) : T
-
Value cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size)
-
block_mask : M
-
block mask. 1 indicates attention and 0 no attention. Its shape is (num_layout, max_blocks, max_blocks), where num_heads is divisible by num_layout, and max_blocks is max_sequence_length / sparse_block_size.
+
past_key : T
+
Key cache with shape (batch_size, kv_num_heads, max_cache_sequence_length, head_size)
+
past_value : T
+
Value cache with shape (batch_size, kv_num_heads, max_cache_sequence_length, head_size)
+
block_row_indices : M
+
The row indices of CSR format of block mask with shape (num_layout, max_blocks + 1).The num_heads is divisible by num_layout, and max_blocks is max_sequence_length / sparse_block_size.
+
block_col_indices : M
+
The col indices of CSR format of block mask with shape (num_layout, max_nnz_blocks).The max_nnz_blocks is the maximum number of non-zeros per layout in block mask.
total_sequence_length : M
Scalar tensor of maximum total sequence length (past_sequence_length + sequence_length) among keys.
key_total_sequence_lengths : M
1D tensor with shape (batch_size) where each value is total sequence length of key excluding paddings.
cos_cache (optional) : T
-
Cos cache of rotary with shape (max_sequence_length, head_size / 2).
+
Cos cache of rotary with shape (max_rotary_sequence_length, head_size / 2).
sin_cache (optional) : T
-
Sin cache of rotary with shape (max_sequence_length, head_size / 2).
+
Sin cache of rotary with shape (max_rotary_sequence_length, head_size / 2).
#### Outputs @@ -5612,9 +5632,9 @@ This version of the operator has been available since version 1 of the 'com.micr
output : T
3D output tensor with shape (batch_size, sequence_length, num_heads * head_size)
present_key : T
-
Updated key cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size).
+
Updated key cache with shape (batch_size, kv_num_heads, max_cache_sequence_length, head_size).
present_value : T
-
Updated value cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size).
+
Updated value cache with shape (batch_size, kv_num_heads, max_cache_sequence_length, head_size).
#### Type Constraints diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 481d36f3f60b..0381c2add86e 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -906,7 +906,7 @@ Do not modify directly.* |SkipGroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*in* skip:**T**
*in* bias:**T**
*out* Y:**T**
*out* S:**T**|1+|**T** = tensor(float), tensor(float16)| |SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)| |SkipSimplifiedLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)| -|SparseAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* block_mask:**M**
*in* total_sequence_length:**M**
*in* key_total_sequence_lengths:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(bfloat16), tensor(float16)| +|SparseAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* block_row_indices:**M**
*in* block_col_indices:**M**
*in* total_sequence_length:**M**
*in* key_total_sequence_lengths:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(bfloat16), tensor(float16)| |TransposeMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |Trilu|*in* X:**T**
*in* k:**tensor(int64)**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |UnfoldTensor|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index 7d07453337fd..a5b9c84c63eb 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -126,11 +126,15 @@ struct SparseAttentionParameters { bool rotary_interleaved; // whether to use interleaved rotary embedding int rotary_dim; // rotary embedding dimension int sparse_block_size; // block size for sparse attention - int num_sparse_layout; // number of sparse layout, or the first dimension of block_mask + int num_sparse_layout; // number of sparse layout + int stride_col_indices; // shape of block_col_indices is [num_sparse_layout, stride_col_indices] + int stride_row_indices; // shape of block_row_indices is [num_sparse_layout, stride_row_indices] float scale; // scaling factor applied prior to softmax bool is_packed_qkv; // whether qkv is packed int total_sequence_length; // maximum total sequence length (past_sequence_length + sequence_length) among keys - int max_sequence_length; // max sequence length allowed + int max_sequence_length; // max sequence length for sparse layout + int max_rotary_sequence_length; // max sequence length for rotary cos/sin cache + int max_cache_sequence_length; // max sequence length for kv cache buffer bool past_present_share_buffer; // whether past_key and present_key share buffer, so is past_value and present_value }; diff --git a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc index d5673b29cf5b..506a6683de6a 100644 --- a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc +++ b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc @@ -4,7 +4,6 @@ #include "contrib_ops/cuda/sparse/sparse_attention_impl.h" #include "contrib_ops/cuda/sparse/sparse_attention.h" #include "contrib_ops/cuda/sparse/sparse_attention_helper.h" -#include "contrib_ops/cuda/sparse/block_mask.h" #include "contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_v1_api.h" #include "contrib_ops/cuda/sparse/sparse_attention_v2/sparse_attention_v2_api.h" #include "core/platform/env_var_utils.h" @@ -26,7 +25,7 @@ namespace cuda { .TypeConstraint("M", DataTypeImpl::GetTensorType()) \ .MayInplace(3, 1) \ .MayInplace(4, 2) \ - .InputMemoryType(OrtMemTypeCPUInput, 6), \ + .InputMemoryType(OrtMemTypeCPUInput, 7), \ SparseAttention); REGISTER_KERNEL_TYPED(MLFloat16) @@ -77,15 +76,16 @@ Status SparseAttention::ComputeInternal(OpKernelContext* context) const { const Tensor* value = context->Input(2); const Tensor* past_key = context->Input(3); const Tensor* past_value = context->Input(4); - const Tensor* block_mask = context->Input(5); - const Tensor* total_seq_len = context->Input(6); - const Tensor* seqlens_k_total = context->Input(7); - const Tensor* cos_cache = context->Input(8); - const Tensor* sin_cache = context->Input(9); + const Tensor* block_row_indices = context->Input(5); + const Tensor* block_col_indices = context->Input(6); + const Tensor* total_seq_len = context->Input(7); + const Tensor* seqlens_k_total = context->Input(8); + const Tensor* cos_cache = context->Input(9); + const Tensor* sin_cache = context->Input(10); SparseAttentionParameters parameters; - // Parameters from node attribute + // Parameters from node attribute shall be set before calling CheckInputs parameters.sparse_block_size = sparse_block_size_; parameters.num_heads = num_heads_; parameters.kv_num_heads = kv_num_heads_; @@ -101,7 +101,8 @@ Status SparseAttention::ComputeInternal(OpKernelContext* context) const { past_value, cos_cache, sin_cache, - block_mask, + block_row_indices, + block_col_indices, seqlens_k_total, total_seq_len)); // Some limitations of CUDA kernels @@ -177,7 +178,7 @@ Status SparseAttention::ComputeInternal(OpKernelContext* context) const { Tensor* output = context->Output(0, output_shape); std::vector present_dims = { - parameters.batch_size, parameters.kv_num_heads, parameters.max_sequence_length, parameters.head_size}; + parameters.batch_size, parameters.kv_num_heads, parameters.max_cache_sequence_length, parameters.head_size}; TensorShape present_shape(present_dims); Tensor* present_key = context->Output(1, present_shape); Tensor* present_value = context->Output(2, present_shape); @@ -188,13 +189,12 @@ Status SparseAttention::ComputeInternal(OpKernelContext* context) const { data.query = reinterpret_cast(query->Data()); data.key = key == nullptr ? nullptr : reinterpret_cast(key->Data()); data.value = value == nullptr ? nullptr : reinterpret_cast(value->Data()); - data.past_key = (nullptr == past_key) ? nullptr : reinterpret_cast(past_key->Data()); - data.past_value = (nullptr == past_value) ? nullptr : reinterpret_cast(past_value->Data()); - data.block_mask = block_mask->Data(); + data.past_key = reinterpret_cast(past_key->Data()); + data.past_value = reinterpret_cast(past_value->Data()); data.seqlens_k_total = (nullptr == seqlens_k_total) ? nullptr : seqlens_k_total->Data(); data.output = reinterpret_cast(output->MutableData()); - data.present_key = (nullptr == present_key) ? nullptr : reinterpret_cast(present_key->MutableData()); - data.present_value = (nullptr == present_value) ? nullptr : reinterpret_cast(present_value->MutableData()); + data.present_key = reinterpret_cast(present_key->MutableData()); + data.present_value = reinterpret_cast(present_value->MutableData()); // Check past and present share buffer. parameters.past_present_share_buffer = (data.past_key != nullptr && data.past_key == data.present_key); @@ -214,29 +214,9 @@ Status SparseAttention::ComputeInternal(OpKernelContext* context) const { // Currently, we use same block size in kernel. // TODO: support kernel block size that is smaller than sparse_block_size in tunable (need expand block mask). data.kernel_layout.block_size = parameters.sparse_block_size; - data.kernel_layout.mask = data.block_mask; data.kernel_layout.num_layout = parameters.num_sparse_layout; - data.kernel_layout.num_cols = parameters.max_sequence_length / data.kernel_layout.block_size; - data.kernel_layout.num_rows = parameters.max_sequence_length / data.kernel_layout.block_size; - - // Allocate buffer for CSR col and row indices. - onnxruntime::Stream* stream = context->GetComputeStream(); - int dense_blocks = data.kernel_layout.num_layout * data.kernel_layout.num_cols * data.kernel_layout.num_rows; - auto csr_col_indices_buffer = GetScratchBuffer(static_cast(dense_blocks), stream); - auto csr_row_indices_buffer = GetScratchBuffer( - static_cast(data.kernel_layout.num_layout * (data.kernel_layout.num_rows + 1)), stream); - - data.kernel_layout.csr_col_indices = reinterpret_cast(csr_col_indices_buffer.get()); - data.kernel_layout.csr_row_indices = reinterpret_cast(csr_row_indices_buffer.get()); - - ConvertMaskToCSR(cuda_stream, - data.kernel_layout.mask, - data.kernel_layout.num_layout, - data.kernel_layout.num_rows, - data.kernel_layout.num_cols, - csr_row_indices_buffer.get(), - csr_col_indices_buffer.get(), - device_prop.maxThreadsPerBlock); + data.kernel_layout.csr_col_indices = block_col_indices->Data(); + data.kernel_layout.csr_row_indices = block_row_indices->Data(); size_t rotary_buffer_bytes = 0; if (do_rotary_) { @@ -244,7 +224,8 @@ Status SparseAttention::ComputeInternal(OpKernelContext* context) const { parameters.sequence_length * parameters.head_size; rotary_buffer_bytes += sizeof(int64_t) * parameters.batch_size * parameters.sequence_length; } - auto rotary_buffer = GetScratchBuffer(rotary_buffer_bytes, context->GetComputeStream()); + onnxruntime::Stream* stream = context->GetComputeStream(); + auto rotary_buffer = GetScratchBuffer(rotary_buffer_bytes, stream); data.rotary_buffer = reinterpret_cast(rotary_buffer.get()); size_t transposed_q_bytes = 0; @@ -252,7 +233,7 @@ Status SparseAttention::ComputeInternal(OpKernelContext* context) const { transposed_q_bytes = parameters.batch_size * parameters.sequence_length * parameters.num_heads * parameters.head_size * sizeof(T); } - auto transposed_q_buffer = GetScratchBuffer(transposed_q_bytes, context->GetComputeStream()); + auto transposed_q_buffer = GetScratchBuffer(transposed_q_bytes, stream); if (transposed_q_buffer) { data.transposed_q_buffer = reinterpret_cast(transposed_q_buffer.get()); } @@ -263,7 +244,7 @@ Status SparseAttention::ComputeInternal(OpKernelContext* context) const { (parameters.num_heads + 2 * parameters.kv_num_heads) * parameters.head_size * sizeof(T)); } - auto unpacked_qkv_buffer = GetScratchBuffer(unpacked_qkv_bytes, context->GetComputeStream()); + auto unpacked_qkv_buffer = GetScratchBuffer(unpacked_qkv_bytes, stream); if (unpacked_qkv_buffer) { data.unpacked_qkv_buffer = reinterpret_cast(unpacked_qkv_buffer.get()); } @@ -327,7 +308,7 @@ Status SparseAttention::ComputeInternal(OpKernelContext* context) const { } } - v2_kernel_buffer = GetScratchBuffer(v2_kernel_buffer_size, context->GetComputeStream()); + v2_kernel_buffer = GetScratchBuffer(v2_kernel_buffer_size, stream); CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(v2_kernel_buffer.get(), v2_kernel_inputs_pinned, sizeof(int32_t) * v2_kernel_buffer_size, cudaMemcpyHostToDevice, cuda_stream)); diff --git a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_helper.h b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_helper.h index 7e98b374c455..a5f1d50e618a 100644 --- a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_helper.h +++ b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_helper.h @@ -19,7 +19,8 @@ Status CheckInputs(void* params, const Tensor* past_value, const Tensor* cos_cache, const Tensor* sin_cache, - const Tensor* block_mask, + const Tensor* block_row_indices, + const Tensor* block_col_indices, const Tensor* seqlens_k_total, const Tensor* total_seq_len) { // No packing for q/k/v: @@ -31,14 +32,14 @@ Status CheckInputs(void* params, // key nullptr // value nullptr // Shape for other inputs: - // past_key (batch_size, kv_num_heads, max_sequence_length, head_size) or nullptr - // past_value (batch_size, kv_num_heads, max_sequence_length, head_size) or nullptr - // block_mask (num_heads, max_blocks, max_blocks) or (1, max_blocks, max_blocks) - // where max_blocks = max_sequence_length / sparse_block_size + // past_key (batch_size, kv_num_heads, max_cache_sequence_length, head_size) + // past_value (batch_size, kv_num_heads, max_cache_sequence_length, head_size) + // block_row_indices (num_layout, max_blocks + 1), where max_blocks = max_sequence_length / sparse_block_size + // block_col_indices (num_layout, max_nnz) // seqlens_k_total (batch_size) when do_rotary is True, optional otherwise // total_seq_len (1) - // cos_cache (max_sequence_length, rotary_dim / 2) when do_rotary is true. - // sin_cache (max_sequence_length, rotary_dim / 2) when do_rotary is true. + // cos_cache (max_rotary_sequence_length, rotary_dim / 2) when do_rotary is true. + // sin_cache (max_rotary_sequence_length, rotary_dim / 2) when do_rotary is true. assert(params != nullptr); SparseAttentionParameters* parameters = reinterpret_cast(params); @@ -121,57 +122,78 @@ Status CheckInputs(void* params, kv_hidden_size = head_size * kv_num_heads; } - const auto& block_mask_dim = block_mask->Shape().GetDims(); - if (!(block_mask_dim.size() == 3 && block_mask_dim[1] == block_mask_dim[2] && - (static_cast(num_heads) % block_mask_dim[0] == 0L))) { + if (!onnxruntime::IsScalarOr1ElementVector(total_seq_len)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "total_sequence_length tensor must be of one element."); + } + int total_sequence_length = *((*total_seq_len).template Data()); + + // Check block_row_indices + const auto& block_row_indices_dim = block_row_indices->Shape().GetDims(); + if (!(block_row_indices_dim.size() == 2 && + block_row_indices_dim[1] > 1 && + (static_cast(num_heads) % block_row_indices_dim[0] == 0L))) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "block_row_indices must have shape (num_layout, max_blocks + 1) where num_heads is divisible by num_layout."); + } + int max_blocks = static_cast(block_row_indices_dim[1]) - 1; + + // Check block_col_indices + const auto& block_col_indices_dim = block_col_indices->Shape().GetDims(); + if (!(block_col_indices_dim.size() == 2 && + block_col_indices_dim[0] == block_row_indices_dim[0] && + block_col_indices_dim[1] <= max_blocks * max_blocks)) { return ORT_MAKE_STATUS( ONNXRUNTIME, INVALID_ARGUMENT, - "block_mask must have shape (num_layout, max_blocks, max_blocks) where num_heads is divisible by num_layout."); + "block_col_indices must have shape (num_layout, max_nnz), " + "where max_nnz <= max_blocks * max_blocks."); } - int max_blocks = static_cast(block_mask_dim[1]); int max_sequence_length = max_blocks * parameters->sparse_block_size; + if (max_sequence_length < total_sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "max_sequence_length should be no less than total_sequence_length:", + total_sequence_length, + ", max_sequence_length deduced from block_row_indices:", max_sequence_length); + } - // Check past-present KV - if (past_key != nullptr && past_value != nullptr) { - if (past_key->Shape() != past_value->Shape()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' and 'past_value' shall have same shape"); - } + // Check kv cache + ORT_ENFORCE(past_key != nullptr && past_value != nullptr); + if (past_key->Shape() != past_value->Shape()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_key' and 'past_value' shall have same shape"); + } - const auto& past_key_dims = past_key->Shape().GetDims(); - if (past_key_dims.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' is expected to have 4 dimensions, got ", - past_key_dims.size()); - } + const auto& past_key_dims = past_key->Shape().GetDims(); + if (past_key_dims.size() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_key' is expected to have 4 dimensions, got ", + past_key_dims.size()); + } - if (past_key_dims[0] != batch_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' dimension 0 should be batch_size ", batch_size, ", got ", - past_key_dims[0]); - } + if (past_key_dims[0] != batch_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_key' dimension 0 should be batch_size ", batch_size, ", got ", + past_key_dims[0]); + } - if (past_key_dims[is_past_bsnh ? 2 : 1] != kv_num_heads) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'past_key' shall have kv_num_heads"); - } + if (past_key_dims[is_past_bsnh ? 2 : 1] != kv_num_heads) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'past_key' shall have kv_num_heads"); + } - int max_cache_sequence_length = static_cast(past_key_dims[is_past_bsnh ? 1 : 2]); - if (max_cache_sequence_length != max_sequence_length) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' and 'block_mask' should have the same sequence length:", - "max_sequence_length deduced from past_key is ", max_cache_sequence_length, - "; max_sequence_length deduced from block_mask is ", max_sequence_length); - } + int max_cache_sequence_length = static_cast(past_key_dims[is_past_bsnh ? 1 : 2]); + if (max_cache_sequence_length < total_sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "max_cache_sequence_length should be no less than total_sequence_length:", + total_sequence_length, + ", max_cache_sequence_length:", max_cache_sequence_length); + } - if (past_key_dims[3] != head_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' dimension 3 should be same as head_size, got ", - past_key_dims[3]); - } - } else if (past_key != nullptr || past_value != nullptr) { + if (past_key_dims[3] != head_size) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' and 'past_value' shall be both present or both absent."); + "Input 'past_key' dimension 3 should be same as head_size, got ", + past_key_dims[3]); } // Check the shape of total_key_sequence_lengths. We do not check the values here. @@ -181,13 +203,8 @@ Status CheckInputs(void* params, "key_total_sequence_lengths must have shape (batch_size)."); } - if (!onnxruntime::IsScalarOr1ElementVector(total_seq_len)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "total_sequence_length tensor must be of one element."); - } - int total_sequence_length = *((*total_seq_len).template Data()); - int rotary_dim = 0; + int max_rotary_sequence_length = 0; if (do_rotary) { if (cos_cache == nullptr || sin_cache == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, @@ -202,14 +219,19 @@ Status CheckInputs(void* params, "head_size shall be a multiple of 16. Got head_size = ", head_size); } - if (cos_dims[0] < total_sequence_length) { + if (cos_dims[0] != sin_dims[0]) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "cos_cache dimension 0 should be not be less than total_sequence_length."); + "cos_cache and sin_cache dimension 0 should be same size."); } - if (sin_dims[0] < total_sequence_length) { + + max_rotary_sequence_length = static_cast(cos_dims[0]); + if (max_rotary_sequence_length < total_sequence_length) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "sin_cache dimension 0 should be not be less than total_sequence_length."); + "max_rotary_sequence_length should be no less than total_sequence_length:", + total_sequence_length, + ", max_rotary_sequence_length:", max_rotary_sequence_length); } + if (cos_dims[1] > (head_size / 16) * 8 || cos_dims[1] % 8 != 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "cos_cache dimension 1 must be <= head_size / 2 and a multiple of 8."); @@ -229,12 +251,16 @@ Status CheckInputs(void* params, parameters->sequence_length = sequence_length; parameters->total_sequence_length = total_sequence_length; parameters->max_sequence_length = max_sequence_length; + parameters->max_cache_sequence_length = max_cache_sequence_length; + parameters->max_rotary_sequence_length = max_rotary_sequence_length; parameters->hidden_size = q_hidden_size; parameters->head_size = head_size; parameters->kv_hidden_size = kv_hidden_size; parameters->rotary_dim = rotary_dim; parameters->is_packed_qkv = is_packed_qkv; - parameters->num_sparse_layout = static_cast(block_mask_dim[0]); + parameters->num_sparse_layout = static_cast(block_row_indices_dim[0]); + parameters->stride_row_indices = static_cast(block_row_indices_dim[1]); + parameters->stride_col_indices = static_cast(block_col_indices_dim[1]); return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_impl.cu b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_impl.cu index 5d6182b613de..d833a7cf0298 100644 --- a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_impl.cu @@ -2,7 +2,6 @@ // Licensed under the MIT License. #include "contrib_ops/cuda/sparse/sparse_attention_impl.h" -#include "contrib_ops/cuda/sparse/block_mask.h" #include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" #include "contrib_ops/cuda/bert/rotary_embedding_impl.h" #include "contrib_ops/cuda/bert/group_query_attention_impl.h" @@ -88,7 +87,7 @@ Status LaunchConcatKVInPlace(contrib::SparseAttentionParameters& parameters, return LaunchConcatKVInPlace(parameters.batch_size, parameters.kv_num_heads, parameters.head_size, - parameters.max_sequence_length, + parameters.max_cache_sequence_length, nullptr, data.seqlens_k_total, parameters.sequence_length, @@ -112,7 +111,6 @@ Status QkvToContext( const int max_threads_per_block = device_prop.maxThreadsPerBlock; const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; - // const int present_sequence_length = parameters.max_sequence_length; const int num_heads = parameters.num_heads; const int kv_num_heads = parameters.kv_num_heads; const int head_size = parameters.head_size; @@ -182,14 +180,14 @@ Status QkvToContext( position_ids_buff, data.cos_cache, data.sin_cache, parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.head_size, - parameters.rotary_dim, parameters.max_sequence_length, + parameters.rotary_dim, parameters.max_rotary_sequence_length, /*position_ids_format*/ 1, parameters.rotary_interleaved, max_threads_per_block, q_layout)); ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel(stream, k_buffer, reinterpret_cast(key), position_ids_buff, data.cos_cache, data.sin_cache, parameters.batch_size, parameters.sequence_length, parameters.kv_num_heads, parameters.head_size, - parameters.rotary_dim, parameters.max_sequence_length, + parameters.rotary_dim, parameters.max_rotary_sequence_length, /*position_ids_format*/ 1, parameters.rotary_interleaved, max_threads_per_block, kv_layout)); query = reinterpret_cast(q_buffer); @@ -215,29 +213,24 @@ Status QkvToContext( // TODO: only dump to total sequence length instead of max sequence length. #if DUMP_TENSOR_LEVEL > 0 - DUMP_TENSOR("key cache", data.present_key, batch_size, kv_num_heads, parameters.max_sequence_length, head_size); - DUMP_TENSOR("value cache", data.present_value, batch_size, kv_num_heads, parameters.max_sequence_length, head_size); - - DUMP_TENSOR("block_mask", - data.kernel_layout.mask, - data.kernel_layout.num_layout, - data.kernel_layout.num_rows, - data.kernel_layout.num_cols); + DUMP_TENSOR("key cache", data.present_key, batch_size, kv_num_heads, + parameters.max_cache_sequence_length, head_size); + DUMP_TENSOR("value cache", data.present_value, batch_size, kv_num_heads, + parameters.max_cache_sequence_length, head_size); DUMP_TENSOR("csr_col_indices", data.kernel_layout.csr_col_indices, data.kernel_layout.num_layout, - data.kernel_layout.num_rows, - data.kernel_layout.num_cols); + parameters.stride_col_indices); DUMP_TENSOR("csr_row_indices", data.kernel_layout.csr_row_indices, data.kernel_layout.num_layout, - data.kernel_layout.num_rows + 1); + parameters.stride_row_indices); printf( "batch_size=%d, sequence_length=%d, num_heads=%d, kv_num_heads=%d head_size=%d, " - "total_sequence_length=%d, max_sequence_length=%d scale=%f block_size=%d " + "total_sequence_length=%d, max_sequence_length=%d max_cache_sequence_length=%d scale=%f block_size=%d " "row_stride=%d col_stride=%d num_layout=%d\n", parameters.batch_size, parameters.sequence_length, @@ -246,10 +239,11 @@ Status QkvToContext( parameters.head_size, parameters.total_sequence_length, parameters.max_sequence_length, + parameters.max_cache_sequence_length, parameters.scale, data.kernel_layout.block_size, - data.kernel_layout.num_rows + 1, - data.kernel_layout.num_rows * data.kernel_layout.num_cols, + parameters.stride_row_indices, + parameters.stride_col_indices, data.kernel_layout.num_layout); #endif @@ -262,19 +256,20 @@ Status QkvToContext( reinterpret_cast(query), reinterpret_cast(data.present_key), reinterpret_cast(data.present_value), + q_layout == LAYOUT_BNSH, parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.kv_num_heads, parameters.head_size, parameters.total_sequence_length, - parameters.max_sequence_length, + parameters.max_cache_sequence_length, parameters.scale, - data.kernel_layout.block_size, // kernel_block_size - data.kernel_layout.csr_row_indices, // skip past_seq_len in row indices - data.kernel_layout.csr_col_indices, // (num_layout, num_rows, num_cols) - data.kernel_layout.num_rows + 1, // stride per head in row indices - data.kernel_layout.num_rows * data.kernel_layout.num_cols, // stride per head in col indices + data.kernel_layout.block_size, // kernel_block_size + data.kernel_layout.csr_row_indices, // shape (num_layout, stride_row_indices) + data.kernel_layout.csr_col_indices, // shape (num_layout, stride_col_indices) + parameters.stride_row_indices, + parameters.stride_col_indices, data.kernel_layout.num_layout, data.active_q_blocks, data.q_batch_starts, @@ -297,19 +292,20 @@ Status QkvToContext( reinterpret_cast(query), reinterpret_cast(data.present_key), reinterpret_cast(data.present_value), + q_layout == LAYOUT_BNSH, parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.kv_num_heads, parameters.head_size, parameters.total_sequence_length, - parameters.max_sequence_length, + parameters.max_cache_sequence_length, parameters.scale, - data.kernel_layout.block_size, // kernel_block_size - data.kernel_layout.csr_row_indices, // (num_layout, num_rows + 1) - data.kernel_layout.csr_col_indices, // (num_layout, num_rows, num_cols) - data.kernel_layout.num_rows + 1, // stride per head in row indices - data.kernel_layout.num_rows * data.kernel_layout.num_cols, // stride per head in col indices + data.kernel_layout.block_size, // kernel_block_size + data.kernel_layout.csr_row_indices, // (num_layout, stride_row_indices) + data.kernel_layout.csr_col_indices, // (num_layout, stride_row_indices) + parameters.stride_row_indices, + parameters.stride_col_indices, data.kernel_layout.num_layout); if constexpr (std::is_same::value) { @@ -319,7 +315,7 @@ Status QkvToContext( } } - DUMP_TENSOR("output", reinterpret_cast(data.output), batch_size, num_heads, sequence_length, head_size); + DUMP_TENSOR("output", reinterpret_cast(data.output), batch_size, sequence_length, num_heads, head_size); return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_impl.h b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_impl.h index 03e2a3dd08f6..0b07b234b731 100644 --- a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_impl.h @@ -16,14 +16,10 @@ namespace contrib { namespace cuda { struct BlockLayout { - const int32_t* mask; // shape (num_layout, num_rows, num_cols), where num_rows = num_cols = max_seq_len / block_size. int num_layout; - int block_size; // kernel block size, which is <= sparse_block_size - - const int* csr_col_indices; - const int* csr_row_indices; - int num_rows; - int num_cols; + int block_size; // kernel block size, which is <= sparse_block_size + const int* csr_row_indices; // shape [num_layout, stride_row_indices] + const int* csr_col_indices; // shape [num_layout, stride_col_indices] }; template diff --git a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_common.h b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_common.h index d69d3621d0ec..a90c603d7dea 100644 --- a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_common.h +++ b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_common.h @@ -22,6 +22,8 @@ struct SparseAttentionParams { const void* k; const void* v; + bool is_q_bnsh; + int batch_size; int num_heads; int kv_num_heads; @@ -30,7 +32,7 @@ struct SparseAttentionParams { int sequence_length; int past_sequence_length; int total_sequence_length; - int max_sequence_length; + int max_cache_sequence_length; float scale; @@ -64,13 +66,14 @@ struct SparseAttentionParams { const void* q, const void* k, const void* v, + bool is_q_bnsh, int batch_size, int sequence_length, int num_heads, int kv_num_heads, int head_size, int total_sequence_length, - int max_sequence_length, + int max_cache_sequence_length, float scale, int kernel_block_size, const int* layout_csr_row_indices, @@ -84,6 +87,7 @@ struct SparseAttentionParams { this->q = q; this->k = k; this->v = v; + this->is_q_bnsh = is_q_bnsh; this->batch_size = batch_size; this->sequence_length = sequence_length; this->num_heads = num_heads; @@ -91,7 +95,7 @@ struct SparseAttentionParams { this->head_size = head_size; this->past_sequence_length = total_sequence_length - sequence_length; this->total_sequence_length = total_sequence_length; - this->max_sequence_length = max_sequence_length; + this->max_cache_sequence_length = max_cache_sequence_length; this->scale = scale == 0.0f ? 1.0f / sqrtf(static_cast(head_size)) : scale; this->kernel_block_size = kernel_block_size; this->layout_csr_row_indices = layout_csr_row_indices; @@ -101,18 +105,16 @@ struct SparseAttentionParams { this->num_layout = num_layout; this->stride_qb = this->num_heads * this->sequence_length * this->head_size; - this->stride_qh = this->sequence_length * this->head_size; + this->stride_qh = (is_q_bnsh ? this->sequence_length : this->num_heads) * this->head_size; this->stride_qm = this->head_size; // When kv buffer has max sequence length, stride should match max sequence length. - int kv_buffer_sequence_length = max_sequence_length; - // KV cache is in BNSH format - this->stride_kb = this->kv_num_heads * kv_buffer_sequence_length * this->head_size; - this->stride_kh = kv_buffer_sequence_length * this->head_size; + this->stride_kb = this->kv_num_heads * max_cache_sequence_length * this->head_size; + this->stride_kh = max_cache_sequence_length * this->head_size; this->stride_kn = this->head_size; - this->stride_vb = this->kv_num_heads * kv_buffer_sequence_length * this->head_size; - this->stride_vh = kv_buffer_sequence_length * this->head_size; + this->stride_vb = this->kv_num_heads * max_cache_sequence_length * this->head_size; + this->stride_vh = max_cache_sequence_length * this->head_size; this->stride_vn = this->head_size; // Output is BSNH format @@ -142,8 +144,8 @@ struct SparseAttentionParams { #if DUMP_TENSOR_LEVEL > 0 DUMP_TENSOR_INIT(); DUMP_TENSOR("q", reinterpret_cast(q), batch_size, num_heads, sequence_length, head_size); - DUMP_TENSOR("k", reinterpret_cast(k), batch_size, kv_num_heads, max_sequence_length, head_size); - DUMP_TENSOR("v", reinterpret_cast(v), batch_size, kv_num_heads, max_sequence_length, head_size); + DUMP_TENSOR("k", reinterpret_cast(k), batch_size, kv_num_heads, max_cache_sequence_length, head_size); + DUMP_TENSOR("v", reinterpret_cast(v), batch_size, kv_num_heads, max_cache_sequence_length, head_size); DUMP_TENSOR("csr_col_indices", layout_csr_col_indices, num_layout, diff --git a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_v2/sparse_attention_v2_common.h b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_v2/sparse_attention_v2_common.h index 328cb8b5d8f8..af19a90b323d 100644 --- a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_v2/sparse_attention_v2_common.h +++ b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_v2/sparse_attention_v2_common.h @@ -18,6 +18,8 @@ struct SparseAttentionParams { const void* k; const void* v; + bool is_q_bnsh; + int batch_size; int num_heads; int kv_num_heads; @@ -26,7 +28,7 @@ struct SparseAttentionParams { int sequence_length; int past_sequence_length; int total_sequence_length; - int max_sequence_length; + int max_cache_sequence_length; float scale; @@ -70,13 +72,14 @@ struct SparseAttentionParams { const void* q, const void* k, const void* v, + bool is_q_bnsh, int batch_size, int sequence_length, int num_heads, int kv_num_heads, int head_size, int total_sequence_length, - int max_sequence_length, + int max_cache_sequence_length, float scale, int kernel_block_size, const int* layout_csr_row_indices, @@ -97,6 +100,7 @@ struct SparseAttentionParams { this->q = q; this->k = k; this->v = v; + this->is_q_bnsh = is_q_bnsh; this->batch_size = batch_size; this->sequence_length = sequence_length; this->num_heads = num_heads; @@ -104,7 +108,7 @@ struct SparseAttentionParams { this->head_size = head_size; this->past_sequence_length = total_sequence_length - sequence_length; this->total_sequence_length = total_sequence_length; - this->max_sequence_length = max_sequence_length; + this->max_cache_sequence_length = max_cache_sequence_length; this->scale = scale == 0.0f ? 1.0f / sqrtf(static_cast(head_size)) : scale; this->kernel_block_size = kernel_block_size; this->layout_csr_row_indices = layout_csr_row_indices; @@ -113,20 +117,18 @@ struct SparseAttentionParams { this->layout_col_stride_h = layout_col_stride_h; this->num_layout = num_layout; - // Q is in BNSH format + // Q can be either BNSH or BSNH format this->stride_qb = this->num_heads * this->sequence_length * this->head_size; - this->stride_qh = this->sequence_length * this->head_size; + this->stride_qh = (is_q_bnsh ? this->sequence_length : this->num_heads) * this->head_size; this->stride_qt = this->head_size; // When kv buffer has max sequence length, stride should match max sequence length. - int kv_buffer_sequence_length = max_sequence_length; - // KV cache is in BNSH format - this->stride_kb = this->kv_num_heads * kv_buffer_sequence_length * this->head_size; - this->stride_kh = kv_buffer_sequence_length * this->head_size; + this->stride_kb = this->kv_num_heads * max_cache_sequence_length * this->head_size; + this->stride_kh = max_cache_sequence_length * this->head_size; this->stride_kt = this->head_size; - this->stride_vb = this->kv_num_heads * kv_buffer_sequence_length * this->head_size; - this->stride_vh = kv_buffer_sequence_length * this->head_size; + this->stride_vb = this->kv_num_heads * max_cache_sequence_length * this->head_size; + this->stride_vh = max_cache_sequence_length * this->head_size; this->stride_vt = this->head_size; // Output is BSNH format @@ -167,8 +169,8 @@ struct SparseAttentionParams { #if DUMP_TENSOR_LEVEL > 0 DUMP_TENSOR_INIT(); DUMP_TENSOR("q", reinterpret_cast(q), batch_size, num_heads, sequence_length, head_size); - DUMP_TENSOR("k", reinterpret_cast(k), batch_size, kv_num_heads, max_sequence_length, head_size); - DUMP_TENSOR("v", reinterpret_cast(v), batch_size, kv_num_heads, max_sequence_length, head_size); + DUMP_TENSOR("k", reinterpret_cast(k), batch_size, kv_num_heads, max_cache_sequence_length, head_size); + DUMP_TENSOR("v", reinterpret_cast(v), batch_size, kv_num_heads, max_cache_sequence_length, head_size); DUMP_TENSOR("csr_col_indices", layout_csr_col_indices, num_layout, @@ -187,13 +189,13 @@ struct SparseAttentionParams { DUMP_TENSOR("q_start_sids", q_start_sids, 1, active_q_blocks); printf( - "layout_row_stride_h=%d, layout_col_stride_h=%d, num_layout=%d, scale=%f,\n" + "layout_row_stride_h=%d, layout_col_stride_h=%d, num_layout=%d, scale=%f, is_q_bnsh=%d,\n" "stride_qb=%d, stride_qt=%d, stride_qh=%d, stride_kb=%d, stride_kt=%d, stride_kh=%d,\n" "stride_vb=%d, stride_vt=%d, stride_vh=%d, stride_ob=%d, stride_ot=%d, stride_oh=%d,\n" "num_heads=%d, kv_num_heads=%d, total_sequence_length=%d, past_sequence_length=%d\n" "output=%p, q=%p, k=%p, v=%p, layout_csr_row_indices=%p, layout_csr_col_indices=%p\n" "q_batch_starts=%p, q_batch_ends=%p, k_batch_starts=%p, k_batch_ends=%p, q_batch_ids=%p, q_start_sids=%p active_q_blocks=%d\n", - layout_row_stride_h, layout_col_stride_h, num_layout, scale, + layout_row_stride_h, layout_col_stride_h, num_layout, scale, static_cast(is_q_bnsh), stride_qb, stride_qt, stride_qh, stride_kb, stride_kt, stride_kh, stride_vb, stride_vt, stride_vh, stride_ob, stride_ot, stride_oh, num_heads, kv_num_heads, total_sequence_length, past_sequence_length, diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 6cac8f9a53af..50b4f56813ac 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -287,7 +287,7 @@ void BaseGroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceConte auto& past_shape = getInputShape(ctx, past_key_index); auto& past_dims = past_shape.dim(); - // past key has shape (batch_size, kv_num_heads, max_sequence_length, head_size) + // past key has shape (batch_size, kv_num_heads, max_cache_sequence_length, head_size) if (past_dims.size() != 4) { fail_shape_inference("The past_key input shall be 4 dimensions"); } @@ -1151,11 +1151,29 @@ block_mask can be used to configure sparse layout for different head. When number of sparse layout is 1, all heads have same sparse layout. Otherwise, different layouts are used cyclically. For example, given 4 layouts (S0, S1, S2, S3), 8 heads will have layouts like (S0, S1, S2, S3, S0, S1, S2, S3). -Padding shall be on the right side. +The block_row_indices and block_col_indices are the CSR representation of block mask. The block_col_indices might contain +paddings at the right side when different layout has different number of non-zeros in block mask. -When do_rotary is True, cos_cache and sin_cache are required. +An example of block mask with 2 layouts where each layout is 4 x 4 blocks: + [[[1, 0, 0, 0], + [1, 1, 0, 0], + [0, 1, 1, 0], + [0, 1, 1, 1]], + + [[1, 0, 0, 0], + [1, 1, 0, 0], + [1, 1, 1, 0], + [1, 0, 1, 1]]] + +The corresponding CSR format: + block_col_indices = [[0, 0, 1, 1, 2, 1, 2, 3, -1], [0, 0, 1, 0, 1, 2, 0, 2, 3]] + block_row_indices = [[0, 1, 3, 5, 8], [0, 1, 3, 6, 9]] + +When do_rotary is True, cos_cache and sin_cache are required. Note that the maximum sequence length supported by cos +or sin cache can be different from the maximum sequence length used by kv cache. Only supports unidirectional attention with cache of past key and value in linear buffers. + For performance, past_key and present_key share same memory buffer, and past_value and present_value too. )DOC"; @@ -1189,36 +1207,38 @@ ONNX_MS_OPERATOR_SET_SCHEMA( OpSchema::Optional) .Input(3, "past_key", - "Key cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size)", - "T", - OpSchema::Optional) + "Key cache with shape (batch_size, kv_num_heads, max_cache_sequence_length, head_size)", + "T") .Input(4, "past_value", - "Value cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size)", - "T", - OpSchema::Optional) + "Value cache with shape (batch_size, kv_num_heads, max_cache_sequence_length, head_size)", + "T") .Input(5, - "block_mask", - "block mask. 1 indicates attention and 0 no attention. " - "Its shape is (num_layout, max_blocks, max_blocks), " - "where num_heads is divisible by num_layout, and max_blocks is max_sequence_length / sparse_block_size.", + "block_row_indices", + "The row indices of CSR format of block mask with shape (num_layout, max_blocks + 1)." + "The num_heads is divisible by num_layout, and max_blocks is max_sequence_length / sparse_block_size.", "M") .Input(6, + "block_col_indices", + "The col indices of CSR format of block mask with shape (num_layout, max_nnz_blocks)." + "The max_nnz_blocks is the maximum number of non-zeros per layout in block mask.", + "M") + .Input(7, "total_sequence_length", "Scalar tensor of maximum total sequence length (past_sequence_length + sequence_length) among keys.", "M") - .Input(7, + .Input(8, "key_total_sequence_lengths", "1D tensor with shape (batch_size) where each value is total sequence length of key excluding paddings.", "M") - .Input(8, + .Input(9, "cos_cache", - "Cos cache of rotary with shape (max_sequence_length, head_size / 2).", + "Cos cache of rotary with shape (max_rotary_sequence_length, head_size / 2).", "T", OpSchema::Optional) - .Input(9, + .Input(10, "sin_cache", - "Sin cache of rotary with shape (max_sequence_length, head_size / 2).", + "Sin cache of rotary with shape (max_rotary_sequence_length, head_size / 2).", "T", OpSchema::Optional) .Output(0, @@ -1227,11 +1247,11 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "T") .Output(1, "present_key", - "Updated key cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size).", + "Updated key cache with shape (batch_size, kv_num_heads, max_cache_sequence_length, head_size).", "T") .Output(2, "present_value", - "Updated value cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size).", + "Updated value cache with shape (batch_size, kv_num_heads, max_cache_sequence_length, head_size).", "T") .TypeConstraint("T", {"tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output to float tensors.") .TypeConstraint("M", {"tensor(int32)"}, "Constrain integer type.") diff --git a/onnxruntime/test/python/transformers/test_sparse_attention.py b/onnxruntime/test/python/transformers/test_sparse_attention.py index 9fcc23288ef4..2dd6dc627b09 100644 --- a/onnxruntime/test/python/transformers/test_sparse_attention.py +++ b/onnxruntime/test/python/transformers/test_sparse_attention.py @@ -38,21 +38,27 @@ def __init__( dtype=torch.float16, share_buffer: bool = True, is_packed_qkv: bool = False, + max_cache_sequence_length=None, + max_rotary_sequence_length=None, ): self.operator = operator self.batch_size = batch_size self.sequence_length = sequence_length self.max_sequence_length = max_sequence_length + self.max_cache_sequence_length = max_cache_sequence_length or max_sequence_length + self.max_rotary_sequence_length = max_rotary_sequence_length or max_sequence_length self.past_sequence_length = past_sequence_length self.num_heads = num_heads self.kv_num_heads = kv_num_heads self.head_size = head_size - self.softmax_scale = softmax_scale if softmax_scale is not None else 1.0 / (head_size**0.5) + self.softmax_scale = softmax_scale or (1.0 / (head_size**0.5)) # Derived values self.total_sequence_length = sequence_length + past_sequence_length - self.past_buffer_length = max_sequence_length if share_buffer else past_sequence_length - self.present_buffer_length = max_sequence_length if share_buffer else (past_sequence_length + sequence_length) + self.past_buffer_length = self.max_cache_sequence_length if share_buffer else past_sequence_length + self.present_buffer_length = ( + self.max_cache_sequence_length if share_buffer else (past_sequence_length + sequence_length) + ) self.do_rotary = do_rotary self.rotary_interleaved = rotary_interleaved @@ -75,8 +81,8 @@ def shape_dict(self): "output": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), "present_key": (self.batch_size, self.kv_num_heads, self.present_buffer_length, self.head_size), "present_value": (self.batch_size, self.kv_num_heads, self.present_buffer_length, self.head_size), - "cos_cache": (self.max_sequence_length, (math.floor(self.head_size / 16) * 16) // 2), - "sin_cache": (self.max_sequence_length, (math.floor(self.head_size / 16) * 16) // 2), + "cos_cache": (self.max_rotary_sequence_length, (math.floor(self.head_size / 16) * 16) // 2), + "sin_cache": (self.max_rotary_sequence_length, (math.floor(self.head_size / 16) * 16) // 2), } if not self.is_packed_qkv: @@ -92,7 +98,7 @@ def shape_dict(self): def get_cos_sin_cache(self, dtype): rotary_fraction = 1.0 rotary_dim = math.floor(int(rotary_fraction * self.head_size) / 16) * 16 - angle = torch.rand(self.max_sequence_length, rotary_dim // 2, device="cpu") * 2 * math.pi + angle = torch.rand(self.max_rotary_sequence_length, rotary_dim // 2, device="cpu") * 2 * math.pi cos = torch.cos(angle).to(dtype=dtype) sin = torch.sin(angle).to(dtype=dtype) return cos.to(device=self.device), sin.to(device=self.device) @@ -151,6 +157,8 @@ def __init__( local_window_size: int = -1, attention_mask=None, is_packed_qkv=False, + max_cache_sequence_length=None, + max_rotary_sequence_length=None, ): super().__init__( "GroupQueryAttention", @@ -166,6 +174,8 @@ def __init__( rotary_interleaved, device, is_packed_qkv=is_packed_qkv, + max_cache_sequence_length=max_cache_sequence_length, + max_rotary_sequence_length=max_rotary_sequence_length, ) # local_window_size is for ORT only, not for Torch implementation. self.local_window_size = local_window_size @@ -212,6 +222,8 @@ def __init__( rotary_interleaved: bool = False, device="cuda", is_packed_qkv=False, + max_cache_sequence_length=None, + max_rotary_sequence_length=None, ): super().__init__( "SparseAttention", @@ -227,6 +239,8 @@ def __init__( rotary_interleaved, device, is_packed_qkv=is_packed_qkv, + max_cache_sequence_length=max_cache_sequence_length, + max_rotary_sequence_length=max_rotary_sequence_length, ) self.sparse_block_size = sparse_block_size self.num_layout = num_layout @@ -237,18 +251,23 @@ def __init__( def block_mask(self): return get_block_mask(self.num_layout, self.max_blocks, self.local_blocks, self.vert_stride).to(self.device) + def block_indices(self): + row_indices, col_indices = dense_to_csr(self.block_mask()) + return row_indices.to(torch.int32).to(self.device), col_indices.to(torch.int32).to(self.device) + def dense_mask(self): - expand_block_mask = self.block_mask() dense_mask = get_dense_mask( - expand_block_mask, self.total_sequence_length, self.sequence_length, self.sparse_block_size + self.block_mask(), self.total_sequence_length, self.sequence_length, self.sparse_block_size ) return dense_mask.repeat(self.batch_size, self.num_heads // self.num_layout, 1, 1).to(self.device) def shape_dict(self): shapes = super().shape_dict() + block_row_indices, block_col_indices = self.block_indices() shapes.update( { - "block_mask": (self.num_layout, self.max_blocks, self.max_blocks), + "block_row_indices": tuple(block_row_indices.shape), + "block_col_indices": tuple(block_col_indices.shape), "key_total_sequence_lengths": (self.batch_size,), } ) @@ -257,10 +276,11 @@ def shape_dict(self): def random_inputs(self): feeds = super().random_inputs() k_seqlens = torch.ones((self.batch_size,), device=self.device, dtype=torch.int32) * self.total_sequence_length + block_row_indices, block_col_indices = self.block_indices() feeds.update( { - "block_mask": self.block_mask(), - "total_sequence_length": torch.tensor([self.total_sequence_length], dtype=torch.int32), + "block_row_indices": block_row_indices, + "block_col_indices": block_col_indices, "key_total_sequence_lengths": k_seqlens, } ) @@ -281,6 +301,8 @@ def get_comparable_ort_gqa_config(self, use_local=False) -> GroupQueryAttentionC self.device, local_window_size=self.local_blocks * self.sparse_block_size if use_local else -1, is_packed_qkv=self.is_packed_qkv, + max_cache_sequence_length=self.max_cache_sequence_length, + max_rotary_sequence_length=self.max_rotary_sequence_length, ) def get_comparable_torch_gqa_config(self, use_sparse=False) -> GroupQueryAttentionConfig: @@ -305,6 +327,8 @@ def get_comparable_torch_gqa_config(self, use_sparse=False) -> GroupQueryAttenti self.device, attention_mask=attention_mask, is_packed_qkv=False, # torch reference implementation does not support packed qkv. + max_cache_sequence_length=self.max_cache_sequence_length, + max_rotary_sequence_length=self.max_rotary_sequence_length, ) @@ -327,6 +351,19 @@ def get_block_mask(num_layout, max_blocks, local_blocks, vert_stride): return block_mask +def dense_to_csr(x): + """Turning a 3D torch tensor (x) to CSR rows/cols indexing.""" + assert x.dim() == 3 + pad = -1 + x = [xi.to_sparse_csr() for xi in x] + row_indices = torch.vstack([xi.crow_indices() for xi in x]) + cols = [xi.col_indices() for xi in x] + max_cols = max(len(xi) for xi in cols) + cols = [torch.cat([xi, pad + xi.new_zeros(max_cols - xi.shape[0])]) for xi in cols] + col_indices = torch.vstack(cols) + return row_indices, col_indices + + def get_dense_mask(block_mask, total_seq_len, query_seq_len, block_size): dense_mask = torch.kron(block_mask, block_mask.new_ones((block_size, block_size)))[ :, :total_seq_len, :total_seq_len @@ -350,7 +387,8 @@ def create_sparse_attention_onnx_model(config: SparseAttentionConfig): "value" + suffix if not config.is_packed_qkv else "", "past_key" + suffix, "past_value" + suffix, - "block_mask", + "block_row_indices", # no suffix since int32 need not cast for bfloat graph. + "block_col_indices", "total_sequence_length" if config.share_buffer else "", "key_total_sequence_lengths", "cos_cache" + suffix if config.do_rotary else "", @@ -410,7 +448,12 @@ def create_sparse_attention_onnx_model(config: SparseAttentionConfig): [ helper.make_tensor_value_info("past_key", io_float_type, list(shape_dict["past_key"])), helper.make_tensor_value_info("past_value", io_float_type, list(shape_dict["past_value"])), - helper.make_tensor_value_info("block_mask", TensorProto.INT32, list(shape_dict["block_mask"])), + helper.make_tensor_value_info( + "block_row_indices", TensorProto.INT32, list(shape_dict["block_row_indices"]) + ), + helper.make_tensor_value_info( + "block_col_indices", TensorProto.INT32, list(shape_dict["block_col_indices"]) + ), helper.make_tensor_value_info( "total_sequence_length", TensorProto.INT32, list(shape_dict["total_sequence_length"]) ), @@ -704,7 +747,8 @@ def __init__(self, config: SparseAttentionConfig): print("query(BSNH, SA)", query) print("key(BSNH, SA)", key) print("value(BSNH, SA)", value) - print("block_mask (SA)", self.feed_dict["block_mask"]) + print("block_row_indices", self.feed_dict["block_row_indices"]) + print("block_col_indices", self.feed_dict["block_col_indices"]) print("total_sequence_length", self.feed_dict["total_sequence_length"]) print("key_total_sequence_lengths", self.feed_dict["key_total_sequence_lengths"]) @@ -778,6 +822,7 @@ def run_relevance_no_past(self, sm: int, device): softmax_scale=1.8 / (128**0.5), device=device, is_packed_qkv=packed_qkv, + max_cache_sequence_length=None if seq_len >= 128 else 128, # test smaller kv cache buffer. ) self.run_one_relevance_test(config) @@ -806,6 +851,7 @@ def run_relevance_past(self, sm: int, device, do_rotary: bool): rotary_interleaved=(past_seq_len % 2 == 1), device=device, is_packed_qkv=packed_qkv, + max_rotary_sequence_length=None if past_seq_len >= 128 else 128, # test smaller rotary buffer. ) if do_rotary: diff --git a/tools/ci_build/github/apple/package_release_tasks.py b/tools/ci_build/github/apple/package_release_tasks.py index 38c3509dc84a..7885023f23a5 100755 --- a/tools/ci_build/github/apple/package_release_tasks.py +++ b/tools/ci_build/github/apple/package_release_tasks.py @@ -5,6 +5,7 @@ from __future__ import annotations import argparse +import glob import os import shlex import subprocess @@ -50,6 +51,15 @@ def update_podspec(pod_archive_path: Path, podspec_path: Path): podspec_path.write_text(podspec_content) +def _resolve_single_path_from_pattern(path_pattern: str) -> Path: + matches = glob.glob(path_pattern) + if len(matches) != 1: + raise argparse.ArgumentTypeError( + f"Expected exactly 1 match for pattern '{path_pattern}' but got {len(matches)} matches." + ) + return Path(matches[0]).resolve(strict=True) + + def _parse_args(): parser = argparse.ArgumentParser( description="Helper script to perform release tasks. " @@ -58,14 +68,14 @@ def _parse_args(): parser.add_argument( "--pod-archive-path", - type=Path, - help="Pod archive path.", + type=_resolve_single_path_from_pattern, + help="Pod archive path. It may be a pattern, in which case it must match exactly one path.", ) parser.add_argument( "--podspec-path", - type=Path, - help="Podspec path.", + type=_resolve_single_path_from_pattern, + help="Podspec path. It may be a pattern, in which case it must match exactly one path.", ) parser.add_argument( @@ -82,11 +92,9 @@ def _validate_args( ): if require_pod_archive_path: assert args.pod_archive_path is not None, "--pod-archive-path must be specified." - args.pod_archive_path = args.pod_archive_path.resolve(strict=True) if require_podspec_path: assert args.podspec_path is not None, "--podspec-path must be specified." - args.podspec_path = args.podspec_path.resolve(strict=True) def main(): diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index 21bcd5a767d1..4f645085c290 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -99,6 +99,17 @@ variables: ${{ if eq(parameters.CudaVersion, '12.2') }}: value: 12.4 +- name: win_trt_home + ${{ if eq(parameters.CudaVersion, '11.8') }}: + value: $(Agent.TempDirectory)\TensorRT-10.0.1.6.Windows10.x86_64.cuda-11.8 + ${{ if eq(parameters.CudaVersion, '12.2') }}: + value: $(Agent.TempDirectory)\TensorRT-10.0.1.6.Windows10.x86_64.cuda-12.4 +- name: win_cuda_home + ${{ if eq(parameters.CudaVersion, '11.8') }}: + value: $(Agent.TempDirectory)\v11.8 + ${{ if eq(parameters.CudaVersion, '12.2') }}: + value: $(Agent.TempDirectory)\v12.2 + stages: - stage: Setup jobs: @@ -123,11 +134,25 @@ stages: echo "##vso[task.setvariable variable=ReleaseVersionSuffix;isOutput=true]" fi name: Set_Release_Version_Suffix + - script: | + # Extracting hours and minutes + date=$(date +'%Y%m%d') + # Set the hhmm value as a pipeline variable + echo "##vso[task.setvariable variable=BuildDate;isOutput=true]$date" + displayName: 'Set Start Date as Variable' + name: Set_Build_Date + + - script: | + # Extracting hours and minutes + hhmm=$(date +'%H%M') + # Set the hhmm value as a pipeline variable + echo "##vso[task.setvariable variable=BuildTime;isOutput=true]$hhmm" + displayName: 'Set Start Time as Variable' + name: Set_Build_Time - template: templates/component-governance-component-detection-steps.yml parameters : condition : 'succeeded' - - stage: Debug dependsOn: Setup jobs: @@ -136,17 +161,20 @@ stages: vmImage: ubuntu-latest variables: MyVar: $[stageDependencies.Setup.Set_Variables.outputs['Set_Release_Version_Suffix.ReleaseVersionSuffix']] + BuildDate: $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Date.BuildDate']] + BuildTime: $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Time.BuildTime']] steps: - checkout: none - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 displayName: 'Clean Agent Directories' condition: always() - bash: echo $(MyVar) + - bash: echo $(BuildTime) + - bash: echo $(BuildDate) - template: templates/component-governance-component-detection-steps.yml parameters : condition : 'succeeded' - - stage: Download_Java_Tools dependsOn: [] jobs: @@ -212,41 +240,6 @@ stages: buildJava: true buildNodejs: true -#CUDA without tensorrt -- template: templates/win-ci.yml - parameters: - ort_build_pool_name: 'onnxruntime-Win2022-GPU-T4' - DoCompliance: ${{ parameters.DoCompliance }} - DoEsrp: ${{ parameters.DoEsrp }} - stage_name_suffix: gpu - buildArch: x64 - msbuildPlatform: x64 - packageName: x64-cuda - buildparameter: --use_cuda --cuda_home=$(Agent.TempDirectory)\v${{ parameters.CudaVersion }} --enable_onnx_tests --enable_wcos --build_java --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=60;61;70;75;80" ${{parameters.AdditionalBuildFlag}} - runTests: ${{ parameters.RunOnnxRuntimeTests }} - buildJava: true - java_artifact_id: onnxruntime_gpu - CudaVersion: ${{ parameters.CudaVersion }} - SpecificArtifact: ${{ parameters.SpecificArtifact }} - BuildId: ${{ parameters.BuildId }} - -# CUDA with Tensorrt -- template: templates/win-ci.yml - parameters: - ort_build_pool_name: 'onnxruntime-Win2022-GPU-T4' - DoCompliance: ${{ parameters.DoCompliance }} - DoEsrp: ${{ parameters.DoEsrp }} - stage_name_suffix: tensorrt - buildArch: x64 - msbuildPlatform: x64 - packageName: x64-tensorrt - buildparameter: --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\TensorRT-10.0.1.6.Windows10.x86_64.cuda-${{ variables.win_trt_version }}" --cuda_home="$(Agent.TempDirectory)\v${{ parameters.CudaVersion }}" --enable_onnx_tests --enable_wcos --build_java --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=60;61;70;75;80" - runTests: ${{ parameters.RunOnnxRuntimeTests }} - buildJava: true - java_artifact_id: onnxruntime_gpu - CudaVersion: ${{ parameters.CudaVersion }} - UseIncreasedTimeoutForTests: ${{ parameters.UseIncreasedTimeoutForTests }} - # ROCm - stage: Linux_C_API_Packaging_ROCm_x64 dependsOn: [] @@ -313,503 +306,29 @@ stages: condition: 'succeeded' - template: templates/clean-agent-build-directory-step.yml -- stage: Jar_Packaging_GPU - dependsOn: - - Linux_C_API_Packaging_GPU_x64 - - Linux_C_API_Packaging_GPU_TensorRT_x64 - - Windows_Packaging_gpu_Testing - - Windows_Packaging_tensorrt_Testing - - Download_Java_Tools - condition: succeeded() - jobs: - - job: - workspace: - clean: all - pool: 'onnxruntime-Win-CPU-2022' - - steps: - - checkout: self - submodules: false - - template: templates/set-version-number-variables-step.yml - - - template: templates/flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download Pipeline Artifact - Win x64' - ArtifactName: 'drop-onnxruntime-java-win-x64-tensorrt' - TargetPath: '$(Build.BinariesDirectory)\java-artifact\onnxruntime-java-win-x64' - SpecificArtifact: ${{ parameters.specificArtifact }} - BuildId: ${{ parameters.BuildId }} - - - template: templates/flex-downloadPipelineArtifact.yml - parameters: - stepName: 'Download Pipeline Artifact - Linux x64' - artifactName: 'drop-onnxruntime-java-linux-x64-cuda' - targetPath: '$(Build.BinariesDirectory)\java-artifact\onnxruntime-java-linux-x64' - SpecificArtifact: ${{ parameters.specificArtifact }} - BuildId: ${{ parameters.BuildId }} - - - template: templates/flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download Pipeline Artifact - Linux x64' - ArtifactName: 'drop-onnxruntime-java-linux-x64-tensorrt' - targetPath: '$(Build.BinariesDirectory)\java-artifact\onnxruntime-java-linux-x64-tensorrt' - SpecificArtifact: ${{ parameters.specificArtifact }} - BuildId: ${{ parameters.BuildId }} - - - task: PowerShell@2 - displayName: 'PowerShell Script' - inputs: - targetType: filePath - filePath: $(Build.SourcesDirectory)\tools\ci_build\github\windows\jar_gpu_packaging.ps1 - failOnStderr: true - showWarnings: true - workingDirectory: '$(Build.BinariesDirectory)\java-artifact' - - - task: CopyFiles@2 - displayName: 'Copy Java Files to Artifact Staging Directory' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\java-artifact\onnxruntime-java-win-x64' - TargetFolder: '$(Build.ArtifactStagingDirectory)' - - - task: PublishPipelineArtifact@1 - displayName: 'Publish Pipeline Artifact' - inputs: - targetPath: '$(Build.ArtifactStagingDirectory)' - artifact: 'onnxruntime-java-gpu' - - - template: templates/component-governance-component-detection-steps.yml - parameters : - condition : 'succeeded' - - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 - displayName: 'Clean Agent Directories' - condition: always() - - -- stage: Final_Jar_Testing_Windows_GPU - dependsOn: - Jar_Packaging_GPU - jobs: - - job: - workspace: - clean: all - pool: 'onnxruntime-Win2022-GPU-T4' - timeoutInMinutes: 60 - variables: - - name: runCodesignValidationInjection - value: false - - steps: - - template: templates/set-version-number-variables-step.yml - - - template: templates/jobs/download_win_gpu_library.yml - parameters: - CudaVersion: ${{ parameters.CudaVersion }} - DownloadCUDA: true - DownloadTRT: true - - - template: templates\flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download Final Jar' - ArtifactName: onnxruntime-java-gpu - TargetPath: '$(Build.BinariesDirectory)\final-jar' - SpecificArtifact: ${{ parameters.SpecificArtifact }} - BuildId: ${{ parameters.BuildId }} - - - template: templates\flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download Jar Tools' - ArtifactName: onnxruntime-java-tools - TargetPath: '$(Build.BinariesDirectory)\final-jar' - SpecificArtifact: ${{ parameters.SpecificArtifact }} - BuildId: ${{ parameters.BuildId }} - - - task: CmdLine@2 - inputs: - script: | - mkdir test - pushd test - jar xf $(Build.BinariesDirectory)\final-jar\testing.jar - popd - java -DUSE_CUDA=1 -jar junit-platform-console-standalone-1.6.2.jar -cp .;.\test;protobuf-java-3.21.7.jar;onnxruntime_gpu-$(OnnxRuntimeVersion).jar --scan-class-path --fail-if-no-tests --disable-banner - workingDirectory: '$(Build.BinariesDirectory)\final-jar' - - - template: templates/component-governance-component-detection-steps.yml - parameters: - condition: 'succeeded' - - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 - displayName: 'Clean Agent Directories' - condition: always() - - -- stage: Final_Jar_Testing_Linux_GPU - dependsOn: - Jar_Packaging_GPU - jobs: - - job: - workspace: - clean: all - pool: 'Onnxruntime-Linux-GPU' - variables: - - name: runCodesignValidationInjection - value: false - timeoutInMinutes: 60 - - steps: - - checkout: self - submodules: false - - template: templates/set-version-number-variables-step.yml - - - template: templates/flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download Final Jar' - ArtifactName: onnxruntime-java-gpu - TargetPath: '$(Build.BinariesDirectory)/final-jar' - SpecificArtifact: ${{ parameters.specificArtifact }} - BuildId: ${{ parameters.BuildId }} - - - template: templates/get-docker-image-steps.yml - parameters: - Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0 - Context: tools/ci_build/github/linux/docker/ - DockerBuildArgs: " - --build-arg BUILD_UID=$( id -u ) - --build-arg BASEIMAGE=${{ variables.docker_base_image }} - --build-arg TRT_VERSION=${{ variables.linux_trt_version }} - " - Repository: onnxruntimeubi8packagestest - UpdateDepsTxt: false - - - bash: | - docker run --rm \ - --gpus all \ - --volume $(Build.SourcesDirectory):/onnxruntime_src \ - --volume $(Build.BinariesDirectory):/build \ - --volume /data/models:/build/models:ro \ - onnxruntimeubi8packagestest \ - /bin/bash /onnxruntime_src/tools/ci_build/github/linux/java_linux_final_test.sh -r /build -v $(OnnxRuntimeVersion) - displayName: 'Test' - - - template: templates/component-governance-component-detection-steps.yml - parameters: - condition: 'succeeded' - - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 - displayName: 'Clean Agent Directories' - condition: always() - - -- stage: Windows_Packaging_combined_GPU - dependsOn: - - Windows_Packaging_gpu_Testing - - Windows_Packaging_tensorrt_Testing - condition: succeeded() - jobs: - - job: - workspace: - clean: all - pool: 'onnxruntime-Win2022-GPU-T4' - - steps: - - checkout: self # due to checkout multiple repos, the root directory is $(Build.SourcesDirectory)/onnxruntime - - checkout: onnxruntime-inference-examples # due to checkout multiple repos, the root directory is $(Build.SourcesDirectory)/onnxruntime-inference-examples - submodules: false - - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 - displayName: 'Clean Agent Directories' - condition: always() - - - script: dir $(Build.SourcesDirectory) - - template: templates/jobs/download_win_gpu_library.yml - parameters: - CudaVersion: ${{ parameters.CudaVersion }} - DownloadCUDA: true - DownloadTRT: true - - template: templates/set-version-number-variables-step.yml - parameters: - versionFileDirectory: '$(Build.SourcesDirectory)\onnxruntime' - workingDirectory: '$(Build.SourcesDirectory)\onnxruntime' - - task: DownloadPipelineArtifact@2 - displayName: 'Download Pipeline Artifact - Combined GPU' - inputs: - artifactName: 'onnxruntime-win-x64-cuda' - targetPath: '$(Build.BinariesDirectory)/zip-artifacts' - - - task: DownloadPipelineArtifact@2 - displayName: 'Download Pipeline Artifact - Combined GPU' - inputs: - artifactName: 'onnxruntime-win-x64-tensorrt' - targetPath: '$(Build.BinariesDirectory)/zip-artifacts' - - - task: PowerShell@2 - displayName: 'PowerShell Script' - inputs: - targetType: filePath - filePath: $(Build.SourcesDirectory)\onnxruntime\tools\ci_build\github\windows\extract_zip_files_gpu.ps1 - - - script: | - dir - workingDirectory: '$(Build.BinariesDirectory)/zip-artifacts' - displayName: 'List artifacts' - - - task: BatchScript@1 - displayName: 'Bundle CUDA/TRT EP binaries' - inputs: - filename: $(Build.SourcesDirectory)\onnxruntime\tools\ci_build\github\windows\bundle_dlls_gpu.bat - workingFolder: $(Build.BinariesDirectory)\zip-artifacts - - - task: CopyFiles@2 - displayName: 'Copy zip file to: $(Build.ArtifactStagingDirectory)' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\zip-artifacts' - Contents: 'onnxruntime-win-x64-gpu-*.zip' - TargetFolder: '$(Build.ArtifactStagingDirectory)' - - - template: templates/validate-package.yml - parameters: - PackageType: 'zip' - PackagePath: '$(Build.ArtifactStagingDirectory)' - PackageName: 'onnxruntime-win-x64-gpu-$(OnnxRuntimeVersion).zip' - ScriptPath: '$(Build.SourcesDirectory)\onnxruntime\tools\nuget\validate_package.py' - PlatformsSupported: 'win-x64' - VerifyNugetSigning: false - workingDirectory: '$(Build.ArtifactStagingDirectory)' - - - task: BatchScript@1 - displayName: 'Test C API application for GPU package' - inputs: - filename: $(Build.SourcesDirectory)\onnxruntime-inference-examples\c_cxx\squeezenet\run_capi_application.bat - arguments: $(Build.SourcesDirectory)\onnxruntime $(Build.ArtifactStagingDirectory)\onnxruntime-win-x64-gpu-$(OnnxRuntimeVersion).zip $(Build.SourcesDirectory)\onnxruntime-inference-examples\c_cxx\squeezenet - workingFolder: '$(Build.ArtifactStagingDirectory)' - - - task: PublishPipelineArtifact@0 - displayName: 'Publish Pipeline Combined GPU Package Artifact' - inputs: - artifactName: 'onnxruntime-win-x64-gpu' - targetPath: '$(Build.ArtifactStagingDirectory)/onnxruntime-win-x64-gpu-$(OnnxRuntimeVersion).zip' - - - - template: templates/component-governance-component-detection-steps.yml - parameters : - condition : 'succeeded' - -- stage: NuGet_Packaging_GPU - dependsOn: - - Setup - - Windows_Packaging_gpu_Testing - - Windows_Packaging_CPU_x64_default - - Windows_Packaging_tensorrt_Testing - - Linux_C_API_Packaging_GPU_x64 - - Linux_C_API_Packaging_GPU_TensorRT_x64 - condition: succeeded() - jobs: - - job: - workspace: - clean: all - # we need to use the 2022 pool to create the nuget package with both pre-net6+Xamarin and net6 targets. - # VS2019 has no support for net6 and we need to use msbuild (from the VS install) to do the packing - pool: 'Onnxruntime-Win-CPU-2022' - variables: - breakCodesignValidationInjection: ${{ parameters.DoEsrp }} - ReleaseVersionSuffix: $[stageDependencies.Setup.Set_Variables.outputs['Set_Release_Version_Suffix.ReleaseVersionSuffix']] - BuildDate: $[format('{0:yyyyMMdd}', pipeline.startTime)] - BuildTime: $[format('{0:HHmm}', pipeline.startTime)] - - steps: - - checkout: self - submodules: true - - - template: templates/flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download Pipeline Artifact - NuGet' - ArtifactName: 'onnxruntime-win-x64-cuda' - TargetPath: '$(Build.BinariesDirectory)/nuget-artifact' - SpecificArtifact: ${{ parameters.SpecificArtifact }} - BuildId: ${{ parameters.BuildId }} - - - template: templates/flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download Pipeline Artifact - NuGet' - ArtifactName: 'onnxruntime-win-x64-tensorrt' - TargetPath: '$(Build.BinariesDirectory)/nuget-artifact' - SpecificArtifact: ${{ parameters.SpecificArtifact }} - BuildId: ${{ parameters.BuildId }} - - - template: templates/flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download Pipeline Artifact - NuGet' - ArtifactName: 'onnxruntime-linux-x64-cuda' - TargetPath: '$(Build.BinariesDirectory)/nuget-artifact' - SpecificArtifact: ${{ parameters.SpecificArtifact }} - BuildId: ${{ parameters.BuildId }} - - - template: templates/flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download Pipeline Artifact - NuGet' - ArtifactName: 'onnxruntime-linux-x64-tensorrt' - TargetPath: '$(Build.BinariesDirectory)/nuget-artifact' - SpecificArtifact: ${{ parameters.SpecificArtifact }} - BuildId: ${{ parameters.BuildId }} - - # The following one is from a CPU job that publishes protoc.exe - - template: templates/flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download Pipeline Artifact - NuGet' - ArtifactName: 'drop-extra' - TargetPath: '$(Build.BinariesDirectory)/extra-artifact' - SpecificArtifact: ${{ parameters.SpecificArtifact }} - BuildId: ${{ parameters.BuildId }} - - # Reconstruct the build dir - - task: PowerShell@2 - displayName: 'PowerShell Script' - inputs: - targetType: filePath - filePath: $(Build.SourcesDirectory)\tools\ci_build\github\windows\extract_nuget_files_gpu.ps1 - - - script: | - dir - workingDirectory: '$(Build.BinariesDirectory)/nuget-artifact' - displayName: 'List artifacts' - - - script: | - mklink /D /J models C:\local\models - workingDirectory: '$(Build.BinariesDirectory)' - displayName: 'Create models link' - - - task: NuGetToolInstaller@0 - displayName: Use Nuget 6.2.1 - inputs: - versionSpec: 6.2.1 - - - task: PowerShell@2 - displayName: Install mobile workloads - inputs: - targetType: 'inline' - script: | - dotnet workload install android ios maccatalyst - workingDirectory: '$(Build.SourcesDirectory)\csharp' - - - task: MSBuild@1 - displayName: 'Restore NuGet Packages and create project.assets.json' - inputs: - solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' - platform: 'Any CPU' - configuration: RelWithDebInfo - msbuildArguments: '-t:restore -p:OrtPackageId="Microsoft.ML.OnnxRuntime.Gpu"' - workingDirectory: '$(Build.SourcesDirectory)\csharp' - - - task: MSBuild@1 - displayName: 'Build C# bindings' - inputs: - solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' - configuration: RelWithDebInfo - platform: 'Any CPU' - msbuildArguments: '-p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId="Microsoft.ML.OnnxRuntime.Gpu" -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} -p:ReleaseVersionSuffix=$(ReleaseVersionSuffix)' - workingDirectory: '$(Build.SourcesDirectory)\csharp' - - - template: templates/win-esrp-dll.yml - parameters: - FolderPath: '$(Build.SourcesDirectory)\csharp\src\Microsoft.ML.OnnxRuntime\bin\RelWithDebInfo' - DisplayName: 'ESRP - Sign C# dlls' - DoEsrp: ${{ parameters.DoEsrp }} - - - task: MSBuild@1 - displayName: 'Build Nuget Packages Microsoft.ML.OnnxRuntime.Gpu' - inputs: - solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.proj' - configuration: RelWithDebInfo - platform: 'Any CPU' - msbuildArguments: '-t:CreatePackage -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=Microsoft.ML.OnnxRuntime.Gpu -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} - -p:ReleaseVersionSuffix=$(ReleaseVersionSuffix) -p:CurrentDate=$(BuildDate) -p:CurrentTime=$(BuildTime)' - workingDirectory: '$(Build.SourcesDirectory)\csharp' - - - task: BatchScript@1 - displayName: 'Add TensorRT header file to the native nuGet package' - inputs: - filename: $(Build.SourcesDirectory)\tools\ci_build\github\windows\bundle_nuget_with_native_headers.bat - workingFolder: $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo - - - task: CopyFiles@2 - displayName: 'Copy nuget packages to: $(Build.ArtifactStagingDirectory)' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' - Contents: '*.snupkg' - TargetFolder: '$(Build.ArtifactStagingDirectory)' - - - task: CopyFiles@2 - displayName: 'Copy nuget packages to: $(Build.ArtifactStagingDirectory)' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' - Contents: '*.nupkg' - TargetFolder: '$(Build.ArtifactStagingDirectory)' - - - task: CopyFiles@2 - displayName: 'Copy nuget packages to: $(Build.ArtifactStagingDirectory)' - inputs: - SourceFolder: '$(Build.SourcesDirectory)\csharp\src\Microsoft.ML.OnnxRuntime\bin\RelWithDebInfo' - Contents: '*.nupkg' - TargetFolder: '$(Build.ArtifactStagingDirectory)' - - - template: templates/esrp_nuget.yml - parameters: - DisplayName: 'ESRP - sign NuGet package' - FolderPath: '$(Build.ArtifactStagingDirectory)' - DoEsrp: ${{ parameters.DoEsrp }} - - - template: templates/validate-package.yml - parameters: - PackageType: 'nuget' - PackagePath: '$(Build.ArtifactStagingDirectory)' - PlatformsSupported: 'win-x64,linux-x64' - # 1* stands for version number. we use it to filter Gpu.Windows and Gpu.Linux packages - PackageName: 'Microsoft.ML.OnnxRuntime.Gpu.1*nupkg' - VerifyNugetSigning: false - - - template: templates/validate-package.yml - parameters: - PackageType: 'nuget' - PackagePath: '$(Build.ArtifactStagingDirectory)' - PackageName: 'Microsoft.ML.OnnxRuntime.Gpu.Windows.*nupkg' - PlatformsSupported: 'win-x64' - VerifyNugetSigning: false - - - template: templates/validate-package.yml - parameters: - PackageType: 'nuget' - PackagePath: '$(Build.ArtifactStagingDirectory)' - PackageName: 'Microsoft.ML.OnnxRuntime.Gpu.Linux.*nupkg' - PlatformsSupported: 'linux-x64' - VerifyNugetSigning: false - - - task: PublishPipelineArtifact@0 - displayName: 'Publish Pipeline NuGet Artifact' - inputs: - artifactName: 'drop-signed-nuget-GPU' - targetPath: '$(Build.ArtifactStagingDirectory)' - - - - task: MSBuild@1 - displayName: 'Clean C#' - inputs: - solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' - platform: 'Any CPU' - configuration: RelWithDebInfo - msbuildArguments: '-t:Clean -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=Microsoft.ML.OnnxRuntime.Gpu' - workingDirectory: '$(Build.SourcesDirectory)\csharp' - - - - task: RoslynAnalyzers@2 - displayName: 'Run Roslyn Analyzers' - inputs: - userProvideBuildInfo: msBuildInfo - msBuildCommandline: '"C:\Program Files\Microsoft Visual Studio\2022\Enterprise\MSBuild\Current\Bin\msbuild.exe" $(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln -p:configuration="RelWithDebInfo" -p:Platform="Any CPU" -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=Microsoft.ML.OnnxRuntime.Gpu' - condition: and(succeeded(), eq('${{ parameters.DoCompliance }}', true)) +- template: stages/java-cuda-packaging-stage.yml + parameters: + CudaVersion: ${{ parameters.CudaVersion }} + SpecificArtifact: ${{ parameters.SpecificArtifact }} + BuildId: ${{ parameters.BuildId }} - - template: templates/component-governance-component-detection-steps.yml - parameters : - condition : 'succeeded' +- template: stages/nuget-win-cuda-packaging-stage.yml + parameters: + CudaVersion: ${{ parameters.CudaVersion }} + SpecificArtifact: ${{ parameters.SpecificArtifact }} + BuildId: ${{ parameters.BuildId }} + DoEsrp: ${{ parameters.DoEsrp }} + DoCompliance: ${{ parameters.DoCompliance }} + UseIncreasedTimeoutForTests: ${{ parameters.UseIncreasedTimeoutForTests }} + RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} + win_trt_home: ${{ variables.win_trt_home }} + win_cuda_home: ${{ variables.win_cuda_home }} - - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 - displayName: 'Clean Agent Directories' - condition: always() +- template: stages/nuget-combine-cuda-stage.yml + parameters: + DoEsrp: ${{ parameters.DoEsrp }} + DoCompliance: ${{ parameters.DoCompliance }} + IsReleaseBuild: ${{ parameters.IsReleaseBuild }} - stage: NuGet_Packaging_ROCm dependsOn: @@ -826,6 +345,8 @@ stages: variables: breakCodesignValidationInjection: ${{ parameters.DoEsrp }} ReleaseVersionSuffix: $[stageDependencies.Setup.Set_Variables.outputs['Set_Release_Version_Suffix.ReleaseVersionSuffix']] + BuildDate : $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Date.BuildDate']] + BuildTime : $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Time.BuildTime']] steps: - checkout: self @@ -953,7 +474,7 @@ stages: solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.proj' configuration: RelWithDebInfo platform: 'Any CPU' - msbuildArguments: '-t:CreatePackage -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=Microsoft.ML.OnnxRuntime.ROCm -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} -p:ReleaseVersionSuffix=$(ReleaseVersionSuffix)' + msbuildArguments: '-t:CreatePackage -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=Microsoft.ML.OnnxRuntime.ROCm -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} -p:ReleaseVersionSuffix=$(ReleaseVersionSuffix) -p:CurrentTime=$(BuildTime) -p:CurrentDate=$(BuildDate)' workingDirectory: '$(Build.SourcesDirectory)\csharp' - task: CopyFiles@2 @@ -1032,8 +553,8 @@ stages: - template: nuget/templates/test_win.yml parameters: - AgentPool : 'onnxruntime-Win2022-GPU-A10' - NugetPackageName : 'Microsoft.ML.OnnxRuntime.Gpu' + AgentPool: 'onnxruntime-Win2022-GPU-A10' + NugetPackageName: 'Microsoft.ML.OnnxRuntime.Gpu' ArtifactSuffix: 'GPU' StageSuffix: 'GPU' Skipx86Tests: 'true' @@ -1043,8 +564,8 @@ stages: - template: nuget/templates/test_win.yml parameters: - AgentPool : 'onnxruntime-Win2022-GPU-A10' - NugetPackageName : 'Microsoft.ML.OnnxRuntime.Gpu.Windows' + AgentPool: 'onnxruntime-Win2022-GPU-A10' + NugetPackageName: 'Microsoft.ML.OnnxRuntime.Gpu.Windows' ArtifactSuffix: 'GPU' StageSuffix: 'GPU' MoreSuffix: '_Windows' @@ -1055,7 +576,7 @@ stages: - template: nuget/templates/test_linux.yml parameters: - AgentPool : Onnxruntime-Linux-GPU-A10 + AgentPool: Onnxruntime-Linux-GPU-A10 ArtifactSuffix: 'GPU' StageSuffix: 'GPU' NugetPackageName : 'Microsoft.ML.OnnxRuntime.Gpu' @@ -1065,7 +586,7 @@ stages: - template: nuget/templates/test_linux.yml parameters: - AgentPool : Onnxruntime-Linux-GPU-A10 + AgentPool: Onnxruntime-Linux-GPU-A10 ArtifactSuffix: 'GPU' StageSuffix: 'GPU' MoreSuffix: '_Linux' @@ -1079,14 +600,14 @@ stages: AgentPool: AMD-GPU ArtifactSuffix: 'ROCm' StageSuffix: 'ROCm' - NugetPackageName : 'Microsoft.ML.OnnxRuntime.ROCm' + NugetPackageName: 'Microsoft.ML.OnnxRuntime.ROCm' SpecificArtifact: ${{ parameters.specificArtifact }} CustomOpArtifactName: 'onnxruntime-linux-x64-rocm' BuildId: ${{ parameters.BuildId }} - template: nuget/templates/dml-vs-2022.yml parameters: - AgentPool : 'onnxruntime-Win2022-GPU-dml-A10' + AgentPool: 'onnxruntime-Win2022-GPU-dml-A10' IsReleaseBuild: ${{ parameters.IsReleaseBuild }} ArtifactName: 'drop-nuget-dml' StageName: 'Windows_CI_GPU_DML_Dev' @@ -1096,11 +617,11 @@ stages: EnvSetupScript: 'setup_env.bat' sln_platform: 'x64' DoDebugBuild: 'false' - DoNugetPack : 'true' + DoNugetPack: 'true' DoCompliance: 'false' DoEsrp: ${{ parameters.DoEsrp }} NuPackScript: | - msbuild $(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.proj /p:Configuration=RelWithDebInfo /t:CreatePackage /p:OrtPackageId=Microsoft.ML.OnnxRuntime.DirectML /p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} + msbuild $(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.proj /p:Configuration=RelWithDebInfo /t:CreatePackage /p:OrtPackageId=Microsoft.ML.OnnxRuntime.DirectML /p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} /p:CurrentData=$(BuildDate) /p:CurrentTime=$(BuildTime) copy $(Build.SourcesDirectory)\csharp\src\Microsoft.ML.OnnxRuntime\bin\RelWithDebInfo\*.nupkg $(Build.ArtifactStagingDirectory) copy $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\*.nupkg $(Build.ArtifactStagingDirectory) mkdir $(Build.ArtifactStagingDirectory)\testdata @@ -1108,7 +629,7 @@ stages: - template: nuget/templates/dml-vs-2022.yml parameters: - AgentPool : 'onnxruntime-Win2022-GPU-dml-A10' + AgentPool: 'onnxruntime-Win2022-GPU-dml-A10' IsReleaseBuild: ${{ parameters.IsReleaseBuild }} ArtifactName: 'drop-win-dml-x86-zip' StageName: 'Windows_CI_GPU_DML_Dev_x86' @@ -1117,7 +638,7 @@ stages: EnvSetupScript: 'setup_env_x86.bat' sln_platform: 'Win32' DoDebugBuild: 'false' - DoNugetPack : 'true' + DoNugetPack: 'true' DoCompliance: ${{ parameters.DoCompliance }} DoEsrp: ${{ parameters.DoEsrp }} RunTests: 'false' @@ -1131,7 +652,7 @@ stages: - template: nuget/templates/dml-vs-2022.yml parameters: - AgentPool : 'onnxruntime-Win2022-GPU-dml-A10' + AgentPool: 'onnxruntime-Win2022-GPU-dml-A10' IsReleaseBuild: ${{ parameters.IsReleaseBuild }} ArtifactName: 'drop-win-dml-arm64-zip' StageName: 'Windows_CI_GPU_DML_Dev_arm64' @@ -1140,7 +661,7 @@ stages: EnvSetupScript: 'setup_env.bat' sln_platform: 'arm64' DoDebugBuild: 'false' - DoNugetPack : 'true' + DoNugetPack: 'true' DoCompliance: ${{ parameters.DoCompliance }} DoEsrp: ${{ parameters.DoEsrp }} RunTests: 'false' @@ -1246,5 +767,5 @@ stages: artifactName: 'drop-signed-nuget-dml' targetPath: '$(Build.ArtifactStagingDirectory)' - template: templates/component-governance-component-detection-steps.yml - parameters : - condition : 'succeeded' + parameters: + condition: 'succeeded' \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml deleted file mode 100644 index 7260fef085e2..000000000000 --- a/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml +++ /dev/null @@ -1,199 +0,0 @@ -parameters: - - name: RunOnnxRuntimeTests - displayName: Run Tests? - type: boolean - default: true - - - name: UseIncreasedTimeoutForTests - displayName: Increase timeout for tests? Set it to false if you are doing an Onnx Runtime release. - type: boolean - default: false - - - name: DoCompliance - displayName: Run Compliance Tasks? - type: boolean - default: true - - - name: DoEsrp - displayName: Run code sign tasks? Must be true if you are doing an ONNX Runtime release - type: boolean - default: true - - - name: IsReleaseBuild - displayName: Is a release build? Set it to true if you are doing an ONNX Runtime release. - type: boolean - default: false - - - name: PreReleaseVersionSuffixString - displayName: Suffix added to pre-release package version. Only used if IsReleaseBuild is true. Denotes the type of pre-release package. - type: string - values: - - alpha - - beta - - rc - - none - default: none - - - name: PreReleaseVersionSuffixNumber - displayName: Number added to pre-release package version. Only used if IsReleaseBuild is true. Denotes the sequence of a pre-release package. - type: number - default: 0 - - # these 2 parameters are used for debugging. - - name: SpecificArtifact - displayName: Use Specific Artifact (Debugging only) - type: boolean - default: false - - - name: BuildId - displayName: Pipeline BuildId, you could find it in the URL - type: string - default: '0' - - - name: CudaVersion - displayName: CUDA version - type: string - default: '12.2' - values: - - 11.8 - - 12.2 - -variables: - - name: ReleaseVersionSuffix - value: '' - - name: docker_base_image - ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: nvidia/cuda:11.8.0-cudnn8-devel-ubi8 - ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: nvidia/cuda:12.2.2-cudnn8-devel-ubi8 - - name: linux_trt_version - ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: 10.0.1.6-1.cuda11.8 - ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: 10.0.1.6-1.cuda12.4 - - name: win_trt_home - ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: $(Agent.TempDirectory)\TensorRT-10.0.1.6.Windows10.x86_64.cuda-11.8 - ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: $(Agent.TempDirectory)\TensorRT-10.0.1.6.Windows10.x86_64.cuda-12.4 - - name: win_cuda_home - ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: $(Agent.TempDirectory)\v11.8 - ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: $(Agent.TempDirectory)\v12.2 -resources: - repositories: - - repository: onnxruntime-inference-examples # The name used to reference this repository in the checkout step - type: github - endpoint: ort-examples - name: microsoft/onnxruntime-inference-examples - - repository: manylinux - type: Github - endpoint: Microsoft - name: pypa/manylinux - ref: 5eda9aded5462201e6310105728d33016e637ea7 - -stages: -# Set ReleaseVersionSuffix - - stage: Set_ReleaseVersionSuffix - jobs: - - job: Set_Variables - pool: - vmImage: ubuntu-latest - steps: - - checkout: none - - bash: | - # Do not output ##vso[] commands with `set -x` or they may be parsed again and include a trailing quote. - set +x - if [[ "${{ parameters.IsReleaseBuild }}" = True && "${{ parameters.PreReleaseVersionSuffixString }}" != "none" ]]; then - if [[ "${{ parameters.PreReleaseVersionSuffixNumber }}" -eq 0 ]]; then - echo "##vso[task.setvariable variable=ReleaseVersionSuffix;isOutput=true]-${{ parameters.PreReleaseVersionSuffixString }}" - else - echo "##vso[task.setvariable variable=ReleaseVersionSuffix;isOutput=true]-${{ parameters.PreReleaseVersionSuffixString }}.${{ parameters.PreReleaseVersionSuffixNumber }}" - fi - else - echo "##vso[task.setvariable variable=ReleaseVersionSuffix;isOutput=true]" - fi - name: Set_Release_Version_Suffix - - bash: echo $(ReleaseVersionSuffix) - name: Debug_Release_Version_Suffix - # this is needed for certain artifacts to be published - - stage: Linux_C_API_Packaging_CPU_x64 - dependsOn: [ ] - jobs: - - template: templates/c-api-linux-cpu.yml - parameters: - BaseImage: 'registry.access.redhat.com/ubi8/ubi' - OnnxruntimeArch: 'x64' - OnnxruntimeNodejsBindingArch: 'x64' - PoolName: 'onnxruntime-Ubuntu2204-AMD-CPU' - PackageJava: false - PackageNodeJS: false - # Nuget Packaging - - - template: stages/nuget-linux-cuda-packaging-stage.yml - parameters: - CudaVersion: ${{ parameters.CudaVersion }} - docker_base_image: ${{ variables.docker_base_image }} - linux_trt_version: ${{ variables.linux_trt_version }} - - template: stages/nuget-win-cuda-packaging-stage.yml - parameters: - RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} - UseIncreasedTimeoutForTests: ${{ parameters.UseIncreasedTimeoutForTests }} - CudaVersion: ${{ parameters.CudaVersion }} - win_trt_home: ${{ variables.win_trt_home }} - win_cuda_home: ${{ variables.win_cuda_home }} - SpecificArtifact: ${{ parameters.SpecificArtifact }} - BuildId: ${{ parameters.BuildId }} - - template: stages/nuget-combine-cuda-stage.yml - parameters: - DoCompliance: ${{ parameters.DoCompliance }} - DoEsrp: ${{ parameters.DoEsrp }} - IsReleaseBuild: ${{ parameters.IsReleaseBuild }} - # Testing - - template: nuget/templates/test_win.yml - parameters: - AgentPool : 'onnxruntime-Win2022-GPU-A10' - NugetPackageName : 'Microsoft.ML.OnnxRuntime.Gpu' - ArtifactSuffix: 'GPU' - StageSuffix: 'GPU' - Skipx86Tests: 'true' - CudaVersion: ${{ parameters.CudaVersion }} - SpecificArtifact: ${{ parameters.SpecificArtifact }} - BuildId: ${{ parameters.BuildId }} - - - template: nuget/templates/test_win.yml - parameters: - AgentPool : 'onnxruntime-Win2022-GPU-A10' - NugetPackageName : 'Microsoft.ML.OnnxRuntime.Gpu.Windows' - ArtifactSuffix: 'GPU' - StageSuffix: 'GPU' - MoreSuffix: '_Windows' - Skipx86Tests: 'true' - CudaVersion: ${{ parameters.CudaVersion }} - SpecificArtifact: ${{ parameters.SpecificArtifact }} - BuildId: ${{ parameters.BuildId }} - - - template: nuget/templates/test_linux.yml - parameters: - AgentPool : Onnxruntime-Linux-GPU-A10 - ArtifactSuffix: 'GPU' - StageSuffix: 'GPU' - NugetPackageName : 'Microsoft.ML.OnnxRuntime.Gpu' - CudaVersion: ${{ parameters.CudaVersion }} - SpecificArtifact: ${{ parameters.specificArtifact }} - BuildId: ${{ parameters.BuildId }} - - - template: nuget/templates/test_linux.yml - parameters: - AgentPool : Onnxruntime-Linux-GPU-A10 - ArtifactSuffix: 'GPU' - StageSuffix: 'GPU' - MoreSuffix: '_Linux' - NugetPackageName : 'Microsoft.ML.OnnxRuntime.Gpu.Linux' - CudaVersion: ${{ parameters.CudaVersion }} - SpecificArtifact: ${{ parameters.specificArtifact }} - BuildId: ${{ parameters.BuildId }} - -## Win/Linux GPU Combined Publishing -#- template: templates/publish-nuget.yml diff --git a/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml b/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml index 3a3375a313ca..cc1e798e6cd2 100644 --- a/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml +++ b/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml @@ -24,7 +24,7 @@ parameters: IsReleaseBuild: false stages: - stage: ${{ parameters.StageName }} - dependsOn: [] + dependsOn: Setup jobs: - job: timeoutInMinutes: 200 @@ -47,6 +47,8 @@ stages: runCodesignValidationInjection: and(${{ parameters.DoNodejsPack }},${{ parameters. DoEsrp}}) #For the others, code sign is in a separated job DOTNET_SKIP_FIRST_TIME_EXPERIENCE: true ALLOW_RELEASED_ONNX_OPSET_ONLY: ${{ parameters.AllowReleasedOpsetOnly }} + BuildDate : $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Date.BuildDate']] + BuildTime : $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Time.BuildTime']] ${{ if eq(parameters.EnableLto, true) }}: build_py_lto_flag: --enable_lto diff --git a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml index 8cfcaa443cc8..18c3bb783e92 100644 --- a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml @@ -73,6 +73,7 @@ jobs: FolderPath: '$(Build.BinariesDirectory)\Windows\${{ parameters.build_config }}\${{ parameters.build_config }}' DisplayName: 'ESRP - Sign dlls' DoEsrp: ${{ parameters.DoEsrp }} + Pattern: 'onnxruntime.dll' - task: CmdLine@2 displayName: 'Generating nuspec for the native Nuget package x64' @@ -178,6 +179,7 @@ jobs: FolderPath: '$(Build.BinariesDirectory)\Win_arm64\${{ parameters.build_config }}\${{ parameters.build_config }}' DisplayName: 'ESRP - Sign dlls' DoEsrp: ${{ parameters.DoEsrp }} + Pattern: 'onnxruntime.dll' - task: CmdLine@2 displayName: 'Generating nuspec for the native Nuget package arm64' diff --git a/tools/ci_build/github/azure-pipelines/stages/java-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/java-cuda-packaging-stage.yml new file mode 100644 index 000000000000..8c81972d607e --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/stages/java-cuda-packaging-stage.yml @@ -0,0 +1,187 @@ +parameters: +- name: CudaVersion + type: string +- name: SpecificArtifact + type: string +- name: BuildId + type: string + +stages: +- stage: Java_GPU_Packaging + dependsOn: + - Linux_C_API_Packaging_Combined_CUDA + - Windows_Packaging_CUDA + - Windows_Packaging_TensorRT + - Download_Java_Tools + jobs: + - job: Jar_Packaging_GPU + workspace: + clean: all + pool: 'onnxruntime-Win-CPU-2022' + dependsOn: [] + condition: succeeded() + steps: + - checkout: self + submodules: false + - template: ../templates/set-version-number-variables-step.yml + + - template: ../templates/flex-downloadPipelineArtifact.yml + parameters: + StepName: 'Download Pipeline Artifact - Win x64' + ArtifactName: 'drop-onnxruntime-java-win-x64-tensorrt' + TargetPath: '$(Build.BinariesDirectory)\java-artifact\onnxruntime-java-win-x64' + SpecificArtifact: ${{ parameters.specificArtifact }} + BuildId: ${{ parameters.BuildId }} + + - template: ../templates/flex-downloadPipelineArtifact.yml + parameters: + stepName: 'Download Pipeline Artifact - Linux x64' + artifactName: 'drop-onnxruntime-java-linux-x64-cuda' + targetPath: '$(Build.BinariesDirectory)\java-artifact\onnxruntime-java-linux-x64' + SpecificArtifact: ${{ parameters.specificArtifact }} + BuildId: ${{ parameters.BuildId }} + + - template: ../templates/flex-downloadPipelineArtifact.yml + parameters: + StepName: 'Download Pipeline Artifact - Linux x64' + ArtifactName: 'drop-onnxruntime-java-linux-x64-tensorrt' + targetPath: '$(Build.BinariesDirectory)\java-artifact\onnxruntime-java-linux-x64-tensorrt' + SpecificArtifact: ${{ parameters.specificArtifact }} + BuildId: ${{ parameters.BuildId }} + + - task: PowerShell@2 + displayName: 'PowerShell Script' + inputs: + targetType: filePath + filePath: $(Build.SourcesDirectory)\tools\ci_build\github\windows\jar_gpu_packaging.ps1 + failOnStderr: true + showWarnings: true + workingDirectory: '$(Build.BinariesDirectory)\java-artifact' + + - task: CopyFiles@2 + displayName: 'Copy Java Files to Artifact Staging Directory' + inputs: + SourceFolder: '$(Build.BinariesDirectory)\java-artifact\onnxruntime-java-win-x64' + TargetFolder: '$(Build.ArtifactStagingDirectory)' + + - task: PublishPipelineArtifact@1 + displayName: 'Publish Pipeline Artifact' + inputs: + targetPath: '$(Build.ArtifactStagingDirectory)' + artifact: 'onnxruntime-java-gpu' + + - template: ../templates/component-governance-component-detection-steps.yml + parameters: + condition: 'succeeded' + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 + displayName: 'Clean Agent Directories' + condition: always() + - job: Final_Jar_Testing_Windows_GPU + dependsOn: + Jar_Packaging_GPU + workspace: + clean: all + pool: 'onnxruntime-Win2022-GPU-T4' + timeoutInMinutes: 60 + variables: + - name: runCodesignValidationInjection + value: false + + steps: + - template: ../templates/set-version-number-variables-step.yml + + - template: ../templates/jobs/download_win_gpu_library.yml + parameters: + CudaVersion: ${{ parameters.CudaVersion }} + DownloadCUDA: true + DownloadTRT: true + + - template: ../templates/flex-downloadPipelineArtifact.yml + parameters: + StepName: 'Download Final Jar' + ArtifactName: onnxruntime-java-gpu + TargetPath: '$(Build.BinariesDirectory)\final-jar' + SpecificArtifact: ${{ parameters.SpecificArtifact }} + BuildId: ${{ parameters.BuildId }} + + - template: ../templates/flex-downloadPipelineArtifact.yml + parameters: + StepName: 'Download Jar Tools' + ArtifactName: onnxruntime-java-tools + TargetPath: '$(Build.BinariesDirectory)\final-jar' + SpecificArtifact: ${{ parameters.SpecificArtifact }} + BuildId: ${{ parameters.BuildId }} + + - task: CmdLine@2 + inputs: + script: | + mkdir test + pushd test + jar xf $(Build.BinariesDirectory)\final-jar\testing.jar + popd + java -DUSE_CUDA=1 -jar junit-platform-console-standalone-1.6.2.jar -cp .;.\test;protobuf-java-3.21.7.jar;onnxruntime_gpu-$(OnnxRuntimeVersion).jar --scan-class-path --fail-if-no-tests --disable-banner + workingDirectory: '$(Build.BinariesDirectory)\final-jar' + + - template: ../templates/component-governance-component-detection-steps.yml + parameters: + condition: 'succeeded' + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 + displayName: 'Clean Agent Directories' + condition: always() + - job: Final_Jar_Testing_Linux_GPU + dependsOn: + Jar_Packaging_GPU + workspace: + clean: all + pool: 'Onnxruntime-Linux-GPU' + variables: + - name: runCodesignValidationInjection + value: false + - name: docker_base_image + ${{ if eq(parameters.CudaVersion, '11.8') }}: + value: nvidia/cuda:11.8.0-cudnn8-devel-ubi8 + ${{ if eq(parameters.CudaVersion, '12.2') }}: + value: nvidia/cuda:12.2.2-cudnn8-devel-ubi8 + timeoutInMinutes: 60 + + steps: + - checkout: self + submodules: false + - template: ../templates/set-version-number-variables-step.yml + + - template: ../templates/flex-downloadPipelineArtifact.yml + parameters: + StepName: 'Download Final Jar' + ArtifactName: onnxruntime-java-gpu + TargetPath: '$(Build.BinariesDirectory)/final-jar' + SpecificArtifact: ${{ parameters.specificArtifact }} + BuildId: ${{ parameters.BuildId }} + + - template: ../templates/get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0 + Context: tools/ci_build/github/linux/docker/ + DockerBuildArgs: " + --build-arg BUILD_UID=$( id -u ) + --build-arg BASEIMAGE=${{ variables.docker_base_image }} + --build-arg TRT_VERSION=${{ variables.linux_trt_version }} + " + Repository: onnxruntimeubi8packagestest + UpdateDepsTxt: false + + - bash: | + docker run --rm \ + --gpus all \ + --volume $(Build.SourcesDirectory):/onnxruntime_src \ + --volume $(Build.BinariesDirectory):/build \ + --volume /data/models:/build/models:ro \ + onnxruntimeubi8packagestest \ + /bin/bash /onnxruntime_src/tools/ci_build/github/linux/java_linux_final_test.sh -r /build -v $(OnnxRuntimeVersion) + displayName: 'Test' + + - template: ../templates/component-governance-component-detection-steps.yml + parameters: + condition: 'succeeded' + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 + displayName: 'Clean Agent Directories' + condition: always() diff --git a/tools/ci_build/github/azure-pipelines/stages/jobs/linux-gpu-tensorrt-packaging-job.yml b/tools/ci_build/github/azure-pipelines/stages/jobs/linux-gpu-tensorrt-packaging-job.yml new file mode 100644 index 000000000000..541ab1ac7487 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/stages/jobs/linux-gpu-tensorrt-packaging-job.yml @@ -0,0 +1,108 @@ +parameters: +- name: artifactName + type: string + default: 'onnxruntime-linux-x64-gpu-tensorrt-$(OnnxRuntimeVersion)' + +- name: artifactNameNoVersionString + type: string + default: 'onnxruntime-linux-x64-gpu-tensorrt' + +- name: buildJava + type: boolean + default: false + +- name: buildJavaOption + type: string + default: '' + +- name: buildNodejs + type: boolean + default: true + +- name: buildNodejsOption + type: string + default: '' + +- name: CudaVersion + displayName: CUDA version + type: string + default: '11.8' + values: + - 11.8 + - 12.2 + +jobs: +- job: Linux_C_API_Packaging_TensorRT + dependsOn: [] + workspace: + clean: all + timeoutInMinutes: 180 + pool: 'Onnxruntime-Linux-GPU' + variables: + - name: CUDA_VERSION_MAJOR + ${{ if eq(parameters.CudaVersion, '11.8') }}: + value: '11' + ${{ if eq(parameters.CudaVersion, '12.2') }}: + value: '12' + - name: CUDA_VERSION + value: ${{ parameters.CudaVersion }} + - name: linux_trt_version + ${{ if eq(parameters.CudaVersion, '11.8') }}: + value: 10.0.1.6-1.cuda11.8 + ${{ if eq(parameters.CudaVersion, '12.2') }}: + value: 10.0.1.6-1.cuda12.4 + - name: docker_base_image + ${{ if eq(parameters.CudaVersion, '11.8') }}: + value: nvidia/cuda:11.8.0-cudnn8-devel-ubi8 + ${{ if eq(parameters.CudaVersion, '12.2') }}: + value: nvidia/cuda:12.2.2-cudnn8-devel-ubi8 + steps: + - checkout: self + clean: true + submodules: recursive + - template: ../../templates/get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda + Context: tools/ci_build/github/linux/docker + DockerBuildArgs: " + --network=host + --build-arg BASEIMAGE=${{ variables.docker_base_image }} + --build-arg TRT_VERSION=${{ variables.linux_trt_version }} + --build-arg BUILD_UID=$( id -u ) + " + Repository: onnxruntimecuda${{ variables.CUDA_VERSION_MAJOR }}xtrt86build + - template: ../../templates/set-version-number-variables-step.yml + + - script: $(Build.SourcesDirectory)/tools/ci_build/github/linux/build_tensorrt_c_api_package.sh + workingDirectory: $(Build.SourcesDirectory) + displayName: 'Build and Test' + + - ${{ if eq(parameters.buildJava, true) }}: + - template: ../../templates/java-api-artifacts-package-and-publish-steps-posix.yml + parameters: + arch: 'linux-x64' + buildConfig: 'Release' + artifactName: 'onnxruntime-java-linux-x64-tensorrt' + version: '$(OnnxRuntimeVersion)' + libraryName: 'libonnxruntime.so' + nativeLibraryName: 'libonnxruntime4j_jni.so' + + - ${{ if eq(parameters.buildNodejs, 'true') }}: + - template: ../../templates/nodejs-artifacts-package-and-publish-steps-posix.yml + parameters: + arch: 'x64' + os: 'linux' + artifactName: 'drop-onnxruntime-nodejs-linux-x64-tensorrt' + + - template: ../../templates/c-api-artifacts-package-and-publish-steps-posix.yml + parameters: + buildConfig: 'Release' + artifactName: ${{ parameters.artifactName }} + artifactNameNoVersionString: ${{ parameters.artifactNameNoVersionString }} + libraryName: 'libonnxruntime.so.$(OnnxRuntimeVersion)' + + + - template: ../../templates/component-governance-component-detection-steps.yml + parameters: + condition: 'succeeded' + - template: ../../templates/clean-agent-build-directory-step.yml diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml index 2452e2885e74..9f5ca3db3170 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml @@ -1,27 +1,22 @@ parameters: - name: DoCompliance type: boolean - default: true - name: DoEsrp type: boolean - default: true - name: IsReleaseBuild type: boolean - default: false stages: ######## Nuget ######## # Win/Linux CUDA Combined packaging - stage: NuGet_Packaging_GPU dependsOn: - - Set_ReleaseVersionSuffix - - Windows_Packaging_gpu_Testing - - Windows_Packaging_tensorrt_Testing - - Linux_C_API_Packaging_CPU_x64 - - Linux_C_API_Packaging_GPU_x64 - - Linux_C_API_Packaging_GPU_TensorRT_x64 + - Setup + - Windows_Packaging_CUDA + - Windows_Packaging_TensorRT + - Linux_C_API_Packaging_Combined_CUDA condition: succeeded() jobs: - job: @@ -37,35 +32,35 @@ stages: steps: - checkout: self submodules: true - # Download the all artifacts - - task: DownloadPipelineArtifact@2 - displayName: 'Download Pipeline Artifact from Linux_C_API_Packaging_GPU_x64 Stage' - inputs: - artifactName: 'onnxruntime-win-x64-cuda' - targetPath: '$(Build.BinariesDirectory)/nuget-artifact' - - task: DownloadPipelineArtifact@2 - displayName: 'Download Pipeline Artifact from Linux_C_API_Packaging_GPU_TensorRT_x64 Stage' - inputs: - artifactName: 'onnxruntime-win-x64-tensorrt' - targetPath: '$(Build.BinariesDirectory)/nuget-artifact' + - template: ../templates/flex-downloadPipelineArtifact.yml + parameters: + StepName: 'Download Pipeline Artifact - NuGet' + ArtifactName: 'onnxruntime-win-x64-cuda' + TargetPath: '$(Build.BinariesDirectory)/nuget-artifact' - - task: DownloadPipelineArtifact@2 - displayName: 'Download Pipeline Artifact from Windows_Packaging_gpu Stage' - inputs: - artifactName: 'onnxruntime-linux-x64-cuda' - targetPath: '$(Build.BinariesDirectory)/nuget-artifact' + - template: ../templates/flex-downloadPipelineArtifact.yml + parameters: + StepName: 'Download Pipeline Artifact - NuGet' + ArtifactName: 'onnxruntime-win-x64-tensorrt' + TargetPath: '$(Build.BinariesDirectory)/nuget-artifact' - - task: DownloadPipelineArtifact@2 - displayName: 'Download Pipeline Artifact from Windows_Packaging_tensorrt Stage' - inputs: - artifactName: 'onnxruntime-linux-x64-tensorrt' - targetPath: '$(Build.BinariesDirectory)/nuget-artifact' + - template: ../templates/flex-downloadPipelineArtifact.yml + parameters: + StepName: 'Download Pipeline Artifact - NuGet' + ArtifactName: 'onnxruntime-linux-x64-cuda' + TargetPath: '$(Build.BinariesDirectory)/nuget-artifact' - - task: DownloadPipelineArtifact@2 - displayName: 'Download Pipeline Artifact - protoc from Windows_Packaging_(cpu|gpu) Stage' - inputs: - artifactName: 'drop-extra' - targetPath: '$(Build.BinariesDirectory)/extra-artifact' + - template: ../templates/flex-downloadPipelineArtifact.yml + parameters: + StepName: 'Download Pipeline Artifact - NuGet' + ArtifactName: 'onnxruntime-linux-x64-tensorrt' + TargetPath: '$(Build.BinariesDirectory)/nuget-artifact' + + - template: ../templates/flex-downloadPipelineArtifact.yml + parameters: + StepName: 'Download Pipeline Artifact - NuGet' + ArtifactName: 'drop-extra' + TargetPath: '$(Build.BinariesDirectory)/extra-artifact' # Reconstruct the build dir - task: PowerShell@2 @@ -222,7 +217,6 @@ stages: msbuildArguments: '-t:Clean -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=Microsoft.ML.OnnxRuntime.Gpu' workingDirectory: '$(Build.SourcesDirectory)\csharp' - - task: RoslynAnalyzers@2 displayName: 'Run Roslyn Analyzers' inputs: @@ -242,4 +236,4 @@ stages: - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 displayName: 'Clean Agent Directories' - condition: always() + condition: always() \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml index 6530400deedf..d42f89b26774 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml @@ -14,23 +14,27 @@ parameters: default: false stages: - # Linux CUDA without TensorRT Packaging -- stage: Linux_C_API_Packaging_GPU_x64 +- stage: Linux_C_API_Packaging_Combined_CUDA dependsOn: [] jobs: - - job: + - job: Linux_C_API_Packaging_CUDA workspace: clean: all timeoutInMinutes: 150 pool: 'Onnxruntime-Linux-GPU' variables: - - name: CUDA_VERSION_MAJOR - ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: '11' - ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: '12' - - name: CUDA_VERSION - value: ${{ parameters.CudaVersion }} + - name: CUDA_VERSION_MAJOR + ${{ if eq(parameters.CudaVersion, '11.8') }}: + value: '11' + ${{ if eq(parameters.CudaVersion, '12.2') }}: + value: '12' + - name: CUDA_VERSION + value: ${{ parameters.CudaVersion }} + - name: docker_base_image + ${{ if eq(parameters.CudaVersion, '11.8') }}: + value: nvidia/cuda:11.8.0-cudnn8-devel-ubi8 + ${{ if eq(parameters.CudaVersion, '12.2') }}: + value: nvidia/cuda:12.2.2-cudnn8-devel-ubi8 steps: - template: ../templates/set-version-number-variables-step.yml - template: ../templates/get-docker-image-steps.yml @@ -65,122 +69,120 @@ stages: parameters: condition: 'succeeded' - template: ../templates/clean-agent-build-directory-step.yml -# Linux CUDA with TensorRT Packaging -- template: ../templates/linux-gpu-tensorrt-packaging-pipeline.yml - parameters: - artifactName: 'onnxruntime-linux-x64-tensorrt-$(OnnxRuntimeVersion)' - artifactNameNoVersionString: 'onnxruntime-linux-x64-tensorrt' - buildJava: ${{ parameters.buildJava }} - buildJavaOption: '--build_java' - buildNodejs: ${{ parameters.buildNodejs }} - buildNodejsOption: '--build_nodejs' - CudaVersion: ${{ parameters.CudaVersion }} -# Linux CUDA Combined Testing and Publishing -- stage: Linux_Packaging_combined_GPU - dependsOn: - - Linux_C_API_Packaging_GPU_x64 - - Linux_C_API_Packaging_GPU_TensorRT_x64 - condition: succeeded() - jobs: - - job: - workspace: - clean: all - pool: 'Onnxruntime-Linux-GPU' - variables: - - name: CUDA_VERSION_MAJOR - ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: '11' - ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: '12' - steps: - - checkout: self # due to checkout multiple repos, the root directory is $(Build.SourcesDirectory)/onnxruntime - submodules: false - - checkout: onnxruntime-inference-examples # due to checkout multiple repos, the root directory is $(Build.SourcesDirectory)/onnxruntime-inference-examples - submodules: false - - checkout: manylinux # due to checkout multiple repos, the root directory is $(Build.SourcesDirectory)/manylinux - submodules: false + # Linux CUDA with TensorRT Packaging + - template: jobs/linux-gpu-tensorrt-packaging-job.yml + parameters: + artifactName: 'onnxruntime-linux-x64-tensorrt-$(OnnxRuntimeVersion)' + artifactNameNoVersionString: 'onnxruntime-linux-x64-tensorrt' + buildJava: ${{ parameters.buildJava }} + buildJavaOption: '--build_java' + buildNodejs: ${{ parameters.buildNodejs }} + buildNodejsOption: '--build_nodejs' + CudaVersion: ${{ parameters.CudaVersion }} + # Linux CUDA Combined Testing and Publishing + - job: Linux_Packaging_combined_CUDA + dependsOn: + - Linux_C_API_Packaging_CUDA + - Linux_C_API_Packaging_TensorRT + condition: succeeded() + workspace: + clean: all + pool: 'Onnxruntime-Linux-GPU' + variables: + - name: CUDA_VERSION_MAJOR + ${{ if eq(parameters.CudaVersion, '11.8') }}: + value: '11' + ${{ if eq(parameters.CudaVersion, '12.2') }}: + value: '12' + steps: + - checkout: self # due to checkout multiple repos, the root directory is $(Build.SourcesDirectory)/onnxruntime + submodules: false + - checkout: onnxruntime-inference-examples # due to checkout multiple repos, the root directory is $(Build.SourcesDirectory)/onnxruntime-inference-examples + submodules: false + - checkout: manylinux # due to checkout multiple repos, the root directory is $(Build.SourcesDirectory)/manylinux + submodules: false - - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 - displayName: 'Clean Agent Directories' - condition: always() + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 + displayName: 'Clean Agent Directories' + condition: always() - - script: | - set -e -x - cd $(Build.SourcesDirectory) - mv manylinux onnxruntime - ls + - script: | + set -e -x + cd $(Build.SourcesDirectory) + mv manylinux onnxruntime + ls - - template: ../templates/with-container-registry-steps.yml - parameters: - Steps: - - script: | - tools/ci_build/get_docker_image.py \ - --dockerfile tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda \ - --context tools/ci_build/github/linux/docker \ - --docker-build-args "--network=host --build-arg BASEIMAGE=${{ parameters.docker_base_image }} --build-arg TRT_VERSION=${{ parameters.linux_trt_version }} --build-arg BUILD_UID=$( id -u )" \ - --container-registry onnxruntimebuildcache \ - --multiple_repos \ - --repository onnxruntimecuda${{ variables.CUDA_VERSION_MAJOR }}xtrt86build - displayName: "Get onnxruntimecuda${{ variables.CUDA_VERSION_MAJOR }}xtrt86build image for tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda" - workingDirectory: $(Build.SourcesDirectory)/onnxruntime - ContainerRegistry: onnxruntimebuildcache + - template: ../templates/with-container-registry-steps.yml + parameters: + Steps: + - script: | + tools/ci_build/get_docker_image.py \ + --dockerfile tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda \ + --context tools/ci_build/github/linux/docker \ + --docker-build-args "--network=host --build-arg BASEIMAGE=${{ parameters.docker_base_image }} --build-arg TRT_VERSION=${{ parameters.linux_trt_version }} --build-arg BUILD_UID=$( id -u )" \ + --container-registry onnxruntimebuildcache \ + --multiple_repos \ + --repository onnxruntimecuda${{ variables.CUDA_VERSION_MAJOR }}xtrt86build + displayName: "Get onnxruntimecuda${{ variables.CUDA_VERSION_MAJOR }}xtrt86build image for tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda" + workingDirectory: $(Build.SourcesDirectory)/onnxruntime + ContainerRegistry: onnxruntimebuildcache - - template: ../templates/set-version-number-variables-step.yml - parameters: - versionFileDirectory: '$(Build.SourcesDirectory)/onnxruntime' - workingDirectory: '$(Build.SourcesDirectory)/onnxruntime' - - task: DownloadPipelineArtifact@2 - displayName: 'Download Pipeline Artifact - Combined GPU' - inputs: - artifactName: 'onnxruntime-linux-x64-cuda' - targetPath: '$(Build.BinariesDirectory)/tgz-artifacts' + - template: ../templates/set-version-number-variables-step.yml + parameters: + versionFileDirectory: '$(Build.SourcesDirectory)/onnxruntime' + workingDirectory: '$(Build.SourcesDirectory)/onnxruntime' + - task: DownloadPipelineArtifact@2 + displayName: 'Download Pipeline Artifact - Combined GPU' + inputs: + artifactName: 'onnxruntime-linux-x64-cuda' + targetPath: '$(Build.BinariesDirectory)/tgz-artifacts' - - task: DownloadPipelineArtifact@2 - displayName: 'Download Pipeline Artifact - Combined GPU' - inputs: - artifactName: 'onnxruntime-linux-x64-tensorrt' - targetPath: '$(Build.BinariesDirectory)/tgz-artifacts' + - task: DownloadPipelineArtifact@2 + displayName: 'Download Pipeline Artifact - Combined GPU' + inputs: + artifactName: 'onnxruntime-linux-x64-tensorrt' + targetPath: '$(Build.BinariesDirectory)/tgz-artifacts' - - task: ShellScript@2 - displayName: 'Shell Script' - inputs: - scriptPath: 'onnxruntime/tools/ci_build/github/linux/extract_and_bundle_gpu_package.sh' - args: '-a $(Build.BinariesDirectory)/tgz-artifacts' - workingDirectory: '$(Build.BinariesDirectory)/tgz-artifacts' + - task: ShellScript@2 + displayName: 'Shell Script' + inputs: + scriptPath: 'onnxruntime/tools/ci_build/github/linux/extract_and_bundle_gpu_package.sh' + args: '-a $(Build.BinariesDirectory)/tgz-artifacts' + workingDirectory: '$(Build.BinariesDirectory)/tgz-artifacts' - - task: ArchiveFiles@2 - inputs: - rootFolderOrFile: '$(Build.BinariesDirectory)/tgz-artifacts/onnxruntime-linux-x64-gpu' - includeRootFolder: false - archiveType: 'tar' # Options: zip, 7z, tar, wim - tarCompression: 'gz' - archiveFile: '$(Build.ArtifactStagingDirectory)/onnxruntime-linux-x64-gpu-$(OnnxRuntimeVersion).tgz' - replaceExistingArchive: true + - task: ArchiveFiles@2 + inputs: + rootFolderOrFile: '$(Build.BinariesDirectory)/tgz-artifacts/onnxruntime-linux-x64-gpu' + includeRootFolder: false + archiveType: 'tar' # Options: zip, 7z, tar, wim + tarCompression: 'gz' + archiveFile: '$(Build.ArtifactStagingDirectory)/onnxruntime-linux-x64-gpu-$(OnnxRuntimeVersion).tgz' + replaceExistingArchive: true - - template: ../templates/validate-package.yml - parameters: - PackageType: 'tarball' - PackagePath: '$(Build.ArtifactStagingDirectory)' - PackageName: 'onnxruntime-linux-x64-gpu-$(OnnxRuntimeVersion).tgz' - ScriptPath: '$(Build.SourcesDirectory)/onnxruntime/tools/nuget/validate_package.py' - PlatformsSupported: 'linux-x64' - VerifyNugetSigning: false - workingDirectory: '$(Build.ArtifactStagingDirectory)' + - template: ../templates/validate-package.yml + parameters: + PackageType: 'tarball' + PackagePath: '$(Build.ArtifactStagingDirectory)' + PackageName: 'onnxruntime-linux-x64-gpu-$(OnnxRuntimeVersion).tgz' + ScriptPath: '$(Build.SourcesDirectory)/onnxruntime/tools/nuget/validate_package.py' + PlatformsSupported: 'linux-x64' + VerifyNugetSigning: false + workingDirectory: '$(Build.ArtifactStagingDirectory)' - - task: CmdLine@2 - displayName: 'Test C API application for GPU package' - inputs: - script: | - docker run --gpus all -e CFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" -e CXXFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" -e NVIDIA_VISIBLE_DEVICES=all --rm --volume /data/models:/data/models --volume $(Build.SourcesDirectory):/src_dir \ - --volume $(Build.ArtifactStagingDirectory):/artifact_src -e NIGHTLY_BUILD onnxruntimecuda${{ variables.CUDA_VERSION_MAJOR }}xtrt86build \ - /src_dir/onnxruntime-inference-examples/c_cxx/squeezenet/run_capi_application.sh -o /src_dir/onnxruntime -p /artifact_src/onnxruntime-linux-x64-gpu-$(OnnxRuntimeVersion).tgz -w /src_dir/onnxruntime-inference-examples/c_cxx/squeezenet - workingDirectory: '$(Build.ArtifactStagingDirectory)' + - task: CmdLine@2 + displayName: 'Test C API application for GPU package' + inputs: + script: | + docker run --gpus all -e CFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" -e CXXFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" -e NVIDIA_VISIBLE_DEVICES=all --rm --volume /data/models:/data/models --volume $(Build.SourcesDirectory):/src_dir \ + --volume $(Build.ArtifactStagingDirectory):/artifact_src -e NIGHTLY_BUILD onnxruntimecuda${{ variables.CUDA_VERSION_MAJOR }}xtrt86build \ + /src_dir/onnxruntime-inference-examples/c_cxx/squeezenet/run_capi_application.sh -o /src_dir/onnxruntime -p /artifact_src/onnxruntime-linux-x64-gpu-$(OnnxRuntimeVersion).tgz -w /src_dir/onnxruntime-inference-examples/c_cxx/squeezenet + workingDirectory: '$(Build.ArtifactStagingDirectory)' - - task: PublishPipelineArtifact@1 - inputs: - targetPath: '$(Build.ArtifactStagingDirectory)/onnxruntime-linux-x64-gpu-$(OnnxRuntimeVersion).tgz' - artifactName: 'onnxruntime-linux-x64-gpu' - - template: ../templates/component-governance-component-detection-steps.yml - parameters: - condition: 'succeeded' + - task: PublishPipelineArtifact@1 + inputs: + targetPath: '$(Build.ArtifactStagingDirectory)/onnxruntime-linux-x64-gpu-$(OnnxRuntimeVersion).tgz' + artifactName: 'onnxruntime-linux-x64-gpu' + - template: ../templates/component-governance-component-detection-steps.yml + parameters: + condition: 'succeeded' diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml index f599f45059c0..ad5c41b9fbd1 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml @@ -18,8 +18,10 @@ parameters: - name: CudaVersion type: string default: '11.8' + - name: win_cuda_home type: string + - name: win_trt_home type: string @@ -40,16 +42,16 @@ stages: ort_build_pool_name: 'onnxruntime-Win2022-GPU-T4' DoCompliance: ${{ parameters.DoCompliance }} DoEsrp: ${{ parameters.DoEsrp }} - stage_name_suffix: gpu + stage_name_suffix: CUDA buildArch: x64 msbuildPlatform: x64 packageName: x64-cuda CudaVersion: ${{ parameters.CudaVersion }} buildparameter: --use_cuda --cuda_home=${{ parameters.win_cuda_home }} --enable_onnx_tests --enable_wcos --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=60;61;70;75;80" runTests: ${{ parameters.RunOnnxRuntimeTests }} - buildJava: false + buildJava: true java_artifact_id: onnxruntime_gpu - PublishProtoc: true + UseIncreasedTimeoutForTests: ${{ parameters.UseIncreasedTimeoutForTests }} SpecificArtifact: ${{ parameters.SpecificArtifact }} BuildId: ${{ parameters.BuildId }} # Windows CUDA with TensorRT Packaging @@ -58,14 +60,14 @@ stages: ort_build_pool_name: 'onnxruntime-Win2022-GPU-T4' DoCompliance: ${{ parameters.DoCompliance }} DoEsrp: ${{ parameters.DoEsrp }} - stage_name_suffix: tensorrt + stage_name_suffix: TensorRT buildArch: x64 msbuildPlatform: x64 CudaVersion: ${{ parameters.CudaVersion }} packageName: x64-tensorrt buildparameter: --use_tensorrt --tensorrt_home=${{ parameters.win_trt_home }} --cuda_home=${{ parameters.win_cuda_home }} --enable_onnx_tests --enable_wcos --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=60;61;70;75;80" runTests: ${{ parameters.RunOnnxRuntimeTests }} - buildJava: false + buildJava: true java_artifact_id: onnxruntime_gpu UseIncreasedTimeoutForTests: ${{ parameters.UseIncreasedTimeoutForTests }} SpecificArtifact: ${{ parameters.SpecificArtifact }} @@ -74,8 +76,8 @@ stages: # Windows CUDA Combined Testing and Publishing - stage: Windows_Packaging_combined_GPU dependsOn: - - Windows_Packaging_gpu_Testing - - Windows_Packaging_tensorrt_Testing + - Windows_Packaging_CUDA + - Windows_Packaging_TensorRT condition: succeeded() jobs: @@ -90,6 +92,10 @@ stages: - checkout: self # due to checkout multiple repos, the root directory is $(Build.SourcesDirectory)/onnxruntime - checkout: onnxruntime-inference-examples # due to checkout multiple repos, the root directory is $(Build.SourcesDirectory)/onnxruntime-inference-examples submodules: false + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 + displayName: 'Clean Agent Directories' + condition: always() + - script: dir $(Build.SourcesDirectory) - template: ../templates/jobs/download_win_gpu_library.yml parameters: @@ -153,24 +159,13 @@ stages: filename: $(Build.SourcesDirectory)\onnxruntime-inference-examples\c_cxx\squeezenet\run_capi_application.bat arguments: $(Build.SourcesDirectory)\onnxruntime $(Build.ArtifactStagingDirectory)\onnxruntime-win-x64-gpu-$(OnnxRuntimeVersion).zip $(Build.SourcesDirectory)\onnxruntime-inference-examples\c_cxx\squeezenet workingFolder: '$(Build.ArtifactStagingDirectory)' - - script: | - dir - workingDirectory: '$(Build.ArtifactStagingDirectory)' - displayName: 'List ArtifactStagingDirectory before delete' - - - task: DeleteFiles@1 - displayName: 'Clean up none zip files from ArtifactStagingDirectory' - inputs: - SourceFolder: $(Build.ArtifactStagingDirectory) - Contents: '*/' - - - script: | - dir - workingDirectory: '$(Build.ArtifactStagingDirectory)' - displayName: 'List ArtifactStagingDirectory after delete' - task: PublishPipelineArtifact@0 displayName: 'Publish Pipeline Combined GPU Package Artifact' inputs: artifactName: 'onnxruntime-win-x64-gpu' targetPath: '$(Build.ArtifactStagingDirectory)/onnxruntime-win-x64-gpu-$(OnnxRuntimeVersion).zip' + + - template: ../templates/component-governance-component-detection-steps.yml + parameters: + condition: 'succeeded' diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml index f8c0009f5abe..b103bef31bac 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml @@ -306,6 +306,8 @@ stages: OrtPackageId: ${{ parameters.OrtNugetPackageId }} breakCodesignValidationInjection: ${{ parameters.DoEsrp }} ReleaseVersionSuffix: $[stageDependencies.Setup.Set_Variables.outputs['Set_Release_Version_Suffix.ReleaseVersionSuffix']] + BuildDate : $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Date.BuildDate']] + BuildTime : $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Time.BuildTime']] steps: - checkout: self @@ -443,7 +445,7 @@ stages: solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.proj' platform: 'Any CPU' configuration: RelWithDebInfo - msbuildArguments: '-t:CreatePackage -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=$(OrtPackageId) -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} -p:ReleaseVersionSuffix=$(ReleaseVersionSuffix)' + msbuildArguments: '-t:CreatePackage -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=$(OrtPackageId) -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} -p:ReleaseVersionSuffix=$(ReleaseVersionSuffix) -p:CurrentTime=$(BuildTime) -p:CurrentDate=$(BuildDate)' workingDirectory: '$(Build.SourcesDirectory)\csharp' - task: CopyFiles@2 @@ -517,7 +519,7 @@ stages: - Windows_CI_GPU_DML_Dev - Windows_CI_GPU_DML_Dev_arm64 - Linux_C_API_Packaging_CPU - - Linux_C_API_Packaging_GPU_TensorRT_x64 + - Linux_C_API_Packaging_Combined_CUDA - MacOS_C_API_Package_Publish condition: succeeded() jobs: diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-gpu-tensorrt-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/templates/linux-gpu-tensorrt-packaging-pipeline.yml deleted file mode 100644 index d0d89ed6abd1..000000000000 --- a/tools/ci_build/github/azure-pipelines/templates/linux-gpu-tensorrt-packaging-pipeline.yml +++ /dev/null @@ -1,117 +0,0 @@ -parameters: -- name: artifactName - type: string - default: 'onnxruntime-linux-x64-gpu-tensorrt-$(OnnxRuntimeVersion)' - -- name: artifactNameNoVersionString - type: string - default: 'onnxruntime-linux-x64-gpu-tensorrt' - -- name: buildJava - type: boolean - default: false - -- name: buildJavaOption - type: string - default: '' - -- name: buildNodejs - type: boolean - default: true - -- name: buildNodejsOption - type: string - default: '' - -- name: CudaVersion - displayName: CUDA version - type: string - default: '11.8' - values: - - 11.8 - - 12.2 - - - -# We only have CUDA/TRT on x64. We do not have a build for CUDA/TRT for ARM64. -# Therefore this file does not have an `OnnxruntimeNodejsBindingArch` parameter - -stages: -- stage: Linux_C_API_Packaging_GPU_TensorRT_x64 - dependsOn: [] - variables: - - name: linux_trt_version - ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: 10.0.1.6-1.cuda11.8 - ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: 10.0.1.6-1.cuda12.4 - - name: docker_base_image - ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: nvidia/cuda:11.8.0-cudnn8-devel-ubi8 - ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: nvidia/cuda:12.2.2-cudnn8-devel-ubi8 - jobs: - - job: - dependsOn: [] - workspace: - clean: all - timeoutInMinutes: 180 - pool: 'Onnxruntime-Linux-GPU' - variables: - - name: CUDA_VERSION_MAJOR - ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: '11' - ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: '12' - - name: CUDA_VERSION - value: ${{ parameters.CudaVersion }} - steps: - - checkout: self - clean: true - submodules: recursive - - template: get-docker-image-steps.yml - parameters: - Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda - Context: tools/ci_build/github/linux/docker - DockerBuildArgs: " - --network=host - --build-arg BASEIMAGE=${{ variables.docker_base_image }} - --build-arg TRT_VERSION=${{ variables.linux_trt_version }} - --build-arg BUILD_UID=$( id -u ) - " - Repository: onnxruntimecuda${{ variables.CUDA_VERSION_MAJOR }}xtrt86build - - template: set-version-number-variables-step.yml - - - script: $(Build.SourcesDirectory)/tools/ci_build/github/linux/build_tensorrt_c_api_package.sh - workingDirectory: $(Build.SourcesDirectory) - displayName: 'Build and Test' - - - ${{ if eq(parameters.buildJava, true) }}: - - template: java-api-artifacts-package-and-publish-steps-posix.yml - parameters: - arch: 'linux-x64' - buildConfig: 'Release' - artifactName: 'onnxruntime-java-linux-x64-tensorrt' - version: '$(OnnxRuntimeVersion)' - libraryName: 'libonnxruntime.so' - nativeLibraryName: 'libonnxruntime4j_jni.so' - - - ${{ if eq(parameters.buildNodejs, 'true') }}: - - template: nodejs-artifacts-package-and-publish-steps-posix.yml - parameters: - arch: 'x64' - os: 'linux' - artifactName: 'drop-onnxruntime-nodejs-linux-x64-tensorrt' - - - template: c-api-artifacts-package-and-publish-steps-posix.yml - parameters: - buildConfig: 'Release' - artifactName: ${{ parameters.artifactName }} - artifactNameNoVersionString: ${{ parameters.artifactNameNoVersionString }} - libraryName: 'libonnxruntime.so.$(OnnxRuntimeVersion)' - - - - template: component-governance-component-detection-steps.yml - parameters : - condition : 'succeeded' - - template: clean-agent-build-directory-step.yml diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml index 1836bf19e4e9..4a695e1f3c43 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml @@ -115,7 +115,7 @@ jobs: FolderPath: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\onnxruntime\capi' DisplayName: 'ESRP - Sign Native dlls' DoEsrp: true - Pattern: '*.pyd,*.dll' + Pattern: '*.pyd' - task: PythonScript@0 displayName: 'Build wheel' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml index 44ffbeb0705f..dfebf17d95aa 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml @@ -127,7 +127,7 @@ jobs: FolderPath: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\onnxruntime\capi' DisplayName: 'ESRP - Sign Native dlls' DoEsrp: true - Pattern: '*.pyd,*.dll' + Pattern: '*.pyd' - task: PythonScript@0 displayName: 'Build wheel' diff --git a/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml b/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml index 1b7962059e30..d13cb7a99e7f 100644 --- a/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml @@ -347,13 +347,6 @@ stages: condition: succeededOrFailed() displayName: Publish React Native Detox iOS e2e Test Results - - task: PublishPipelineArtifact@1 - inputs: - artifact: e2e_test_logs - targetPath: '$(Build.SourcesDirectory)/js/react_native/e2e/artifacts' - condition: succeededOrFailed() - displayName: Publish React Native Detox E2E test logs - - script: | git restore . workingDirectory: '$(Build.SourcesDirectory)/js' @@ -381,14 +374,21 @@ stages: targetFolder: $(Build.ArtifactStagingDirectory) displayName: Create Artifacts onnxruntime-react-native + - template: component-governance-component-detection-steps.yml + parameters : + condition : 'succeeded' + + - task: PublishPipelineArtifact@1 + inputs: + artifact: e2e_test_logs_$(Build.BuildId)_$(Build.BuildNumber)_$(System.JobAttempt) + targetPath: '$(Build.SourcesDirectory)/js/react_native/e2e/artifacts' + condition: succeededOrFailed() + displayName: Publish React Native Detox E2E test logs + - task: PublishPipelineArtifact@0 inputs: artifactName: '${{parameters.PackageName}}' targetPath: '$(Build.ArtifactStagingDirectory)' displayName: Publish Pipeline Artifact - - template: component-governance-component-detection-steps.yml - parameters : - condition : 'succeeded' - - template: explicitly-defined-final-tasks.yml diff --git a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml index f2a29ef8a4c6..d6e8a30b441d 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml @@ -88,8 +88,16 @@ stages: GRADLE_OPTS: '-Dorg.gradle.daemon=false' VSGenerator: 'Visual Studio 17 2022' CUDA_MODULE_LOADING: 'LAZY' + ${{ if eq(parameters['buildJava'], 'true') }}: + buildJavaParameter: '--build_java' + ${{ else }}: + buildJavaParameter: '' + ${{ if eq(parameters['UseIncreasedTimeoutForTests'], 'true') }}: + timeoutParameter: '--test_all_timeout 72000' + ${{ else }}: + timeoutParameter: '' jobs: - - job: + - job: Windows_Packaging_${{ parameters.stage_name_suffix }} workspace: clean: all ${{ if contains(parameters.ort_build_pool_name, 'GPU') }}: @@ -166,10 +174,7 @@ stages: displayName: 'Generate cmake config' inputs: scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' - ${{ if eq(parameters['UseIncreasedTimeoutForTests'], 'true') }}: - arguments: '--config RelWithDebInfo --use_binskim_compliant_compile_flags --enable_lto --disable_rtti --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --update --cmake_generator "$(VSGenerator)" --enable_onnx_tests $(TelemetryOption) ${{ parameters.buildparameter }} --test_all_timeout 72000' - ${{ else }}: - arguments: '--config RelWithDebInfo --use_binskim_compliant_compile_flags --enable_lto --disable_rtti --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --update --cmake_generator "$(VSGenerator)" --enable_onnx_tests $(TelemetryOption) ${{ parameters.buildparameter }} ' + arguments: '--config RelWithDebInfo --use_binskim_compliant_compile_flags --enable_lto --disable_rtti --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --update --cmake_generator "$(VSGenerator)" --enable_onnx_tests $(TelemetryOption) ${{ parameters.buildparameter }} $(timeoutParameter) $(buildJavaParameter)' workingDirectory: '$(Build.BinariesDirectory)' @@ -302,14 +307,9 @@ stages: - template: component-governance-component-detection-steps.yml parameters : condition : 'succeeded' - -- ${{ if contains(parameters.ort_build_pool_name, 'GPU') }}: - - stage: Windows_Packaging_${{ parameters.stage_name_suffix }}_Testing - dependsOn: Windows_Packaging_${{ parameters.stage_name_suffix }} - variables: - CUDA_MODULE_LOADING: 'LAZY' - jobs: + - ${{ if contains(parameters.ort_build_pool_name, 'GPU') }}: - job: Windows_Packaging_${{ parameters.stage_name_suffix }}_Testing + dependsOn: Windows_Packaging_${{ parameters.stage_name_suffix }} workspace: clean: all pool: ${{ parameters.ort_build_pool_name }}