Commit

Update cookbook.
vahidk committed Sep 6, 2017
1 parent c8fbf35 commit 0a79918
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions README.md
@@ -16,6 +16,8 @@ Table of Contents
12. [Numerical stability in TensorFlow](#stable)
13. [Building a neural network training framework with learn API](#tf_learn)
14. [TensorFlow Cookbook](#cookbook)
- [Get shape](#get_shape)
- [Batch gather](#batch_gather)
- [Beam search](#beam_search)
- [Merge](#merge)
- [Entropy](#entropy)
@@ -1202,22 +1204,20 @@ And that's it! This is all you need to get started with TensorFlow learn API. I
<a name="cookbook"></a>
This section includes implementations of a set of common operations in TensorFlow.

### Beam Search <a name="beam_search"></a>
### Get shape <a name="get_shape"></a>
```python
import tensorflow as tf

def get_shape(tensor):
  """Returns static shape if available and dynamic shape otherwise."""
  static_shape = tensor.shape.as_list()
  dynamic_shape = tf.unstack(tf.shape(tensor))
  dims = [s[1] if s[0] is None else s[0]
          for s in zip(static_shape, dynamic_shape)]
  return dims
```
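The merging rule in `get_shape` — prefer the statically known dimension, fall back to the dynamic one — can be sketched without TensorFlow. `merge_shapes` is a hypothetical name used only for this illustration:

```python
def merge_shapes(static_shape, dynamic_shape):
    # Pick the static dimension when it is known (not None),
    # otherwise fall back to the dynamic value.
    return [d if s is None else s
            for s, d in zip(static_shape, dynamic_shape)]

print(merge_shapes([None, 128], [32, 128]))  # [32, 128]
```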

def log_prob_from_logits(logits, axis=-1):
  """Normalize the log-probabilities so that probabilities sum to one."""
  return logits - tf.reduce_logsumexp(logits, axis=axis, keep_dims=True)
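The removed `log_prob_from_logits` helper and the built-in `tf.nn.log_softmax` that replaces it later in this commit compute the same quantity, `logits - logsumexp(logits)`. A NumPy sketch of that identity, for illustration only:

```python
import numpy as np

def log_softmax_np(logits, axis=-1):
    # logits - logsumexp(logits), computed stably by shifting
    # by the max before exponentiating.
    z = logits - logits.max(axis=axis, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=axis, keepdims=True))

probs = np.exp(log_softmax_np(np.array([1.0, 2.0, 3.0])))
print(probs.sum())  # sums to 1 (up to float error)
```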
### Batch Gather <a name="batch_gather"></a>

```python
def batch_gather(tensor, indices):
  """Gather in batch from a tensor of arbitrary size.
@@ -1237,6 +1237,11 @@ def batch_gather(tensor, indices):
  offset = tf.reshape(tf.range(shape[0]) * shape[1], offset_shape)
  output = tf.gather(flat_first, indices + offset)
  return output
```
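The trick in `batch_gather` is to flatten the first two dimensions, then shift each batch's indices by `batch_index * shape[1]` before a single flat gather. A NumPy sketch of the same offset arithmetic (a standalone illustration, not the TensorFlow implementation):

```python
import numpy as np

def batch_gather_np(tensor, indices):
    # Flatten the first two dimensions into one.
    shape = tensor.shape
    flat = tensor.reshape([shape[0] * shape[1]] + list(shape[2:]))
    # Offset each batch's indices into its own slice of the flat axis.
    offset = (np.arange(shape[0]) * shape[1]).reshape(
        [shape[0]] + [1] * (indices.ndim - 1))
    return flat[indices + offset]

x = np.array([[10, 11, 12], [20, 21, 22]])
idx = np.array([[2], [0]])
print(batch_gather_np(x, idx))  # [[12], [20]]
```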

### Beam Search <a name="beam_search"></a>
```python
import tensorflow as tf

def rnn_beam_search(update_fn, initial_state, sequence_length, beam_width,
                    begin_token_id, end_token_id, name="rnn"):
@@ -1270,7 +1275,7 @@ def rnn_beam_search(update_fn, initial_state, sequence_length, beam_width,
    with tf.variable_scope(name, reuse=True if i > 0 else None):

      state, logits = update_fn(state, ids)
      logits = log_prob_from_logits(logits)
      logits = tf.nn.log_softmax(logits)

      sum_logprobs = (
          tf.expand_dims(sel_sum_logprobs, axis=2) +
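The `sum_logprobs` accumulation the hunk begins — add each beam's running score to the per-token log-probabilities, then keep the best `beam_width` candidates over all (beam, token) pairs — can be sketched in NumPy. The batch dimension is omitted and the names are assumptions for illustration:

```python
import numpy as np

def beam_step(sel_sum_logprobs, token_logprobs, beam_width):
    # sel_sum_logprobs: [beam], token_logprobs: [beam, vocab].
    # Candidate score for every (beam, token) pair.
    scores = sel_sum_logprobs[:, None] + token_logprobs   # [beam, vocab]
    flat = scores.reshape(-1)
    # Indices of the beam_width highest-scoring candidates.
    top = np.argsort(flat)[::-1][:beam_width]
    vocab = token_logprobs.shape[1]
    return flat[top], top // vocab, top % vocab

sel = np.array([0.0, -1.0])
tok = np.array([[-0.1, -2.0], [-0.2, -3.0]])
print(beam_step(sel, tok, 2))  # scores [-0.1, -1.2], beams [0, 1], tokens [0, 0]
```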
