From 1f2d970c2e98a80ea426c2d3f14ec48ac37c0870 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Thu, 1 Aug 2024 23:55:37 +0100 Subject: [PATCH 1/6] Update _dependencies.py --- skbase/utils/dependencies/_dependencies.py | 422 +++++++++++++++------ 1 file changed, 310 insertions(+), 112 deletions(-) diff --git a/skbase/utils/dependencies/_dependencies.py b/skbase/utils/dependencies/_dependencies.py index cb4f4dd2..832884e5 100644 --- a/skbase/utils/dependencies/_dependencies.py +++ b/skbase/utils/dependencies/_dependencies.py @@ -2,25 +2,25 @@ """Utility to check soft dependency imports, and raise warnings or errors.""" import sys import warnings -from importlib import import_module +from functools import lru_cache +from importlib.metadata import distributions +from importlib.util import find_spec from inspect import isclass -from typing import List +from packaging.markers import InvalidMarker, Marker from packaging.requirements import InvalidRequirement, Requirement -from packaging.specifiers import InvalidSpecifier, SpecifierSet - -from skbase.utils.stdout_mute import StdoutMute - -__author__: List[str] = ["fkiraly", "mloning"] +from packaging.specifiers import InvalidSpecifier, Specifier, SpecifierSet +from packaging.version import InvalidVersion, Version +# todo 0.32.0: remove suppress_import_stdout argument def _check_soft_dependencies( *packages, - package_import_alias=None, + package_import_alias="deprecated", severity="error", obj=None, msg=None, - suppress_import_stdout=False, + suppress_import_stdout="deprecated", ): """Check if required soft dependencies are installed and raise error or warning. @@ -29,33 +29,33 @@ def _check_soft_dependencies( packages : str or list/tuple of str, or length-1-tuple containing list/tuple of str str should be package names and/or package version specifications to check. Each str must be a PEP 440 compatible specifier string, for a single package. 
- For instance, the PEP 440 compatible package name such as "pandas"; - or a package requirement specifier string such as "pandas>1.2.3". + For instance, the PEP 440 compatible package name such as ``"pandas"``; + or a package requirement specifier string such as ``"pandas>1.2.3"``. arg can be str, kwargs tuple, or tuple/list of str, following calls are valid: - `_check_soft_dependencies("package1")` - `_check_soft_dependencies("package1", "package2")` - `_check_soft_dependencies(("package1", "package2"))` - `_check_soft_dependencies(["package1", "package2"])` - package_import_alias : dict with str keys and values, optional, default=empty - key-value pairs are package name, import name - import name is str used in python import, i.e., from import_name import ... - should be provided if import name differs from package name + ``_check_soft_dependencies("package1")`` + ``_check_soft_dependencies("package1", "package2")`` + ``_check_soft_dependencies(("package1", "package2"))`` + ``_check_soft_dependencies(["package1", "package2"])`` + + package_import_alias : ignored, present only for backwards compatibility + severity : str, "error" (default), "warning", "none" - behaviour for raising errors or warnings - "error" - raises a `ModuleNotFoundError` if one of packages is not installed - "warning" - raises a warning if one of packages is not installed - function returns False if one of packages is not installed, otherwise True - "none" - does not raise exception or warning - function returns False if one of packages is not installed, otherwise True + whether the check should raise an error, a warning, or nothing + + * "error" - raises a ``ModuleNotFoundError`` if one of packages is not installed + * "warning" - raises a warning if one of packages is not installed + function returns False if one of packages is not installed, otherwise True + * "none" - does not raise exception or warning + function returns False if one of packages is not installed, otherwise True + obj : 
python class, object, str, or None, default=None if self is passed here when _check_soft_dependencies is called within __init__, or a class is passed when it is called at the start of a single-class module, the error message is more informative and will refer to the class/object; if str is passed, will be used as name of the class/object or module + msg : str, or None, default=None if str, will override the error message or warning shown with msg - suppress_import_stdout : bool, optional. Default=False - whether to suppress stdout printout upon import. Raises ------ @@ -66,6 +66,22 @@ def _check_soft_dependencies( ------- boolean - whether all packages are installed, only if no exception is raised """ + # todo 0.10.0: remove this warning + if suppress_import_stdout != "deprecated": + warnings.warn( + "In skbase _check_soft_dependencies, the suppress_import_stdout argument " + "is deprecated and no longer has any effect. " + "The argument will be removed in version 0.10.0, so users of the " + "_check_soft_dependencies utility should not pass this argument anymore. " + "The _check_soft_dependencies utility also no longer causes imports, " + "hence no stdout " + "output is created from imports, for any setting of the " + "suppress_import_stdout argument. 
If you wish to import packages " + "and make use of stdout prints, import the package directly instead.", + DeprecationWarning, + stacklevel=2, + ) + if len(packages) == 1 and isinstance(packages[0], (tuple, list)): packages = packages[0] if not all(isinstance(x, str) for x in packages): @@ -74,20 +90,6 @@ def _check_soft_dependencies( f"str, but found packages argument of type {type(packages)}" ) - if package_import_alias is None: - package_import_alias = {} - msg_pkg_import_alias = ( - "package_import_alias argument of _check_soft_dependencies must " - "be a dict with str keys and values, but found " - f"package_import_alias of type {type(package_import_alias)}" - ) - if not isinstance(package_import_alias, dict): - raise TypeError(msg_pkg_import_alias) - if not all(isinstance(x, str) for x in package_import_alias.keys()): - raise TypeError(msg_pkg_import_alias) - if not all(isinstance(x, str) for x in package_import_alias.values()): - raise TypeError(msg_pkg_import_alias) - if obj is None: class_name = "This functionality" elif not isclass(obj): @@ -112,65 +114,47 @@ def _check_soft_dependencies( for package in packages: try: req = Requirement(package) + req = _normalize_requirement(req) except InvalidRequirement: msg_version = ( - f"wrong format for package requirement string, " + f"wrong format for package requirement string " f"passed via packages argument of _check_soft_dependencies, " f'must be PEP 440 compatible requirement string, e.g., "pandas"' f' or "pandas>1.1", but found {package!r}' ) - raise InvalidRequirement(msg_version) from None + raise InvalidRequirement(msg_version) package_name = req.name package_version_req = req.specifier - # determine the package import - if package_name in package_import_alias.keys(): - package_import_name = package_import_alias[package_name] - else: - package_import_name = package_name - # attempt import - if not possible, we know we need to raise warning/exception - try: - with 
StdoutMute(active=suppress_import_stdout): - pkg_ref = import_module(package_import_name) - # if package cannot be imported, make the user aware of installation requirement - except ModuleNotFoundError as e: - if msg is None: + pkg_env_version = _get_pkg_version(package_name) + + # if package not present, make the user aware of installation reqs + if pkg_env_version is None: + if obj is None and msg is None: msg = ( - f"{e}. " f"{class_name} requires package {package!r} to be present " f"in the python environment, but {package!r} was not found. " ) - if obj is not None: - msg = msg + ( - f"{package!r} is a dependency of {class_name} and required " - f"to construct it. " - ) - msg = msg + ( - f"Please run: `pip install {package}` to " - f"install the {package} package. " + elif msg is None: # obj is not None, msg is None + msg = ( + f"{class_name} requires package {package!r} to be present " + f"in the python environment, but {package!r} was not found. " + f"{package!r} is a dependency of {class_name} and required " + f"to construct it. " ) + msg = msg + ( + f"Please run: `pip install {package}` to " + f"install the {package} package. " + ) # if msg is not None, none of the above is executed, # so if msg is passed it overrides the default messages - if severity == "error": - raise ModuleNotFoundError(msg) from e - elif severity == "warning": - warnings.warn(msg, stacklevel=2) - return False - elif severity == "none": - return False - else: - raise RuntimeError( - "Error in calling _check_soft_dependencies, severity " - 'argument must be "error", "warning", or "none",' - f"found {severity!r}." 
- ) from e + _raise_at_severity(msg, severity, caller="_check_soft_dependencies") + return False # now we check compatibility with the version specifier if non-empty if package_version_req != SpecifierSet(""): - pkg_env_version = pkg_ref.__version__ - msg = ( f"{class_name} requires package {package!r} to be present " f"in the python environment, with version {package_version_req}, " @@ -184,23 +168,67 @@ def _check_soft_dependencies( # raise error/warning or return False if version is incompatible if pkg_env_version not in package_version_req: - if severity == "error": - raise ModuleNotFoundError(msg) - elif severity == "warning": - warnings.warn(msg, stacklevel=2) - elif severity == "none": - return False - else: - raise RuntimeError( - "Error in calling _check_soft_dependencies, severity argument" - f' must be "error", "warning", or "none", found {severity!r}.' - ) + _raise_at_severity(msg, severity, caller="_check_soft_dependencies") + return False # if package can be imported and no version issue was caught for any string, # then obj is compatible with the requirements and we should return True return True +@lru_cache +def _get_installed_packages_private(): + """Get a dictionary of installed packages and their versions. + + Same as _get_installed_packages, but internal to avoid mutating the lru_cache + by accident. + """ + dists = distributions() + packages = {dist.metadata["Name"]: dist.version for dist in dists} + return packages + + +def _get_installed_packages(): + """Get a dictionary of installed packages and their versions. + + Returns + ------- + dict : dictionary of installed packages and their versions + keys are PEP 440 compatible package names, values are package versions + MAJOR.MINOR.PATCH version format is used for versions, e.g., "1.2.3" + """ + return _get_installed_packages_private().copy() + + +def _get_pkg_version(package_name): + """Check whether package is available in environment, and return its version if yes. 
+ + Returns ``Version`` object from ``lru_cache``, this should not be mutated. + + Parameters + ---------- + package_name : str + name of package to check, + PEP 440 compatible specifier string, e.g., "pandas" or "scikit-learn". + This is the pypi package name, not the import name, e.g., + ``scikit-learn``, not ``sklearn``. + + Returns + ------- + None, if package is not found in python environment. + ``packaging`` ``Version`` of package, if present in environment. + """ + pkgs = _get_installed_packages() + pkg_vers_str = pkgs.get(package_name, None) + if pkg_vers_str is None: + return None + try: + pkg_env_version = Version(pkg_vers_str) + except InvalidVersion: + pkg_env_version = None + return pkg_env_version + + def _check_python_version(obj, package=None, msg=None, severity="error"): """Check if system python version is compatible with requirements of obj. @@ -208,13 +236,22 @@ def _check_python_version(obj, package=None, msg=None, severity="error"): ---------- obj : BaseObject descendant used to check python version + package : str, default = None if given, will be used in error message as package name + msg : str, optional, default = default message (msg below) - error message to be returned in the `ModuleNotFoundError`, overrides default - severity : str, "error" (default), "warning", or "none" + error message to be returned in the ``ModuleNotFoundError``, overrides default + + severity : str, "error" (default), "warning", "none" whether the check should raise an error, a warning, or nothing + * "error" - raises a ``ModuleNotFoundError`` if one of packages is not installed + * "warning" - raises a warning if one of packages is not installed + function returns False if one of packages is not installed, otherwise True + * "none" - does not raise exception or warning + function returns False if one of packages is not installed, otherwise True + Returns ------- compatible : bool, whether obj is compatible with system python version @@ -239,7 
+276,7 @@ def _check_python_version(obj, package=None, msg=None, severity="error"): f'must be PEP 440 compatible specifier string, e.g., "<3.9, >= 3.6.3",' f" but found {est_specifier_tag!r}" ) - raise InvalidSpecifier(msg_version) from None + raise InvalidSpecifier(msg_version) # python sys version, e.g., "3.8.12" sys_version = sys.version.split(" ")[0] @@ -247,6 +284,7 @@ def _check_python_version(obj, package=None, msg=None, severity="error"): if sys_version in est_specifier: return True # now we know that est_version is not compatible with sys_version + if isclass(obj): class_name = obj.__name__ else: @@ -263,18 +301,80 @@ def _check_python_version(obj, package=None, msg=None, severity="error"): f" This is due to python version requirements of the {package} package." ) - if severity == "error": - raise ModuleNotFoundError(msg) - elif severity == "warning": - warnings.warn(msg, stacklevel=2) - elif severity == "none": - return False + _raise_at_severity(msg, severity, caller="_check_python_version") + return False + + +def _check_env_marker(obj, package=None, msg=None, severity="error"): + """Check if packaging marker tag is with requirements of obj. 
+ + Parameters + ---------- + obj : BaseObject descendant + used to check the env_marker tag + package : str, default = None + if given, will be used in error message as package name + msg : str, optional, default = default message (msg below) + error message to be returned in the `ModuleNotFoundError`, overrides default + + severity : str, "error" (default), "warning", "none" + whether the check should raise an error, a warning, or nothing + + * "error" - raises a ``ModuleNotFoundError`` if one of packages is not installed + * "warning" - raises a warning if one of packages is not installed + function returns False if one of packages is not installed, otherwise True + * "none" - does not raise exception or warning + function returns False if one of packages is not installed, otherwise True + + Returns + ------- + compatible : bool, whether obj is compatible with the python environment + check is using the env_marker tag of obj + + Raises + ------ + InvalidMarker + User friendly error if obj has env_marker tag that is not a + packaging compatible marker string + ModuleNotFoundError + User friendly error if obj has an env_marker tag that is + incompatible with the python environment. If package is given, + error message gives package as the reason for incompatibility. 
+ """ + est_marker_tag = obj.get_class_tag("env_marker", tag_value_default="None") + if est_marker_tag in ["None", None]: + return True + + try: + est_marker = Marker(est_marker_tag) + except InvalidMarker: + msg_version = ( + f"wrong format for env_marker tag, " + f"must be PEP 508 compatible specifier string, e.g., " + f'platform_system!="windows", but found "{est_marker_tag}"' + ) + raise InvalidMarker(msg_version) + + if est_marker.evaluate(): + return True + # now we know that est_marker is not compatible with the environment + + if isclass(obj): + class_name = obj.__name__ else: - raise RuntimeError( - "Error in calling _check_python_version, severity " - f'argument must be "error", "warning", or "none", found {severity!r}.' + class_name = type(obj).__name__ + + if not isinstance(msg, str): + msg = ( + f"{class_name} requires an environment to satisfy " + f"packaging marker spec {est_marker}, but environment does not satisfy it." ) - return True + + if package is not None: + msg += f" This is due to requirements of the {package} package." 
+ + _raise_at_severity(msg, severity, caller="_check_env_marker") + return False def _check_estimator_deps(obj, msg=None, severity="error"): @@ -288,17 +388,20 @@ def _check_estimator_deps(obj, msg=None, severity="error"): Parameters ---------- - obj : `BaseObject` descendant, instance or class, or list/tuple thereof + obj : BaseObject descendant, instance or class, or list/tuple thereof object(s) that this function checks compatibility of, with the python env + msg : str, optional, default = default message (msg below) - error message to be returned in the `ModuleNotFoundError`, overrides default - severity : str, "error" (default), "warning", or "none" - behaviour for raising errors or warnings - "error" - raises a `ModuleNotFoundError` if environment is incompatible - "warning" - raises a warning if environment is incompatible - function returns False if environment is incompatible, otherwise True - "none" - does not raise exception or warning - function returns False if environment is incompatible, otherwise True + error message to be returned in the ``ModuleNotFoundError``, overrides default + + severity : str, "error" (default), "warning", "none" + whether the check should raise an error, a warning, or nothing + + * "error" - raises a ``ModuleNotFoundError`` if one of packages is not installed + * "warning" - raises a warning if one of packages is not installed + function returns False if one of packages is not installed, otherwise True + * "none" - does not raise exception or warning + function returns False if one of packages is not installed, otherwise True Returns ------- @@ -327,6 +430,7 @@ def _check_estimator_deps(obj, msg=None, severity="error"): return compatible compatible = compatible and _check_python_version(obj, severity=severity) + compatible = compatible and _check_env_marker(obj, severity=severity) pkg_deps = obj.get_class_tag("python_dependencies", None) pck_alias = obj.get_class_tag("python_dependencies_alias", None) @@ -339,3 +443,97 @@ 
def _check_estimator_deps(obj, msg=None, severity="error"): compatible = compatible and pkg_deps_ok return compatible + + +def _normalize_requirement(req): + """Normalize packaging Requirement by removing build metadata from versions. + + Parameters + ---------- + req : packaging.requirements.Requirement + requirement string to normalize, e.g., Requirement("pandas>1.2.3+foobar") + + Returns + ------- + normalized_req : packaging.requirements.Requirement + normalized requirement object with build metadata removed from versions, + e.g., Requirement("pandas>1.2.3") + """ + # Process each specifier in the requirement + normalized_specs = [] + for spec in req.specifier: + # Parse the version and remove the build metadata + spec_v = Version(spec.version) + version_wo_build_metadata = f"{spec_v.major}.{spec_v.minor}.{spec_v.micro}" + + # Create a new specifier without the build metadata + normalized_spec = Specifier(f"{spec.operator}{version_wo_build_metadata}") + normalized_specs.append(normalized_spec) + + # Reconstruct the specifier set + normalized_specifier_set = SpecifierSet(",".join(str(s) for s in normalized_specs)) + + # Create a new Requirement object with the normalized specifiers + normalized_req = Requirement(f"{req.name}{normalized_specifier_set}") + + return normalized_req + + +def _raise_at_severity( + msg, + severity, + exception_type=None, + warning_type=None, + stacklevel=2, + caller="_raise_at_severity", +): + """Raise exception or warning or take no action, based on severity. 
+ + Parameters + ---------- + msg : str + message to raise or warn + + severity : str, "error" (default), "warning", "none" + whether the check should raise an error, a warning, or nothing + + * "error" - raises a ``ModuleNotFoundError`` if one of packages is not installed + * "warning" - raises a warning if one of packages is not installed + function returns False if one of packages is not installed, otherwise True + * "none" - does not raise exception or warning + function returns False if one of packages is not installed, otherwise True + + exception_type : Exception, default=ModuleNotFoundError + exception type to raise if severity="error" + warning_type : warning, default=UserWarning + warning type to raise if severity="warning" + stacklevel : int, default=2 + stacklevel for warnings, if severity="warning" + caller : str, default="_raise_at_severity" + caller name, used in exception if severity not in ["error", "warning", "none"] + + Returns + ------- + None + + Raises + ------ + exception : exception_type, if severity="error" + warning : warning_type, if severity="warning" + ValueError : if severity not in ["error", "warning", "none"] + """ + if exception_type is None: + exception_type = ModuleNotFoundError + + if severity == "error": + raise exception_type(msg) + elif severity == "warning": + warnings.warn(msg, category=warning_type, stacklevel=stacklevel) + elif severity == "none": + return None + else: + raise ValueError( + f"Error in calling {caller}, severity " + f'argument must be "error", "warning", or "none", found {severity!r}.' 
+ ) + return None From 5c87b5699006482d0ba20b805b7d7babb7dfe439 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Thu, 1 Aug 2024 23:57:16 +0100 Subject: [PATCH 2/6] Update _dependencies.py --- skbase/utils/dependencies/_dependencies.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skbase/utils/dependencies/_dependencies.py b/skbase/utils/dependencies/_dependencies.py index 832884e5..348fcf28 100644 --- a/skbase/utils/dependencies/_dependencies.py +++ b/skbase/utils/dependencies/_dependencies.py @@ -13,7 +13,7 @@ from packaging.version import InvalidVersion, Version -# todo 0.32.0: remove suppress_import_stdout argument +# todo 0.10.0: remove suppress_import_stdout argument def _check_soft_dependencies( *packages, package_import_alias="deprecated", From f9e8bc18115efbba9ea7ee2405ae3e0143719ac5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Thu, 1 Aug 2024 23:57:29 +0100 Subject: [PATCH 3/6] Update _dependencies.py --- skbase/utils/dependencies/_dependencies.py | 1 - 1 file changed, 1 deletion(-) diff --git a/skbase/utils/dependencies/_dependencies.py b/skbase/utils/dependencies/_dependencies.py index 348fcf28..da8b62f5 100644 --- a/skbase/utils/dependencies/_dependencies.py +++ b/skbase/utils/dependencies/_dependencies.py @@ -4,7 +4,6 @@ import warnings from functools import lru_cache from importlib.metadata import distributions -from importlib.util import find_spec from inspect import isclass from packaging.markers import InvalidMarker, Marker From 98e0921c7cbc3b579fc09a8361a6095c0f04bbfa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Thu, 1 Aug 2024 23:58:54 +0100 Subject: [PATCH 4/6] Update _dependencies.py --- skbase/utils/dependencies/_dependencies.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/skbase/utils/dependencies/_dependencies.py b/skbase/utils/dependencies/_dependencies.py index da8b62f5..838cf3b4 100644 --- 
a/skbase/utils/dependencies/_dependencies.py +++ b/skbase/utils/dependencies/_dependencies.py @@ -58,8 +58,12 @@ def _check_soft_dependencies( Raises ------ + InvalidRequirement + if package requirement strings are not PEP 440 compatible ModuleNotFoundError error with informative message, asking to install required soft dependencies + TypeError, ValueError + on invalid arguments Returns ------- @@ -116,7 +120,7 @@ def _check_soft_dependencies( req = _normalize_requirement(req) except InvalidRequirement: msg_version = ( - f"wrong format for package requirement string " + f"wrong format for package requirement string, " f"passed via packages argument of _check_soft_dependencies, " f'must be PEP 440 compatible requirement string, e.g., "pandas"' f' or "pandas>1.1", but found {package!r}' From a7b12ef06f8e15388c2b8b09cf5d376b55496eb0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Fri, 2 Aug 2024 00:03:35 +0100 Subject: [PATCH 5/6] Update _dependencies.py --- skbase/utils/dependencies/_dependencies.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/skbase/utils/dependencies/_dependencies.py b/skbase/utils/dependencies/_dependencies.py index 838cf3b4..9fcf5205 100644 --- a/skbase/utils/dependencies/_dependencies.py +++ b/skbase/utils/dependencies/_dependencies.py @@ -125,7 +125,7 @@ def _check_soft_dependencies( f'must be PEP 440 compatible requirement string, e.g., "pandas"' f' or "pandas>1.1", but found {package!r}' ) - raise InvalidRequirement(msg_version) + raise InvalidRequirement(msg_version) from None package_name = req.name package_version_req = req.specifier @@ -279,7 +279,7 @@ def _check_python_version(obj, package=None, msg=None, severity="error"): f'must be PEP 440 compatible specifier string, e.g., "<3.9, >= 3.6.3",' f" but found {est_specifier_tag!r}" ) - raise InvalidSpecifier(msg_version) + raise InvalidSpecifier(msg_version) from None # python sys version, e.g., "3.8.12" sys_version = sys.version.split(" 
")[0] @@ -354,9 +354,9 @@ def _check_env_marker(obj, package=None, msg=None, severity="error"): msg_version = ( f"wrong format for env_marker tag, " f"must be PEP 508 compatible specifier string, e.g., " - f'platform_system!="windows", but found "{est_marker_tag}"' + f'platform_system!="windows", but found {est_marker_tag!r}' ) - raise InvalidMarker(msg_version) + raise InvalidMarker(msg_version) from None if est_marker.evaluate(): return True From 1515c1b15c215df9224b49b09d8184ea8451962a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Fri, 2 Aug 2024 00:10:46 +0100 Subject: [PATCH 6/6] Update conftest.py --- skbase/tests/conftest.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/skbase/tests/conftest.py b/skbase/tests/conftest.py index 279478ef..9615a3b8 100644 --- a/skbase/tests/conftest.py +++ b/skbase/tests/conftest.py @@ -251,7 +251,10 @@ "skbase.utils.dependencies._dependencies": ( "_check_soft_dependencies", "_check_python_version", + "_check_env_marker", "_check_estimator_deps", + "_get_pkg_version", + "_get_installed_packages", "_normalize_requirement", "_raise_at_severity", ),