forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[torch.package] add test case for repackaging parent module (pytorch#…
…72367) Summary: Pull Request resolved: pytorch#72367 Pull Request resolved: pytorch#72299 Test Plan: Before pytorch#71520: ``` Summary Pass: 106 Fail: 1 ✗ caffe2/test:package - test_repackage_import_indirectly_via_parent_module (package.package_d.test_repackage.TestRepackage) Skip: 22 ... ListingSuccess: 1 ``` After pytorch#71520: ``` BUILD SUCCEEDED ✓ ListingSuccess: caffe2/test:package : 129 tests discovered (28.595) ✓ Pass: caffe2/test:package - test_repackage_import_indirectly_via_parent_module (package.package_d.test_repackage.TestRepackage) (18.635) Summary Pass: 1 ListingSuccess: 1 ``` Reviewed By: PaliC Differential Revision: D34015540 fbshipit-source-id: b45af5872ae4a5f52afbc0008494569d1080fa38 (cherry picked from commit 432d728)
- Loading branch information
1 parent
29c81bb
commit f0f49a1
Showing
6 changed files
with
74 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
import torch | ||
|
||
from .subpackage_0.subsubpackage_0 import important_string | ||
|
||
|
||
class ImportsDirectlyFromSubSubPackage(torch.nn.Module): | ||
|
||
key = important_string | ||
|
||
def forward(self, inp): | ||
return torch.sum(inp) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
import torch | ||
|
||
from .subpackage_0 import important_string | ||
|
||
|
||
class ImportsIndirectlyFromSubPackage(torch.nn.Module): | ||
|
||
key = important_string | ||
|
||
def forward(self, inp): | ||
return torch.sum(inp) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .subsubpackage_0 import important_string |
1 change: 1 addition & 0 deletions
1
test/package/package_d/subpackage_0/subsubpackage_0/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
important_string = "subsubpackage_0" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# Owner(s): ["oncall: package/deploy"] | ||
|
||
from io import BytesIO | ||
|
||
from torch.package import ( | ||
PackageExporter, | ||
PackageImporter, | ||
sys_importer, | ||
) | ||
from torch.testing._internal.common_utils import run_tests | ||
|
||
try: | ||
from .common import PackageTestCase | ||
except ImportError: | ||
# Support the case where we run this file directly. | ||
from common import PackageTestCase | ||
|
||
|
||
class TestRepackage(PackageTestCase): | ||
"""Tests for repackaging.""" | ||
|
||
def test_repackage_import_indirectly_via_parent_module(self): | ||
from package_d.imports_directly import ImportsDirectlyFromSubSubPackage | ||
from package_d.imports_indirectly import ImportsIndirectlyFromSubPackage | ||
|
||
model_a = ImportsDirectlyFromSubSubPackage() | ||
buffer = BytesIO() | ||
with PackageExporter(buffer) as pe: | ||
pe.intern("**") | ||
pe.save_pickle("default", "model.py", model_a) | ||
|
||
buffer.seek(0) | ||
pi = PackageImporter(buffer) | ||
loaded_model = pi.load_pickle("default", "model.py") | ||
|
||
model_b = ImportsIndirectlyFromSubPackage() | ||
buffer = BytesIO() | ||
with PackageExporter( | ||
buffer, | ||
importer=( | ||
pi, | ||
sys_importer, | ||
), | ||
) as pe: | ||
pe.intern("**") | ||
pe.save_pickle("default", "model_b.py", model_b) | ||
|
||
|
||
if __name__ == "__main__": | ||
run_tests() |