Skip to content

Commit

Permalink
[js/webgpu] Create Split indices helpers by rank, not by shape (#19554)
Browse files Browse the repository at this point in the history
### Description
This is required to make shape uniforms really work.

### Motivation and Context
The bug was unveiled in a model with multiple Split nodes. The later
nodes would try to reuse a previous pipeline cache, while the old shapes
were hardcoded as constants in cache.
  • Loading branch information
hujiajie committed Feb 20, 2024
1 parent 7efb0db commit 1b48054
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/ops/split.ts
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: Split
const dataType = inputs[0].dataType;
const axis = ShapeUtil.normalizeAxis(attributes.axis, inputShape.length);
const outputs = new Array<IndicesHelper>(attributes.numOutputs);
const input = inputVariable('input', dataType, inputShape);
const input = inputVariable('input', dataType, inputShape.length);
const sizeInSplitAxis = new Array<number>(attributes.numOutputs);
const outputsTensorInfo: TensorInfo[] = [];
const outputShapes: number[][] = [];
Expand All @@ -80,7 +80,7 @@ const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: Split
const outputShape = inputShape.slice();
outputShape[attributes.axis] = attributes.splitSizes[i];
outputShapes.push(outputShape);
outputs[i] = outputVariable(`output${i}`, dataType, outputShape);
outputs[i] = outputVariable(`output${i}`, dataType, outputShape.length);
outputsTensorInfo.push({dims: outputShapes[i], dataType: inputs[0].dataType});
}
programUniforms.push(
Expand Down

0 comments on commit 1b48054

Please sign in to comment.