diff --git a/skbase/tests/conftest.py b/skbase/tests/conftest.py index 279478ef..9615a3b8 100644 --- a/skbase/tests/conftest.py +++ b/skbase/tests/conftest.py @@ -251,7 +251,10 @@ "skbase.utils.dependencies._dependencies": ( "_check_soft_dependencies", "_check_python_version", + "_check_env_marker", "_check_estimator_deps", + "_get_pkg_version", + "_get_installed_packages", "_normalize_requirement", "_raise_at_severity", ), diff --git a/skbase/utils/dependencies/_dependencies.py b/skbase/utils/dependencies/_dependencies.py index 226c2912..9fcf5205 100644 --- a/skbase/utils/dependencies/_dependencies.py +++ b/skbase/utils/dependencies/_dependencies.py @@ -3,22 +3,19 @@ import sys import warnings from functools import lru_cache -from importlib.metadata import PackageNotFoundError, version -from importlib.util import find_spec +from importlib.metadata import distributions from inspect import isclass -from typing import List +from packaging.markers import InvalidMarker, Marker from packaging.requirements import InvalidRequirement, Requirement from packaging.specifiers import InvalidSpecifier, Specifier, SpecifierSet from packaging.version import InvalidVersion, Version -__author__: List[str] = ["fkiraly", "mloning"] - # todo 0.10.0: remove suppress_import_stdout argument def _check_soft_dependencies( *packages, - package_import_alias=None, + package_import_alias="deprecated", severity="error", obj=None, msg=None, @@ -31,32 +28,24 @@ def _check_soft_dependencies( packages : str or list/tuple of str, or length-1-tuple containing list/tuple of str str should be package names and/or package version specifications to check. Each str must be a PEP 440 compatible specifier string, for a single package. - For instance, the PEP 440 compatible package name such as "pandas"; - or a package requirement specifier string such as "pandas>1.2.3". 
+ For instance, the PEP 440 compatible package name such as ``"pandas"``; + or a package requirement specifier string such as ``"pandas>1.2.3"``. arg can be str, kwargs tuple, or tuple/list of str, following calls are valid: + ``_check_soft_dependencies("package1")`` + ``_check_soft_dependencies("package1", "package2")`` + ``_check_soft_dependencies(("package1", "package2"))`` + ``_check_soft_dependencies(["package1", "package2"])`` - * ``_check_soft_dependencies("package1")`` - * ``_check_soft_dependencies("package1", "package2")`` - * ``_check_soft_dependencies(("package1", "package2"))`` - * ``_check_soft_dependencies(["package1", "package2"])`` - - package_import_alias : dict with str keys and values, optional, default=empty - key-value pairs are package name, import name. - import name is str used in python import, i.e., ``from import_name import ...``, - should be provided if import names differ from package name. - For example, ``{"scikit-learn": "sklearn"}`` for the well-known package. - The argument is used as a lookup and can cover more packages - than passed in ``packages``, so a global dictionary of known - aliases can be passed. + package_import_alias : ignored, present only for backwards compatibility severity : str, "error" (default), "warning", "none" - behaviour for raising errors or warnings: + whether the check should raise an error, a warning, or nothing * "error" - raises a ``ModuleNotFoundError`` if one of packages is not installed - * "warning" - raises a warning if one of packages is not installed. - The function returns False if one of packages is not installed, otherwise True - * "none" - does not raise exception or warning. 
- The function returns False if one of packages is not installed, otherwise True + * "warning" - raises a warning if one of packages is not installed + function returns False if one of packages is not installed, otherwise True + * "none" - does not raise exception or warning + function returns False if one of packages is not installed, otherwise True obj : python class, object, str, or None, default=None if self is passed here when _check_soft_dependencies is called within __init__, @@ -104,20 +93,6 @@ def _check_soft_dependencies( f"str, but found packages argument of type {type(packages)}" ) - if package_import_alias is None: - package_import_alias = {} - msg_pkg_import_alias = ( - "package_import_alias argument of _check_soft_dependencies must " - "be a dict with str keys and values, but found " - f"package_import_alias of type {type(package_import_alias)}" - ) - if not isinstance(package_import_alias, dict): - raise TypeError(msg_pkg_import_alias) - if not all(isinstance(x, str) for x in package_import_alias.keys()): - raise TypeError(msg_pkg_import_alias) - if not all(isinstance(x, str) for x in package_import_alias.values()): - raise TypeError(msg_pkg_import_alias) - if obj is None: class_name = "This functionality" elif not isclass(obj): @@ -155,13 +130,7 @@ def _check_soft_dependencies( package_name = req.name package_version_req = req.specifier - # determine the package import - if package_name in package_import_alias.keys(): - package_import_name = package_import_alias[package_name] - else: - package_import_name = package_name - - pkg_env_version = _get_pkg_version(package_name, package_import_name) + pkg_env_version = _get_pkg_version(package_name) # if package not present, make the user aware of installation reqs if pkg_env_version is None: @@ -170,23 +139,21 @@ def _check_soft_dependencies( f"{class_name} requires package {package!r} to be present " f"in the python environment, but {package!r} was not found. 
" ) - if obj is not None: - msg = msg + ( - f"{package!r} is a dependency of {class_name} and required " - f"to construct it. " - ) - msg = msg + ( - f"Please run: `pip install {package}` to " - f"install the {package} package. " + elif msg is None: # obj is not None, msg is None + msg = ( + f"{class_name} requires package {package!r} to be present " + f"in the python environment, but {package!r} was not found. " + f"{package!r} is a dependency of {class_name} and required " + f"to construct it. " ) + msg = msg + ( + f"Please run: `pip install {package}` to " + f"install the {package} package. " + ) # if msg is not None, none of the above is executed, # so if msg is passed it overrides the default messages - _raise_at_severity( - msg, - severity=severity, - caller="_check_soft_dependencies", - ) + _raise_at_severity(msg, severity, caller="_check_soft_dependencies") return False # now we check compatibility with the version specifier if non-empty @@ -204,11 +171,7 @@ def _check_soft_dependencies( # raise error/warning or return False if version is incompatible if pkg_env_version not in package_version_req: - _raise_at_severity( - msg, - severity=severity, - caller="_check_soft_dependencies", - ) + _raise_at_severity(msg, severity, caller="_check_soft_dependencies") return False # if package can be imported and no version issue was caught for any string, @@ -217,7 +180,30 @@ def _check_soft_dependencies( @lru_cache -def _get_pkg_version(package_name, package_import_name=None): +def _get_installed_packages_private(): + """Get a dictionary of installed packages and their versions. + + Same as _get_installed_packages, but internal to avoid mutating the lru_cache + by accident. + """ + dists = distributions() + packages = {dist.metadata["Name"]: dist.version for dist in dists} + return packages + + +def _get_installed_packages(): + """Get a dictionary of installed packages and their versions. 
+ + Returns + ------- + dict : dictionary of installed packages and their versions + keys are distribution names as recorded in package metadata, values are package versions + versions are the version strings reported by the package metadata, e.g., "1.2.3" + """ + return _get_installed_packages_private().copy() + + +def _get_pkg_version(package_name): """Check whether package is available in environment, and return its version if yes. Returns ``Version`` object from ``lru_cache``, this should not be mutated. @@ -225,37 +211,24 @@ def _get_pkg_version(package_name, package_import_name=None): Parameters ---------- package_name : str, optional, default=None - name of package to check, e.g., "pandas" or "sklearn". + name of package to check, + PEP 440 compatible specifier string, e.g., "pandas" or "scikit-learn". This is the pypi package name, not the import name, e.g., ``scikit-learn``, not ``sklearn``. - package_import_name : str, optional, default=None - name of package to check for import, e.g., "pandas" or "sklearn". - Note: this is the import name, not the pypi package name, e.g., - ``sklearn``, not ``scikit-learn``. - If not given, ``package_name`` is used as ``package_import_name``, - i.e., it is assumed that the import name is the same as the package name. Returns ------- - None, if package is not found at import ``package_import_name``; - ``importlib`` ``Version`` of package, if found at import ``package_import_name`` + None, if package is not found in python environment. + ``packaging`` ``Version`` of package, if present in environment. 
""" - if package_import_name is None: - package_import_name = package_name - - # optimized branching to check presence of import - # and presence of package distribution - # first we check import, then we check distribution - # because try/except consumes more runtime - pkg_spec = find_spec(package_import_name) - if pkg_spec is not None: - try: - pkg_env_version = Version(version(package_name)) - except (InvalidVersion, PackageNotFoundError): - pkg_env_version = None - else: + pkgs = _get_installed_packages() + pkg_vers_str = pkgs.get(package_name, None) + if pkg_vers_str is None: + return None + try: + pkg_env_version = Version(pkg_vers_str) + except InvalidVersion: pkg_env_version = None - return pkg_env_version @@ -266,13 +239,22 @@ def _check_python_version(obj, package=None, msg=None, severity="error"): ---------- obj : BaseObject descendant used to check python version + package : str, default = None if given, will be used in error message as package name + msg : str, optional, default = default message (msg below) - error message to be returned in the `ModuleNotFoundError`, overrides default - severity : str, "error" (default), "warning", or "none" + error message to be returned in the ``ModuleNotFoundError``, overrides default + + severity : str, "error" (default), "warning", "none" whether the check should raise an error, a warning, or nothing + * "error" - raises a ``ModuleNotFoundError`` if one of packages is not installed + * "warning" - raises a warning if one of packages is not installed + function returns False if one of packages is not installed, otherwise True + * "none" - does not raise exception or warning + function returns False if one of packages is not installed, otherwise True + Returns ------- compatible : bool, whether obj is compatible with system python version @@ -305,6 +287,7 @@ def _check_python_version(obj, package=None, msg=None, severity="error"): if sys_version in est_specifier: return True # now we know that est_version is not 
compatible with sys_version + if isclass(obj): class_name = obj.__name__ else: @@ -321,7 +304,79 @@ def _check_python_version(obj, package=None, msg=None, severity="error"): f" This is due to python version requirements of the {package} package." ) - _raise_at_severity(msg, severity=severity, caller="_check_python_version") + _raise_at_severity(msg, severity, caller="_check_python_version") + return False + + +def _check_env_marker(obj, package=None, msg=None, severity="error"): + """Check if the python environment satisfies the env_marker tag of obj. + + Parameters + ---------- + obj : BaseObject descendant + used to check the environment marker + package : str, default = None + if given, will be used in error message as package name + msg : str, optional, default = default message (msg below) + error message to be returned in the `ModuleNotFoundError`, overrides default + + severity : str, "error" (default), "warning", "none" + whether the check should raise an error, a warning, or nothing + + * "error" - raises a ``ModuleNotFoundError`` if the environment is incompatible + * "warning" - raises a warning if the environment is incompatible + function returns False if the environment is incompatible, otherwise True + * "none" - does not raise exception or warning + function returns False if the environment is incompatible, otherwise True + + Returns + ------- + compatible : bool, whether obj is compatible with the python environment + check is using the env_marker tag of obj + + Raises + ------ + InvalidMarker + User friendly error if obj has env_marker tag that is not a + packaging compatible marker string + ModuleNotFoundError + User friendly error if obj has an env_marker tag that is + incompatible with the python environment. If package is given, + error message gives package as the reason for incompatibility. 
+ """ + est_marker_tag = obj.get_class_tag("env_marker", tag_value_default="None") + if est_marker_tag in ["None", None]: + return True + + try: + est_marker = Marker(est_marker_tag) + except InvalidMarker: + msg_version = ( + f"wrong format for env_marker tag, " + f"must be PEP 508 compatible specifier string, e.g., " + f'platform_system!="windows", but found {est_marker_tag!r}' + ) + raise InvalidMarker(msg_version) from None + + if est_marker.evaluate(): + return True + # now we know that est_marker is not compatible with the environment + + if isclass(obj): + class_name = obj.__name__ + else: + class_name = type(obj).__name__ + + if not isinstance(msg, str): + msg = ( + f"{class_name} requires an environment to satisfy " + f"packaging marker spec {est_marker}, but environment does not satisfy it." + ) + + if package is not None: + msg += f" This is due to requirements of the {package} package." + + _raise_at_severity(msg, severity, caller="_check_env_marker") return False @@ -336,17 +391,20 @@ def _check_estimator_deps(obj, msg=None, severity="error"): Parameters ---------- - obj : `BaseObject` descendant, instance or class, or list/tuple thereof + obj : BaseObject descendant, instance or class, or list/tuple thereof object(s) that this function checks compatibility of, with the python env + msg : str, optional, default = default message (msg below) - error message to be returned in the `ModuleNotFoundError`, overrides default - severity : str, "error" (default), "warning", or "none" - behaviour for raising errors or warnings - "error" - raises a `ModuleNotFoundError` if environment is incompatible - "warning" - raises a warning if environment is incompatible - function returns False if environment is incompatible, otherwise True - "none" - does not raise exception or warning - function returns False if environment is incompatible, otherwise True + error message to be returned in the ``ModuleNotFoundError``, overrides default + + severity : str, "error" 
(default), "warning", "none" + whether the check should raise an error, a warning, or nothing + + * "error" - raises a ``ModuleNotFoundError`` if environment is incompatible + * "warning" - raises a warning if environment is incompatible + function returns False if environment is incompatible, otherwise True + * "none" - does not raise exception or warning + function returns False if environment is incompatible, otherwise True Returns ------- @@ -375,6 +433,7 @@ def _check_estimator_deps(obj, msg=None, severity="error"): return compatible compatible = compatible and _check_python_version(obj, severity=severity) + compatible = compatible and _check_env_marker(obj, severity=severity) pkg_deps = obj.get_class_tag("python_dependencies", None) pck_alias = obj.get_class_tag("python_dependencies_alias", None) @@ -437,8 +496,16 @@ def _raise_at_severity( ---------- msg : str message to raise or warn - severity : str, "error", "warning", or "none" - behaviour for raising errors or warnings + + severity : str, "error" (default), "warning", "none" + whether the check should raise an error, a warning, or nothing + + * "error" - raises a ``ModuleNotFoundError`` if one of packages is not installed + * "warning" - raises a warning if one of packages is not installed + function returns False if one of packages is not installed, otherwise True + * "none" - does not raise exception or warning + function returns False if one of packages is not installed, otherwise True + exception_type : Exception, default=ModuleNotFoundError exception type to raise if severity="severity" warning_type : warning, default=Warning