Skip to content

Commit

Permalink
Fix broken tensorflow-port tests + include in Makefile
Browse files Browse the repository at this point in the history
  • Loading branch information
Garrett Smith committed Dec 21, 2016
1 parent 8637f74 commit d795826
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 11 deletions.
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ upgrade:

test: compile
test/internal $(TESTS)
priv/bin/tensorflow-port test

test-operations: compile
test/operations
Expand Down
28 changes: 17 additions & 11 deletions priv/bin/tensorflow-port
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ class ModelStats(object):

def _last_memory_bytes(self):
if self._batches:
return self._batches[-1][2] * 1000
return self._batches[-1][2] * 1024
else:
return None

Expand Down Expand Up @@ -555,7 +555,8 @@ def test():
def test_model_stats():
def generate_stats(batches):
stats = ModelStats()
stats._batches = batches
for count, time, mem in batches:
stats.update(count, time, mem)
return stats.generate()

def assert_equals(stats, name, expected):
Expand All @@ -569,32 +570,37 @@ def test_model_stats():
assert_equals(s1, "last_batch_time_ms", None)
assert_equals(s1, "average_batch_time_ms", None)
assert_equals(s1, "predictions_per_second", None)
assert_equals(s1, "last_memory_bytes", None)

s2 = generate_stats([(1, 10000)])
s2 = generate_stats([(1, 10000, 100)])
assert_equals(s2, "last_batch_time_ms", 10.0)
assert_equals(s2, "average_batch_time_ms", 10.0)
assert_equals(s2, "predictions_per_second", 100.0)
assert_equals(s2, "last_memory_bytes", 102400)

s3 = generate_stats([(1, 10000), (1, 10000)])
s3 = generate_stats([(1, 10000, 100), (1, 10000, 200)])
assert_equals(s3, "last_batch_time_ms", 10.0)
assert_equals(s3, "average_batch_time_ms", 10.0)
assert_equals(s3, "predictions_per_second", 100.0)
assert_equals(s3, "last_memory_bytes", 204800)

s4 = generate_stats([(1, 10000), (1, 20000)])
s4 = generate_stats([(1, 10000, 200), (1, 20000, 100)])
assert_equals(s4, "last_batch_time_ms", 20.0)
assert_equals(s4, "average_batch_time_ms", 50000 / 3 / 1000)
assert_equals(s4, "predictions_per_second", 1 / (50000 / 3 / 1000000))
assert_equals(s4, "last_memory_bytes", 102400)

s5 = generate_stats(
[(1, 60000),
(1, 50000),
(1, 40000),
(1, 30000),
(1, 20000),
(1, 10000)])
[(1, 60000, 101),
(1, 50000, 102),
(1, 40000, 103),
(1, 30000, 104),
(1, 20000, 105),
(1, 10000, 106)])
assert_equals(s5, "last_batch_time_ms", 10.0)
assert_equals(s5, "average_batch_time_ms", 350000 / 15 / 1000)
assert_equals(s5, "predictions_per_second", 1 / (350000 / 15 / 1000000))
assert_equals(s5, "last_memory_bytes", 108544)

# ===================================================================
# Main
Expand Down

0 comments on commit d795826

Please sign in to comment.