Commit

Enable torch tracing by changing assertions in d2go forwards to allow for torch.fx.proxy.Proxy type.

Summary:
Pull Request resolved: #4227

X-link: facebookresearch/d2go#241

Torch FX tracing propagates a type of `torch.fx.proxy.Proxy` through the graph.

Existing type assertions in the d2go code base trigger during torch FX tracing, causing tracing to fail.

This adds a check for FX tracing in progress and a helper function `assert_fx_safe()` that can be used in place of a standard assertion. The helper applies the assertion only when tracing is not in progress, making d2go's assertion checks compatible with FX tracing.
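The pattern can be sketched in plain Python (no torch required; `_TRACING` here is a hypothetical stand-in for the real FX tracing check, not part of the commit):

```python
# Minimal sketch of the guarded-assertion pattern this commit introduces:
# the assertion runs only when no symbolic trace is in progress, so Proxy
# values flowing through forward() cannot trip a type check.

_TRACING = False  # stand-in for torch.fx's tracing detection


def assert_fx_safe(condition: bool, message: str) -> None:
    """Apply the assertion only when we are not tracing."""
    if not _TRACING:
        assert condition, message


def forward(x):
    # During tracing, x would be a torch.fx.proxy.Proxy, not a list;
    # the guarded assert would be skipped so tracing can proceed.
    assert_fx_safe(isinstance(x, list), "Arguments to pooler must be lists")
    return x
```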

Reviewed By: wat3rBro

Differential Revision: D35518556

fbshipit-source-id: a9b5d3d580518ca74948544973ae89f8b9de3282
Simon Hollis authored and facebook-github-bot committed Aug 18, 2022
1 parent 5aeb252 commit 36a65a0
Showing 2 changed files with 65 additions and 12 deletions.
26 changes: 14 additions & 12 deletions detectron2/modeling/poolers.py
@@ -7,6 +7,7 @@

 from detectron2.layers import ROIAlign, ROIAlignRotated, cat, nonzero_tuple, shapes_to_tensor
 from detectron2.structures import Boxes
+from detectron2.utils.tracing import assert_fx_safe

 """
 To export ROIPooler to torchscript, in this file, variables that should be annotated with
@@ -219,19 +220,20 @@ def forward(self, x: List[torch.Tensor], box_lists: List[Boxes]):
"""
num_level_assignments = len(self.level_poolers)

assert isinstance(x, list) and isinstance(
box_lists, list
), "Arguments to pooler must be lists"
assert (
len(x) == num_level_assignments
), "unequal value, num_level_assignments={}, but x is list of {} Tensors".format(
num_level_assignments, len(x)
assert_fx_safe(
isinstance(x, list) and isinstance(box_lists, list), "Arguments to pooler must be lists"
)

assert len(box_lists) == x[0].size(
0
), "unequal value, x[0] batch dim 0 is {}, but box_list has length {}".format(
x[0].size(0), len(box_lists)
assert_fx_safe(
len(x) == num_level_assignments,
"unequal value, num_level_assignments={}, but x is list of {} Tensors".format(
num_level_assignments, len(x)
),
)
assert_fx_safe(
len(box_lists) == x[0].size(0),
"unequal value, x[0] batch dim 0 is {}, but box_list has length {}".format(
x[0].size(0), len(box_lists)
),
)
if len(box_lists) == 0:
return _create_zeros(None, x[0].shape[1], *self.output_size, x[0])
51 changes: 51 additions & 0 deletions detectron2/utils/tracing.py
@@ -0,0 +1,51 @@
import inspect
from typing import Union
import torch
from torch.fx._symbolic_trace import _orig_module_call
from torch.fx._symbolic_trace import is_fx_tracing as is_fx_tracing_current

from detectron2.utils.env import TORCH_VERSION


@torch.jit.ignore
def is_fx_tracing_legacy() -> bool:
"""
Returns a bool indicating whether torch.fx is currently symbolically tracing a module.
Can be useful for gating module logic that is incompatible with symbolic tracing.
"""
return torch.nn.Module.__call__ is not _orig_module_call


@torch.jit.ignore
def is_fx_tracing() -> bool:
"""Returns whether execution is currently in
Torch FX tracing mode"""
if TORCH_VERSION >= (1, 10):
return is_fx_tracing_current()
else:
return is_fx_tracing_legacy()


@torch.jit.ignore
def assert_fx_safe(condition: Union[bool, str], message: str):
"""An FX-tracing safe version of assert.
Avoids erroneous type assertion triggering when types are masked inside
an fx.proxy.Proxy object during tracing.
Args: condition - either a boolean expression or a string representing
the condition to test. If this assert triggers an exception when tracing
due to dynamic control flow, try encasing the expression in quotation
marks and supplying it as a string."""
if not is_fx_tracing():
try:
if isinstance(condition, str):
caller_frame = inspect.currentframe().f_back
torch._assert(
eval(condition, caller_frame.f_globals, caller_frame.f_locals), message
)
else:
torch._assert(condition, message)
except torch.fx.proxy.TraceError as e:
        print(
            "Found a non-FX compatible assertion. Skipping the check. Failure is shown below:\n"
            + str(e)
        )
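The string-condition branch of `assert_fx_safe()` can be illustrated in isolation. This is a pure-Python sketch of the eval-in-caller-frame mechanism with the tracing gate omitted; `assert_safe` and `check` are hypothetical names, not part of the commit:

```python
import inspect


def assert_safe(condition, message):
    """Sketch of assert_fx_safe()'s string branch: a string condition is
    evaluated in the caller's frame so the caller's local names resolve."""
    if isinstance(condition, str):
        caller = inspect.currentframe().f_back
        assert eval(condition, caller.f_globals, caller.f_locals), message
    else:
        assert condition, message


def check(x):
    # "x" in the string below refers to check()'s local variable,
    # found through the caller frame inside assert_safe().
    assert_safe("len(x) > 0", "x must be non-empty")
    return len(x)
```

Deferring evaluation via a string is what makes the check skippable under tracing: the expression is never evaluated against a Proxy, so dynamic control flow cannot raise a TraceError.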

8 comments on commit 36a65a0

@QingZhong1996

For pytorch 1.11.0 I get: ImportError: cannot import name 'is_fx_tracing' from 'torch.fx._symbolic_trace'

@Deam0on

Same

@miranheo

PyTorch 1.10.0 gives the same error.

@lucasjinreal

Please verify commits on stable PyTorch before pushing to master!

No one knows where your dirty imports came from...

@snapcart-ruben

torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html

ModuleNotFoundError: No module named 'torch.fx._symbolic_trace'

@basir2021

Getting the error cannot import name 'is_fx_tracing' from 'torch.fx._symbolic_trace' on pytorch=1.12.

@simonhollis

Hi everyone. Thanks for your feedback regarding incompatibility of this change with older versions of pytorch. I will prepare an update to resolve this. Thanks for your patience.

@simonhollis commented on 36a65a0 Aug 21, 2022


I have a patch to resolve these import errors available in PR #4491 and facebookresearch/d2go#362

The patch passes CI tests for PyTorch 1.10 and I am aiming to commit this to master on Monday, but if you need unblocking before then, please try out this patch and let me know any feedback.
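A version-guarded import along these lines would sidestep the ImportError reported above. This is a hedged sketch of the general technique, not the contents of PR #4491; the fallback name `_is_fx_tracing_impl` is hypothetical:

```python
# Guard the import that fails on torch builds where torch.fx._symbolic_trace
# does not expose is_fx_tracing (e.g. the 1.8-1.12 releases reported above).
try:
    from torch.fx._symbolic_trace import is_fx_tracing as _is_fx_tracing_impl
except ImportError:  # ModuleNotFoundError is a subclass, so both cases land here
    _is_fx_tracing_impl = None


def is_fx_tracing() -> bool:
    """Report FX tracing state; treat 'detection unavailable' as 'not tracing'."""
    if _is_fx_tracing_impl is not None:
        return bool(_is_fx_tracing_impl())
    # A real patch would fall back to a legacy check here (e.g. comparing
    # torch.nn.Module.__call__ against the original, as tracing.py does).
    return False
```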
