Skip to content

Commit

Permalink
[layers] Properly support channelsFirst dataFormat in Flatten layer (#…
Browse files Browse the repository at this point in the history
…2346)

* [layers] Properly support channelsFirst dataFormat in Flatten layer

BUG

Fixes: #2205
  • Loading branch information
caisq authored Nov 7, 2019
1 parent 6382f9b commit a1f6a78
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 14 deletions.
4 changes: 2 additions & 2 deletions tfjs-layers/src/exports_layers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import {input} from './exports';
import {ELU, ELULayerArgs, LeakyReLU, LeakyReLULayerArgs, PReLU, PReLULayerArgs, ReLU, ReLULayerArgs, Softmax, SoftmaxLayerArgs, ThresholdedReLU, ThresholdedReLULayerArgs} from './layers/advanced_activations';
import {Conv1D, Conv2D, Conv2DTranspose, Conv3D, ConvLayerArgs, Cropping2D, Cropping2DLayerArgs, SeparableConv2D, SeparableConvLayerArgs, UpSampling2D, UpSampling2DLayerArgs} from './layers/convolutional';
import {DepthwiseConv2D, DepthwiseConv2DLayerArgs} from './layers/convolutional_depthwise';
import {Activation, ActivationLayerArgs, Dense, DenseLayerArgs, Dropout, DropoutLayerArgs, Flatten, Masking, MaskingArgs, Permute, PermuteLayerArgs, RepeatVector, RepeatVectorLayerArgs, Reshape, ReshapeLayerArgs} from './layers/core';
import {Activation, ActivationLayerArgs, Dense, DenseLayerArgs, Dropout, DropoutLayerArgs, Flatten, FlattenLayerArgs, Masking, MaskingArgs, Permute, PermuteLayerArgs, RepeatVector, RepeatVectorLayerArgs, Reshape, ReshapeLayerArgs} from './layers/core';
import {Embedding, EmbeddingLayerArgs} from './layers/embeddings';
import {Add, Average, Concatenate, ConcatenateLayerArgs, Dot, DotLayerArgs, Maximum, Minimum, Multiply} from './layers/merge';
import {AlphaDropout, AlphaDropoutArgs, GaussianDropout, GaussianDropoutArgs, GaussianNoise, GaussianNoiseArgs} from './layers/noise';
Expand Down Expand Up @@ -554,7 +554,7 @@ export function dropout(args: DropoutLayerArgs): Layer {
* ```
*/
/** @doc {heading: 'Layers', subheading: 'Basic', namespace: 'layers'} */
export function flatten(args?: LayerArgs): Layer {
export function flatten(args?: FlattenLayerArgs): Layer {
return new Flatten(args);
}

Expand Down
38 changes: 34 additions & 4 deletions tfjs-layers/src/layers/core.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import {DisposeResult, InputSpec, Layer, LayerArgs} from '../engine/topology';
import {ValueError} from '../errors';
import {getInitializer, Initializer, InitializerIdentifier, serializeInitializer} from '../initializers';
import {ActivationIdentifier} from '../keras_format/activation_config';
import {Shape} from '../keras_format/common';
import {DataFormat, Shape} from '../keras_format/common';
import {getRegularizer, Regularizer, RegularizerIdentifier, serializeRegularizer} from '../regularizers';
import {Kwargs} from '../types';
import {assertPositiveInteger, mapActivationToFusedKernel} from '../utils/generic_utils';
Expand Down Expand Up @@ -284,12 +284,21 @@ export class Dense extends Layer {
}
serialization.registerClass(Dense);

export declare interface FlattenLayerArgs extends LayerArgs {
  /** Image data format: `channelsLast` (default) or `channelsFirst`. */
  dataFormat?: DataFormat;
}

export class Flatten extends Layer {
private dataFormat: DataFormat;

/** @nocollapse */
static className = 'Flatten';
constructor(args?: LayerArgs) {
super(args || {});
constructor(args?: FlattenLayerArgs) {
args = args || {};
super(args);
this.inputSpec = [{minNDim: 3}];
this.dataFormat = args.dataFormat;
}

computeOutputShape(inputShape: Shape|Shape[]): Shape|Shape[] {
Expand All @@ -309,9 +318,30 @@ export class Flatten extends Layer {
call(inputs: Tensor|Tensor[], kwargs: Kwargs): Tensor|Tensor[] {
return tidy(() => {
this.invokeCallHook(inputs, kwargs);
return K.batchFlatten(getExactlyOneTensor(inputs));

let input = getExactlyOneTensor(inputs);
if (this.dataFormat === 'channelsFirst' && input.rank > 1) {
const permutation: number[] = [0];
for (let i = 2; i < input.rank; ++i) {
permutation.push(i);
}
permutation.push(1);
input = input.transpose(permutation);
}

return K.batchFlatten(input);
});
}

getConfig(): serialization.ConfigDict {
const config: serialization.ConfigDict = {};
if (this.dataFormat != null) {
config['dataFormat'] = this.dataFormat;
}
const baseConfig = super.getConfig();
Object.assign(config, baseConfig);
return config;
}
}
serialization.registerClass(Flatten);

Expand Down
21 changes: 21 additions & 0 deletions tfjs-layers/src/layers/core_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,14 @@ describe('Flatten Layer: Symbolic', () => {
const x = new tfl.SymbolicTensor('float32', [8, 4, null], null, [], null);
expect(() => flattenLayer.apply(x)).toThrowError(/not fully defined/);
});
it('Serialization round trip', () => {
const layer = tfl.layers.flatten({dataFormat: 'channelsFirst'});
const pythonicConfig = convertTsToPythonic(layer.getConfig());
// tslint:disable-next-line:no-any
const tsConfig = convertPythonicToTs(pythonicConfig) as any;
const layerPrime = tfl.layers.flatten(tsConfig);
expect(layerPrime.getConfig().dataFormat).toEqual('channelsFirst');
});
});

describeMathCPUAndGPU('Flatten Layer: Tensor', () => {
Expand Down Expand Up @@ -428,6 +436,19 @@ describeMathCPUAndGPU('Flatten Layer: Tensor', () => {
[2, 8]);
expectTensorsClose(flattenLayer.apply(x, null) as Tensor, expectedOutput);
});
it('Flattens Tensor4D, channelFirst', () => {
const flattenLayer = tfl.layers.flatten({dataFormat: 'channelsFirst'});
const x = tensor4d(
[
[[[10, 20], [30, 40]], [[-10, -20], [-30, -40]]],
[[[1, 2], [3, 4]], [[-1, -2], [-3, -4]]]
],
[2, 2, 2, 2]);
const expectedOutput = tensor2d(
[10, -10, 20, -20, 30, -30, 40, -40, 1, -1, 2, -2, 3, -3, 4, -4],
[2, 8]);
expectTensorsClose(flattenLayer.apply(x, null) as Tensor, expectedOutput);
});
});

describeMathCPUAndGPU('Activation Layer: Tensor', () => {
Expand Down
16 changes: 8 additions & 8 deletions tfjs-layers/src/model_save_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ describeMathCPUAndGPU('LayersModel.save', () => {

await model.save(handler);
expect(handler.savedArtifacts.format).toEqual('layers-model');
expect(handler.savedArtifacts.generatedBy).toEqual(
`TensorFlow.js tfjs-layers v${version}`);
expect(handler.savedArtifacts.generatedBy)
.toEqual(`TensorFlow.js tfjs-layers v${version}`);
expect(handler.savedArtifacts.convertedBy).toEqual(null);
});

Expand All @@ -49,8 +49,8 @@ describeMathCPUAndGPU('LayersModel.save', () => {

await model.save(handler);
expect(handler.savedArtifacts.format).toEqual('layers-model');
expect(handler.savedArtifacts.generatedBy).toEqual(
`TensorFlow.js tfjs-layers v${version}`);
expect(handler.savedArtifacts.generatedBy)
.toEqual(`TensorFlow.js tfjs-layers v${version}`);
expect(handler.savedArtifacts.convertedBy).toEqual(null);
});

Expand Down Expand Up @@ -443,8 +443,8 @@ describeMathGPU('Save-load round trips', () => {

const gramSchmidtSpy = spyOn(linalg, 'gramSchmidt').and.callThrough();
const strict = false;
const modelPrime = await tfl.loadLayersModel(
io.fromMemory(savedArtifacts), {strict});
const modelPrime =
await tfl.loadLayersModel(io.fromMemory(savedArtifacts), {strict});
const weightsPrime = modelPrime.getWeights();
expect(weightsPrime.length).toEqual(weights.length);
expectTensorsClose(weightsPrime[0], weights[0]);
Expand Down Expand Up @@ -474,8 +474,8 @@ describeMathGPU('Save-load round trips', () => {

const gramSchmidtSpy = spyOn(linalg, 'gramSchmidt').and.callThrough();
const strict = false;
const modelPrime = await tfl.loadLayersModel(
io.fromMemory(savedArtifacts), {strict});
const modelPrime =
await tfl.loadLayersModel(io.fromMemory(savedArtifacts), {strict});
const weightsPrime = modelPrime.getWeights();
expect(weightsPrime.length).toEqual(weights.length);
expectTensorsClose(weightsPrime[0], weights[0]);
Expand Down

0 comments on commit a1f6a78

Please sign in to comment.