From 2dc70f9d7f262a88fa194f4c83795390d8bb9bcf Mon Sep 17 00:00:00 2001 From: Simon Hollis Date: Mon, 1 Aug 2022 19:57:13 -0700 Subject: [PATCH] Enable torch tracing by changing assertions in d2go forwards to allow for torch.fx.proxy.Proxy type. Summary: X-link: https://github.com/facebookresearch/detectron2/pull/4227 Pull Request resolved: https://github.com/facebookresearch/d2go/pull/241 Torch FX tracing propagates a type of `torch.fx.proxy.Proxy` through the graph. Existing type assertions in the d2go code base trigger during torch FX tracing, causing tracing to fail. This adds a check for FX tracing in progress and adds a helper function `assert_fx_safe()`, that can be used in place of a standard assertion. This function only applies the assertion if one is not tracing, allowing d2go assertion tests to be compatible with FX tracing. Reviewed By: wat3rBro Differential Revision: D35518556 fbshipit-source-id: b5d65165f271722af24e3dd9d33b3e37e4cf0e34 --- d2go/modeling/backbone/modules.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/d2go/modeling/backbone/modules.py b/d2go/modeling/backbone/modules.py index 652759ed..94bdd2a4 100644 --- a/d2go/modeling/backbone/modules.py +++ b/d2go/modeling/backbone/modules.py @@ -7,6 +7,7 @@ import torch import torch.nn as nn from detectron2 import layers +from detectron2.layers import assert_fx_safe from mobile_cv.arch.fbnet_v2.irf_block import IRFBlock @@ -33,7 +34,7 @@ def __init__(self, in_channels, num_anchors, box_dim=4): torch.nn.init.constant_(l.bias, 0) def forward(self, x: List[torch.Tensor]): - assert isinstance(x, (list, tuple)) + assert_fx_safe(isinstance(x, (list, tuple)), "Unexpected data type") logits = [self.cls_logits(y) for y in x] bbox_reg = [self.bbox_pred(y) for y in x]