forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[reland][quant][pt2e] Add early prototype top level quantize_pt2e APIs (
pytorch#90971) Summary: This PR introduces the top level APIs for quantization support in PyTorch 2.0 Export stack * torch.ao.quantization.quantize_pt2e.prepare_pt2e Takes a model that is captured by the PyTorch 2.0 export (torchdynamo full graph mode) and prepares the model for calibration for post training quantization * torch.ao.quantization.quantize_pt2e.convert_pt2e Takes a calibrated model and converts that to a reference quantized model that can be lowered later to quantized operator libraries or delegation modules Also added a backend config for the qnnpack_pt2e backend: * torch.ao.quantization.backend_config.get_qnnpack_pt2e_backend_config Note: everything related to quantize_pt2e are experimental (prototype), and we don't have any bc guarantees Test Plan: python test/test_quantization.py TestQuantizePT2EModels Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: pytorch#90971 Approved by: https://github.com/HDCharles
- Loading branch information
1 parent
e48c916
commit 7dd5e55
Showing
7 changed files
with
456 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
# Owner(s): ["oncall: quantization"] | ||
import torch | ||
import torch._dynamo as torchdynamo | ||
from torch.testing._internal.common_quantization import ( | ||
QuantizationTestCase, | ||
skip_if_no_torchvision, | ||
skipIfNoQNNPACK, | ||
) | ||
from torch.testing._internal.common_quantized import ( | ||
override_quantized_engine, | ||
) | ||
from torch.ao.quantization import ( | ||
get_default_qconfig, | ||
QConfigMapping, | ||
) | ||
from torch.ao.quantization.backend_config import ( | ||
get_qnnpack_backend_config, | ||
) | ||
from torch.ao.quantization.backend_config._qnnpack_pt2e import get_qnnpack_pt2e_backend_config | ||
from torch.ao.quantization.quantize_fx import prepare_fx, convert_to_reference_fx | ||
from torch.ao.quantization._quantize_pt2e import prepare_pt2e, convert_pt2e | ||
from torch.ao.ns.fx.utils import ( | ||
compute_sqnr, | ||
) | ||
|
||
class TestQuantizePT2EModels(QuantizationTestCase): | ||
@skip_if_no_torchvision | ||
@skipIfNoQNNPACK | ||
def test_resnet18(self): | ||
import copy | ||
import torchvision | ||
with override_quantized_engine("qnnpack"): | ||
example_inputs = (torch.randn(1, 3, 224, 224),) | ||
m = torchvision.models.resnet18().eval() | ||
m_copy = copy.deepcopy(m) | ||
# program capture | ||
m, guards = torchdynamo.export( | ||
m, | ||
*copy.deepcopy(example_inputs), | ||
aten_graph=True, | ||
tracing_mode="real", | ||
) | ||
|
||
backend_config = get_qnnpack_pt2e_backend_config() | ||
# TODO: define qconfig_mapping specifically for executorch | ||
qconfig = get_default_qconfig("qnnpack") | ||
qconfig_mapping = QConfigMapping().set_global(qconfig) | ||
before_fusion_result = m(*example_inputs) | ||
|
||
m = prepare_pt2e(m, qconfig_mapping, example_inputs, backend_config) | ||
|
||
# checking that we inserted observers correctly for maxpool operator (input and | ||
# output share observer instance) | ||
self.assertEqual(id(m.activation_post_process_3), id(m.activation_post_process_2)) | ||
after_prepare_result = m(*example_inputs) | ||
m = convert_pt2e(m) | ||
|
||
after_quant_result = m(*example_inputs) | ||
|
||
# comparing with existing fx graph mode quantization reference flow | ||
backend_config = get_qnnpack_backend_config() | ||
m_fx = prepare_fx(m_copy, qconfig_mapping, example_inputs, backend_config=backend_config) | ||
after_prepare_result_fx = m_fx(*example_inputs) | ||
m_fx = convert_to_reference_fx(m_fx, backend_config=backend_config) | ||
|
||
after_quant_result_fx = m_fx(*example_inputs) | ||
|
||
# the result matches exactly after prepare | ||
self.assertEqual(after_prepare_result, after_prepare_result_fx) | ||
self.assertEqual(compute_sqnr(after_prepare_result, after_prepare_result_fx), torch.tensor(float("inf"))) | ||
# there are slight differences after convert due to different implementations | ||
# of quant/dequant | ||
self.assertTrue(torch.max(after_quant_result - after_quant_result_fx) < 1e-1) | ||
self.assertTrue(compute_sqnr(after_quant_result, after_quant_result_fx) > 35) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,165 @@ | ||
import torch | ||
import torch._dynamo as torchdynamo | ||
from torch.fx import GraphModule | ||
from torch.nn.utils.fusion import fuse_conv_bn_weights | ||
# TODO[jerryzh168]: move this to a more general util function | ||
from torch.ao.quantization.fx.prepare import ( | ||
is_activation_post_process_node, | ||
) | ||
from collections import OrderedDict | ||
import copy | ||
import operator | ||
|
||
# TODO[qihan]: longer term, don't need to retrace or parse the string | ||
# we should have node.meta["nn_module_stack"] that store the dict | ||
def _infer_nn_stack_trace_and_append_on_meta(m, gm, args_as_list): | ||
trace_func, guards = torchdynamo.export( | ||
m, | ||
*copy.deepcopy(args_as_list), | ||
aten_graph=True, | ||
tracing_mode="real" | ||
) | ||
reset_metadata = {} | ||
for node in trace_func.graph.nodes: | ||
nn_module_stack = {} | ||
stack_trace = node.meta.get("stack_trace", None) | ||
if stack_trace is not None: | ||
for line in stack_trace.split("\n"): | ||
if line.startswith("Module stack:"): | ||
mod_trace = eval(line.replace("Module stack:", "")) # pyre-ignore | ||
nn_module_stack = {"nn_module_stack": mod_trace} | ||
reset_metadata[node.name] = nn_module_stack | ||
|
||
for n in gm.graph.nodes: | ||
meta = reset_metadata.get(n.name, None) | ||
if meta is not None: | ||
n.meta.update(meta) | ||
|
||
# TODO[qihan]: longer term, this should happen in the dynamo stack as well | ||
def _get_renamed_nn_module_stack(nn_module_stack): | ||
# initialize with top level parent scope | ||
nn_module_stack_renamed = OrderedDict([("", None)]) | ||
if nn_module_stack: | ||
# Rename module_key, e.g. "self_layer1_1__conv1" to "self.layer1.1._conv1", for easier downstream parsing | ||
prev_key = "" | ||
for key, value in nn_module_stack.items(): | ||
if not prev_key: | ||
if key.startswith("self_"): | ||
new_key = key[5:] | ||
prev_key = new_key | ||
else: | ||
new_key = prev_key + "." + key[len(prev_key) + 6 :] | ||
nn_module_stack_renamed[new_key] = value | ||
prev_key = new_key | ||
return nn_module_stack_renamed | ||
|
||
def _get_tensor_constant_from_node(node, m): | ||
if node is None: | ||
return None | ||
assert node.op == "get_attr" | ||
return getattr(m, node.target) | ||
|
||
# fuse conv bn weights, inplace modification of the graph_module and graph | ||
def _fuse_conv_bn_(m: GraphModule) -> None: | ||
for n in m.graph.nodes: | ||
if n.op != "call_function" or n.target != torch.ops.aten.native_batch_norm.default: | ||
continue | ||
bn_op = n | ||
n = bn_op.args[0] | ||
if n.op != "call_function" or n.target != torch.ops.aten.convolution.default: | ||
continue | ||
conv_op = n | ||
|
||
# conv weight | ||
conv_w = _get_tensor_constant_from_node(conv_op.args[1], m) | ||
# conv bias | ||
conv_b = _get_tensor_constant_from_node(conv_op.args[2], m) | ||
transpose = conv_op.args[6] | ||
|
||
# bn weight | ||
bn_w = _get_tensor_constant_from_node(bn_op.args[1], m) | ||
# bn bias | ||
bn_b = _get_tensor_constant_from_node(bn_op.args[2], m) | ||
# bn running mean | ||
bn_rm = _get_tensor_constant_from_node(bn_op.args[3], m) | ||
# bn running variance | ||
bn_rv = _get_tensor_constant_from_node(bn_op.args[4], m) | ||
bn_eps = bn_op.args[7] | ||
|
||
fused_weight, fused_bias = fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, transpose=False) | ||
|
||
# update the weight and bias for conv | ||
conv_args = list(conv_op.args) | ||
# calling data since the fused_weight and fused_bias are nn.Parameter | ||
weight_attr_name = conv_args[1].target | ||
setattr(m, weight_attr_name, fused_weight) | ||
if conv_args[2] is not None: | ||
bias_attr_name = conv_args[2].target | ||
else: | ||
bias_attr_name = weight_attr_name + "_bias" | ||
with m.graph.inserting_before(conv_op): | ||
get_bias_node = m.graph.get_attr(bias_attr_name) | ||
conv_args[2] = get_bias_node | ||
setattr(m, bias_attr_name, fused_bias) | ||
conv_op.args = tuple(conv_args) | ||
|
||
# native_batch_norm has 3 outputs, we expect getitem calls on the output | ||
# and we want to replace the uses of getitem 0 with the output of conv | ||
# | ||
# Before: | ||
# conv -> bn - (first output) -> users1 | ||
# \ - (second output) -> users2 | ||
# \ - (third output) -> users3 | ||
# After: | ||
# conv -> (first output) -> users1 | ||
# bn - | ||
# \ - (second output) -> users2 | ||
# \ - (third output) -> users3 | ||
# if users2 and users3 are empty then bn will be removed through dead code elimination | ||
|
||
for user in bn_op.users: | ||
if user.op != "call_function" or user.target != operator.getitem or user.args[1] != 0: | ||
continue | ||
user.replace_all_uses_with(conv_op) | ||
m.graph.eliminate_dead_code() | ||
m.recompile() | ||
|
||
def _rearrange_weight_observer_for_addmm( | ||
model: GraphModule, | ||
) -> None: | ||
""" | ||
before: | ||
weight - t - observer \ | ||
input - observer - addmm | ||
after: | ||
weight - observer - t \ | ||
input - observer - addmm | ||
""" | ||
named_modules = dict(model.named_modules(remove_duplicate=False)) | ||
for node in model.graph.nodes: | ||
if node.target != torch.ops.aten.addmm.default: | ||
continue | ||
addmm = node | ||
maybe_weight_obs = addmm.args[2] | ||
if not is_activation_post_process_node(maybe_weight_obs, named_modules): | ||
continue | ||
transpose_node = maybe_weight_obs.args[0] | ||
if transpose_node.target != torch.ops.aten.t.default: | ||
continue | ||
# swap the order of transpose and observation | ||
|
||
maybe_weight_obs.replace_input_with(transpose_node, transpose_node.args[0]) | ||
# remove the transpose node | ||
with model.graph.inserting_after(maybe_weight_obs): | ||
args = list(transpose_node.args) | ||
args[0] = maybe_weight_obs | ||
new_transpose_node = model.graph.create_node( | ||
"call_function", | ||
torch.ops.aten.t.default, | ||
tuple(args), | ||
transpose_node.kwargs | ||
) | ||
addmm.replace_input_with(maybe_weight_obs, new_transpose_node) | ||
|
||
model.graph.eliminate_dead_code() | ||
model.graph.lint() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
from torch.fx import GraphModule | ||
|
||
from .qconfig_mapping import QConfigMapping | ||
from .backend_config import BackendConfig | ||
from .fx import prepare | ||
from .quantize_fx import _convert_to_reference_decomposed_fx | ||
from ._pt2e.utils import ( | ||
_infer_nn_stack_trace_and_append_on_meta, | ||
_get_renamed_nn_module_stack, | ||
_fuse_conv_bn_, | ||
_rearrange_weight_observer_for_addmm, | ||
) | ||
|
||
from typing import Tuple, Any, Dict | ||
|
||
def prepare_pt2e( | ||
model: GraphModule, | ||
qconfig_mapping: QConfigMapping, | ||
example_inputs: Tuple[Any, ...], | ||
backend_config: BackendConfig, | ||
): | ||
# TODO[jerryzh168]: check if the model is using EXIR - aten dialect | ||
_infer_nn_stack_trace_and_append_on_meta(model, model, example_inputs) | ||
# TODO: move this information to fx node itself | ||
node_name_to_scope: Dict[str, Tuple[str, type]] = {} | ||
for n in model.graph.nodes: | ||
renamed_stack = _get_renamed_nn_module_stack(n.meta.get("nn_module_stack", None)) | ||
current_scope = list(renamed_stack.items())[-1] | ||
node_name_to_scope[n.name] = current_scope | ||
|
||
# TODO: check qconfig_mapping to make sure conv and bn are both configured | ||
# to be quantized before fusion | ||
# TODO: (maybe) rewrite this with subgraph_rewriter | ||
_fuse_conv_bn_(model) | ||
model = prepare( | ||
model, | ||
qconfig_mapping, | ||
False, # is_qat | ||
node_name_to_scope, | ||
example_inputs, | ||
backend_config=backend_config | ||
) | ||
|
||
# TODO: remove hack when we have better support for pattern matching | ||
# move around the observer for addmm | ||
_rearrange_weight_observer_for_addmm(model) | ||
return model | ||
|
||
def convert_pt2e( | ||
model: GraphModule | ||
): | ||
return _convert_to_reference_decomposed_fx(model) |
Oops, something went wrong.