Skip to content

Commit

Permalink
Enable torch tracing by changing assertions in d2go forwards to allow…
Browse files Browse the repository at this point in the history
… for torch.fx.proxy.Proxy type.

Summary:
X-link: facebookresearch/detectron2#4227

Pull Request resolved: facebookresearch#241

Torch FX tracing propagates a type of `torch.fx.proxy.Proxy` through the graph.

Existing type assertions in the d2go code base trigger during torch FX tracing, causing tracing to fail.

This adds a check for FX tracing in progress and  adds a helper function `assert_fx_safe()`, that can be used in place of a standard assertion. This function only applies the assertion if one is not tracing, allowing d2go assertion tests to be compatible with FX tracing.

Reviewed By: wat3rBro

Differential Revision: D35518556

fbshipit-source-id: b5d65165f271722af24e3dd9d33b3e37e4cf0e34
  • Loading branch information
Simon Hollis authored and facebook-github-bot committed Aug 2, 2022
1 parent 80d3844 commit 2dc70f9
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion d2go/modeling/backbone/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import torch
import torch.nn as nn
from detectron2 import layers
from detectron2.layers import assert_fx_safe
from mobile_cv.arch.fbnet_v2.irf_block import IRFBlock


Expand All @@ -33,7 +34,7 @@ def __init__(self, in_channels, num_anchors, box_dim=4):
torch.nn.init.constant_(l.bias, 0)

def forward(self, x: List[torch.Tensor]):
assert isinstance(x, (list, tuple))
assert_fx_safe(isinstance(x, (list, tuple)), "Unexpected data type")
logits = [self.cls_logits(y) for y in x]
bbox_reg = [self.bbox_pred(y) for y in x]

Expand Down

0 comments on commit 2dc70f9

Please sign in to comment.