From d795826ac3e93ce57baea3c2b145e9b99e5e8842 Mon Sep 17 00:00:00 2001 From: Garrett Smith Date: Wed, 21 Dec 2016 13:03:14 -0600 Subject: [PATCH] Fix broken tensorflow-port tests + include in Makefile --- Makefile | 1 + priv/bin/tensorflow-port | 28 +++++++++++++++++----------- 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/Makefile b/Makefile index c16c070..17d48b9 100644 --- a/Makefile +++ b/Makefile @@ -18,6 +18,7 @@ upgrade: test: compile test/internal $(TESTS) + priv/bin/tensorflow-port test test-operations: compile test/operations diff --git a/priv/bin/tensorflow-port b/priv/bin/tensorflow-port index d9fa0b8..864816e 100755 --- a/priv/bin/tensorflow-port +++ b/priv/bin/tensorflow-port @@ -279,7 +279,7 @@ class ModelStats(object): def _last_memory_bytes(self): if self._batches: - return self._batches[-1][2] * 1000 + return self._batches[-1][2] * 1024 else: return None @@ -555,7 +555,8 @@ def test(): def test_model_stats(): def generate_stats(batches): stats = ModelStats() - stats._batches = batches + for count, time, mem in batches: + stats.update(count, time, mem) return stats.generate() def assert_equals(stats, name, expected): @@ -569,32 +570,37 @@ def test_model_stats(): assert_equals(s1, "last_batch_time_ms", None) assert_equals(s1, "average_batch_time_ms", None) assert_equals(s1, "predictions_per_second", None) + assert_equals(s1, "last_memory_bytes", None) - s2 = generate_stats([(1, 10000)]) + s2 = generate_stats([(1, 10000, 100)]) assert_equals(s2, "last_batch_time_ms", 10.0) assert_equals(s2, "average_batch_time_ms", 10.0) assert_equals(s2, "predictions_per_second", 100.0) + assert_equals(s2, "last_memory_bytes", 102400) - s3 = generate_stats([(1, 10000), (1, 10000)]) + s3 = generate_stats([(1, 10000, 100), (1, 10000, 200)]) assert_equals(s3, "last_batch_time_ms", 10.0) assert_equals(s3, "average_batch_time_ms", 10.0) assert_equals(s3, "predictions_per_second", 100.0) + assert_equals(s3, "last_memory_bytes", 204800) - s4 = generate_stats([(1, 10000), (1, 20000)]) + s4 = generate_stats([(1, 10000, 200), (1, 20000, 100)]) assert_equals(s4, "last_batch_time_ms", 20.0) assert_equals(s4, "average_batch_time_ms", 50000 / 3 / 1000) assert_equals(s4, "predictions_per_second", 1 / (50000 / 3 / 1000000)) + assert_equals(s4, "last_memory_bytes", 102400) s5 = generate_stats( - [(1, 60000), - (1, 50000), - (1, 40000), - (1, 30000), - (1, 20000), - (1, 10000)]) + [(1, 60000, 101), + (1, 50000, 102), + (1, 40000, 103), + (1, 30000, 104), + (1, 20000, 105), + (1, 10000, 106)]) assert_equals(s5, "last_batch_time_ms", 10.0) assert_equals(s5, "average_batch_time_ms", 350000 / 15 / 1000) assert_equals(s5, "predictions_per_second", 1 / (350000 / 15 / 1000000)) + assert_equals(s5, "last_memory_bytes", 108544) # =================================================================== # Main