Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[JS/WebGPU] Preserve zero size input tensor dims. #19737

Merged
merged 12 commits into from
Mar 8, 2024
146 changes: 69 additions & 77 deletions js/web/lib/wasm/jsep/webgpu/ops/concat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,32 @@ export interface ConcatAttributes extends AttributeWithCacheKey {
readonly axis: number;
}

const validateInputs = (inputs: readonly TensorView[]): void => {
const validateInputs = (inputs: readonly TensorView[], axis: number): void => {
if (!inputs || inputs.length < 1) {
throw new Error('too few inputs');
}
fs-eire marked this conversation as resolved.
Show resolved Hide resolved

const inputType = inputs[0].dataType;
const inputDimensionality = inputs[0].dims.length;

for (const input of inputs) {
const referenceIndex = 0;
fs-eire marked this conversation as resolved.
Show resolved Hide resolved
const referenceInput = inputs[referenceIndex];
const inputType = referenceInput.dataType;
const inputRank = referenceInput.dims.length;
inputs.forEach((input, i) => {
if (i === referenceIndex) {
return;
}
// make sure types of all inputs match
if (input.dataType !== inputType) {
throw new Error('input tensors should be one type');
}

// make sure the dimensionality of all inputs are the same
if (input.dims.length !== inputDimensionality) {
if (input.dims.length !== inputRank) {
throw new Error('input tensors should have the same shape');
}
}
input.dims.forEach((dim, i) => {
if (i !== axis && dim !== referenceInput.dims[i]) {
throw new Error('non concat dimensions must match');
}
});
});
};

const calculateInputIndexImpl = (numberOfTensors: number, sizeInConcatAxisStr: string): string => `
Expand Down Expand Up @@ -64,65 +71,43 @@ const assignOutputData = (inputs: readonly IndicesHelper[], output: IndicesHelpe
return codeLines.join('\n');
};

const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): ProgramInfo => {
const inputShape = inputs[0].dims.slice();
if (axis >= inputShape.length || axis < (-1 * inputShape.length)) {
throw new Error('axis specified for concat doesn\'t match input dimensionality');
}
const adjustedAxis = (axis < 0) ? inputShape.length + axis : axis;
// ensure all of the non-concatenated axes match each other
// calculate the shape of the output tensor while we do that
const outputShape = inputShape.slice(0);
for (let i = 1; i < inputs.length; i++) {
const dataNShape = inputs[i].dims.slice();
for (let axisIndex = 0; axisIndex < inputShape.length; axisIndex++) {
// add to the placeholder for computing output shape
if (axisIndex === adjustedAxis) {
outputShape[adjustedAxis] += dataNShape[axisIndex];
const createConcatProgramInfo =
(inputs: readonly TensorView[], adjustedAxis: number, outputShape: number[], dataType: DataType): ProgramInfo => {
const outputSize = ShapeUtil.size(outputShape);

const sizeInConcatAxis = new Array<number>(inputs.length);
const inputVars = new Array<IndicesHelper>(inputs.length);

let previousSum = 0;
const inputDependencies: ProgramInputTensorInfoDependency[] = [];
const inputRanks = [];
const programUniforms: ProgramUniform[] = [{type: DataType.uint32, data: outputSize}];
for (let i = 0; i < inputs.length; ++i) {
previousSum += inputs[i].dims[adjustedAxis];
sizeInConcatAxis[i] = previousSum;
inputRanks.push(inputs[i].dims.length);
inputVars[i] = inputVariable(`input${i}`, dataType, inputRanks[i]);
inputDependencies.push('rank');
programUniforms.push({type: DataType.uint32, data: sizeInConcatAxis[i]});
}
// ensure all non-cancatenated axes match each other
else if (inputShape[axisIndex] !== dataNShape[axisIndex]) {
throw new Error('non concat dimensions must match');
for (let i = 0; i < inputs.length; ++i) {
programUniforms.push(...createTensorShapeVariables(inputs[i].dims));
}
}
}

const outputSize = ShapeUtil.size(outputShape);

const sizeInConcatAxis = new Array<number>(inputs.length);
const inputVars = new Array<IndicesHelper>(inputs.length);
const dataType = inputs[0].dataType;

let previousSum = 0;
const inputDependencies: ProgramInputTensorInfoDependency[] = [];
const inputRanks = [];
const programUniforms: ProgramUniform[] = [{type: DataType.uint32, data: outputSize}];
for (let i = 0; i < inputs.length; ++i) {
previousSum += inputs[i].dims[adjustedAxis];
sizeInConcatAxis[i] = previousSum;
inputRanks.push(inputs[i].dims.length);
inputVars[i] = inputVariable(`input${i}`, dataType, inputRanks[i]);
inputDependencies.push('rank');
programUniforms.push({type: DataType.uint32, data: sizeInConcatAxis[i]});
}
for (let i = 0; i < inputs.length; ++i) {
programUniforms.push(...createTensorShapeVariables(inputs[i].dims));
}
programUniforms.push(...createTensorShapeVariables(outputShape));
programUniforms.push(...createTensorShapeVariables(outputShape));

const output = outputVariable('output', dataType, outputShape.length);
const indicesAxis = output.indicesGet('indices', adjustedAxis);
const sizeInConcatAxisStr =
Array.from(Array(sizeInConcatAxis.length).keys()).map(i => `uniforms.sizeInConcatAxis${i}`).join(',');
const getShaderSource = (shaderHelper: ShaderHelper) => `
const output = outputVariable('output', dataType, outputShape.length);
const indicesAxis = output.indicesGet('indices', adjustedAxis);
const sizeInConcatAxisStr =
Array.from(Array(sizeInConcatAxis.length).keys()).map(i => `uniforms.sizeInConcatAxis${i}`).join(',');
const getShaderSource = (shaderHelper: ShaderHelper) => `

${(() => {
shaderHelper.registerUniform('outputSize', 'u32');
for (let i = 0; i < inputs.length; i++) {
shaderHelper.registerUniform(`sizeInConcatAxis${i}`, 'u32');
}
return shaderHelper.declareVariables(...inputVars, output);
})()}
shaderHelper.registerUniform('outputSize', 'u32');
for (let i = 0; i < inputs.length; i++) {
shaderHelper.registerUniform(`sizeInConcatAxis${i}`, 'u32');
}
return shaderHelper.declareVariables(...inputVars, output);
})()}

${calculateInputIndexImpl(sizeInConcatAxis.length, sizeInConcatAxisStr)}

Expand All @@ -140,23 +125,30 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P
${assignOutputData(inputVars, output)}
}`;

return {
name: 'Concat',
shaderCache: {hint: `${axis}`, inputDependencies},
getRunData: () => ({
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
programUniforms,
}),
getShaderSource,
};
};
return {
name: 'Concat',
shaderCache: {hint: `${adjustedAxis}`, inputDependencies},
getRunData: () => ({
outputs: [{dims: outputShape, dataType}],
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
programUniforms,
}),
getShaderSource,
};
};

