[js/web] add ConvTranspose2D to WebGL backend (#11990)
* Add ConvTranspose

* Update docs + tests

* fix lint

* fix output shape calculations

* Revert "fix output shape calculations"

This reverts commit 8014fa9.

* fix format

* remove broken output_shape test
101arrowz committed Jul 27, 2022
1 parent d2b25a7 commit 148b1ef
Showing 9 changed files with 279 additions and 10 deletions.
2 changes: 1 addition & 1 deletion js/web/docs/operators.md
@@ -36,7 +36,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat
| [ConstantOfShape](https://github.com/onnx/onnx/blob/master/docs/Operators.md#ConstantOfShape) | |
| [Conv](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Conv) | [1-10](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Conv-1), [11+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Conv-11) |
| [ConvInteger](https://github.com/onnx/onnx/blob/master/docs/Operators.md#ConvInteger) | |
-| [ConvTranspose](https://github.com/onnx/onnx/blob/master/docs/Operators.md#ConvTranspose) | |
+| [ConvTranspose](https://github.com/onnx/onnx/blob/master/docs/Operators.md#ConvTranspose) | [1-10](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#ConvTranspose-1), [11+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#ConvTranspose-11) |
| [Cos](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Cos) | [7+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Cos-7) |
| [Cosh](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Cosh) | |
| [CumSum](https://github.com/onnx/onnx/blob/master/docs/Operators.md#CumSum) | |
2 changes: 2 additions & 0 deletions js/web/lib/onnxjs/backends/webgl/op-resolve-rules.ts
@@ -8,6 +8,7 @@ import * as binaryOps from './ops/binary-op';
import {cast, parseCastAttributes} from './ops/cast';
import {concat, parseConcatAttributes} from './ops/concat';
import {conv, parseConvAttributes} from './ops/conv';
+import {convTranspose, parseConvTransposeAttributes} from './ops/conv-transpose';
import {depthToSpace, parseDepthToSpaceAttributes} from './ops/depth-to-space';
import {flatten, parseFlattenAttributes} from './ops/flatten';
import {gather, parseGatherAttributes} from './ops/gather';
@@ -48,6 +49,7 @@ export const WEBGL_OP_RESOLVE_RULES: readonly OpSet.ResolveRule[] = [
['Clip', '', '11+', unaryOps.clipV11],
['Concat', '', '4+', concat, parseConcatAttributes],
['Conv', '', '1+', conv, parseConvAttributes],
+['ConvTranspose', '', '1+', convTranspose, parseConvTransposeAttributes],
['Cos', '', '7+', unaryOps.cos],
['Div', '', '7+', binaryOps.div],
['Dropout', '', '7+', unaryOps.identity],
4 changes: 2 additions & 2 deletions js/web/lib/onnxjs/backends/webgl/ops/conv-grouped.ts
@@ -8,7 +8,7 @@ import {WebGLInferenceHandler} from '../inference-handler';
import {ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType} from '../types';

import {calculateOutputShape, ConvAttributes} from './conv';
-import {getActicationSnippet} from './fuse-utils';
+import {getActivationSnippet} from './fuse-utils';

const createUnpackedGroupedConvProgramMetadata = (hasBias: boolean, cacheHint: string): ProgramMetadata => ({
name: 'GroupedConv',
@@ -33,7 +33,7 @@ const createUnpackedGroupedConvProgramInfo
const outputShape =
calculateOutputShape(xShape, wShape, attributes.dilations, attributes.pads, attributes.strides);
const glsl = getGlsl(inferenceHandler.session.backend.glContext.version);
-const {activationFunction, applyActivation} = getActicationSnippet(attributes);
+const {activationFunction, applyActivation} = getActivationSnippet(attributes);

const shaderSource = `
const ivec2 strides = ivec2(${attributes.strides[0]}, ${attributes.strides[1]});
259 changes: 259 additions & 0 deletions js/web/lib/onnxjs/backends/webgl/ops/conv-transpose.ts
@@ -0,0 +1,259 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {createAttributeWithCacheKey} from '../../../attribute-with-cache-key';
import {InferenceHandler} from '../../../backend';
import {Graph} from '../../../graph';
import {OperatorImplementation, OperatorInitialization} from '../../../operators';
import {Tensor} from '../../../tensor';
import {getGlsl} from '../glsl-source';
import {WebGLInferenceHandler} from '../inference-handler';
import {ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType} from '../types';

import {ConvAttributes} from './conv';
import {getActivationSnippet, parseInternalActivationAttributes} from './fuse-utils';

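// Total padding for one spatial axis: the unpadded ConvTranspose output length
// is (inDim - 1) * stride + adj + (kernel - 1) * dilation + 1, so the pads
// must absorb the difference between that and the requested outSize.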
const computeTotalPad =
(inDim: number, stride: number, adj: number, kernel: number, dilation: number, outSize: number) =>
(inDim - 1) * stride + adj + (kernel - 1) * dilation + 1 - outSize;

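// Splits totalPad between the head and tail pads of one spatial axis. For odd
// totals, SAME_UPPER places the extra element at the tail and SAME_LOWER at
// the head; all other auto_pad modes leave the explicit pads untouched.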
const distributePadding = (totalPad: number, autoPad: string, pads: number[], head: number, tail: number) => {
const smallPad = Math.floor(totalPad / 2);
if (autoPad === 'SAME_UPPER') {
pads[head] = smallPad;
pads[tail] = totalPad - smallPad;
} else if (autoPad === 'SAME_LOWER') {
pads[head] = totalPad - smallPad;
pads[tail] = smallPad;
}
};

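// ONNX ConvTranspose shape inference: if output_shape is absent, each spatial
// dim becomes stride * (in - 1) + output_padding + (kernel - 1) * dilation + 1
// - pad_head - pad_tail (e.g. in = 3, stride = 1, kernel = 3, dilation = 1,
// no pads: 2 + 2 + 1 = 5); if output_shape is given, pads are derived from it.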
const calculateOutputShapeAndPads =
(inputShape: readonly number[], kernelShape: readonly number[], dilations: readonly number[], autoPad: string,
pads: number[], strides: readonly number[], outputPadding: readonly number[], outputShape: number[]) => {
const spatialRank = inputShape.length - 2;
const updateShape = outputShape.length === 0;
for (let i = 0; i < spatialRank; ++i) {
const outSize = updateShape ? inputShape[i + 2] * strides[i] : outputShape[i];
const totalPad = computeTotalPad(inputShape[i + 2], strides[i], pads[i], kernelShape[i], dilations[i], outSize);
distributePadding(totalPad, autoPad, pads, i, i + spatialRank);
if (updateShape) {
outputShape.push(
strides[i] * (inputShape[i + 2] - 1) + outputPadding[i] + (kernelShape[i] - 1) * dilations[i] + 1 -
pads[i] - pads[i + spatialRank]);
}
}
};

export interface ConvTransposeAttributes extends ConvAttributes {
readonly outputPadding: readonly number[];
readonly outputShape: readonly number[];
}

export const convTranspose: OperatorImplementation<ConvTransposeAttributes> =
(inferenceHandler: InferenceHandler, inputs: Tensor[], attributes: ConvTransposeAttributes): Tensor[] => {
validateInputs(inputs, attributes); // currently will fail if not convTranspose2D
return convTranspose2d(inferenceHandler, inputs, attributes);
};

const convTranspose2d: OperatorImplementation<ConvTransposeAttributes> =
(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: ConvTransposeAttributes): Tensor[] => {
const adjustedAttributes = getAdjustedConvTransposeAttributes(attributes, inputs);
return [convTranspose2DUnpacked(inferenceHandler, inputs, adjustedAttributes)];
};

const createConvTransposeProgramMetadata = (hasBias: boolean, cacheHint: string) => ({
name: 'ConvTranspose',
inputNames: hasBias ? ['X', 'W', 'B'] : ['X', 'W'],
inputTypes: hasBias ? [TextureType.unpacked, TextureType.unpacked, TextureType.unpacked] :
[TextureType.unpacked, TextureType.unpacked],
cacheHint
});

const createUnpackedConvTransposeProgramInfo =
(inferenceHandler: WebGLInferenceHandler, inputs: readonly Tensor[], metadata: ProgramMetadata,
attributes: ConvTransposeAttributes): ProgramInfo => {
const hasBias = inputs.length > 2;
const valueInit = hasBias ? 'getB(output_channel)' : '0.0';
const xShape = inputs[0].dims;
const wShape = inputs[1].dims;
const outputChannelsPerGroup = wShape[1];
const inputChannelsPerGroup = wShape[0] / attributes.group;
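// ConvTranspose weights are laid out as (C, M/group, kH, kW), so dims[0] spans
// the input channels and dims[1] the output channels of each group.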
const outputShape = [inputs[0].dims[0], inputs[1].dims[1] * attributes.group, ...attributes.outputShape];
const glsl = getGlsl(inferenceHandler.session.backend.glContext.version);
const {activationFunction, applyActivation} = getActivationSnippet(attributes);

const shaderSource = `
const ivec2 strides = ivec2(${attributes.strides[0]}, ${attributes.strides[1]});
const ivec2 pads = ivec2(${attributes.pads[0]}, ${attributes.pads[1]});
${activationFunction}
void main() {
ivec4 coords = getOutputCoords();
int batch = coords.x;
int output_channel = coords.y;
ivec2 loc = coords.zw + pads;
int group_id = output_channel / ${outputChannelsPerGroup};
int wOutChannel = output_channel - group_id * ${outputChannelsPerGroup};
float value = ${valueInit};
for (int inChannelOffset = 0; inChannelOffset < ${inputChannelsPerGroup}; inChannelOffset++) {
int input_channel = group_id * ${inputChannelsPerGroup} + inChannelOffset;
for (int wWOff = 0; wWOff < ${wShape[2]}; wWOff++) {
for (int wHOff = 0; wHOff < ${wShape[3]}; wHOff++) {
ivec2 wOff = ivec2(wWOff * ${attributes.dilations[0]}, wHOff * ${attributes.dilations[1]});
ivec2 wLoc = loc - wOff;
ivec2 wLocIn = wLoc / strides;
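// Gather formulation: an input pixel contributes to this output pixel only
// when wLoc is an exact stride multiple that falls inside the input bounds.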
if (
wLocIn * strides == wLoc &&
wLocIn.x >= 0 && wLocIn.x < ${xShape[2]} &&
wLocIn.y >= 0 && wLocIn.y < ${xShape[3]}
) {
float xVal = getX(batch, input_channel, wLocIn.y, wLocIn.x);
float wVal = getW(input_channel, wOutChannel, wHOff, wWOff);
value += xVal * wVal;
}
}
}
}
${applyActivation}
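// Unpacked textures hold one scalar per texel; write the result to channel r.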
${glsl.output} = vec4(value, .0, .0, .0);
}
`;
return {
...metadata,
output: {dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked},
shaderSource,
hasMain: true,
};
};

const createUnpackedConvTransposeProgramInfoLoader =
(inferenceHandler: WebGLInferenceHandler, inputs: readonly Tensor[], attributes: ConvTransposeAttributes):
ProgramInfoLoader => {
const metadata = createConvTransposeProgramMetadata(inputs.length > 2, attributes.cacheKey);
return {
...metadata,
get: () => createUnpackedConvTransposeProgramInfo(inferenceHandler, inputs, metadata, attributes)
};
};


const convTranspose2DUnpacked =
(inferenceHandler: WebGLInferenceHandler, inputs: readonly Tensor[], attributes: ConvTransposeAttributes):
Tensor => {
const result = inferenceHandler.run(
createUnpackedConvTransposeProgramInfoLoader(inferenceHandler, inputs, attributes), inputs);
return result;
};

const getAdjustedConvTransposeAttributes = <T extends ConvTransposeAttributes>(attributes: T, inputs: Tensor[]): T => {
const kernelShape = attributes.kernelShape.slice();
// if kernelShape is not specified in the attributes of this op, infer it from the weight tensor dims
if (attributes.kernelShape.length === 0) {
for (let i = 2; i < inputs[1].dims.length; ++i) {
kernelShape.push(inputs[1].dims[i]);
}
}

const pads = attributes.pads.slice();
const outputShape = attributes.outputShape.slice();
const inputShape = inputs[0].dims;
// If outputShape is not specified in the attributes of this op, infer it from the parameters
// Similarly, automatically infer pads if not specified
calculateOutputShapeAndPads(
inputShape, kernelShape, attributes.dilations, attributes.autoPad, pads, attributes.strides,
attributes.outputPadding, outputShape);

// always return a new object so we do not modify the original attributes
const newAttributes: T = Object.assign({}, attributes);
Object.assign(newAttributes, {kernelShape, pads, outputShape, cacheKey: attributes.cacheKey});
return newAttributes;
};

export const parseConvTransposeAttributes: OperatorInitialization<ConvTransposeAttributes> =
(node: Graph.Node): ConvTransposeAttributes => {
const attributes = node.attributes;
const activationAttributes = parseInternalActivationAttributes(attributes);
// TODO : Make this generic enough to compute default attributes for multi-dimensional conv
const autoPad = attributes.getString('auto_pad', 'NOTSET');
const dilations = attributes.getInts('dilations', [1, 1]);
const group = attributes.getInt('group', 1);
const kernelShape = attributes.getInts('kernel_shape', []);
const outputPadding = attributes.getInts('output_padding', [0, 0]);
const outputShape = attributes.getInts('output_shape', []);
const pads = attributes.getInts('pads', [0, 0, 0, 0]);
const strides = attributes.getInts('strides', [1, 1]);

return createAttributeWithCacheKey(
{autoPad, dilations, group, kernelShape, outputPadding, outputShape, pads, strides, ...activationAttributes});
};

const validateInputs = (inputs: Tensor[], attributes: ConvTransposeAttributes): void => {
// Refer to the below link for all input checks
// https://github.com/onnx/onnx/blob/master/docs/Operators.md#ConvTranspose
if (!inputs || (inputs.length !== 2 && inputs.length !== 3)) {
throw new Error('ConvTranspose requires 2 or 3 inputs');
}

// TODO : Need to add support for multi-dimensional conv
if (inputs[0].dims.length !== 4 || inputs[1].dims.length !== 4) {
throw new Error('currently only supports 2-dimensional conv');
}

// FILTER_IN_CHANNEL should be equal to DATA_CHANNEL
const dataChannel = inputs[0].dims[1];
const filterInChannel = inputs[1].dims[0];
if (dataChannel !== filterInChannel) {
throw new Error('FILTER_IN_CHANNEL should be equal to DATA_CHANNEL');
}

const featureMaps = inputs[1].dims[1] * attributes.group;

// if bias is provided it should be 1D and the number of elements should be equal to the number of feature maps
if (inputs.length === 3 && (inputs[2].dims.length !== 1 || inputs[2].dims[0] !== featureMaps)) {
throw new Error('invalid bias');
}

const spatialRank = inputs[0].dims.length - 2;
// Wrong dilations dimension
if (attributes.dilations.length !== spatialRank) {
throw new Error(`dilations should be ${spatialRank}D`);
}

// Wrong strides dimension
if (attributes.strides.length !== spatialRank) {
throw new Error(`strides should be ${spatialRank}D`);
}

// Wrong pads dimension
if (attributes.pads.length !== spatialRank * 2) {
throw new Error(`pads should be ${spatialRank * 2}D`);
}

// Wrong output padding dimension
if (attributes.outputPadding.length !== spatialRank) {
throw new Error(`output_padding should be ${spatialRank}D`);
}

// if kernelShape is specified, its length must be 2 less than the rank of the weights tensor
// (the first 2 dims of the weights are the input-channel and per-group output-channel counts)
if (attributes.kernelShape.length !== 0 && attributes.kernelShape.length !== inputs[1].dims.length - 2) {
throw new Error('invalid kernel shape');
}

// as with kernelShape, must have same number of spatial dims as input
if (attributes.outputShape.length !== 0 && attributes.outputShape.length !== inputs[0].dims.length - 2) {
throw new Error('invalid output shape');
}

// TODO : Need to add support for float64
if (inputs[0].type !== 'float32' || inputs[1].type !== 'float32') {
throw new Error('ConvTranspose input(X,W) should be float tensor');
}

if (inputs.length === 3 && inputs[2].type !== 'float32') {
throw new Error('ConvTranspose input(bias) should be float tensor');
}
};
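With the resolve rule registered above, a ConvTranspose node now dispatches to this implementation whenever the WebGL backend is selected. Below is a minimal usage sketch with onnxruntime-web; the model file and the input/output names X and Y are hypothetical:

import {InferenceSession, Tensor} from 'onnxruntime-web';

const main = async () => {
  // Select the WebGL execution provider so the shader above handles ConvTranspose.
  const session = await InferenceSession.create('./conv_transpose.onnx', {executionProviders: ['webgl']});
  // NCHW input: batch 1, 1 channel, 3x3 spatial.
  const x = new Tensor('float32', new Float32Array(9).fill(1), [1, 1, 3, 3]);
  const {Y} = await session.run({X: x});
  // With a 3x3 kernel, stride 1, dilation 1 and no padding:
  // out = 1 * (3 - 1) + (3 - 1) * 1 + 1 = 5, so Y.dims is [1, M, 5, 5],
  // where M is the model's output channel count.
  console.log(Y.dims);
};
main();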
4 changes: 2 additions & 2 deletions js/web/lib/onnxjs/backends/webgl/ops/dot-product.ts
@@ -7,7 +7,7 @@ import {getGlsl} from '../glsl-source';
import {WebGLInferenceHandler} from '../inference-handler';
import {ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType} from '../types';

-import {getActicationSnippet, InternalActivationAttributes} from './fuse-utils';
+import {getActivationSnippet, InternalActivationAttributes} from './fuse-utils';
import {calculateIm2ColDims} from './im2col';

const createDotProductProgramMetadata = (hasBias: boolean, attributes: InternalActivationAttributes) => ({
@@ -35,7 +35,7 @@ const createDotProductProgramInfo

const initValue = (inputs.length < 3) ? '0.0' : '_B(b)';
const sharedDim = Math.ceil(xshape[1] * kshape[2] * kshape[3] / 4);
-const {activationFunction, applyActivation} = getActicationSnippet(attributes);
+const {activationFunction, applyActivation} = getActivationSnippet(attributes);
const glsl = getGlsl(inferenceHandler.session.backend.glContext.version);
const shaderSource = `
${activationFunction}
2 changes: 1 addition & 1 deletion js/web/lib/onnxjs/backends/webgl/ops/fuse-utils.ts
@@ -14,7 +14,7 @@ export interface InternalActivationAttributes {
readonly activationCacheKey: string;
}

-export function getActicationSnippet(attributes: InternalActivationAttributes) {
+export function getActivationSnippet(attributes: InternalActivationAttributes) {
let func: GlslValueFunction;
switch (attributes.activation) {
case 'Relu':
4 changes: 2 additions & 2 deletions js/web/lib/onnxjs/backends/webgl/ops/matmul-pack.ts
@@ -8,7 +8,7 @@ import {WebGLInferenceHandler} from '../inference-handler';
import {ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType} from '../types';
import {getCoordsDataType, getGlChannels} from '../utils';

-import {getActicationSnippet, InternalActivationAttributes} from './fuse-utils';
+import {getActivationSnippet, InternalActivationAttributes} from './fuse-utils';
import {getBiasForMatmul} from './matmul';

const createPackedMatmulProgramMetadata = (hasBias: boolean, cacheHint: string) => ({
@@ -41,7 +41,7 @@ const createPackedMatmulProgramInfo
const coordsDataType = getCoordsDataType(outputShape.length);
const outRank = outputShape.length;
const allGlChannels = getGlChannels();
-const {activationFunction, applyActivation} = getActicationSnippet(activationAttributes);
+const {activationFunction, applyActivation} = getActivationSnippet(activationAttributes);

const getBiasForMatmulSnippet =
hasBias ? `${getBiasForMatmul(coordsDataType, allGlChannels, inputs[2].dims, outputShape, true)}` : '';
4 changes: 2 additions & 2 deletions js/web/lib/onnxjs/backends/webgl/ops/matmul.ts
@@ -9,7 +9,7 @@ import {WebGLInferenceHandler} from '../inference-handler';
import {ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType} from '../types';
import {getCoordsDataType, getGlChannels} from '../utils';

-import {getActicationSnippet, InternalActivationAttributes, parseInternalActivationAttributes} from './fuse-utils';
+import {getActivationSnippet, InternalActivationAttributes, parseInternalActivationAttributes} from './fuse-utils';
import {createPackedMatmulProgramInfoLoader} from './matmul-pack';

export const matMul: OperatorImplementation<InternalActivationAttributes> =
@@ -45,7 +45,7 @@ function createMatmulProgramInfo(
}
const coordsDataType = getCoordsDataType(outputShape.length);
const allGlChannels = getGlChannels();
-const {activationFunction, applyActivation} = getActicationSnippet(activationAttributes);
+const {activationFunction, applyActivation} = getActivationSnippet(activationAttributes);

const hasBias = inputs.length > 2;
const processBias = hasBias ? 'value += getBiasForMatmul();' : '';
8 changes: 8 additions & 0 deletions js/web/test/suite-test-list.jsonc
@@ -57,6 +57,14 @@
"test_conv_with_strides_and_asymmetric_padding",
"test_conv_with_strides_no_padding",
"test_conv_with_strides_padding",
"test_convtranspose",
"test_convtranspose_pad",
"test_convtranspose_pads",
// TODO: add this when test-case file in opset v8 is fixed (i.e. output_shape has 2 dims)
// Might have to rewrite git history for that...
// "test_convtranspose_output_shape",
"test_convtranspose_kernel_shape",
"test_convtranspose_dilations",
"test_constant",
"test_cos_example",
"test_cos",
