Commit

Make LinearPackedParams work with both torchscript and torch.package (pytorch#71656)

Summary:
Pull Request resolved: pytorch#71656

The customized `__getstate__`/`__setstate__` didn't call super (`torch.nn.Module`), so attributes such as `_modules` were not restored after the module was serialized and deserialized via torch.package.
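
As a minimal sketch of the failure mode (the `Outer`/`Inner` names are illustrative, not from this patch): a module whose custom `__getstate__`/`__setstate__` bypasses the default `nn.Module` state handling drops `_modules` on a copy/pickle round trip.

import copy
import torch

class Inner(torch.nn.Module):
    def forward(self, x):
        return x

class Outer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = Inner()  # registered in self._modules

    # Mimics the old LinearPackedParams behavior: only selected
    # fields survive, and super's state is never restored.
    def __getstate__(self):
        return (self.training,)

    def __setstate__(self, state):
        self.training = state[0]

m2 = copy.deepcopy(Outer())     # deepcopy routes through __getstate__/__setstate__
print(hasattr(m2, "_modules"))  # False: the child-module registry was not restored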

After a few iterations it turned out that packing/unpacking the linear params is already supported by the torchbind class, so there is no need to hack the torch module anymore.
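
For reference, a hedged sketch of that torchbind-level pack/unpack support, exercised through the public quantized ops (the tensor values here are arbitrary):

import torch

W = torch.randn(4, 4)
W_q = torch.quantize_per_tensor(W, scale=0.1, zero_point=0, dtype=torch.qint8)
B = torch.randn(4)

# The torchbind class behind LinearPackedParams handles pack/unpack itself,
# so the Python module no longer needs custom state handling.
packed = torch.ops.quantized.linear_prepack(W_q, B)    # pack
W_q2, B2 = torch.ops.quantized.linear_unpack(packed)   # unpack
assert torch.equal(W_q2.dequantize(), W_q.dequantize())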

Test Plan: `buck test caffe2/test/:quantization -- test_linear_api`

Reviewed By: jerryzh168

Differential Revision: D33711086

fbshipit-source-id: 3a36d10c64b7da414d3657d2ef766bb9a9290ea9
(cherry picked from commit 6337b6c)
ShijunK authored and pytorchmergebot committed Feb 7, 2022
1 parent 717d8c6 commit 09e2fb8
Showing 2 changed files with 19 additions and 20 deletions.
test/quantization/core/test_quantized_module.py (19 additions, 0 deletions)
@@ -13,6 +13,7 @@
     default_float_qparams_observer,
     PerChannelMinMaxObserver,
 )
+from torch.package import PackageExporter, PackageImporter
 from torch.testing._internal.common_quantization import (
     QuantizationTestCase,
     prepare_dynamic,
@@ -107,6 +108,8 @@ def _test_linear_api_impl(self, batch_size, in_features, out_features, use_bias,
         qlinear = class_map[use_fused](in_features, out_features)
 
         qlinear_copy = copy.deepcopy(qlinear)
+        # Set a random quantized weight and bias before testing scriptability.
+        qlinear_copy.set_weight_bias(W_q, B)
         self.checkScriptable(qlinear_copy, [[X_q]], check_save_load=True)
         # Run module with default-initialized parameters.
         # This tests that the constructor is correct.
@@ -175,6 +178,22 @@ def _test_linear_api_impl(self, batch_size, in_features, out_features, use_bias,
         self.assertEqual(qlinear.scale, loaded.scale)
         self.assertEqual(qlinear.zero_point, loaded.zero_point)
 
+        # Test torch.package
+        buffer = io.BytesIO()
+        with PackageExporter(buffer) as pe:
+            pe.save_pickle("module", "qlinear.pkl", qlinear)
+        buffer.seek(0)
+
+        importer = PackageImporter(buffer)
+        loaded_from_package = importer.load_pickle("module", "qlinear.pkl")
+        self.assertEqual(qlinear.weight(), loaded_from_package.weight())
+        self.assertEqual(qlinear.scale, loaded_from_package.scale)
+        self.assertEqual(qlinear.zero_point, loaded_from_package.zero_point)
+
+        for name, module in loaded_from_package.named_modules():
+            # No-op; just make sure the "_modules" attribute is restored correctly during torch.package import.
+            assert(name is not None)
+
         # Test copy and deepcopy
         copied_linear = copy.copy(qlinear)
         self.assertEqual(copied_linear.bias(), qlinear.bias())
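
Outside the test harness, the same round trip can be sketched standalone (the module construction here is illustrative; the packaging calls mirror the test above):

import io
import torch
from torch.package import PackageExporter, PackageImporter

qlinear = torch.nn.quantized.Linear(4, 4)  # any module holding LinearPackedParams

buffer = io.BytesIO()
with PackageExporter(buffer) as pe:
    pe.save_pickle("module", "qlinear.pkl", qlinear)
buffer.seek(0)

loaded = PackageImporter(buffer).load_pickle("module", "qlinear.pkl")
# With this fix, nn.Module attributes such as _modules survive the import.
assert "_packed_params" in dict(loaded.named_modules())
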
torch/nn/quantized/modules/linear.py (0 additions, 20 deletions)
@@ -83,26 +83,6 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
         super(LinearPackedParams, self)._load_from_state_dict(state_dict, prefix, local_metadata, False,
                                                               missing_keys, unexpected_keys, error_msgs)
 
-    @torch.jit.export
-    def __getstate__(self):
-        qweight, bias = self._weight_bias()
-        return qweight, bias, self.training, self.dtype
-
-    @torch.jit.export
-    def __setstate__(self, state):
-        self.dtype = state[3]
-        self.set_weight_bias(state[0], state[1])
-        self.training = state[2]
-
-    def __deepcopy__(self, memo):
-        new_instance = type(self).__new__(type(self))
-        torch.nn.Module.__init__(new_instance)
-        state = self.__getstate__()
-        new_instance.__setstate__(state)
-        return new_instance
-
-    def __copy__(self):
-        return self.__deepcopy__({})
 
     def __repr__(self):
         return self._weight_bias().__repr__()
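
With the custom hooks gone, copying and pickling fall back to the default `nn.Module` machinery (which preserves `__dict__`, including `_modules`), while the packed weight round-trips through the torchbind class. A small sketch of the now-expected behavior (module construction illustrative):

import copy
import torch

qlinear = torch.nn.quantized.Linear(4, 4)
qlinear_copy = copy.deepcopy(qlinear)  # no custom __deepcopy__ needed anymore

assert "_packed_params" in qlinear_copy._modules
assert torch.equal(qlinear_copy.weight().dequantize(),
                   qlinear.weight().dequantize())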
