From 36a65a0907d90ed591479b2ebaa8b61cfa0b4ef0 Mon Sep 17 00:00:00 2001 From: Simon Hollis Date: Wed, 17 Aug 2022 19:43:18 -0700 Subject: [PATCH] Enable torch tracing by changing assertions in d2go forwards to allow for torch.fx.proxy.Proxy type. Summary: Pull Request resolved: https://github.com/facebookresearch/detectron2/pull/4227 X-link: https://github.com/facebookresearch/d2go/pull/241 Torch FX tracing propagates a type of `torch.fx.proxy.Proxy` through the graph. Existing type assertions in the d2go code base trigger during torch FX tracing, causing tracing to fail. This adds a check for FX tracing in progress and adds a helper function `assert_fx_safe()`, that can be used in place of a standard assertion. This function only applies the assertion if one is not tracing, allowing d2go assertion tests to be compatible with FX tracing. Reviewed By: wat3rBro Differential Revision: D35518556 fbshipit-source-id: a9b5d3d580518ca74948544973ae89f8b9de3282 --- detectron2/modeling/poolers.py | 26 +++++++++-------- detectron2/utils/tracing.py | 51 ++++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 12 deletions(-) create mode 100644 detectron2/utils/tracing.py diff --git a/detectron2/modeling/poolers.py b/detectron2/modeling/poolers.py index 23d8767d1d..12073b0524 100644 --- a/detectron2/modeling/poolers.py +++ b/detectron2/modeling/poolers.py @@ -7,6 +7,7 @@ from detectron2.layers import ROIAlign, ROIAlignRotated, cat, nonzero_tuple, shapes_to_tensor from detectron2.structures import Boxes +from detectron2.utils.tracing import assert_fx_safe """ To export ROIPooler to torchscript, in this file, variables that should be annotated with @@ -219,19 +220,20 @@ def forward(self, x: List[torch.Tensor], box_lists: List[Boxes]): """ num_level_assignments = len(self.level_poolers) - assert isinstance(x, list) and isinstance( - box_lists, list - ), "Arguments to pooler must be lists" - assert ( - len(x) == num_level_assignments - ), "unequal value, num_level_assignments={}, but x is list of {} Tensors".format( - num_level_assignments, len(x) + assert_fx_safe( + isinstance(x, list) and isinstance(box_lists, list), "Arguments to pooler must be lists" ) - - assert len(box_lists) == x[0].size( - 0 - ), "unequal value, x[0] batch dim 0 is {}, but box_list has length {}".format( - x[0].size(0), len(box_lists) + assert_fx_safe( + len(x) == num_level_assignments, + "unequal value, num_level_assignments={}, but x is list of {} Tensors".format( + num_level_assignments, len(x) + ), + ) + assert_fx_safe( + len(box_lists) == x[0].size(0), + "unequal value, x[0] batch dim 0 is {}, but box_list has length {}".format( + x[0].size(0), len(box_lists) + ), ) if len(box_lists) == 0: return _create_zeros(None, x[0].shape[1], *self.output_size, x[0]) diff --git a/detectron2/utils/tracing.py b/detectron2/utils/tracing.py new file mode 100644 index 0000000000..994b615a76 --- /dev/null +++ b/detectron2/utils/tracing.py @@ -0,0 +1,51 @@ +import inspect +from typing import Union +import torch +from torch.fx._symbolic_trace import _orig_module_call +from torch.fx._symbolic_trace import is_fx_tracing as is_fx_tracing_current + +from detectron2.utils.env import TORCH_VERSION + + +@torch.jit.ignore +def is_fx_tracing_legacy() -> bool: + """ + Returns a bool indicating whether torch.fx is currently symbolically tracing a module. + Can be useful for gating module logic that is incompatible with symbolic tracing. + """ + return torch.nn.Module.__call__ is not _orig_module_call + + +@torch.jit.ignore +def is_fx_tracing() -> bool: + """Returns whether execution is currently in + Torch FX tracing mode""" + if TORCH_VERSION >= (1, 10): + return is_fx_tracing_current() + else: + return is_fx_tracing_legacy() + + +@torch.jit.ignore +def assert_fx_safe(condition: Union[bool, str], message: str): + """An FX-tracing safe version of assert. + Avoids erroneous type assertion triggering when types are masked inside + an fx.proxy.Proxy object during tracing. + Args: condition - either a boolean expression or a string representing + the condition to test. If this assert triggers an exception when tracing + due to dynamic control flow, try encasing the expression in quotation + marks and supplying it as a string.""" + if not is_fx_tracing(): + try: + if isinstance(condition, str): + caller_frame = inspect.currentframe().f_back + torch._assert( + eval(condition, caller_frame.f_globals, caller_frame.f_locals), message + ) + else: + torch._assert(condition, message) + except torch.fx.proxy.TraceError as e: + print( + "Found a non-FX compatible assertion. Skipping the check. Failure is shown below" + + str(e) + )