Skip to content

Commit

Permalink
[rfc][pkg] check spec for module source before falling back to file i…
Browse files Browse the repository at this point in the history
…n package exporter (pytorch#90258)

Summary: To get the source for a particular module, the "correct" thing to do is to check the module's spec and call `get_source` if its loader is a SourceFileLoader, since subclasses may look somewhere other than `__file__` and the spec is the source of truth. The torch packager, however, prefers to read source through linecache. Because the loader can still remap the file, we determine the module's filename from the spec's loader, when possible, rather than from `module.__file__`.

Test Plan: This code path will get exercised by CI. Also added a test for remapped files.

Differential Revision: D41412983

Pull Request resolved: pytorch#90258
Approved by: https://github.com/PaliC
  • Loading branch information
smacke authored and pytorchmergebot committed Dec 8, 2022
1 parent e1674d7 commit 0c972fb
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 11 deletions.
1 change: 1 addition & 0 deletions test/package/module_a_remapped_path.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Marker value for the remapped-loader test: when a loader substitutes this
# file for "module_a", the packaged source must contain "remapped_path".
result = "module_a_remapped_path"
56 changes: 56 additions & 0 deletions test/package/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
# Owner(s): ["oncall: package/deploy"]

import inspect
import os
import platform
import sys
from io import BytesIO
from pathlib import Path
from textwrap import dedent
Expand Down Expand Up @@ -104,6 +106,60 @@ def test_file_structure(self):
import_exclude,
)

def test_loaders_that_remap_files_work_ok(self):
    """Check that the exporter packages the file chosen by the module's loader.

    A meta-path finder is installed that swaps the loader of ``module_a`` for
    one whose ``get_filename`` points at ``module_a_remapped_path.py``. After
    exporting and re-importing, the stored source for ``module_a`` must be
    the remapped file's source.
    """
    from importlib.abc import MetaPathFinder
    from importlib.machinery import SourceFileLoader
    from importlib.util import spec_from_loader

    class LoaderThatRemapsModuleA(SourceFileLoader):
        def get_filename(self, name):
            result = super().get_filename(name)
            if name == "module_a":
                # Same directory as the real module, different file.
                return os.path.join(
                    os.path.dirname(result), "module_a_remapped_path.py"
                )
            return result

    class FinderThatRemapsModuleA(MetaPathFinder):
        def find_spec(self, fullname, path, target=None):
            """Try to find the original spec for module_a using all the
            remaining meta_path finders."""
            if fullname != "module_a":
                return None
            spec = None
            for finder in sys.meta_path:
                if finder is self:
                    continue
                if hasattr(finder, "find_spec"):
                    spec = finder.find_spec(fullname, path, target=target)
                elif hasattr(finder, "load_module"):
                    spec = spec_from_loader(fullname, finder)
                if spec is not None:
                    break
            assert spec is not None and isinstance(spec.loader, SourceFileLoader)
            spec.loader = LoaderThatRemapsModuleA(spec.loader.name, spec.loader.path)
            return spec

    remapping_finder = FinderThatRemapsModuleA()
    sys.meta_path.insert(0, remapping_finder)
    # clear it from sys.modules so that we use the custom finder next time
    # it gets imported
    sys.modules.pop("module_a", None)
    try:
        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            import module_a

            he.intern("**")
            he.save_module(module_a.__name__)

        buffer.seek(0)
        hi = PackageImporter(buffer)
        self.assertIn("remapped_path", hi.get_source("module_a"))
    finally:
        # pop it again to ensure it does not mess up other tests
        sys.modules.pop("module_a", None)
        # Remove our finder by identity rather than by position: another
        # finder may have been inserted at index 0 while the test ran.
        if remapping_finder in sys.meta_path:
            sys.meta_path.remove(remapping_finder)

def test_python_version(self):
"""
Tests that the current python version is stored in the package and is available
Expand Down
26 changes: 15 additions & 11 deletions torch/package/package_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from collections import defaultdict, OrderedDict
from dataclasses import dataclass
from enum import Enum
from importlib.machinery import SourceFileLoader
from pathlib import Path
from typing import (
Any,
Expand Down Expand Up @@ -422,17 +423,20 @@ def _module_exists(self, module_name: str) -> bool:
return False

def _get_source_of_module(self, module: types.ModuleType) -> Optional[str]:
filename = getattr(module, "__file__", None)
result = (
None
if filename is None or not filename.endswith(".py")
else linecache.getlines(filename, module.__dict__)
)

if result is None:
return None

return "".join(result)
filename = None
spec = getattr(module, "__spec__", None)
if spec is not None:
loader = getattr(spec, "loader", None)
if loader is not None and isinstance(loader, SourceFileLoader):
try:
filename = loader.get_filename(module.__name__)
except ImportError:
pass
if filename is None:
filename = getattr(module, "__file__", None)
if isinstance(filename, str) and filename.endswith(".py"):
return "".join(linecache.getlines(filename, module.__dict__))
return None

def add_dependency(self, module_name: str, dependencies=True):
"""Given a module, add it to the dependency graph according to patterns
Expand Down

0 comments on commit 0c972fb

Please sign in to comment.