-
Notifications
You must be signed in to change notification settings - Fork 2.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[js/web] add ConvTranspose2D to WebGL backend (#11990)
* Add ConvTranspose * Update docs + tests * fix lint * fix output shape calculations * Revert "fix output shape calculations" This reverts commit 8014fa9. * fix format * remove broken output_shape test
- Loading branch information
Showing
9 changed files
with
279 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,259 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
import {createAttributeWithCacheKey} from '../../../attribute-with-cache-key'; | ||
import {InferenceHandler} from '../../../backend'; | ||
import {Graph} from '../../../graph'; | ||
import {OperatorImplementation, OperatorInitialization} from '../../../operators'; | ||
import {Tensor} from '../../../tensor'; | ||
import {getGlsl} from '../glsl-source'; | ||
import {WebGLInferenceHandler} from '../inference-handler'; | ||
import {ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType} from '../types'; | ||
|
||
import {ConvAttributes} from './conv'; | ||
import {getActivationSnippet, parseInternalActivationAttributes} from './fuse-utils'; | ||
|
||
const computeTotalPad = | ||
(inDim: number, stride: number, adj: number, kernel: number, dilation: number, outSize: number) => | ||
(inDim - 1) * stride + adj + (kernel - 1) * dilation + 1 - outSize; | ||
|
||
const distributePadding = (totalPad: number, autoPad: string, pads: number[], head: number, tail: number) => { | ||
const smallPad = Math.floor(totalPad / 2); | ||
if (autoPad === 'SAME_UPPER') { | ||
pads[head] = smallPad; | ||
pads[tail] = totalPad - smallPad; | ||
} else if (autoPad === 'SAME_LOWER') { | ||
pads[head] = totalPad - smallPad; | ||
pads[tail] = smallPad; | ||
} | ||
}; | ||
|
||
const calculateOutputShapeAndPads = | ||
(inputShape: readonly number[], kernelShape: readonly number[], dilations: readonly number[], autoPad: string, | ||
pads: number[], strides: readonly number[], outputPadding: readonly number[], outputShape: number[]) => { | ||
const spatialRank = inputShape.length - 2; | ||
const updateShape = outputShape.length === 0; | ||
for (let i = 0; i < spatialRank; ++i) { | ||
const outSize = updateShape ? inputShape[i + 2] * strides[i] : outputShape[i]; | ||
const totalPad = computeTotalPad(inputShape[i + 2], strides[i], pads[i], kernelShape[i], dilations[i], outSize); | ||
distributePadding(totalPad, autoPad, pads, i, i + spatialRank); | ||
if (updateShape) { | ||
outputShape.push( | ||
strides[i] * (inputShape[i + 2] - 1) + outputPadding[i] + (kernelShape[i] - 1) * dilations[i] + 1 - | ||
pads[i] - pads[i + spatialRank]); | ||
} | ||
} | ||
}; | ||
|
||
export interface ConvTransposeAttributes extends ConvAttributes { | ||
readonly outputPadding: readonly number[]; | ||
readonly outputShape: readonly number[]; | ||
} | ||
|
||
export const convTranspose: OperatorImplementation<ConvTransposeAttributes> = | ||
(inferenceHandler: InferenceHandler, inputs: Tensor[], attributes: ConvTransposeAttributes): Tensor[] => { | ||
validateInputs(inputs, attributes); // currently will fail if not convTranspose2D | ||
return convTranspose2d(inferenceHandler, inputs, attributes); | ||
}; | ||
|
||
const convTranspose2d: OperatorImplementation<ConvTransposeAttributes> = | ||
(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: ConvTransposeAttributes): Tensor[] => { | ||
const adjustedAttributes = getAdjustedConvTransposeAttributes(attributes, inputs); | ||
return [convTranspose2DUnpacked(inferenceHandler, inputs, adjustedAttributes)]; | ||
}; | ||
|
||
const createConvTransposeProgramMetadata = (hasBias: boolean, cacheHint: string) => ({ | ||
name: 'ConvTranspose', | ||
inputNames: hasBias ? ['X', 'W', 'B'] : ['X', 'W'], | ||
inputTypes: hasBias ? [TextureType.unpacked, TextureType.unpacked, TextureType.unpacked] : | ||
[TextureType.unpacked, TextureType.unpacked], | ||
cacheHint | ||
}); | ||
|
||
const createUnpackedConvTransposeProgramInfo = | ||
(inferenceHandler: WebGLInferenceHandler, inputs: readonly Tensor[], metadata: ProgramMetadata, | ||
attributes: ConvTransposeAttributes): ProgramInfo => { | ||
const hasBias = inputs.length > 2; | ||
const valueInit = hasBias ? 'getB(output_channel)' : '0.0'; | ||
const xShape = inputs[0].dims; | ||
const wShape = inputs[1].dims; | ||
const outputChannelsPerGroup = wShape[1]; | ||
const inputChannelsPerGroup = wShape[0] / attributes.group; | ||
const outputShape = [inputs[0].dims[0], inputs[1].dims[1] * attributes.group, ...attributes.outputShape]; | ||
const glsl = getGlsl(inferenceHandler.session.backend.glContext.version); | ||
const {activationFunction, applyActivation} = getActivationSnippet(attributes); | ||
|
||
const shaderSource = ` | ||
const ivec2 strides = ivec2(${attributes.strides[0]}, ${attributes.strides[1]}); | ||
const ivec2 pads = ivec2(${attributes.pads[0]}, ${attributes.pads[1]}); | ||
${activationFunction} | ||
void main() { | ||
ivec4 coords = getOutputCoords(); | ||
int batch = coords.x; | ||
int output_channel = coords.y; | ||
ivec2 loc = coords.zw + pads; | ||
int group_id = output_channel / ${outputChannelsPerGroup}; | ||
int wOutChannel = output_channel - group_id * ${outputChannelsPerGroup}; | ||
float value = ${valueInit}; | ||
for (int inChannelOffset = 0; inChannelOffset < ${inputChannelsPerGroup}; inChannelOffset++) { | ||
int input_channel = group_id * ${inputChannelsPerGroup} + inChannelOffset; | ||
for (int wWOff = 0; wWOff < ${wShape[2]}; wWOff++) { | ||
for (int wHOff = 0; wHOff < ${wShape[3]}; wHOff++) { | ||
ivec2 wOff = ivec2(wWOff * ${attributes.dilations[0]}, wHOff * ${attributes.dilations[1]}); | ||
ivec2 wLoc = loc - wOff; | ||
ivec2 wLocIn = wLoc / strides; | ||
if ( | ||
wLocIn * strides == wLoc && | ||
wLocIn.x >= 0 && wLocIn.x < ${xShape[2]} && | ||
wLocIn.y >= 0 && wLocIn.y < ${xShape[3]} | ||
) { | ||
float xVal = getX(batch, input_channel, wLocIn.y, wLocIn.x); | ||
float wVal = getW(input_channel, wOutChannel, wHOff, wWOff); | ||
value += xVal * wVal; | ||
} | ||
} | ||
} | ||
} | ||
${applyActivation} | ||
${glsl.output} = vec4(value, .0, .0, .0); | ||
} | ||
`; | ||
return { | ||
...metadata, | ||
output: {dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked}, | ||
shaderSource, | ||
hasMain: true, | ||
}; | ||
}; | ||
|
||
const createUnpackedConvTransposeProgramInfoLoader = | ||
(inferenceHandler: WebGLInferenceHandler, inputs: readonly Tensor[], attributes: ConvTransposeAttributes): | ||
ProgramInfoLoader => { | ||
const metadata = createConvTransposeProgramMetadata(inputs.length > 2, attributes.cacheKey); | ||
return { | ||
...metadata, | ||
get: () => createUnpackedConvTransposeProgramInfo(inferenceHandler, inputs, metadata, attributes) | ||
}; | ||
}; | ||
|
||
|
||
const convTranspose2DUnpacked = | ||
(inferenceHandler: WebGLInferenceHandler, inputs: readonly Tensor[], attributes: ConvTransposeAttributes): | ||
Tensor => { | ||
const result = inferenceHandler.run( | ||
createUnpackedConvTransposeProgramInfoLoader(inferenceHandler, inputs, attributes), inputs); | ||
return result; | ||
}; | ||
|
||
const getAdjustedConvTransposeAttributes = <T extends ConvTransposeAttributes>(attributes: T, inputs: Tensor[]): T => { | ||
const kernelShape = attributes.kernelShape.slice(); | ||
// if kernelShape is not specified in the attributes of this op, infer it from the weight tensor dims | ||
if (attributes.kernelShape.length === 0) { | ||
for (let i = 2; i < inputs[1].dims.length; ++i) { | ||
kernelShape.push(inputs[1].dims[i]); | ||
} | ||
} | ||
|
||
const pads = attributes.pads.slice(); | ||
const outputShape = attributes.outputShape.slice(); | ||
const inputShape = inputs[0].dims; | ||
// If outputShape is not specified in the attributes of this op, infer it from the parameters | ||
// Similarly, automatically infer pads if not specified | ||
calculateOutputShapeAndPads( | ||
inputShape, kernelShape, attributes.dilations, attributes.autoPad, pads, attributes.strides, | ||
attributes.outputPadding, outputShape); | ||
|
||
// always return a new object so does not modify the original attributes | ||
const newAttributes: T = Object.assign({}, attributes); | ||
Object.assign(newAttributes, {kernelShape, pads, outputShape, cacheKey: attributes.cacheKey}); | ||
return newAttributes; | ||
}; | ||
|
||
export const parseConvTransposeAttributes: OperatorInitialization<ConvTransposeAttributes> = | ||
(node: Graph.Node): ConvTransposeAttributes => { | ||
const attributes = node.attributes; | ||
const activationAttributes = parseInternalActivationAttributes(attributes); | ||
// TODO : Make this generic enough to compute default attributes for multi-dimensional conv | ||
const autoPad = attributes.getString('auto_pad', 'NOTSET'); | ||
const dilations = attributes.getInts('dilations', [1, 1]); | ||
const group = attributes.getInt('group', 1); | ||
const kernelShape = attributes.getInts('kernel_shape', []); | ||
const outputPadding = attributes.getInts('output_padding', [0, 0]); | ||
const outputShape = attributes.getInts('output_shape', []); | ||
const pads = attributes.getInts('pads', [0, 0, 0, 0]); | ||
const strides = attributes.getInts('strides', [1, 1]); | ||
|
||
return createAttributeWithCacheKey( | ||
{autoPad, dilations, group, kernelShape, outputPadding, outputShape, pads, strides, ...activationAttributes}); | ||
}; | ||
|
||
const validateInputs = (inputs: Tensor[], attributes: ConvTransposeAttributes): void => { | ||
// Refer to the below link for all input checks | ||
// https://github.com/onnx/onnx/blob/master/docs/Operators.md#Conv | ||
if (!inputs || (inputs.length !== 2 && inputs.length !== 3)) { | ||
throw new Error('Conv requires 2 or 3 inputs'); | ||
} | ||
|
||
// TODO : Need to add support for multi-dimensional conv | ||
if (inputs[0].dims.length !== 4 || inputs[1].dims.length !== 4) { | ||
throw new Error('currently only support 2-dimensional conv'); | ||
} | ||
|
||
// FILTER_IN_CHANNEL should be equal to DATA_CHANNEL | ||
const dataChannel = inputs[0].dims[1]; | ||
const filterInChannel = inputs[1].dims[0]; | ||
if (dataChannel !== filterInChannel) { | ||
throw new Error('FILTER_IN_CHANNEL should be equal to DATA_CHANNEL'); | ||
} | ||
|
||
const featureMaps = inputs[1].dims[1] * attributes.group; | ||
|
||
// if bias is provided it should be 1D and the number of elements should be equal to the number of feature maps | ||
if (inputs.length === 3 && (inputs[2].dims.length !== 1 || inputs[2].dims[0] !== featureMaps)) { | ||
throw new Error('invalid bias'); | ||
} | ||
|
||
const spatialRank = inputs[0].dims.length - 2; | ||
// wrong dilations dimension | ||
if (attributes.dilations.length !== spatialRank) { | ||
throw new Error(`dilations should be ${spatialRank}D`); | ||
} | ||
|
||
// Wrong strides dimension | ||
if (attributes.strides.length !== spatialRank) { | ||
throw new Error(`strides should be ${spatialRank}D`); | ||
} | ||
|
||
// Wrong pads dimension | ||
if (attributes.pads.length !== spatialRank * 2) { | ||
throw new Error(`pads should be ${spatialRank * 2}D`); | ||
} | ||
|
||
// Wrong output padding dimension | ||
if (attributes.outputPadding.length !== spatialRank) { | ||
throw new Error(`output_padding should be ${spatialRank}D`); | ||
} | ||
|
||
// if kernelShape is specified, it's data length must be 2 less than dims length of the weights tensor | ||
// (the first 2 dims are batch_size and channels) | ||
if (attributes.kernelShape.length !== 0 && attributes.kernelShape.length !== inputs[1].dims.length - 2) { | ||
throw new Error('invalid kernel shape'); | ||
} | ||
|
||
// as with kernelShape, must have same number of spatial dims as input | ||
if (attributes.outputShape.length !== 0 && attributes.outputShape.length !== inputs[0].dims.length - 2) { | ||
throw new Error('invalid output shape'); | ||
} | ||
|
||
// TODO : Need to add support for float64 | ||
if (inputs[0].type !== 'float32' || inputs[1].type !== 'float32') { | ||
throw new Error('ConvTranspose input(X,W) should be float tensor'); | ||
} | ||
|
||
if (inputs.length === 3 && inputs[2].type !== 'float32') { | ||
throw new Error('ConvTranspose input(bias) should be float tensor'); | ||
} | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters