From c650c73cbcc3765ef67ae328e0ee37431597912a Mon Sep 17 00:00:00 2001 From: Kaiyu Shi Date: Mon, 8 Jan 2018 20:54:55 +0800 Subject: [PATCH] Extract the finish check for profiler (#4519) * Extract the finish check for profiler Delete unused import and rearrange the import order. * Add imports for win support --- torch/autograd/profiler.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/torch/autograd/profiler.py b/torch/autograd/profiler.py index 7e2c251feb9d5..83066a656407f 100644 --- a/torch/autograd/profiler.py +++ b/torch/autograd/profiler.py @@ -1,11 +1,11 @@ -import torch import subprocess import os import sys -import copy -import tempfile +import re import itertools -from collections import defaultdict, namedtuple +from collections import defaultdict + +import torch try: FileNotFoundError @@ -202,27 +202,27 @@ def __str__(self): return '' return str(self.function_events) - def table(self, sort_by=None): + def _check_finish(self): if self.function_events is None: raise RuntimeError("can't export a trace that didn't finish running") + + def table(self, sort_by=None): + self._check_finish() return self.function_events.table(sort_by) table.__doc__ = EventList.table.__doc__ def export_chrome_trace(self, path): - if self.function_events is None: - raise RuntimeError("can't export a trace that didn't finish running") + self._check_finish() return self.function_events.export_chrome_trace(path) export_chrome_trace.__doc__ = EventList.export_chrome_trace.__doc__ def key_averages(self): - if self.function_events is None: - raise RuntimeError("can't average a trace that didn't finish running") + self._check_finish() return self.function_events.key_averages() key_averages.__doc__ = EventList.key_averages.__doc__ def total_average(self): - if self.function_events is None: - raise RuntimeError("can't average a trace that didn't finish running") + self._check_finish() return self.function_events.total_average() total_average.__doc__ = EventList.total_average.__doc__