Skip to content

Commit

Permalink
& yarnabrina [ENH] prevent imports in _check_soft_dependencies (#340)
Browse files Browse the repository at this point in the history
The `_check_soft_dependencies` utility had inefficiencies and side
effects, because to check the soft dependency being present, an import
would be attempted, which can lead to circular imports or cascading
imports.

This PR replaces the logic with an import-free check.

It also deprecates the argument `suppress_import_stdout`, as there is no
longer any import, hence no `stdout` output that would be created.

Mirror of sktime/sktime#6355 which includes
contributions by @yarnabrina, and of
sktime/sktime#6719
  • Loading branch information
fkiraly authored Jul 12, 2024
1 parent f0d3306 commit f09f148
Show file tree
Hide file tree
Showing 2 changed files with 202 additions and 66 deletions.
2 changes: 2 additions & 0 deletions skbase/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,8 @@
"_check_soft_dependencies",
"_check_python_version",
"_check_estimator_deps",
"_normalize_requirement",
"_raise_at_severity",
),
"skbase.utils.random_state": (
"check_random_state",
Expand Down
266 changes: 200 additions & 66 deletions skbase/utils/dependencies/_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,27 @@
"""Utility to check soft dependency imports, and raise warnings or errors."""
import sys
import warnings
from importlib import import_module
from functools import lru_cache
from importlib.metadata import PackageNotFoundError, version
from importlib.util import find_spec
from inspect import isclass
from typing import List

from packaging.requirements import InvalidRequirement, Requirement
from packaging.specifiers import InvalidSpecifier, SpecifierSet

from skbase.utils.stdout_mute import StdoutMute
from packaging.specifiers import InvalidSpecifier, Specifier, SpecifierSet
from packaging.version import InvalidVersion, Version

__author__: List[str] = ["fkiraly", "mloning"]


# todo 0.10.0: remove suppress_import_stdout argument
def _check_soft_dependencies(
*packages,
package_import_alias=None,
severity="error",
obj=None,
msg=None,
suppress_import_stdout=False,
suppress_import_stdout="deprecated",
):
"""Check if required soft dependencies are installed and raise error or warning.
Expand All @@ -32,40 +34,68 @@ def _check_soft_dependencies(
For instance, the PEP 440 compatible package name such as "pandas";
or a package requirement specifier string such as "pandas>1.2.3".
arg can be str, kwargs tuple, or tuple/list of str, following calls are valid:
`_check_soft_dependencies("package1")`
`_check_soft_dependencies("package1", "package2")`
`_check_soft_dependencies(("package1", "package2"))`
`_check_soft_dependencies(["package1", "package2"])`
* ``_check_soft_dependencies("package1")``
* ``_check_soft_dependencies("package1", "package2")``
* ``_check_soft_dependencies(("package1", "package2"))``
* ``_check_soft_dependencies(["package1", "package2"])``
package_import_alias : dict with str keys and values, optional, default=empty
key-value pairs are package name, import name
import name is str used in python import, i.e., from import_name import ...
should be provided if import name differs from package name
key-value pairs are package name, import name.
import name is str used in python import, i.e., ``from import_name import ...``,
should be provided if import names differ from package name.
For example, ``{"scikit-learn": "sklearn"}`` for the well-known package.
The argument is used as a lookup and can cover more packages
than passed in ``packages``, so a global dictionary of known
aliases can be passed.
severity : str, "error" (default), "warning", "none"
behaviour for raising errors or warnings
"error" - raises a `ModuleNotFoundError` if one of packages is not installed
"warning" - raises a warning if one of packages is not installed
function returns False if one of packages is not installed, otherwise True
"none" - does not raise exception or warning
function returns False if one of packages is not installed, otherwise True
behaviour for raising errors or warnings:
* "error" - raises a ``ModuleNotFoundError`` if one of packages is not installed
* "warning" - raises a warning if one of packages is not installed.
The function returns False if one of packages is not installed, otherwise True
* "none" - does not raise exception or warning.
The function returns False if one of packages is not installed, otherwise True
obj : python class, object, str, or None, default=None
if self is passed here when _check_soft_dependencies is called within __init__,
or a class is passed when it is called at the start of a single-class module,
the error message is more informative and will refer to the class/object;
if str is passed, will be used as name of the class/object or module
msg : str, or None, default=None
if str, will override the error message or warning shown with msg
suppress_import_stdout : bool, optional. Default=False
whether to suppress stdout printout upon import.
Raises
------
InvalidRequirement
if package requirement strings are not PEP 440 compatible
ModuleNotFoundError
error with informative message, asking to install required soft dependencies
TypeError, ValueError
on invalid arguments
Returns
-------
boolean - whether all packages are installed, only if no exception is raised
"""
# todo 0.10.0: remove this warning
if suppress_import_stdout != "deprecated":
warnings.warn(
"In skbase _check_soft_dependencies, the suppress_import_stdout argument "
"is deprecated and no longer has any effect. "
"The argument will be removed in version 0.10.0, so users of the "
"_check_soft_dependencies utility should not pass this argument anymore. "
"The _check_soft_dependencies utility also no longer causes imports, "
"hence no stdout "
"output is created from imports, for any setting of the "
"suppress_import_stdout argument. If you wish to import packages "
"and make use of stdout prints, import the package directly instead.",
DeprecationWarning,
stacklevel=2,
)

if len(packages) == 1 and isinstance(packages[0], (tuple, list)):
packages = packages[0]
if not all(isinstance(x, str) for x in packages):
Expand Down Expand Up @@ -112,6 +142,7 @@ def _check_soft_dependencies(
for package in packages:
try:
req = Requirement(package)
req = _normalize_requirement(req)
except InvalidRequirement:
msg_version = (
f"wrong format for package requirement string, "
Expand All @@ -129,15 +160,13 @@ def _check_soft_dependencies(
package_import_name = package_import_alias[package_name]
else:
package_import_name = package_name
# attempt import - if not possible, we know we need to raise warning/exception
try:
with StdoutMute(active=suppress_import_stdout):
pkg_ref = import_module(package_import_name)
# if package cannot be imported, make the user aware of installation requirement
except ModuleNotFoundError as e:
if msg is None:

pkg_env_version = _get_pkg_version(package_name, package_import_name)

# if package not present, make the user aware of installation reqs
if pkg_env_version is None:
if obj is None and msg is None:
msg = (
f"{e}. "
f"{class_name} requires package {package!r} to be present "
f"in the python environment, but {package!r} was not found. "
)
Expand All @@ -153,24 +182,15 @@ def _check_soft_dependencies(
# if msg is not None, none of the above is executed,
# so if msg is passed it overrides the default messages

if severity == "error":
raise ModuleNotFoundError(msg) from e
elif severity == "warning":
warnings.warn(msg, stacklevel=2)
return False
elif severity == "none":
return False
else:
raise RuntimeError(
"Error in calling _check_soft_dependencies, severity "
'argument must be "error", "warning", or "none",'
f"found {severity!r}."
) from e
_raise_at_severity(
msg,
severity=severity,
caller="_check_soft_dependencies",
)
return False

# now we check compatibility with the version specifier if non-empty
if package_version_req != SpecifierSet(""):
pkg_env_version = pkg_ref.__version__

msg = (
f"{class_name} requires package {package!r} to be present "
f"in the python environment, with version {package_version_req}, "
Expand All @@ -184,23 +204,61 @@ def _check_soft_dependencies(

# raise error/warning or return False if version is incompatible
if pkg_env_version not in package_version_req:
if severity == "error":
raise ModuleNotFoundError(msg)
elif severity == "warning":
warnings.warn(msg, stacklevel=2)
elif severity == "none":
return False
else:
raise RuntimeError(
"Error in calling _check_soft_dependencies, severity argument"
f' must be "error", "warning", or "none", found {severity!r}.'
)
_raise_at_severity(
msg,
severity=severity,
caller="_check_soft_dependencies",
)
return False

# if package can be imported and no version issue was caught for any string,
# then obj is compatible with the requirements and we should return True
return True


@lru_cache
def _get_pkg_version(package_name, package_import_name=None):
    """Check whether package is available in environment, and return its version if yes.

    The check does not import the package itself: presence is determined via
    ``importlib.util.find_spec``, the version via ``importlib.metadata.version``.

    Returns ``Version`` object from ``lru_cache``, this should not be mutated.

    Parameters
    ----------
    package_name : str
        name of package to check, e.g., "pandas" or "scikit-learn".
        This is the pypi package name, not the import name, e.g.,
        ``scikit-learn``, not ``sklearn``.
    package_import_name : str, optional, default=None
        name of package to check for import, e.g., "pandas" or "sklearn".
        Note: this is the import name, not the pypi package name, e.g.,
        ``sklearn``, not ``scikit-learn``.
        If not given, ``package_name`` is used as ``package_import_name``,
        i.e., it is assumed that the import name is the same as the package name.

    Returns
    -------
    None, if no module spec is found for ``package_import_name``, or if no
    distribution with a valid version is found for ``package_name``;
    ``packaging.version.Version`` of the package otherwise.
    """
    if package_import_name is None:
        package_import_name = package_name

    # first we check import, then we check distribution:
    # the distribution lookup is only needed if the import name resolves
    try:
        pkg_spec = find_spec(package_import_name)
    except (ModuleNotFoundError, ValueError):
        # find_spec raises instead of returning None in two situations:
        # * dotted names, e.g., "ruamel.yaml", whose parent package is absent
        #   (ModuleNotFoundError from the attempt to resolve the parent)
        # * empty names, or modules in sys.modules without a valid __spec__
        #   (ValueError)
        # in both cases the package counts as not available
        pkg_spec = None

    if pkg_spec is None:
        return None

    try:
        return Version(version(package_name))
    except (InvalidVersion, PackageNotFoundError):
        # importable, but no distribution metadata or unparsable version string
        return None


def _check_python_version(obj, package=None, msg=None, severity="error"):
"""Check if system python version is compatible with requirements of obj.
Expand Down Expand Up @@ -263,18 +321,8 @@ def _check_python_version(obj, package=None, msg=None, severity="error"):
f" This is due to python version requirements of the {package} package."
)

if severity == "error":
raise ModuleNotFoundError(msg)
elif severity == "warning":
warnings.warn(msg, stacklevel=2)
elif severity == "none":
return False
else:
raise RuntimeError(
"Error in calling _check_python_version, severity "
f'argument must be "error", "warning", or "none", found {severity!r}.'
)
return True
_raise_at_severity(msg, severity=severity, caller="_check_python_version")
return False


def _check_estimator_deps(obj, msg=None, severity="error"):
Expand Down Expand Up @@ -339,3 +387,89 @@ def _check_estimator_deps(obj, msg=None, severity="error"):
compatible = compatible and pkg_deps_ok

return compatible


def _normalize_requirement(req):
    """Normalize packaging Requirement by removing build metadata from versions.

    Only the local version label (the ``+...`` suffix) is stripped from each
    version in the specifier set. Epoch, release segments, and pre-, post- and
    dev-release markers are kept exactly as given, so that, e.g.,
    ``~=1.2`` and ``>=2.0.0rc1`` retain their meaning.

    Specifiers whose version part is not a plain version, such as the wildcard
    ``==1.2.*`` or arbitrary equality ``===``, are passed through unchanged.

    The returned requirement is rebuilt from the name and specifiers of ``req``
    only; extras, markers and URLs of ``req`` are not carried over.

    Parameters
    ----------
    req : packaging.requirements.Requirement
        requirement to normalize, e.g., Requirement("pandas==1.2.3+foobar")

    Returns
    -------
    normalized_req : packaging.requirements.Requirement
        normalized requirement object with build metadata removed from versions,
        e.g., Requirement("pandas==1.2.3")
    """
    normalized_specs = []
    for spec in req.specifier:
        try:
            # .public is the version string without the local version label;
            # unlike major.minor.micro it keeps epoch, pre/post/dev segments
            # and does not pad short versions such as "1.2" to "1.2.0"
            public_version = Version(spec.version).public
        except InvalidVersion:
            # wildcard ("==1.2.*") or arbitrary equality ("===foo") versions
            # are not parsable as Version - keep such specifiers as they are
            normalized_specs.append(spec)
            continue

        normalized_specs.append(Specifier(f"{spec.operator}{public_version}"))

    # Reconstruct the specifier set from the normalized specifiers
    normalized_specifier_set = SpecifierSet(",".join(str(s) for s in normalized_specs))

    # Create a new Requirement object with the normalized specifiers
    normalized_req = Requirement(f"{req.name}{normalized_specifier_set}")

    return normalized_req


def _raise_at_severity(
    msg,
    severity,
    exception_type=None,
    warning_type=None,
    stacklevel=2,
    caller="_raise_at_severity",
):
    """Raise exception or warning or take no action, based on severity.

    Parameters
    ----------
    msg : str
        message to raise or warn
    severity : str, "error", "warning", or "none"
        behaviour for raising errors or warnings
    exception_type : Exception, default=ModuleNotFoundError
        exception type to raise if severity="error"
    warning_type : warning, default=None
        warning type to raise if severity="warning", passed on as ``category``
        to ``warnings.warn``; if None, ``warnings.warn`` applies its own default
    stacklevel : int, default=2
        stacklevel for warnings, if severity="warning"
    caller : str, default="_raise_at_severity"
        caller name, used in exception if severity not in ["error", "warning", "none"]

    Returns
    -------
    None

    Raises
    ------
    exception : exception_type, if severity="error"
    warning : warning_type, if severity="warning"
    ValueError : if severity not in ["error", "warning", "none"]
    """
    # silent mode: nothing to do
    if severity == "none":
        return None

    # warning mode: emit the warning directly from this frame,
    # so that stacklevel is counted from here
    if severity == "warning":
        warnings.warn(msg, category=warning_type, stacklevel=stacklevel)
        return None

    # anything other than "error" at this point is an invalid severity
    if severity != "error":
        raise ValueError(
            f"Error in calling {caller}, severity "
            f'argument must be "error", "warning", or "none", found {severity!r}.'
        )

    error_cls = ModuleNotFoundError if exception_type is None else exception_type
    raise error_cls(msg)

0 comments on commit f09f148

Please sign in to comment.