Skip to content

Commit

Permalink
Merge pull request #3 from pjadzinsky/master
Browse files Browse the repository at this point in the history
Added .gitingore and made it python3 compatible
  • Loading branch information
mpezeshki committed Jun 14, 2015
2 parents 6563a3c + c5ff44a commit 904e8c7
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 12 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
__pycache__/
*.pyc
37 changes: 25 additions & 12 deletions test_ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
from blocks.extensions import FinishAfter, Printing
from blocks.bricks.recurrent import SimpleRecurrent
from blocks.graph import ComputationGraph
import cPickle as pickle
try:
import cPickle as pickle
except:
import pickle

floatX = theano.config.floatX

Expand All @@ -35,14 +38,24 @@ def print_pred(y_hat):
num_classes = 4

with open("ctc_test_data.pkl", "rb") as pkl_file:
data = pickle.load(pkl_file)
inputs = data['inputs']
labels = data['labels']
# from S x T x B x D to S x T x B
inputs_mask = numpy.max(data['mask_inputs'], axis=-1)
labels_mask = data['mask_labels']

print 'Building model ...'
try:
data = pickle.load(pkl_file)
inputs = data['inputs']
labels = data['labels']
# from S x T x B x D to S x T x B
inputs_mask = numpy.max(data['mask_inputs'], axis=-1)
labels_mask = data['mask_labels']
except:
data = pickle.load(pkl_file, encoding='bytes')
inputs = data[b'inputs']
labels = data[b'labels']
# from S x T x B x D to S x T x B
inputs_mask = numpy.max(data[b'mask_inputs'], axis=-1)
labels_mask = data[b'mask_labels']



print('Building model ...')
# T x B x F
x = tensor.tensor3('x', dtype=floatX)
# T x B
Expand Down Expand Up @@ -78,14 +91,14 @@ def print_pred(y_hat):
brick.biases_init = Constant(0)
brick.initialize()

print 'Bulding DataStream ...'
print('Bulding DataStream ...')
dataset = IterableDataset({'x': inputs,
'x_mask': inputs_mask,
'y': labels,
'y_mask': labels_mask})
stream = DataStream(dataset)

print 'Bulding training process...'
print('Bulding training process...')
algorithm = GradientDescent(cost=cost,
params=ComputationGraph(cost).parameters,
step_rule=CompositeRule([StepClipping(10.0),
Expand Down Expand Up @@ -118,5 +131,5 @@ def print_pred(y_hat):
Printing()],
model=model)

print 'Starting training ...'
print('Starting training ...')
main_loop.run()

0 comments on commit 904e8c7

Please sign in to comment.