Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[JS/WebGPU] Preserve zero size input tensor dims. #19737

Merged
merged 12 commits into from
Mar 8, 2024
Prev Previous commit
Next Next commit
Removed referenceInput.
  • Loading branch information
satyajandhyala committed Mar 7, 2024
commit ed048b095bfc4f5b2bc6714cd0edcf175f7ff60e
15 changes: 4 additions & 11 deletions js/web/lib/wasm/jsep/webgpu/ops/concat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ export interface ConcatAttributes extends AttributeWithCacheKey {
readonly axis: number;
}

const validateInputs = (inputs: readonly TensorView[], referenceIndex: number, axis: number): void => {
const validateInputs = (inputs: readonly TensorView[], axis: number): void => {
if (!inputs || inputs.length < 1) {
throw new Error('too few inputs');
}
fs-eire marked this conversation as resolved.
Show resolved Hide resolved

const referenceIndex = 0;
fs-eire marked this conversation as resolved.
Show resolved Hide resolved
const referenceInput = inputs[referenceIndex];
const inputType = referenceInput.dataType;
const inputRank = referenceInput.dims.length;
Expand Down Expand Up @@ -141,17 +141,10 @@ export const concat = (context: ComputeContext, attributes: ConcatAttributes): v
// find a none zero tensor as reference to determine the output shape
// choose input 0 as reference if all input tensors are zero-sized.
const inputs = context.inputs;
let referenceIndex = 0;
for (let i = 0; i < inputs.length; i++) {
if (ShapeUtil.size(inputs[i].dims) > 0) {
referenceIndex = i;
break;
}
}

const inputShape = inputs[referenceIndex].dims;
const inputShape = inputs[0].dims;
const adjustedAxis = attributes.axis + (attributes.axis < 0 ? inputShape.length : 0);
validateInputs(inputs, referenceIndex, adjustedAxis);
validateInputs(inputs, adjustedAxis);
const outputShape = inputShape.slice();
outputShape[adjustedAxis] =
inputs.reduce((sum, input) => sum + (input.dims.length > adjustedAxis ? input.dims[adjustedAxis] : 0), 0);
Expand Down
Loading