Mkaglins/pruning masks applying #985

4 changes: 0 additions & 4 deletions nncf/config/schema.py
@@ -769,10 +769,6 @@ def with_attributes(schema: Dict, **kwargs) -> Dict:
" `False` by default.",
default=False
),
"zero_grad": with_attributes(_BOOLEAN,
description="Whether to setting gradients corresponding to zeroed"
" filters to zero during training, `True` by default.",
default=True),
"save_ranking_coeffs_path": with_attributes(_STRING),
"load_ranking_coeffs_path": with_attributes(_STRING),
"legr_params":
38 changes: 33 additions & 5 deletions nncf/torch/layers.py
@@ -94,20 +94,21 @@ def from_module(module):
def _custom_forward_fn(self, input_):
proxy_padding_value = getattr(self, NNCF_PADDING_VALUE_ATTR_NAME) # hack to get value from ProxyModule
proxy_weight = self.weight
return self._conv_forward_proxy(input_, proxy_weight, proxy_padding_value)
proxy_bias = self.bias
return self._conv_forward_proxy(input_, proxy_weight, proxy_bias, proxy_padding_value)


def _conv_forward_proxy(self, input_, weight, padding_value):
def _conv_forward_proxy(self, input_, weight, bias, padding_value):
Collaborator

Is it really needed to use proxy_weight and proxy_bias in _custom_forward_fn, instead of self in _conv_forward_proxy?

Contributor Author

Interesting question, I will discuss it with @vshampor.

Contributor @ljaljushkin (Oct 21, 2021)

I think it's really needed; otherwise it would take the original weight/bias directly from the conv module, not from the proxy one (ProxyModule) that holds the compressed params.
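To make that point concrete, here is a minimal, hypothetical sketch. TinyProxy below is invented for illustration and is not NNCF's actual ProxyModule; it only shows why the forward function has to read weight and bias through self (the proxy), which shadows those attributes with the compressed tensors, while the wrapped conv module still holds the untouched parameters.

    # Hypothetical sketch (not the actual NNCF ProxyModule): why _custom_forward_fn must
    # read weight/bias via `self`, which is the proxy, so it sees the compressed params.
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class TinyProxy:
        """Wraps a conv module and shadows `weight`/`bias` with processed values."""
        def __init__(self, module, weight, bias):
            self._module = module
            self.weight = weight    # e.g. weight with a pruning mask applied
            self.bias = bias        # e.g. bias with the same filter mask applied

        def __getattr__(self, name):
            # Everything not shadowed (stride, padding, ...) comes from the original module.
            return getattr(self._module, name)

    conv = nn.Conv2d(3, 8, 3)
    mask = torch.ones(8)
    mask[4:] = 0                                              # prune half of the filters
    proxy = TinyProxy(conv,
                      conv.weight * mask.view(-1, 1, 1, 1),   # masked weight
                      conv.bias * mask)                       # masked bias
    x = torch.randn(1, 3, 16, 16)
    # Reading weight/bias from the proxy uses the masked tensors;
    # reading conv.weight/conv.bias directly would silently bypass the compression.
    y = F.conv2d(x, proxy.weight, proxy.bias, proxy.stride, proxy.padding, proxy.dilation, proxy.groups)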

self.get_padding_value_ref().data.fill_(padding_value.item())
if self.padding_mode != 'zeros':
return F.conv2d(F.pad(input_, self._reversed_padding_repeated_twice, mode=self.padding_mode,
value=self.get_padding_value_ref().item()),
weight, self.bias, self.stride,
weight, bias, self.stride,
(0, 0), self.dilation, self.groups)
if not self.get_padding_value_ref():
return F.conv2d(input_, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
return F.conv2d(input_, weight, bias, self.stride, self.padding, self.dilation, self.groups)
return F.conv2d(F.pad(input_, self._reversed_padding_repeated_twice, value=self.get_padding_value_ref().item()),
weight, self.bias, self.stride,
weight, bias, self.stride,
(0, 0), self.dilation, self.groups)


@@ -122,6 +123,31 @@ def from_module(module):
dict_update(nncf_linear.__dict__, module.__dict__)
return nncf_linear


class NNCFBatchNorm(_NNCFModuleMixin, nn.BatchNorm2d):
    op_func_name = "batch_norm"

    @staticmethod
    def from_module(module):
        assert module.__class__.__name__ == nn.BatchNorm2d.__name__

        nncf_bn = NNCFBatchNorm(module.num_features)
        dict_update(nncf_bn.__dict__, module.__dict__)
        return nncf_bn


class NNCFGroupNorm(_NNCFModuleMixin, nn.GroupNorm):
    op_func_name = "group_norm"

    @staticmethod
    def from_module(module):
        assert module.__class__.__name__ == nn.GroupNorm.__name__

        # nn.GroupNorm has no `num_features`; it is constructed from num_groups/num_channels
        nncf_gn = NNCFGroupNorm(num_groups=module.num_groups, num_channels=module.num_channels)
        dict_update(nncf_gn.__dict__, module.__dict__)
        return nncf_gn


class NNCFConvTranspose2d(_NNCFModuleMixin, nn.ConvTranspose2d):
op_func_name = "conv_transpose2d"
target_weight_dim_for_compression = 1
@@ -208,6 +234,8 @@ def from_module(module):
NNCFConv2d: nn.Conv2d,
NNCFConv3d: nn.Conv3d,
NNCFLinear: nn.Linear,
NNCFBatchNorm: nn.BatchNorm2d,
NNCFGroupNorm: nn.GroupNorm,
NNCFConvTranspose2d: nn.ConvTranspose2d,
NNCFConvTranspose3d: nn.ConvTranspose3d,
NNCFEmbedding: nn.Embedding,
38 changes: 35 additions & 3 deletions nncf/torch/module_operations.py
@@ -25,8 +25,8 @@ def __init__(self, op):
def operand(self):
return self.op

def forward(self, *inputs):
return self.op(*inputs)
def forward(self, *inputs, **kwargs):
return self.op(*inputs, **kwargs)


class UpdateInputs(BaseOp):
@@ -42,7 +42,6 @@ def __init__(self, param_name, op):
def __call__(self, module, _):
if not hasattr(module, self._param_name):
raise TypeError('{} should have {} attribute'.format(type(module), self._param_name))

value = getattr(module, self._param_name)
result = super().__call__(value)
setattr(module, self._param_name, result)
@@ -53,6 +52,39 @@ def __init__(self, op):
super().__init__("weight", op)


class UpdateParameterList(BaseOp):
    """
    A module which updates a list of attributes of the module fed to
    the forward method call by an operand call.
    """

    def __init__(self, param_names, op):
        super().__init__(op)
        self._param_names = param_names

    def __call__(self, module, _):
        param_values = []
        for param_name in self._param_names:
            if not hasattr(module, param_name):
                raise TypeError('{} should have {} attribute'.format(type(module), param_name))
            param_values.append(getattr(module, param_name))
        updated_kwargs = dict(zip(self._param_names, param_values))
        updated_values = super().__call__(**updated_kwargs)

        for param_name, updated_value in zip(self._param_names, updated_values):
            setattr(module, param_name, updated_value)


class UpdateWeightAndBias(UpdateParameterList):
    """
    A module which updates the `weight` and `bias` attributes of the module
    fed to the forward method call by an operand call.
    """

    def __init__(self, op):
        super().__init__(["weight", "bias"], op)


class UpdatePaddingValue(UpdateParameter):
    def __init__(self, op):
        super().__init__(NNCF_PADDING_VALUE_ATTR_NAME, op)
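A rough usage sketch of the new hook follows. ToyFilterMask is hypothetical and only stands in for the pruning block produced by create_weight_pruning_operation, and the SimpleNamespace stands in for the ProxyModule that the pre-op mechanism actually passes in. Because BaseOp.forward now accepts **kwargs, UpdateWeightAndBias can call the operand with weight=... and bias=... and write the returned pair back onto the object it was given.

    # Assumptions: UpdateWeightAndBias as added by this PR; ToyFilterMask is a made-up operand.
    from types import SimpleNamespace

    import torch
    import torch.nn as nn

    from nncf.torch.module_operations import UpdateWeightAndBias  # added by this PR

    class ToyFilterMask(nn.Module):
        def __init__(self, num_filters):
            super().__init__()
            self.register_buffer('mask', torch.ones(num_filters))

        def forward(self, weight, bias):
            # Zero out whole output filters; bias may be None for bias-free convs.
            new_weight = weight * self.mask.view(-1, *([1] * (weight.dim() - 1)))
            new_bias = bias * self.mask if bias is not None else None
            # Values are returned in the same order as UpdateWeightAndBias' param_names.
            return new_weight, new_bias

    conv = nn.Conv2d(3, 8, 3)
    proxy = SimpleNamespace(weight=conv.weight, bias=conv.bias)  # stand-in for the ProxyModule
    op = ToyFilterMask(8)
    op.mask[4:] = 0                    # "prune" the last four filters
    hook = UpdateWeightAndBias(op)
    hook(proxy, None)                  # pre-op call: operand gets weight/bias, results are set back
    assert torch.all(proxy.weight[4:] == 0) and torch.all(proxy.bias[4:] == 0)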
78 changes: 37 additions & 41 deletions nncf/torch/pruning/base_algo.py
@@ -10,30 +10,28 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
import torch
from functools import partial
from functools import update_wrapper
from typing import List, Dict

from torch import nn
import torch
from texttable import Texttable
from torch import nn

from nncf import NNCFConfig
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.pruning.clusterization import Cluster
from nncf.common.pruning.clusterization import Clusterization
from nncf.common.utils.logger import logger as nncf_logger
from nncf.config.extractors import extract_algo_specific_config
from nncf.torch.algo_selector import ZeroCompressionLoss
from nncf.common.graph.transformations.commands import TargetType
from nncf.torch.compression_method_api import PTCompressionAlgorithmBuilder
from nncf.torch.compression_method_api import PTCompressionAlgorithmController
from nncf.common.utils.logger import logger as nncf_logger
from nncf.torch.graph.transformations.commands import PTInsertionCommand
from nncf.torch.graph.transformations.commands import PTTargetPoint
from nncf.torch.graph.transformations.commands import TransformationPriority
from nncf.torch.graph.transformations.layout import PTTransformationLayout
from nncf.torch.module_operations import UpdateWeightAndBias
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.graph.transformations.commands import TransformationPriority
from nncf.torch.graph.transformations.commands import PTTargetPoint
from nncf.torch.graph.transformations.commands import PTInsertionCommand
from nncf.torch.pruning.operations import PT_PRUNING_OPERATOR_METATYPES
from nncf.torch.pruning.filter_pruning.layers import apply_filter_binary_mask
from nncf.common.pruning.clusterization import Clusterization
from nncf.common.pruning.clusterization import Cluster
from nncf.torch.pruning.structs import PrunedModuleInfo


@@ -62,6 +60,7 @@ def __init__(self, config, should_init: bool = True):
self.prune_downsample_convs)

self.pruned_module_groups_info = []
self._pruned_norms_operators = {}

@staticmethod
def _set_default_params_for_ranking_type(params: Dict) -> None:
@@ -111,24 +110,44 @@ def _prune_weights(self, target_model: NNCFNetwork):
assert self._is_pruned_module(module)

nncf_logger.info("Adding Weight Pruner in scope: {}".format(node_name))
operation = self.create_weight_pruning_operation(module)
hook = operation.to(device)
pruning_block = self.create_weight_pruning_operation(module)
# Hook for weights and bias
hook = UpdateWeightAndBias(pruning_block).to(device)
insertion_commands.append(
PTInsertionCommand(
PTTargetPoint(TargetType.OPERATION_WITH_WEIGHTS,
PTTargetPoint(TargetType.PRE_LAYER_OPERATION,
target_node_name=node_name),
hook,
TransformationPriority.PRUNING_PRIORITY
)
)

group_minfos.append(PrunedModuleInfo(node_name=node_name,
module_scope=module_scope,
module=module,
operand=hook,
operand=pruning_block,
node_id=node.node_id))

cluster = Cluster[PrunedModuleInfo](i, group_minfos, [n.node_id for n in group.elements])
self.pruned_module_groups_info.add_cluster(cluster)

# Adding binary masks also for Batch/Group Norms to allow applying masks after propagation
all_norm_layers = target_model_graph.get_nodes_by_types(['batch_norm', 'group_norm'])
for node in all_norm_layers:
node_name = node.node_name
module = target_model.get_containing_module(node_name)

pruning_block = self.create_weight_pruning_operation(module)
# Hook for weights and bias
hook = UpdateWeightAndBias(pruning_block).to(device)
insertion_commands.append(
PTInsertionCommand(
PTTargetPoint(TargetType.PRE_LAYER_OPERATION,
target_node_name=node_name),
hook,
TransformationPriority.PRUNING_PRIORITY
)
)
self._pruned_norms_operators[node_name] = (pruning_block, module)
return insertion_commands

def create_weight_pruning_operation(self, module):
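For context on what the added norm-layer hooks achieve, here is a small self-contained sketch in plain PyTorch (the channel count and mask values are assumptions for the example) of masking a BatchNorm's per-channel weight and bias once the filter mask has been propagated from the preceding pruned convolution:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    bn = nn.BatchNorm2d(8)                 # assumed to follow a conv with 8 output filters
    mask = torch.ones(8)
    mask[4:] = 0                           # binary filter mask propagated from the pruned conv
    masked_weight = bn.weight * mask       # what the UpdateWeightAndBias hook feeds to forward
    masked_bias = bn.bias * mask
    x = torch.randn(2, 8, 16, 16)
    y = F.batch_norm(x, bn.running_mean, bn.running_var,
                     masked_weight, masked_bias, training=False, eps=bn.eps)
    assert torch.all(y[:, 4:] == 0)        # pruned channels contribute nothing downstream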
@@ -169,7 +188,6 @@ def __init__(self, target_model: NNCFNetwork,
self.prune_first = params.get('prune_first_conv', False)
self.prune_last = params.get('prune_last_conv', False)
self.prune_downsample_convs = params.get('prune_downsample_convs', False)
self.zero_grad = params.get('zero_grad', True)
self.prune_flops = False
self.check_pruning_rate(params)
self._hooks = []
@@ -181,29 +199,7 @@ def set_pruning_rate(self, pruning_rate):
raise NotImplementedError

def step(self, next_step):
raise NotImplementedError

def zero_grads_for_pruned_modules(self):
"""
This function registers a hook that will set the
gradients for pruned filters to zero.
"""
self._clean_hooks()

def hook(grad, mask, dim=0):
mask = mask.to(grad.device)
return apply_filter_binary_mask(mask, grad, dim=dim)

for minfo in self.pruned_module_groups_info.get_all_nodes():
mask = minfo.operand.binary_filter_pruning_mask
weight = minfo.module.weight
dim = minfo.module.target_weight_dim_for_compression
partial_hook = update_wrapper(partial(hook, mask=mask, dim=dim), hook)
self._hooks.append(weight.register_hook(partial_hook))
if minfo.module.bias is not None:
bias = minfo.module.bias
partial_hook = update_wrapper(partial(hook, mask=mask), hook)
self._hooks.append(bias.register_hook(partial_hook))
pass

def check_pruning_rate(self, params):
"""