Skip to content

Commit

Permalink
Rollback code allowing zero-sized input non-concat axes dims mismatch…
Browse files Browse the repository at this point in the history
… referance input
  • Loading branch information
satyajandhyala committed Mar 6, 2024
1 parent e8fb4d2 commit f19ece3
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 203 deletions.
48 changes: 21 additions & 27 deletions js/web/lib/wasm/jsep/webgpu/ops/concat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ const validateInputs = (inputs: readonly TensorView[], referenceIndex: number, a
const referenceInput = inputs[referenceIndex];
const inputType = referenceInput.dataType;
const inputRank = referenceInput.dims.length;
const referenceInputSize = ShapeUtil.size(referenceInput.dims);
inputs.forEach((input, i) => {
if (i === referenceIndex) {
return;
Expand All @@ -30,17 +29,15 @@ const validateInputs = (inputs: readonly TensorView[], referenceIndex: number, a
if (input.dataType !== inputType) {
throw new Error('input tensors should be one type');
}
if (referenceInputSize > 0 && ShapeUtil.size(input.dims) > 0) {
// make sure the dimensionality of all inputs are the same
if (input.dims.length !== inputRank) {
throw new Error('input tensors should have the same shape');
}
input.dims.forEach((dim, i) => {
if (i !== axis && dim !== referenceInput.dims[i]) {
throw new Error('non concat dimensions must match');
}
});
// make sure the dimensionality of all inputs are the same
if (input.dims.length !== inputRank) {
throw new Error('input tensors should have the same shape');
}
input.dims.forEach((dim, i) => {
if (i !== axis && dim !== referenceInput.dims[i]) {
throw new Error('non concat dimensions must match');
}
});
});
};

Expand Down Expand Up @@ -120,16 +117,12 @@ const createConcatProgramInfo =
var indices = ${output.offsetToIndices('global_idx')};
let inputIndex = calculateInputIndex(${indicesAxis});
if (inputIndex < ${inputs.length}u) {
if (inputIndex != 0u) {
let sizeInConcatAxis = array<u32, ${sizeInConcatAxis.length}u>(${sizeInConcatAxisStr});
${indicesAxis} -= sizeInConcatAxis[inputIndex - 1u];
}
${assignOutputData(inputVars, output)}
} else {
${output.setByOffset('global_idx', '0')}
if (inputIndex != 0u) {
let sizeInConcatAxis = array<u32, ${sizeInConcatAxis.length}u>(${sizeInConcatAxisStr});
${indicesAxis} -= sizeInConcatAxis[inputIndex - 1u];
}
${assignOutputData(inputVars, output)}
}`;

return {
Expand All @@ -145,14 +138,15 @@ const createConcatProgramInfo =
};

export const concat = (context: ComputeContext, attributes: ConcatAttributes): void => {
// find a none zero tensor to determine the output shape
// Choose input with max rank if all input tensors are zero size to make the output shape independent of the order of
// the inputs.
// find a none zero tensor as reference to determine the output shape
// choose input 0 as reference if all input tensors are zero-sized.
const inputs = context.inputs;
let referenceIndex = inputs.findIndex(input => ShapeUtil.size(input.dims) > 0);
if (referenceIndex === -1) {
referenceIndex = inputs.reduce(
(maxRankIndex, input, index, array) => input.dims > array[maxRankIndex].dims ? index : maxRankIndex, 0);
let referenceIndex = 0;
for (let i = 0; i < inputs.length; i++) {
if (ShapeUtil.size(inputs[i].dims) > 0) {
referenceIndex = i;
break;
}
}

const inputShape = inputs[referenceIndex].dims;
Expand Down
178 changes: 2 additions & 176 deletions js/web/test/data/ops/concat_zero-sized.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -558,56 +558,6 @@
}
]
},
{
"name": "Concat 2D axis=1; Preserve dims",
"operator": "Concat",
"attributes": [
{
"name": "axis",
"data": 1,
"type": "int"
}
],
"cases": [
{
"name": "Some but not all input tensors have 0 in dims along the other axis",
"inputs": [
{
"data": [],
"dims": [0, 0],
"type": "float32"
},
{
"data": [],
"dims": [0, 1],
"type": "float32"
},
{
"data": [],
"dims": [0, 2],
"type": "float32"
},
{
"data": [],
"dims": [0, 3],
"type": "float32"
},
{
"data": [1],
"dims": [1, 1],
"type": "float32"
}
],
"outputs": [
{
"data": [1, 0, 0, 0, 0, 0, 0],
"dims": [1, 7],
"type": "float32"
}
]
}
]
},
{
"name": "Concat 2D axis=1; Preserve dims",
"operator": "Concat",
Expand All @@ -620,28 +570,13 @@
],
"cases": [
{
"name": "Some but not all input tensors have 0 in dims along the axis",
"name": "Some but not all input tensors are zero-sized",
"inputs": [
{
"data": [],
"dims": [0, 0],
"type": "float32"
},
{
"data": [],
"dims": [0, 1],
"type": "float32"
},
{
"data": [],
"dims": [0, 2],
"type": "float32"
},
{
"data": [],
"dims": [0, 3],
"type": "float32"
},
{
"data": [1],
"dims": [1, 1],
Expand Down Expand Up @@ -670,7 +605,7 @@
],
"cases": [
{
"name": "All input tensors have 0 in dims along the other axis",
"name": "All input tensors are zero-sized",
"inputs": [
{
"data": [],
Expand Down Expand Up @@ -702,114 +637,5 @@
]
}
]
},
{
"name": "Concat 2D axis=1; Preserve dims",
"operator": "Concat",
"attributes": [{ "name": "axis", "data": 0, "type": "int" }],
"cases": [
{
"name": "Zero input tensor rank is different from the other input tensors; zero dim along the axis",
"inputs": [
{
"data": [],
"dims": [0, 1, 1],
"type": "float32"
},
{
"data": [1],
"dims": [1, 1],
"type": "float32"
}
],
"outputs": [
{
"data": [1],
"dims": [1, 1],
"type": "float32"
}
]
},
{
"name": "Zero input tensor rank is different from the other input tensors",
"inputs": [
{
"data": [],
"dims": [1, 1, 0],
"type": "float32"
},
{
"data": [1],
"dims": [1, 1],
"type": "float32"
}
],
"outputs": [
{
"data": [1, 0],
"dims": [2, 1],
"type": "float32"
}
]
}
]
},
{
"name": "Concat 2D axis=0; Preserve dims",
"operator": "Concat",
"attributes": [{ "name": "axis", "data": 0, "type": "int" }],
"cases": [
{
"name": "All input tensors have 0 in dims along the axis",
"inputs": [
{
"data": [],
"dims": [0, 1, 1],
"type": "float32"
},
{
"data": [],
"dims": [0, 1],
"type": "float32"
}
],
"outputs": [
{
"data": [],
"dims": [0, 1, 1],
"type": "float32"
}
]
}
]
},
{
"name": "Concat 2D axis=1; Preserve dims",
"operator": "Concat",
"attributes": [{ "name": "axis", "data": 1, "type": "int" }],
"cases": [
{
"name": "All input tensors have 0 in dims along the other axis",
"inputs": [
{
"data": [],
"dims": [0, 1, 1],
"type": "float32"
},
{
"data": [],
"dims": [0, 1],
"type": "float32"
}
],
"outputs": [
{
"data": [],
"dims": [0, 2, 1],
"type": "float32"
}
]
}
]
}
]

0 comments on commit f19ece3

Please sign in to comment.