modularize broadcastTo op (tensorflow#2919)
INTERNAL
tafsiri authored Mar 18, 2020
1 parent 68b830c commit f8e3007
Showing 14 changed files with 312 additions and 151 deletions.
14 changes: 7 additions & 7 deletions tfjs-core/src/engine.ts
@@ -17,9 +17,9 @@

import {BackendTimingInfo, DataMover, KernelBackend} from './backends/backend';
import {Environment, setEnvironmentGlobal} from './environment';
-import {getGradient, getKernel, getKernelsForBackend, NamedAttrMap, TensorInfo} from './kernel_registry';
+import {getGradient, getKernel, getKernelsForBackend, GradFunc, NamedAttrMap, TensorInfo} from './kernel_registry';
import {Profiler} from './profiler';
-import {backpropagateGradients, getFilteredNodesXToY, NamedGradientMap, TapeNode} from './tape';
+import {backpropagateGradients, getFilteredNodesXToY, TapeNode} from './tape';
import {DataId, setTensorTracker, Tensor, TensorTracker, Variable} from './tensor';
import {GradSaveFunc, NamedTensorMap, NamedVariableMap, TensorContainer} from './tensor_types';
import {getTensorsInContainer} from './tensor_util';
@@ -465,7 +465,7 @@ export class Engine implements TensorTracker, DataMover {
const inputs = {x};
const grad = (dy: Tensor) => ({x: () => dy.toFloat()});
const saved: Tensor[] = [];
-this.addTapeNode(this.state.activeScope.name, inputs, [y], grad, saved);
+this.addTapeNode(this.state.activeScope.name, inputs, [y], grad, saved, {});
return y;
}

@@ -604,7 +604,8 @@ export class Engine implements TensorTracker, DataMover {
});

if (isTapeOn) {
-this.addTapeNode(kernelName, inputs, outputs, backwardsFunc, saved);
+this.addTapeNode(
+    kernelName, inputs, outputs, backwardsFunc, saved, attrs);
}

if (this.state.profiling) {
@@ -798,8 +799,7 @@

private addTapeNode(
kernelName: string, inputs: NamedTensorMap, outputs: Tensor[],
-gradientsFunc: (dy: Tensor|Tensor[], saved: Tensor[]) => NamedGradientMap,
-saved: Tensor[]): void {
+gradientsFunc: GradFunc, saved: Tensor[], attrs: NamedAttrMap): void {
const tapeNode: TapeNode =
{id: this.state.nextTapeNodeId++, kernelName, inputs, outputs, saved};

@@ -821,7 +821,7 @@
});
// Grad functions of ops with single outputs expect a dy, while ops
// with multiple outputs expect dys (array of dy).
-return gradientsFunc(dys.length > 1 ? dys : dys[0], saved);
+return gradientsFunc(dys.length > 1 ? dys : dys[0], saved, attrs);
};
}
this.state.activeTape.push(tapeNode);
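The engine change above threads kernel attrs through addTapeNode so the tape can hand them to modular gradient functions at backprop time. A minimal, self-contained sketch of the closure pattern (illustrative names only, not actual engine code):

type AttrMap = {[name: string]: number|number[]|string|boolean};
type GradFn = (dy: number[], saved: number[], attrs: AttrMap) => number[];

// Mirrors engine.ts: attrs are captured when the tape node is recorded and
// supplied automatically when the gradient function later runs.
function recordGradient(grad: GradFn, saved: number[], attrs: AttrMap) {
  return (dy: number[]) => grad(dy, saved, attrs);
}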
49 changes: 49 additions & 0 deletions tfjs-core/src/gradients/BroadcastTo_grad.ts
@@ -0,0 +1,49 @@
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {BroadcastTo, BroadCastToAttrs} from '../kernel_names';
import {GradConfig, NamedAttrMap} from '../kernel_registry';
import {Tensor} from '../tensor';

export const broadcastToGradConfig: GradConfig = {
  kernelName: BroadcastTo,
  gradFunc: (dy: Tensor, saved: Tensor[], attrs: NamedAttrMap) => {
    const broadCastToAttrs: BroadCastToAttrs =
        attrs as unknown as BroadCastToAttrs;

    const inputShape = broadCastToAttrs.inputShape;
    const outputShape = broadCastToAttrs.shape;

    const reps: number[] = Array.from(outputShape);
    for (let i = inputShape.length - 1; i >= 0; i--) {
      if (inputShape[i] === outputShape[i]) {
        reps[i] = 1;
      } else if (inputShape[i] !== 1) {
        throw new Error(`broadcastTo(): [${
            inputShape}] cannot be broadcast to [${outputShape}].`);
      }
    }
    const axes: number[] = [];
    for (let i = 0; i < reps.length; i++) {
      if (reps[i] > 1) {
        axes.push(i);
      }
    }
    const keepDims = true;
    return {x: () => dy.sum(axes, keepDims)};
  }
};
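For this config to take effect it must be registered with the gradient registry. A hedged sketch of the wiring, assuming the registerGradient helper exported from kernel_registry.ts (the actual registration site in the repo may differ):

import {registerGradient} from '../kernel_registry';
import {broadcastToGradConfig} from './BroadcastTo_grad';

// Once registered, the engine finds the gradient by kernel name
// ('BroadcastTo') rather than via a closure passed to runKernelFunc.
registerGradient(broadcastToGradConfig);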
7 changes: 7 additions & 0 deletions tfjs-core/src/kernel_names.ts
@@ -37,6 +37,13 @@ export interface NonMaxSuppressionV5Attrs {
softNmsSigma: number;
}

+export const BroadcastTo = 'BroadcastTo';
+export type BroadcastToInputs = Pick<NamedTensorInfoMap, 'x'>;
+export interface BroadCastToAttrs {
+  shape: number[];
+  inputShape: number[];  // for gradient
+}

/**
* TensorFlow.js-only kernels
*/
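BroadCastToAttrs.inputShape is carried only so the gradient can recover the pre-broadcast shape; the forward pass needs only shape. A small illustration of constructing the attrs (the values here are made up):

import {BroadCastToAttrs} from './kernel_names';

// inputShape is the rank-expanded shape before broadcasting; the gradient
// sums dy over every axis i where inputShape[i] === 1 and shape[i] > 1.
const attrs: BroadCastToAttrs = {shape: [3, 2], inputShape: [3, 1]};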
6 changes: 4 additions & 2 deletions tfjs-core/src/kernel_registry.ts
@@ -15,6 +15,7 @@
* =============================================================================
*/

+import {NamedGradientMap} from './tape';
import {Tensor} from './tensor';
import {DataType, RecursiveArray} from './types';

@@ -37,8 +38,9 @@ export type KernelFunc = (params: {
}) => TensorInfo|TensorInfo[];

/** The function to run when computing a gradient during backprop. */
-export type GradFunc = (dy: Tensor|Tensor[], saved: Tensor[]) =>
-    ({[inputName: string]: () => Tensor});
+export type GradFunc =
+    (dy: Tensor|Tensor[], saved: Tensor[], attrs: NamedAttrMap) =>
+        NamedGradientMap;

/** Function that gets called after the backend initializes. */
export type KernelSetupFunc = (backend: {}) => void;
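Under the new signature, every registered gradient receives the kernel's attrs as a third argument, whether or not it uses them. A sketch of a conforming GradFunc for a hypothetical 'Scale' kernel with a numeric factor attr:

import {GradFunc} from './kernel_registry';
import {Tensor} from './tensor';

const scaleGrad: GradFunc = (dy, saved, attrs) => {
  const factor = attrs['factor'] as number;
  // Single-output kernels receive a lone dy; multi-output kernels get dys.
  return {x: () => (dy as Tensor).mul(factor)};
};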
57 changes: 0 additions & 57 deletions tfjs-core/src/ops/array_ops.ts
@@ -26,62 +26,6 @@ import {op} from './operation';
import {MPRandGauss, RandGamma, UniformRandom} from './rand';
import {zeros, zerosLike} from './tensor_ops';

/**
 * Broadcast an array to a compatible shape NumPy-style.
 *
 * The tensor's shape is compared to the broadcast shape from end to beginning.
 * Ones are prepended to the tensor's shape until it has the same length as
 * the broadcast shape. If input.shape[i]==shape[i], the (i+1)-th axis is
 * already broadcast-compatible. If input.shape[i]==1 and shape[i]==N, then
 * the input tensor is tiled N times along that axis (using tf.tile).
 *
 * @param x The tensor that is to be broadcast.
 * @param shape The input is to be broadcast to this shape.
 */
/** @doc {heading: 'Tensors', subheading: 'Transformations'} */
function broadcastTo_<R extends Rank>(
    x: Tensor|TensorLike, shape: ShapeMap[R]): Tensor<R> {
  let input = convertToTensor(x, 'broadcastTo', 'x');
  const xShape = input.shape;

  if (shape.some(d => !(d > 0) || d % 1 !== 0)) {
    throw new Error(`broadcastTo(): Invalid broadcast shape [${shape}].`);
  }

  if (shape.length < input.rank) {
    throw new Error(`broadcastTo(): shape.length=${shape.length} < input.rank=${
        input.rank}.`);
  }

  if (shape.length > input.rank) {
    const newShape = input.shape.slice();
    while (newShape.length < shape.length) {
      newShape.unshift(1);
    }
    input = input.reshape(newShape);
  }

  const reps: number[] = Array.from(shape);
  for (let i = shape.length - 1; i >= 0; i--) {
    if (input.shape[i] === shape[i]) {
      reps[i] = 1;
    } else if (input.shape[i] !== 1) {
      throw new Error(
          `broadcastTo(): [${xShape}] cannot be broadcast to [${shape}].`);
    }
  }
  const axes = reps.map((n, i) => n > 1 ? i : -1).filter(i => i >= 0);

  if (axes.length === 0) {
    return input.clone() as Tensor<R>;
  }

  return ENGINE.runKernelFunc(
      backend => backend.tile(input, reps), {input},
      (dy: Tensor) =>
          ({input: () => dy.sum(axes, /*keepDims=*/true)})) as Tensor<R>;
}

/**
* Creates a new tensor with the same values and shape as the specified
* tensor.
@@ -1186,7 +1130,6 @@ export {
};

export const batchToSpaceND = op({batchToSpaceND_});
-export const broadcastTo = op({broadcastTo_});
export const cast = op({cast_});
export const clone = op({clone_});
export const cumsum = op({cumsum_});
85 changes: 0 additions & 85 deletions tfjs-core/src/ops/array_ops_test.ts
@@ -19,95 +19,10 @@ import * as tf from '../index';
import {ALL_ENVS, BROWSER_ENVS, describeWithFlags, NODE_ENVS} from '../jasmine_util';
import {expectArraysClose, expectArraysEqual, expectPromiseToFail, expectValuesInRange} from '../test_util';
import {TypedArray} from '../types';
-import {Tensor} from '../tensor';
import * as util from '../util';

import {expectArrayInMeanStdRange, jarqueBeraNormalityTest} from './rand_util';

describeWithFlags('broadcastTo', ALL_ENVS, () => {
it('[] -> [3,2]', async () => {
const a = tf.scalar(4.2);
const A = tf.tensor2d([[4.2, 4.2],
[4.2, 4.2],
[4.2, 4.2]]);

expectArraysClose(
await A.array(),
await tf.broadcastTo(a,A.shape).array()
);

// test gradients
const w = tf.tensor2d([[ 4.7, 4.5],
[-6.1,-6.6],
[-8.1,-3.4]]),
f = (a: Tensor) => tf.broadcastTo(a,A.shape).mul(w).mean().asScalar(),
h = (a: Tensor) => a.mul(w).mean().asScalar();

const df = tf.grad(f),
dh = tf.grad(h);

expectArraysClose(
await df(a).array(),
await dh(a).array()
);
});

it('[2] -> [3,2]', async () => {
const a = tf.tensor1d( [1,2] );
const A = tf.tensor2d([[1,2],
[1,2],
[1,2]]);
expectArraysClose(
await A.array(),
await tf.broadcastTo(a,A.shape).array()
);

// test gradients
const w = tf.tensor2d([[ 4.7, 4.5],
[-6.1,-6.6],
[-8.1,-3.4]]),
f = (a: Tensor) => tf.broadcastTo(a,A.shape).mul(w).mean().asScalar(),
h = (a: Tensor) => a.mul(w).mean().asScalar();

const df = tf.grad(f),
dh = tf.grad(h);

expectArraysClose(
await df(a).array(),
await dh(a).array()
);
});

it('[3,1] -> [3,2]', async () => {
const a = tf.tensor2d([[1],
[2],
[3]]);
const A = tf.tensor2d([[1,1],
[2,2],
[3,3]]);

expectArraysClose(
await A.array(),
await tf.broadcastTo(a,A.shape).array()
);

// test gradients
const w = tf.tensor2d([[ 4.7, 4.5],
[-6.1,-6.6],
[-8.1,-3.4]]),
f = (a: Tensor) => tf.broadcastTo(a,A.shape).mul(w).mean().asScalar(),
h = (a: Tensor) => a.mul(w).mean().asScalar();

const df = tf.grad(f),
dh = tf.grad(h);

expectArraysClose(
await df(a).array(),
await dh(a).array()
);
});
});

describeWithFlags('zeros', ALL_ENVS, () => {
it('1D default dtype', async () => {
const a: tf.Tensor1D = tf.zeros([3]);
92 changes: 92 additions & 0 deletions tfjs-core/src/ops/broadcast_to.ts
@@ -0,0 +1,92 @@
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {KernelBackend} from '../backends/backend';
import {ENGINE} from '../engine';
import {BroadcastTo, BroadCastToAttrs, BroadcastToInputs} from '../kernel_names';
import {NamedAttrMap} from '../kernel_registry';
import {Tensor} from '../tensor';
import {NamedTensorMap} from '../tensor_types';
import {convertToTensor} from '../tensor_util_env';
import {Rank, ShapeMap, TensorLike} from '../types';

import {op} from './operation';

/**
 * Broadcast an array to a compatible shape NumPy-style.
 *
 * The tensor's shape is compared to the broadcast shape from end to beginning.
 * Ones are prepended to the tensor's shape until it has the same length as
 * the broadcast shape. If input.shape[i]==shape[i], the (i+1)-th axis is
 * already broadcast-compatible. If input.shape[i]==1 and shape[i]==N, then
 * the input tensor is tiled N times along that axis (using tf.tile).
 *
 * @param x The tensor that is to be broadcast.
 * @param shape The input is to be broadcast to this shape.
 */
/** @doc {heading: 'Tensors', subheading: 'Transformations'} */
function broadcastTo_<R extends Rank>(
    x: Tensor|TensorLike, shape: ShapeMap[R]): Tensor<R> {
  let input = convertToTensor(x, 'broadcastTo', 'x');
  const xShape = input.shape;

  if (shape.some(d => !(d > 0) || d % 1 !== 0)) {
    throw new Error(`broadcastTo(): Invalid broadcast shape [${shape}].`);
  }

  if (shape.length < input.rank) {
    throw new Error(`broadcastTo(): shape.length=${shape.length} < input.rank=${
        input.rank}.`);
  }

  if (shape.length > input.rank) {
    const newShape = input.shape.slice();
    while (newShape.length < shape.length) {
      newShape.unshift(1);
    }
    input = input.reshape(newShape);
  }

  const inputShape = input.shape;
  const reps: number[] = Array.from(shape);
  for (let i = shape.length - 1; i >= 0; i--) {
    if (inputShape[i] === shape[i]) {
      reps[i] = 1;
    } else if (input.shape[i] !== 1) {
      throw new Error(
          `broadcastTo(): [${xShape}] cannot be broadcast to [${shape}].`);
    }
  }
  const axes = reps.map((n, i) => n > 1 ? i : -1).filter(i => i >= 0);

  if (axes.length === 0) {
    return input.clone() as Tensor<R>;
  }

  const forward = (backend: KernelBackend) => backend.tile(input, reps);
  const keepDims = true;
  const backward = (dy: Tensor) => ({x: () => dy.sum(axes, keepDims)});

  const inputs: BroadcastToInputs = {x: input};
  const attrs: BroadCastToAttrs = {shape, inputShape};

  return ENGINE.runKernelFunc(
      forward, inputs as unknown as NamedTensorMap, backward,
      BroadcastTo, attrs as unknown as NamedAttrMap) as Tensor<R>;
}

export const broadcastTo = op({broadcastTo_});
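A usage sketch of the modularized op, with shapes mirroring the tests removed from array_ops_test.ts above:

import * as tf from '@tensorflow/tfjs-core';

const a = tf.tensor2d([[1], [2], [3]]);  // shape [3, 1]
tf.broadcastTo(a, [3, 2]).print();       // [[1, 1], [2, 2], [3, 3]]

// The gradient sums dy over the broadcast axes (keepDims=true), so for
// f = sum(broadcastTo(x, [3, 2])) each element of x feeds two outputs
// and receives gradient 2.
const f = (x: tf.Tensor) => tf.broadcastTo(x, [3, 2]).sum();
tf.grad(f)(a).print();                   // [[2], [2], [2]]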
(The remaining 7 changed file diffs are not shown.)
