Skip to content

Commit

Permalink
Modularize random distribution tensor creation ops (tensorflow#2939)
Browse files Browse the repository at this point in the history
INTERNAL
  • Loading branch information
tafsiri authored Mar 23, 2020
1 parent 44fe76d commit 7f6c78a
Show file tree
Hide file tree
Showing 17 changed files with 1,225 additions and 992 deletions.
213 changes: 2 additions & 211 deletions tfjs-core/src/ops/array_ops.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import * as util from '../util';
import {getAxesPermutation, getInnerMostAxes} from './axis_util';
import {concat} from './concat_split';
import {op} from './operation';
import {MPRandGauss, RandGamma, UniformRandom} from './rand';
import {zeros, zerosLike} from './tensor_ops';

/**
Expand Down Expand Up @@ -101,207 +100,6 @@ function eye_(
}
}

/**
* Creates a `tf.Tensor` with values sampled from a normal distribution.
*
* ```js
* tf.randomNormal([2, 2]).print();
* ```
*
* @param shape An array of integers defining the output tensor shape.
* @param mean The mean of the normal distribution.
* @param stdDev The standard deviation of the normal distribution.
* @param dtype The data type of the output.
* @param seed The seed for the random number generator.
*/
/** @doc {heading: 'Tensors', subheading: 'Random'} */
function randomNormal_<R extends Rank>(
shape: ShapeMap[R], mean = 0, stdDev = 1, dtype?: 'float32'|'int32',
seed?: number): Tensor<R> {
if (dtype != null && (dtype as DataType) === 'bool') {
throw new Error(`Unsupported data type ${dtype}`);
}
const randGauss =
new MPRandGauss(mean, stdDev, dtype, false /* truncated */, seed);
const res = buffer(shape, dtype);
for (let i = 0; i < res.values.length; i++) {
res.values[i] = randGauss.nextValue();
}
return res.toTensor();
}

/**
* Creates a `tf.Tensor` with values sampled from a truncated normal
* distribution.
*
* ```js
* tf.truncatedNormal([2, 2]).print();
* ```
*
* The generated values follow a normal distribution with specified mean and
* standard deviation, except that values whose magnitude is more than 2
* standard deviations from the mean are dropped and re-picked.
*
* @param shape An array of integers defining the output tensor shape.
* @param mean The mean of the normal distribution.
* @param stdDev The standard deviation of the normal distribution.
* @param dtype The data type of the output tensor.
* @param seed The seed for the random number generator.
*/
/** @doc {heading: 'Tensors', subheading: 'Creation'} */
function truncatedNormal_<R extends Rank>(
shape: ShapeMap[R], mean = 0, stdDev = 1, dtype?: 'float32'|'int32',
seed?: number): Tensor<R> {
if (dtype != null && (dtype as DataType) === 'bool') {
throw new Error(`Unsupported data type ${dtype}`);
}
const randGauss =
new MPRandGauss(mean, stdDev, dtype, true /* truncated */, seed);
const res = buffer(shape, dtype);
for (let i = 0; i < res.values.length; i++) {
res.values[i] = randGauss.nextValue();
}
return res.toTensor();
}

/**
* Creates a `tf.Tensor` with values sampled from a gamma distribution.
*
* ```js
* tf.randomGamma([2, 2], 1).print();
* ```
*
* @param shape An array of integers defining the output tensor shape.
* @param alpha The shape parameter of the gamma distribution.
* @param beta The inverse scale parameter of the gamma distribution. Defaults
* to 1.
* @param dtype The data type of the output. Defaults to float32.
* @param seed The seed for the random number generator.
*/
/** @doc {heading: 'Tensors', subheading: 'Random'} */
function randomGamma_<R extends Rank>(
shape: ShapeMap[R], alpha: number, beta = 1,
dtype: 'float32'|'int32' = 'float32', seed?: number): Tensor<R> {
if (beta == null) {
beta = 1;
}
if (dtype == null) {
dtype = 'float32';
}
if (dtype !== 'float32' && dtype !== 'int32') {
throw new Error(`Unsupported data type ${dtype}`);
}
const rgamma = new RandGamma(alpha, beta, dtype, seed);
const res = buffer(shape, dtype);
for (let i = 0; i < res.values.length; i++) {
res.values[i] = rgamma.nextValue();
}
return res.toTensor();
}

/**
* Creates a `tf.Tensor` with values sampled from a uniform distribution.
*
* The generated values follow a uniform distribution in the range [minval,
* maxval). The lower bound minval is included in the range, while the upper
* bound maxval is excluded.
*
* ```js
* tf.randomUniform([2, 2]).print();
* ```
*
* @param shape An array of integers defining the output tensor shape.
* @param minval The lower bound on the range of random values to generate.
* Defaults to 0.
* @param maxval The upper bound on the range of random values to generate.
* Defaults to 1.
* @param dtype The data type of the output tensor. Defaults to 'float32'.
*/
/** @doc {heading: 'Tensors', subheading: 'Random'} */
function randomUniform_<R extends Rank>(
shape: ShapeMap[R], minval = 0, maxval = 1, dtype: DataType = 'float32',
seed?: number|string): Tensor<R> {
const res = buffer(shape, dtype);
const random = new UniformRandom(minval, maxval, null, seed);
for (let i = 0; i < res.values.length; i++) {
res.values[i] = random.nextValue();
}
return res.toTensor();
}

/**
* Creates a `tf.Tensor` with values sampled from a random number generator
* function defined by the user.
*
* @param shape An array of integers defining the output tensor shape.
* @param randFunction A random number generator function which is called
* for each element in the output tensor.
* @param dtype The data type of the output tensor. Defaults to 'float32'.
*/
function rand_<R extends Rank>(
shape: ShapeMap[R], randFunction: () => number,
dtype?: DataType): Tensor<R> {
const size = util.sizeFromShape(shape);

let values = null;
if (dtype == null || dtype === 'float32') {
values = new Float32Array(size);
} else if (dtype === 'int32') {
values = new Int32Array(size);
} else if (dtype === 'bool') {
values = new Uint8Array(size);
} else {
throw new Error(`Unknown data type ${dtype}`);
}

for (let i = 0; i < size; i++) {
values[i] = randFunction();
}
return ENGINE.makeTensor(values, shape, dtype) as Tensor<R>;
}

/**
* Creates a `tf.Tensor` with values drawn from a multinomial distribution.
*
* ```js
* const probs = tf.tensor([.75, .25]);
* tf.multinomial(probs, 3).print();
* ```
*
* @param logits 1D array with unnormalized log-probabilities, or
* 2D array of shape `[batchSize, numOutcomes]`. See the `normalized`
* parameter.
* @param numSamples Number of samples to draw for each row slice.
* @param seed The seed number.
* @param normalized Whether the provided `logits` are normalized true
* probabilities (sum to 1). Defaults to false.
* @return 1D array of shape `[numSamples]`, or 2D array of shape
* `[batchSize, numSamples]`, depending on the rank of the input.
*/
/** @doc {heading: 'Tensors', subheading: 'Random'} */
function multinomial_(
logits: Tensor1D|Tensor2D|TensorLike, numSamples: number, seed?: number,
normalized = false): Tensor1D|Tensor2D {
const $logits = convertToTensor(logits, 'logits', 'multinomial');
const numOutcomes = $logits.size;
const origRank = $logits.rank;
if (numOutcomes < 2) {
throw new Error(
`Error in multinomial: you need at least 2 outcomes, but got ` +
`${numOutcomes}.`);
}
if (origRank > 2) {
throw new Error(`Rank of probabilities must be 1 or 2, but is ${origRank}`);
}
seed = seed || Math.random();
const logits2D = origRank === 1 ? $logits.as2D(1, -1) : $logits as Tensor2D;
const res = ENGINE.runKernelFunc(
backend => backend.multinomial(logits2D, normalized, numSamples, seed),
{logits2D});

return origRank === 1 ? res.as1D() : res;
}

/**
* Creates a one-hot `tf.Tensor`. The locations represented by `indices` take
* value `onValue` (defaults to 1), while all other locations take value
Expand Down Expand Up @@ -1100,7 +898,7 @@ async function setdiff1dAsync_(
* zeros.
*/
/** @doc {heading: 'Tensors', subheading: 'Creation'} */
function buffer<R extends Rank, D extends DataType = 'float32'>(
export function buffer<R extends Rank, D extends DataType = 'float32'>(
shape: ShapeMap[R], dtype: D = 'float32' as D,
values?: DataTypeMap[D]): TensorBuffer<R, D> {
dtype = dtype || 'float32' as D;
Expand All @@ -1125,8 +923,7 @@ function print<T extends Tensor>(x: T, verbose = false): void {
}

export {
buffer, // Not wrapped in op() since no tensors.
print // Not wrapped in op() since no need to increase stack trace.
print // Not wrapped in op() since no need to increase stack trace.
};

export const batchToSpaceND = op({batchToSpaceND_});
Expand All @@ -1136,22 +933,16 @@ export const cumsum = op({cumsum_});
export const depthToSpace = op({depthToSpace_});
export const expandDims = op({expandDims_});
export const eye = op({eye_});
export const multinomial = op({multinomial_});
export const oneHot = op({oneHot_});
export const pad = op({pad_});
export const pad1d = op({pad1d_});
export const pad2d = op({pad2d_});
export const pad3d = op({pad3d_});
export const pad4d = op({pad4d_});
export const rand = op({rand_});
export const randomNormal = op({randomNormal_});
export const randomGamma = op({randomGamma_});
export const randomUniform = op({randomUniform_});
export const reshape = op({reshape_});
export const spaceToBatchND = op({spaceToBatchND_});
export const squeeze = op({squeeze_});
export const stack = op({stack_});
export const tile = op({tile_});
export const truncatedNormal = op({truncatedNormal_});
export const unstack = op({unstack_});
export const setdiff1dAsync = setdiff1dAsync_;
Loading

0 comments on commit 7f6c78a

Please sign in to comment.