Skip to content

Commit

Permalink
Fix Split index bugs uncovered by QNN SDK 2.19 (#19381)
Browse files Browse the repository at this point in the history
### Description
- When converting ONNX split sizes to QNN split indices, do not include
the split at index 0. QNN 2.19 assumes index 0 is implicit and throws a
validation error if provided.
- Fix bug when using an ONNX Split operator with a `num_outputs`
attribute that does not evenly divide into `shape[axis]`. The ONNX spec
states that the last chunk should be smaller, but QNN EP made the last
chunk larger.
- Fix bug when using an ONNX Split operator with a `split` input. QNN EP
was incorrectly passing the split sizes as split indices without
conversion.

### Motivation and Context
QNN SDK 2.19 updated validation criteria for Split operators. QNN EP was
previously passing a split index that should have been implicit.

Also discovered bugs when using the `num_outputs` attribute and the `split`
input.
  • Loading branch information
adrianlizarraga committed Feb 2, 2024
1 parent 09d5c1b commit a2eb967
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,19 @@ Status SplitOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
return Status::OK();
}

// Translates ONNX split lengths into QNN split indices, i.e., the running offset at which each chunk starts.
// The starting offset of the first chunk (0) is implicit for QNN and is never emitted
// (QNN SDK >= 2.19 raises a validation error if it is included).
// Results are appended to `split_indices`; existing elements are left untouched.
static void ConvertSplitLengthsToSplitIndices(gsl::span<const int64_t> split_lengths,
                                              std::vector<uint32_t>& split_indices) {
  uint32_t running_offset = 0;
  bool is_first_chunk = true;

  for (const int64_t chunk_length : split_lengths) {
    if (is_first_chunk) {
      // Skip the implicit index 0 that belongs to the first chunk.
      is_first_chunk = false;
    } else {
      split_indices.push_back(running_offset);
    }

    // Every length (including the last) goes through the checked SafeInt conversion.
    running_offset += SafeInt<uint32_t>(chunk_length);
  }
}

Status SplitOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
std::vector<std::string>&& input_names,
Expand All @@ -79,22 +92,15 @@ Status SplitOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wr
const int64_t* tensor_data = reinterpret_cast<const int64_t*>(unpacked_tensor.data());
size_t tensor_byte_size = unpacked_tensor.size();
size_t size = tensor_byte_size / sizeof(int64_t);
split_index.push_back(0); // QNN need the start index of each range and starts from 0
std::transform(tensor_data, tensor_data + size, std::back_inserter(split_index),
[](int64_t item) { return SafeInt<uint32_t>(item); });
split_index.pop_back();
ConvertSplitLengthsToSplitIndices({tensor_data, size}, split_index);
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN doesn't support dynamic split");
}
} else {
NodeAttrHelper node_helper(node_unit);
if (node_helper.HasAttr("split")) {
auto split = node_helper.Get("split", std::vector<int32_t>{0});
uint32_t split_it = 0;
for (size_t i = 0; i < split.size(); ++i) {
split_index.push_back(split_it);
split_it += split[i];
}
auto split_lengths = node_helper.Get("split", std::vector<int64_t>{0});
ConvertSplitLengthsToSplitIndices(split_lengths, split_index);
}
}

Expand All @@ -105,11 +111,19 @@ Status SplitOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wr
"Cannot get shape");
ORT_ENFORCE(static_cast<int32_t>(input_shape.size()) > axis_value, "axis not valid!");
ORT_RETURN_IF_NOT(input_shape.at(axis_value) > 0, "Shape value not valid!");
auto num_outputs = node_unit.Outputs().size();
auto step = SafeInt<uint32_t>(input_shape.at(axis_value) / num_outputs);

// ONNX spec states that if not evenly divisible by `num_outputs`, the last chunk is smaller.
// Therefore, we have to use ceil() when computing shape[axis] / num_outputs.
// See: core/providers/cpu/tensor/split.cc::PrepareForCompute()
const float num_outputs = static_cast<float>(node_unit.Outputs().size());
const float split_dim_size = static_cast<float>(input_shape[axis_value]);
const uint32_t step = SafeInt<uint32_t>(std::ceil(split_dim_size / num_outputs));
uint32_t split_it = 0;

for (size_t i = 0; i < num_outputs; ++i) {
split_index.push_back(split_it);
if (i > 0) { // 0th split index is implicit (QNN >= 2.19 raises validation error if included)
split_index.push_back(split_it);
}
split_it += step;
}
}
Expand Down
41 changes: 34 additions & 7 deletions onnxruntime/test/providers/qnn/split_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -302,19 +302,46 @@ TEST_F(QnnHTPBackendTests, Split_Int32_Opset13) {
// Test 8-bit QDQ Split opset 18 on HTP backend: equal split of axis 0 via 'num_outputs' attribute
// and 'split' input.
TEST_F(QnnHTPBackendTests, Split_Equal_Axis0_Opset18) {
// Split 6 into 3 outputs of lengths [2, 2, 2]
TestInputDef<float> input_def({6, 2}, false,
{0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f, 9.0f, 10.0f, 11.0f});

// Use 'split' input (initializer).
RunQDQSplitOpTestOnHTP<uint8_t>(TestInputDef<float>({4, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}),
{2, 2}, // split
0, // axis
-1, // num_outputs
18, // opset
RunQDQSplitOpTestOnHTP<uint8_t>(input_def,
{2, 2, 2}, // split
0, // axis
-1, // num_outputs
18, // opset
ExpectedEPNodeAssignment::All);

// Use 'num_outputs' attribute.
RunQDQSplitOpTestOnHTP<uint8_t>(TestInputDef<float>({4, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}),
RunQDQSplitOpTestOnHTP<uint8_t>(input_def,
{}, // split (use num_outputs instead)
0, // axis
3, // num_outputs
18, // opset
ExpectedEPNodeAssignment::All);
}

// Test 8-bit QDQ Split opset 18 on HTP backend. Use an uneven split (last chunk should be smaller).
TEST_F(QnnHTPBackendTests, Split_NonEqual_Axis0_Opset18) {
// Split 7 into 3 outputs of lengths [3, 3, 1]
TestInputDef<float> input_def({7, 2}, false,
{0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f});

// Use a `split` input with uneven split lengths.
RunQDQSplitOpTestOnHTP<uint8_t>(input_def,
{3, 3, 1}, // split
0, // axis
-1, // num_outputs
18, // opset
ExpectedEPNodeAssignment::All);

// Use a `num_outputs` attribute that does not evenly divide into shape[axis].
RunQDQSplitOpTestOnHTP<uint8_t>(input_def,
{}, // split (use num_outputs instead)
0, // axis
2, // num_outputs
3, // num_outputs
18, // opset
ExpectedEPNodeAssignment::All);
}
Expand Down

0 comments on commit a2eb967

Please sign in to comment.