Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mkaglins/pruning masks applying #985

Prev Previous commit
Next Next commit
GroupNorm added
  • Loading branch information
mkaglins committed Oct 18, 2021
commit 30da9d8e1a8ac67bf4ad67d8fe63f0e5c2ab4f2b
15 changes: 14 additions & 1 deletion nncf/torch/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,18 @@ def from_module(module):
return nncf_bn


class NNCFGroupNorm(_NNCFModuleMixin, nn.GroupNorm):
op_func_name = "group_norm"

@staticmethod
def from_module(module):
assert module.__class__.__name__ == nn.GroupNorm.__name__

nncf_bn = NNCFGroupNorm(module.num_features)
dict_update(nncf_bn.__dict__, module.__dict__)
return nncf_bn


class NNCFConvTranspose2d(_NNCFModuleMixin, nn.ConvTranspose2d):
op_func_name = "conv_transpose2d"
target_weight_dim_for_compression = 1
Expand Down Expand Up @@ -222,7 +234,8 @@ def from_module(module):
NNCFConv2d: nn.Conv2d,
NNCFConv3d: nn.Conv3d,
NNCFLinear: nn.Linear,
NNCFBatchNorm:nn.BatchNorm2d,
NNCFBatchNorm : nn.BatchNorm2d,
NNCFGroupNorm : nn.GroupNorm,
NNCFConvTranspose2d: nn.ConvTranspose2d,
NNCFConvTranspose3d: nn.ConvTranspose3d,
NNCFEmbedding: nn.Embedding,
Expand Down
6 changes: 3 additions & 3 deletions nncf/torch/pruning/filter_pruning/algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ def set_pruning_rate(self, pruning_rate: Union[float, Dict[int, float]],
if self.zero_grad:
self.zero_grads_for_pruned_modules()

self._apply_masks()
self._propagate_masks()
if not groupwise_pruning_rates_set:
self._pruning_rate = passed_pruning_rate
else:
Expand Down Expand Up @@ -647,7 +647,7 @@ def _set_binary_masks_for_pruned_modules_globally_by_flops_target(self,
cur_num += 1
raise RuntimeError("Can't prune model to asked flops pruning rate")

def _apply_masks(self):
def _propagate_masks(self):
nncf_logger.debug("Propagating pruning masks")
# 1. Propagate masks for all modules
graph = self.model.get_original_graph()
Expand Down Expand Up @@ -676,7 +676,7 @@ def prepare_for_export(self):
"""
Applies pruning masks to layer weights before exporting the model to ONNX.
"""
self._apply_masks()
self._propagate_masks()

pruned_layers_stats = self.get_stats_for_pruned_modules()
nncf_logger.debug('Pruned layers statistics: \n%s', pruned_layers_stats.draw())
Expand Down