[quant][pt2e] Support setting qconfig by module_type (pytorch#92355)
Summary:
This PR adds support for the following QConfigMapping feature:
```
qconfig_mapping = QConfigMapping().set_object_type(torch.nn.Conv2d, qconfig)
backend_config = get_qnnpack_pt2e_backend_config()
m = prepare_pt2e(m, qconfig_mapping, example_inputs, backend_config)
```
This sets the qconfig for all calls to `torch.nn.Conv2d` to `qconfig`. Note that this is currently only verified for modules that are broken down to a single aten op when traced; e.g. `torch.nn.Conv2d` becomes a `torch.ops.aten.convolution` op after tracing. Support for more complicated modules that break down into multiple operators (e.g. MaxPool) will be added later.
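
As an illustration of the single-aten-op case, here is a minimal sketch (not part of this PR; it reuses the `torchdynamo.export` call from the test below, with a made-up model and input shapes) showing that `nn.Conv2d` lowers to exactly one `torch.ops.aten.convolution` call:
```
import copy
import torch
import torch._dynamo as torchdynamo

m = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1)).eval()
example_inputs = (torch.randn(1, 1, 3, 3),)
# program capture to an aten-level graph
m, guards = torchdynamo.export(
    m,
    *copy.deepcopy(example_inputs),
    aten_graph=True,
    tracing_mode="real",
)
# the Conv2d module shows up as a single aten op in the traced graph,
# so matching on the module type can be resolved to that one op
conv_nodes = [
    n for n in m.graph.nodes
    if n.target == torch.ops.aten.convolution.default
]
assert len(conv_nodes) == 1
```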

Test Plan:
python test/test_quantization.py TestQuantizePT2E.test_qconfig_module_type

Reviewers:

Subscribers:

Tasks:

Tags:

Pull Request resolved: pytorch#92355
Approved by: https://github.com/jcaip
jerryzh168 authored and pytorchmergebot committed Jan 20, 2023
1 parent 620846c commit 1464db0
Showing 7 changed files with 61 additions and 45 deletions.
17 changes: 9 additions & 8 deletions test/quantization/fx/test_quantize_fx.py
```diff
@@ -1619,28 +1619,29 @@ def forward(self, x):
     def test_qconfig_module_type(self):
         class M(torch.nn.Module):
             def __init__(self):
-                super(M, self).__init__()
-                self.conv1 = nn.Conv2d(1, 1, 1)
-                self.conv2 = nn.Conv2d(1, 1, 1)
+                super().__init__()
+                self.conv = nn.Conv2d(1, 1, 1)
+                self.linear = nn.Linear(9, 3)
 
             def forward(self, x):
-                x = self.conv1(x)
-                x = self.conv2(x)
+                x = self.conv(x)
+                x = x.reshape((1, -1))
+                x = self.linear(x)
                 return x
 
         m = M().eval()
         qconfig_dict = {"object_type": [(torch.nn.Conv2d, default_qconfig)]}
-        example_inputs = (torch.randn(1, 1, 1, 1),)
+        example_inputs = (torch.randn(1, 1, 3, 3),)
         m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)
         m(*example_inputs)
         m = convert_fx(m)
         m(*example_inputs)
-        # first conv is quantized, second conv is not quantized
+        # conv is quantized, linear is not quantized
        node_list = [
             ns.call_function(torch.quantize_per_tensor),
             ns.call_module(nnq.Conv2d),
-            ns.call_module(nnq.Conv2d),
             ns.call_method("dequantize"),
+            ns.call_module(nn.Linear),
         ]
         self.checkGraphModuleNodes(m, expected_node_list=node_list)
```
49 changes: 48 additions & 1 deletion test/quantization/fx/test_quantize_pt2e.py
```diff
@@ -26,8 +26,8 @@
 )
 import copy
 
-@skipIfNoQNNPACK
 class TestQuantizePT2E(QuantizationTestCase):
+    @skipIfNoQNNPACK
     def test_qconfig_none(self):
         class M(torch.nn.Module):
             def __init__(self):
@@ -76,6 +76,53 @@ def forward(self, x):
         self.checkGraphModuleNodes(
             m, expected_node_list=node_list, expected_node_occurrence=node_occurrence)
 
+    def test_qconfig_module_type(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = nn.Conv2d(1, 1, 1)
+                self.linear = nn.Linear(9, 3)
+
+            def forward(self, x):
+                x = self.conv(x)
+                x = x.reshape((1, -1))
+                x = self.linear(x)
+                return x
+
+        with override_quantized_engine("qnnpack"):
+            m = M().eval()
+            example_inputs = (torch.randn(1, 1, 3, 3),)
+
+            # program capture
+            m, guards = torchdynamo.export(
+                m,
+                *copy.deepcopy(example_inputs),
+                aten_graph=True,
+                tracing_mode="real",
+            )
+
+            qconfig = get_default_qconfig("qnnpack")
+            qconfig_mapping = QConfigMapping().set_object_type(torch.nn.Conv2d, qconfig)
+            backend_config = get_qnnpack_pt2e_backend_config()
+            m = prepare_pt2e(m, qconfig_mapping, example_inputs, backend_config)
+            m(*example_inputs)
+            m = convert_pt2e(m)
+            m(*example_inputs)
+            # conv is quantized, linear is not quantized
+            node_occurrence = {
+                # two for input and weight of the conv, one for output for the conv
+                ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor): 3,
+                ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor): 3,
+            }
+            node_list = [
+                ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
+                ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
+                ns.call_function(torch.ops.aten.convolution.default),
+                ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
+                ns.call_function(torch.ops.aten.addmm.default),
+            ]
+            self.checkGraphModuleNodes(m, expected_node_list=node_list)
+
 class TestQuantizePT2EModels(QuantizationTestCase):
     @skip_if_no_torchvision
     @skipIfNoQNNPACK
```
4 changes: 2 additions & 2 deletions test/test_fx.py
```diff
@@ -1712,8 +1712,8 @@ def forward(self, x):
         gm = torch.fx.symbolic_trace(m)
 
         mod_stack = {}
-        expected_stack = [('sub_mod', str(type(m.sub_mod))),
-                          ('sub_mod.conv_mod', str(type(m.sub_mod.conv_mod)))]
+        expected_stack = [('sub_mod', type(m.sub_mod)),
+                          ('sub_mod.conv_mod', type(m.sub_mod.conv_mod))]
         for node in gm.graph.nodes:
             mod_stack = node.meta.get('nn_module_stack', {})
             if mod_stack:
```
2 changes: 1 addition & 1 deletion torch/_dynamo/variables/nn_module.py
```diff
@@ -166,7 +166,7 @@ def call_function(
     @contextmanager
     def record_nn_module_stack():
         try:
-            tx.nn_module_stack[self.module_key] = str(type(mod))
+            tx.nn_module_stack[self.module_key] = type(mod)
             yield
         finally:
             del tx.nn_module_stack[self.module_key]
```
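
With `nn_module_stack` now storing module types rather than their string form (see the two diffs above), downstream code can match graph nodes by originating module type without parsing strings, which is what makes the object-type matching on the exported graph possible. A minimal sketch of such a consumer (the helper name is hypothetical, not part of this PR):
```
import torch

def nodes_from_module_type(gm: torch.fx.GraphModule, module_type: type):
    # Collect graph nodes whose nn_module_stack metadata records
    # `module_type` as one of the originating modules.
    matched = []
    for node in gm.graph.nodes:
        mod_stack = node.meta.get("nn_module_stack", {})
        # values are now module types (e.g. torch.nn.Conv2d), not str(type(...))
        if module_type in mod_stack.values():
            matched.append(node)
    return matched
```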
27 changes: 0 additions & 27 deletions torch/ao/quantization/_pt2e/utils.py
```diff
@@ -1,40 +1,13 @@
 import torch
-import torch._dynamo as torchdynamo
 from torch.fx import GraphModule
 from torch.nn.utils.fusion import fuse_conv_bn_weights
 # TODO[jerryzh168]: move this to a more general util function
 from torch.ao.quantization.fx.prepare import (
     _is_activation_post_process_node,
 )
 from collections import OrderedDict
-import copy
 import operator
 
-# TODO[qihan]: longer term, don't need to retrace or parse the string
-# we should have node.meta["nn_module_stack"] that store the dict
-def _infer_nn_stack_trace_and_append_on_meta(m, gm, args_as_list):
-    trace_func, guards = torchdynamo.export(
-        m,
-        *copy.deepcopy(args_as_list),
-        aten_graph=True,
-        tracing_mode="real"
-    )
-    reset_metadata = {}
-    for node in trace_func.graph.nodes:
-        nn_module_stack = {}
-        stack_trace = node.meta.get("stack_trace", None)
-        if stack_trace is not None:
-            for line in stack_trace.split("\n"):
-                if line.startswith("Module stack:"):
-                    mod_trace = eval(line.replace("Module stack:", ""))  # pyre-ignore
-                    nn_module_stack = {"nn_module_stack": mod_trace}
-        reset_metadata[node.name] = nn_module_stack
-
-    for n in gm.graph.nodes:
-        meta = reset_metadata.get(n.name, None)
-        if meta is not None:
-            n.meta.update(meta)
-
 # TODO[qihan]: longer term, this should happen in the dynamo stack as well
 def _get_renamed_nn_module_stack(nn_module_stack):
     # initialize with top level parent scope
```
5 changes: 0 additions & 5 deletions torch/ao/quantization/_quantize_pt2e.py
```diff
@@ -5,7 +5,6 @@
 from .fx import prepare
 from .quantize_fx import _convert_to_reference_decomposed_fx
 from ._pt2e.utils import (
-    # _infer_nn_stack_trace_and_append_on_meta,
     _get_renamed_nn_module_stack,
     _fuse_conv_bn_,
     _rearrange_weight_observer_for_addmm,
@@ -19,10 +18,6 @@ def prepare_pt2e(
     example_inputs: Tuple[Any, ...],
     backend_config: BackendConfig,
 ):
-    # TODO[jerryzh168]: check if the model is using EXIR - aten dialect
-    # disabled for now, looks like runing torchdynamo.export twice results in
-    # errors
-    # _infer_nn_stack_trace_and_append_on_meta(model, model, example_inputs)
     # TODO: move this information to fx node itself
     node_name_to_scope: Dict[str, Tuple[str, type]] = {}
     for n in model.graph.nodes:
```
2 changes: 1 addition & 1 deletion torch/fx/_symbolic_trace.py
```diff
@@ -440,7 +440,7 @@ def call_module(
         with ScopeContextManager(self.scope, Scope(module_qualified_name, type(m))) as _scope:
             # module_stack is an ordered dict so writing then deleting the
             # entry is equivalent to push/pop on a list
-            self.module_stack[_scope.module_path] = str(_scope.module_type)
+            self.module_stack[_scope.module_path] = _scope.module_type
             if not self.is_leaf_module(m, module_qualified_name):
                 ret_val = forward(*args, **kwargs)
             else:
```
