Skip to content

Commit

Permalink
[quant][be] Merge qconfig_mapping_utils.py in quantization and fx fol…
Browse files Browse the repository at this point in the history
…ders (pytorch#89979)

Summary:
att, no functionality changes

Test Plan:
python test/test_quantization.py TestQuantizeFx

Reviewers:

Subscribers:

Tasks:

Tags:

Pull Request resolved: pytorch#89979
Approved by: https://github.com/vkuzo
  • Loading branch information
jerryzh168 authored and pytorchmergebot committed Dec 1, 2022
1 parent 0ad6715 commit 8aee768
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 120 deletions.
7 changes: 2 additions & 5 deletions test/quantization/fx/test_quantize_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,11 @@
QConfigMapping,
)

from torch.ao.quantization.qconfig_mapping_utils import (
from torch.ao.quantization.fx.qconfig_mapping_utils import (
_get_object_type_qconfig,
_get_module_name_qconfig,
_get_module_name_regex_qconfig,
maybe_adjust_qconfig_for_module_name_object_type_order,
)

from torch.ao.quantization.fx.pattern_utils import (
Expand Down Expand Up @@ -131,10 +132,6 @@
StandaloneModuleConfigEntry,
)

from torch.ao.quantization.fx.qconfig_mapping_utils import (
maybe_adjust_qconfig_for_module_name_object_type_order,
)

from torch.ao.quantization.fx.utils import (
_reroute_tuple_getitem_pattern,
NodeInfo,
Expand Down
4 changes: 1 addition & 3 deletions torch/ao/quantization/fx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,12 @@
qconfig_equals
)
from ..qconfig_mapping import QConfigMapping
from ..qconfig_mapping_utils import (
_update_qconfig_for_qat,
)
from .qconfig_mapping_utils import (
generate_node_name_to_qconfig,
compare_prepare_convert_qconfig_mappings,
update_qconfig_for_fusion,
is_qconfig_supported_by_dtype_configs,
_update_qconfig_for_qat,
)
from torch.ao.quantization.backend_config.utils import (
get_root_module_to_quantized_reference_module,
Expand Down
6 changes: 2 additions & 4 deletions torch/ao/quantization/fx/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,11 @@
from ..qconfig_mapping import (
QConfigMapping,
)
from ..qconfig_mapping_utils import (
_get_flattened_qconfig_dict,
_update_qconfig_for_qat,
)
from .qconfig_mapping_utils import (
generate_node_name_to_qconfig,
update_qconfig_for_fusion,
_get_flattened_qconfig_dict,
_update_qconfig_for_qat,
)

from .quantize_handler import (
Expand Down
96 changes: 91 additions & 5 deletions torch/ao/quantization/fx/qconfig_mapping_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch
import re
from collections import defaultdict, OrderedDict
from typing import Callable, Any, Dict, Tuple, Set, List
from typing import Callable, Any, Dict, Tuple, Set, List, Union
from torch.ao.quantization import QConfig
from torch.ao.quantization.qconfig import _add_module_to_qconfig_obs_ctr, QConfigAny, qconfig_equals
from torch.ao.quantization.quantize import (
Expand All @@ -21,19 +22,18 @@
from ..utils import (
_parent_name,
get_qconfig_dtypes,
get_combined_dict
)
from ..qconfig_mapping import (
_OBJECT_TYPE_DICT_KEY,
_MODULE_NAME_DICT_KEY,
_MODULE_NAME_REGEX_DICT_KEY,
QConfigMapping,
)
from ..qconfig_mapping_utils import (
_get_object_type_qconfig,
_maybe_adjust_qconfig_for_module_type_or_name,
from ..quantization_mappings import (
get_default_qat_module_mappings,
)


# TODO: revisit this list. Many helper methods shouldn't be public
__all__ = [
"check_is_valid_config_dict",
Expand Down Expand Up @@ -264,3 +264,89 @@ def is_qconfig_supported_by_dtype_configs(qconfig: QConfig, dtype_configs: List[
if is_match:
return True
return False

def _get_object_type_qconfig(
qconfig_mapping: QConfigMapping,
object_type: Union[Callable, str],
fallback_qconfig: QConfigAny) -> QConfigAny:
return qconfig_mapping.object_type_qconfigs.get(object_type, fallback_qconfig)


def _get_module_name_regex_qconfig(qconfig_mapping, module_name, fallback_qconfig):
for regex_pattern, qconfig in qconfig_mapping.module_name_regex_qconfigs.items():
if re.match(regex_pattern, module_name):
# first match wins
return qconfig
return fallback_qconfig


def _get_module_name_qconfig(qconfig_mapping, module_name, fallback_qconfig):
if module_name == '':
# module name qconfig not found
return fallback_qconfig
if module_name in qconfig_mapping.module_name_qconfigs:
return qconfig_mapping.module_name_qconfigs[module_name]
else:
parent, _ = _parent_name(module_name)
return _get_module_name_qconfig(qconfig_mapping, parent, fallback_qconfig)


def _maybe_adjust_qconfig_for_module_type_or_name(qconfig_mapping, module_type, module_name, global_qconfig):
# get qconfig for module_name,
# fallback to module_name_regex_qconfig, module_type_qconfig,
# global_qconfig if necessary
module_type_qconfig = _get_object_type_qconfig(
qconfig_mapping, module_type, global_qconfig)
module_name_regex_qconfig = _get_module_name_regex_qconfig(
qconfig_mapping, module_name, module_type_qconfig)
module_name_qconfig = _get_module_name_qconfig(
qconfig_mapping, module_name, module_name_regex_qconfig)
return module_name_qconfig


def _get_flattened_qconfig_dict(qconfig_mapping: QConfigMapping) -> Dict[Union[Callable, str], QConfigAny]:
""" flatten the global, object_type and module_name qconfig
to the same qconfig_dict so that it can be used by
propagate_qconfig_ function.
"module_name_regex" is ignored for now since it's not supported
in propagate_qconfig_, but it can be fixed later.
For example:
Input: {
"": qconfig,
"object_type": [
(torch.add, qconfig)
],
"module_name": [
("conv", qconfig)
]
}
Output: {
"": qconfig,
torch.add: qconfig,
"conv": qconfig
}
"""
flattened: Dict[Union[Callable, str], QConfigAny] = {"": qconfig_mapping.global_qconfig}
for obj, qconfig in qconfig_mapping.object_type_qconfigs.items():
flattened[obj] = qconfig
for obj, qconfig in qconfig_mapping.module_name_qconfigs.items():
flattened[obj] = qconfig
return flattened


def _update_qconfig_for_qat(
qconfig_mapping: QConfigMapping,
additional_qat_module_mapping: Dict[Callable, Callable]):
"""
Update the qconfig_dict to account for module swaps during QAT.
During QAT we perform a module swap on the nn.Module types to the corresponding nn.qat.modules types.
"""
all_qat_mappings = get_combined_dict(
get_default_qat_module_mappings(), additional_qat_module_mapping)
object_type_dict = qconfig_mapping.object_type_qconfigs
new_object_type_dict = object_type_dict.copy()
for k, v in new_object_type_dict.items():
if k in all_qat_mappings:
object_type_dict[all_qat_mappings[k]] = v
103 changes: 0 additions & 103 deletions torch/ao/quantization/qconfig_mapping_utils.py

This file was deleted.

0 comments on commit 8aee768

Please sign in to comment.