Skip to content

Commit

Permalink
[tfjs-layers] Fix a silent error in L1/L2/L1L2 regularizer constructor (
Browse files Browse the repository at this point in the history
#1861)

Re. #1773

BUG

Previosly, if you pass a number argument to `tf.regularizers.l2()`,
for example, no error will be thrown, even though the specified number
will not take effect as the regularization coefficient.

This PR fixes that by explicitly checking the type of the arg.
  • Loading branch information
caisq authored Aug 19, 2019
1 parent 20e5420 commit 355bb0b
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 0 deletions.
12 changes: 12 additions & 0 deletions tfjs-layers/src/regularizers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@ import {abs, add, Scalar, serialization, sum, Tensor, tidy, zeros} from '@tensor
import * as K from './backend/tfjs_backend';
import {deserializeKerasObject, serializeKerasObject} from './utils/generic_utils';

function assertObjectArgs(args: L1Args | L2Args | L1L2Args): void {
if (args != null && typeof args !== 'object') {
throw new Error(
`Argument to L1L2 regularizer's constructor is expected to be an ` +
`object, but received: ${args}`);
}
}

/**
* Regularizer base class.
*/
Expand Down Expand Up @@ -50,6 +58,8 @@ export class L1L2 extends Regularizer {
constructor(args?: L1L2Args) {
super();

assertObjectArgs(args);

this.l1 = args == null || args.l1 == null ? 0.01 : args.l1;
this.l2 = args == null || args.l2 == null ? 0.01 : args.l2;
this.hasL1 = this.l1 !== 0;
Expand Down Expand Up @@ -88,10 +98,12 @@ export class L1L2 extends Regularizer {
serialization.registerClass(L1L2);

export function l1(args?: L1Args) {
assertObjectArgs(args);
return new L1L2({l1: args != null ? args.l1 : null, l2: 0});
}

export function l2(args: L2Args) {
assertObjectArgs(args);
return new L1L2({l2: args != null ? args.l2 : null, l1: 0});
}

Expand Down
11 changes: 11 additions & 0 deletions tfjs-layers/src/regularizers_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,17 @@ describeMathCPU('Built-in Regularizers', () => {
expectTensorsClose(
score, scalar(1 * (1 + 2 + 3 + 4) + 2 * (1 + 4 + 9 + 16)));
});
it('Using number arg for constructor leads to error', () => {
// tslint:disable-next-line:no-any
expect(() => tfl.regularizers.l1(0.001 as any))
.toThrowError(/expected.*object.*received.*0\.001/);
// tslint:disable-next-line:no-any
expect(() => tfl.regularizers.l2(0.001 as any))
.toThrowError(/expected.*object.*received.*0\.001/);
// tslint:disable-next-line:no-any
expect(() => tfl.regularizers.l1l2(0.001 as any))
.toThrowError(/expected.*object.*received.*0\.001/);
});
});

describeMathCPU('regularizers.get', () => {
Expand Down

0 comments on commit 355bb0b

Please sign in to comment.