
Commit

Added a switch variable which allows disabling dropout with an input variable.
markusnagel committed Aug 3, 2016
1 parent f77c420 commit 998d853
Showing 1 changed file with 6 additions and 1 deletion.
kaffe/tensorflow/network.py
@@ -40,6 +40,10 @@ def __init__(self, inputs, trainable=True):
         self.layers = dict(inputs)
         # If true, the resulting variables are set as trainable
         self.trainable = trainable
+        # Switch variable for dropout
+        self.use_dropout = tf.placeholder_with_default(tf.constant(1.0),
+                                                       shape=[],
+                                                       name='use_dropout')
         self.setup()
 
     def setup(self):
@@ -236,4 +240,5 @@ def batch_normalization(self, input, name, scale_offset=True, relu=False):
 
     @layer
     def dropout(self, input, keep_prob, name):
-        return tf.nn.dropout(input, keep_prob, name=name)
+        keep = 1 - self.use_dropout + (self.use_dropout * keep_prob)
+        return tf.nn.dropout(input, keep, name=name)
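The arithmetic makes `use_dropout` a soft switch on the keep probability: when `use_dropout` is 1.0 the effective keep probability equals `keep_prob`, and when it is 0.0 it becomes 1.0, so `tf.nn.dropout` is a no-op. A minimal, self-contained sketch of that behaviour outside the network class, assuming TensorFlow 1.x APIs and a hypothetical `keep_prob` of 0.5 (not taken from this repository):

# Sketch only: reproduces the switch arithmetic from this commit with
# hypothetical inputs; assumes TensorFlow 1.x.
import tensorflow as tf

use_dropout = tf.placeholder_with_default(tf.constant(1.0), shape=[],
                                          name='use_dropout')
keep_prob = 0.5
# use_dropout == 1.0  ->  keep == keep_prob  (dropout active, training)
# use_dropout == 0.0  ->  keep == 1.0        (dropout disabled, inference)
keep = 1 - use_dropout + (use_dropout * keep_prob)

x = tf.ones([1, 8])
out = tf.nn.dropout(x, keep)

with tf.Session() as sess:
    print(sess.run(out))                                # dropout applied (default 1.0)
    print(sess.run(out, feed_dict={use_dropout: 0.0}))  # dropout disabled

With the network built this way, evaluation code can feed `net.use_dropout: 0.0` through `feed_dict` to disable dropout, instead of constructing a separate inference graph.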