export const concat = (context: ComputeContext, attributes: ConcatAttributes): void => {
validateInputs(context.inputs);
const inputs = context.inputs;
const inputShape = inputs[0].dims;
const adjustedAxis = ShapeUtil.normalizeAxis(attributes.axis, inputShape.length);
validateInputs(inputs, adjustedAxis);
const outputShape = inputShape.slice();
outputShape[adjustedAxis] =
inputs.reduce((sum, input) => sum + (input.dims.length > adjustedAxis ? input.dims[adjustedAxis] : 0), 0);
// 0 length tensors are valid for concat, remove them
const nonEmptyInputs = context.inputs.filter(input => ShapeUtil.size(input.dims) > 0);
context.compute(createConcatProgramInfo(nonEmptyInputs, attributes.axis), {inputs: nonEmptyInputs});
const nonEmptyInputs = inputs.filter(input => ShapeUtil.size(input.dims) > 0);
context.compute(
createConcatProgramInfo(nonEmptyInputs, adjustedAxis, outputShape, inputs[0].dataType), {inputs: nonEmptyInputs});
};

export const parseConcatAttributes = (attributes: Record<string, unknown>): ConcatAttributes =>
Expand Down
80 changes: 80 additions & 0 deletions js/web/test/data/ops/concat_zero-sized.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -557,5 +557,85 @@
]
}
]
},
{
"name": "Concat 2D axis=1; Preserve dims",
"operator": "Concat",
"attributes": [
{
"name": "axis",
"data": 0,
"type": "int"
}
],
"cases": [
{
"name": "Some but not all input tensors are zero-sized",
"inputs": [
{
"data": [],
"dims": [0, 1],
"type": "float32"
},
{
"data": [1],
"dims": [1, 1],
"type": "float32"
}
],
"outputs": [
{
"data": [1],
"dims": [1, 1],
"type": "float32"
}
]
}
]
},
{
"name": "Concat 2D axis=1; Preserve dims",
"operator": "Concat",
"attributes": [
{
"name": "axis",
"data": 1,
"type": "int"
}
],
"cases": [
{
"name": "All input tensors are zero-sized",
"inputs": [
{
"data": [],
"dims": [0, 0],
"type": "float32"
},
{
"data": [],
"dims": [0, 1],
"type": "float32"
},
{
"data": [],
"dims": [0, 2],
"type": "float32"
},
{
"data": [],
"dims": [0, 3],
"type": "float32"
}
],
"outputs": [
{
"data": [],
"dims": [0, 6],
"type": "float32"
}
]
}
]
}
]
Loading