[reland][quant][pt2e] Add early prototype top level quantize_pt2e APIs (pytorch#90971)

Summary:
This PR introduces the top level APIs for quantization support in the PyTorch 2.0 Export stack:

* torch.ao.quantization.quantize_pt2e.prepare_pt2e
  Takes a model captured by PyTorch 2.0 export (torchdynamo full graph mode) and prepares it for calibration
  for post training quantization.

* torch.ao.quantization.quantize_pt2e.convert_pt2e
  Takes a calibrated model and converts it to a reference quantized model that can later be lowered to quantized operator libraries or delegation modules.

Also adds a backend config for the qnnpack_pt2e backend:
* torch.ao.quantization.backend_config.get_qnnpack_pt2e_backend_config

Note: everything related to quantize_pt2e is experimental (prototype), and there are no backward-compatibility guarantees. A minimal usage sketch follows.
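
Usage sketch (mirrors the test added in this PR; the resnet18 model, the qnnpack qconfig, and the example inputs are illustrative choices, not requirements of the API; note the underscore-prefixed module path used by this PR's test):

    import copy
    import torch
    import torch._dynamo as torchdynamo
    import torchvision
    from torch.ao.quantization import get_default_qconfig, QConfigMapping
    from torch.ao.quantization.backend_config._qnnpack_pt2e import get_qnnpack_pt2e_backend_config
    from torch.ao.quantization._quantize_pt2e import prepare_pt2e, convert_pt2e

    example_inputs = (torch.randn(1, 3, 224, 224),)
    m = torchvision.models.resnet18().eval()
    # program capture via PyTorch 2.0 export (torchdynamo full graph mode)
    m, guards = torchdynamo.export(m, *copy.deepcopy(example_inputs), aten_graph=True, tracing_mode="real")
    qconfig_mapping = QConfigMapping().set_global(get_default_qconfig("qnnpack"))
    # insert observers for post training quantization
    m = prepare_pt2e(m, qconfig_mapping, example_inputs, get_qnnpack_pt2e_backend_config())
    # calibrate with representative inputs
    m(*example_inputs)
    # convert to a reference quantized model
    m = convert_pt2e(m)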

Test Plan:
python test/test_quantization.py TestQuantizePT2EModels

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: pytorch#90971
Approved by: https://github.com/HDCharles
jerryzh168 authored and pytorchmergebot committed Dec 16, 2022
1 parent e48c916 commit 7dd5e55
Showing 7 changed files with 456 additions and 2 deletions.
74 changes: 74 additions & 0 deletions test/quantization/fx/test_quantize_pt2e.py
@@ -0,0 +1,74 @@
# Owner(s): ["oncall: quantization"]
import torch
import torch._dynamo as torchdynamo
from torch.testing._internal.common_quantization import (
    QuantizationTestCase,
    skip_if_no_torchvision,
    skipIfNoQNNPACK,
)
from torch.testing._internal.common_quantized import (
    override_quantized_engine,
)
from torch.ao.quantization import (
    get_default_qconfig,
    QConfigMapping,
)
from torch.ao.quantization.backend_config import (
    get_qnnpack_backend_config,
)
from torch.ao.quantization.backend_config._qnnpack_pt2e import get_qnnpack_pt2e_backend_config
from torch.ao.quantization.quantize_fx import prepare_fx, convert_to_reference_fx
from torch.ao.quantization._quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.ns.fx.utils import (
    compute_sqnr,
)

class TestQuantizePT2EModels(QuantizationTestCase):
    @skip_if_no_torchvision
    @skipIfNoQNNPACK
    def test_resnet18(self):
        import copy
        import torchvision
        with override_quantized_engine("qnnpack"):
            example_inputs = (torch.randn(1, 3, 224, 224),)
            m = torchvision.models.resnet18().eval()
            m_copy = copy.deepcopy(m)
            # program capture
            m, guards = torchdynamo.export(
                m,
                *copy.deepcopy(example_inputs),
                aten_graph=True,
                tracing_mode="real",
            )

            backend_config = get_qnnpack_pt2e_backend_config()
            # TODO: define qconfig_mapping specifically for executorch
            qconfig = get_default_qconfig("qnnpack")
            qconfig_mapping = QConfigMapping().set_global(qconfig)
            before_fusion_result = m(*example_inputs)

            m = prepare_pt2e(m, qconfig_mapping, example_inputs, backend_config)

            # checking that we inserted observers correctly for maxpool operator (input and
            # output share observer instance)
            self.assertEqual(id(m.activation_post_process_3), id(m.activation_post_process_2))
            after_prepare_result = m(*example_inputs)
            m = convert_pt2e(m)

            after_quant_result = m(*example_inputs)

            # comparing with existing fx graph mode quantization reference flow
            backend_config = get_qnnpack_backend_config()
            m_fx = prepare_fx(m_copy, qconfig_mapping, example_inputs, backend_config=backend_config)
            after_prepare_result_fx = m_fx(*example_inputs)
            m_fx = convert_to_reference_fx(m_fx, backend_config=backend_config)

            after_quant_result_fx = m_fx(*example_inputs)

            # the result matches exactly after prepare
            self.assertEqual(after_prepare_result, after_prepare_result_fx)
            self.assertEqual(compute_sqnr(after_prepare_result, after_prepare_result_fx), torch.tensor(float("inf")))
            # there are slight differences after convert due to different implementations
            # of quant/dequant
            self.assertTrue(torch.max(after_quant_result - after_quant_result_fx) < 1e-1)
            self.assertTrue(compute_sqnr(after_quant_result, after_quant_result_fx) > 35)
2 changes: 2 additions & 0 deletions test/test_quantization.py
@@ -74,6 +74,8 @@
    from quantization.fx.test_quantize_fx import TestQuantizeFxOps  # noqa: F401
    from quantization.fx.test_quantize_fx import TestQuantizeFxModels  # noqa: F401
    from quantization.fx.test_subgraph_rewriter import TestSubgraphRewriter  # noqa: F401
    # Quantization for PyTorch 2.0 Export path
    from quantization.fx.test_quantize_pt2e import TestQuantizePT2EModels  # noqa: F401
except ImportError:
    # In FBCode we separate FX out into a separate target for the sake of dev
    # velocity. These are covered by a separate test target `quantization_fx`
Empty file.
165 changes: 165 additions & 0 deletions torch/ao/quantization/_pt2e/utils.py
@@ -0,0 +1,165 @@
import torch
import torch._dynamo as torchdynamo
from torch.fx import GraphModule
from torch.nn.utils.fusion import fuse_conv_bn_weights
# TODO[jerryzh168]: move this to a more general util function
from torch.ao.quantization.fx.prepare import (
    is_activation_post_process_node,
)
from collections import OrderedDict
import copy
import operator

# TODO[qihan]: longer term, don't need to retrace or parse the string
# we should have node.meta["nn_module_stack"] that stores the dict
def _infer_nn_stack_trace_and_append_on_meta(m, gm, args_as_list):
    trace_func, guards = torchdynamo.export(
        m,
        *copy.deepcopy(args_as_list),
        aten_graph=True,
        tracing_mode="real"
    )
    reset_metadata = {}
    for node in trace_func.graph.nodes:
        nn_module_stack = {}
        stack_trace = node.meta.get("stack_trace", None)
        if stack_trace is not None:
            for line in stack_trace.split("\n"):
                if line.startswith("Module stack:"):
                    mod_trace = eval(line.replace("Module stack:", ""))  # pyre-ignore
                    nn_module_stack = {"nn_module_stack": mod_trace}
        reset_metadata[node.name] = nn_module_stack

    for n in gm.graph.nodes:
        meta = reset_metadata.get(n.name, None)
        if meta is not None:
            n.meta.update(meta)

# TODO[qihan]: longer term, this should happen in the dynamo stack as well
def _get_renamed_nn_module_stack(nn_module_stack):
    # initialize with top level parent scope
    nn_module_stack_renamed = OrderedDict([("", None)])
    if nn_module_stack:
        # Rename module_key, e.g. "self_layer1_1__conv1" to "self.layer1.1._conv1", for easier downstream parsing
        prev_key = ""
        for key, value in nn_module_stack.items():
            if not prev_key:
                if key.startswith("self_"):
                    new_key = key[5:]
                    prev_key = new_key
            else:
                new_key = prev_key + "." + key[len(prev_key) + 6:]
            nn_module_stack_renamed[new_key] = value
            prev_key = new_key
    return nn_module_stack_renamed

def _get_tensor_constant_from_node(node, m):
    if node is None:
        return None
    assert node.op == "get_attr"
    return getattr(m, node.target)

# fuse conv bn weights, in-place modification of the graph_module and graph
def _fuse_conv_bn_(m: GraphModule) -> None:
    for n in m.graph.nodes:
        if n.op != "call_function" or n.target != torch.ops.aten.native_batch_norm.default:
            continue
        bn_op = n
        n = bn_op.args[0]
        if n.op != "call_function" or n.target != torch.ops.aten.convolution.default:
            continue
        conv_op = n

        # conv weight
        conv_w = _get_tensor_constant_from_node(conv_op.args[1], m)
        # conv bias
        conv_b = _get_tensor_constant_from_node(conv_op.args[2], m)
        transpose = conv_op.args[6]

        # bn weight
        bn_w = _get_tensor_constant_from_node(bn_op.args[1], m)
        # bn bias
        bn_b = _get_tensor_constant_from_node(bn_op.args[2], m)
        # bn running mean
        bn_rm = _get_tensor_constant_from_node(bn_op.args[3], m)
        # bn running variance
        bn_rv = _get_tensor_constant_from_node(bn_op.args[4], m)
        bn_eps = bn_op.args[7]

        fused_weight, fused_bias = fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, transpose=False)

        # update the weight and bias for conv
        conv_args = list(conv_op.args)
        # calling data since the fused_weight and fused_bias are nn.Parameter
        weight_attr_name = conv_args[1].target
        setattr(m, weight_attr_name, fused_weight)
        if conv_args[2] is not None:
            bias_attr_name = conv_args[2].target
        else:
            bias_attr_name = weight_attr_name + "_bias"
            with m.graph.inserting_before(conv_op):
                get_bias_node = m.graph.get_attr(bias_attr_name)
                conv_args[2] = get_bias_node
        setattr(m, bias_attr_name, fused_bias)
        conv_op.args = tuple(conv_args)

        # native_batch_norm has 3 outputs, we expect getitem calls on the output
        # and we want to replace the uses of getitem 0 with the output of conv
        #
        # Before:
        # conv -> bn - (first output) -> users1
        #            \ - (second output) -> users2
        #            \ - (third output) -> users3
        # After:
        # conv -> (first output) -> users1
        #       bn -
        #            \ - (second output) -> users2
        #            \ - (third output) -> users3
        # if users2 and users3 are empty then bn will be removed through dead code elimination

        for user in bn_op.users:
            if user.op != "call_function" or user.target != operator.getitem or user.args[1] != 0:
                continue
            user.replace_all_uses_with(conv_op)
    m.graph.eliminate_dead_code()
    m.recompile()

def _rearrange_weight_observer_for_addmm(
    model: GraphModule,
) -> None:
    """
    before:
        weight - t - observer \
          input - observer - addmm
    after:
        weight - observer - t \
          input - observer - addmm
    """
    named_modules = dict(model.named_modules(remove_duplicate=False))
    for node in model.graph.nodes:
        if node.target != torch.ops.aten.addmm.default:
            continue
        addmm = node
        maybe_weight_obs = addmm.args[2]
        if not is_activation_post_process_node(maybe_weight_obs, named_modules):
            continue
        transpose_node = maybe_weight_obs.args[0]
        if transpose_node.target != torch.ops.aten.t.default:
            continue
        # swap the order of transpose and observation

        maybe_weight_obs.replace_input_with(transpose_node, transpose_node.args[0])
        # insert a new transpose node after the observer; the old transpose node
        # becomes dead and is removed by dead code elimination below
        with model.graph.inserting_after(maybe_weight_obs):
            args = list(transpose_node.args)
            args[0] = maybe_weight_obs
            new_transpose_node = model.graph.create_node(
                "call_function",
                torch.ops.aten.t.default,
                tuple(args),
                transpose_node.kwargs
            )
        addmm.replace_input_with(maybe_weight_obs, new_transpose_node)

    model.graph.eliminate_dead_code()
    model.graph.lint()
52 changes: 52 additions & 0 deletions torch/ao/quantization/_quantize_pt2e.py
@@ -0,0 +1,52 @@
from torch.fx import GraphModule

from .qconfig_mapping import QConfigMapping
from .backend_config import BackendConfig
from .fx import prepare
from .quantize_fx import _convert_to_reference_decomposed_fx
from ._pt2e.utils import (
    _infer_nn_stack_trace_and_append_on_meta,
    _get_renamed_nn_module_stack,
    _fuse_conv_bn_,
    _rearrange_weight_observer_for_addmm,
)

from typing import Tuple, Any, Dict

def prepare_pt2e(
    model: GraphModule,
    qconfig_mapping: QConfigMapping,
    example_inputs: Tuple[Any, ...],
    backend_config: BackendConfig,
):
    # TODO[jerryzh168]: check if the model is using EXIR - aten dialect
    _infer_nn_stack_trace_and_append_on_meta(model, model, example_inputs)
    # TODO: move this information to fx node itself
    node_name_to_scope: Dict[str, Tuple[str, type]] = {}
    for n in model.graph.nodes:
        renamed_stack = _get_renamed_nn_module_stack(n.meta.get("nn_module_stack", None))
        current_scope = list(renamed_stack.items())[-1]
        node_name_to_scope[n.name] = current_scope

    # TODO: check qconfig_mapping to make sure conv and bn are both configured
    # to be quantized before fusion
    # TODO: (maybe) rewrite this with subgraph_rewriter
    _fuse_conv_bn_(model)
    model = prepare(
        model,
        qconfig_mapping,
        False,  # is_qat
        node_name_to_scope,
        example_inputs,
        backend_config=backend_config
    )

    # TODO: remove hack when we have better support for pattern matching
    # move around the observer for addmm
    _rearrange_weight_observer_for_addmm(model)
    return model

def convert_pt2e(
    model: GraphModule
):
    return _convert_to_reference_decomposed_fx(model)