From 355bb0b6c35d82968b42693f6c3cdafaa65c771e Mon Sep 17 00:00:00 2001 From: Shanqing Cai Date: Mon, 19 Aug 2019 15:38:28 -0400 Subject: [PATCH] [tfjs-layers] Fix a silent error in L1/L2/L1L2 regularizer constructor (#1861) Re. https://github.com/tensorflow/tfjs/issues/1773 BUG Previosly, if you pass a number argument to `tf.regularizers.l2()`, for example, no error will be thrown, even though the specified number will not take effect as the regularization coefficient. This PR fixes that by explicitly checking the type of the arg. --- tfjs-layers/src/regularizers.ts | 12 ++++++++++++ tfjs-layers/src/regularizers_test.ts | 11 +++++++++++ 2 files changed, 23 insertions(+) diff --git a/tfjs-layers/src/regularizers.ts b/tfjs-layers/src/regularizers.ts index 378329278ec..7bb0e78d329 100644 --- a/tfjs-layers/src/regularizers.ts +++ b/tfjs-layers/src/regularizers.ts @@ -15,6 +15,14 @@ import {abs, add, Scalar, serialization, sum, Tensor, tidy, zeros} from '@tensor import * as K from './backend/tfjs_backend'; import {deserializeKerasObject, serializeKerasObject} from './utils/generic_utils'; +function assertObjectArgs(args: L1Args | L2Args | L1L2Args): void { + if (args != null && typeof args !== 'object') { + throw new Error( + `Argument to L1L2 regularizer's constructor is expected to be an ` + + `object, but received: ${args}`); + } +} + /** * Regularizer base class. */ @@ -50,6 +58,8 @@ export class L1L2 extends Regularizer { constructor(args?: L1L2Args) { super(); + assertObjectArgs(args); + this.l1 = args == null || args.l1 == null ? 0.01 : args.l1; this.l2 = args == null || args.l2 == null ? 0.01 : args.l2; this.hasL1 = this.l1 !== 0; @@ -88,10 +98,12 @@ export class L1L2 extends Regularizer { serialization.registerClass(L1L2); export function l1(args?: L1Args) { + assertObjectArgs(args); return new L1L2({l1: args != null ? args.l1 : null, l2: 0}); } export function l2(args: L2Args) { + assertObjectArgs(args); return new L1L2({l2: args != null ? args.l2 : null, l1: 0}); } diff --git a/tfjs-layers/src/regularizers_test.ts b/tfjs-layers/src/regularizers_test.ts index 04ebb0aa6a3..21ed06a7a22 100644 --- a/tfjs-layers/src/regularizers_test.ts +++ b/tfjs-layers/src/regularizers_test.ts @@ -44,6 +44,17 @@ describeMathCPU('Built-in Regularizers', () => { expectTensorsClose( score, scalar(1 * (1 + 2 + 3 + 4) + 2 * (1 + 4 + 9 + 16))); }); + it('Using number arg for constructor leads to error', () => { + // tslint:disable-next-line:no-any + expect(() => tfl.regularizers.l1(0.001 as any)) + .toThrowError(/expected.*object.*received.*0\.001/); + // tslint:disable-next-line:no-any + expect(() => tfl.regularizers.l2(0.001 as any)) + .toThrowError(/expected.*object.*received.*0\.001/); + // tslint:disable-next-line:no-any + expect(() => tfl.regularizers.l1l2(0.001 as any)) + .toThrowError(/expected.*object.*received.*0\.001/); + }); }); describeMathCPU('regularizers.get', () => {