diff --git a/src/poetry/core/factory.py b/src/poetry/core/factory.py index c70175b6c..103076079 100644 --- a/src/poetry/core/factory.py +++ b/src/poetry/core/factory.py @@ -10,10 +10,19 @@ from typing import Union from warnings import warn +from packaging.utils import NormalizedName from packaging.utils import canonicalize_name +from poetry.core.constraints.generic import BaseConstraint +from poetry.core.constraints.generic import Constraint +from poetry.core.constraints.generic import MultiConstraint +from poetry.core.constraints.generic import UnionConstraint from poetry.core.utils.helpers import combine_unicode from poetry.core.utils.helpers import readme_content_type +from poetry.core.version.markers import BaseMarker +from poetry.core.version.markers import MarkerUnion +from poetry.core.version.markers import MultiMarker +from poetry.core.version.markers import SingleMarkerLike if TYPE_CHECKING: @@ -113,6 +122,38 @@ def _add_package_group_dependencies( package.add_dependency_group(group) + @classmethod + def _get_extras_from_constraint( + cls, constraint: BaseConstraint + ) -> set[NormalizedName]: + if isinstance(constraint, Constraint): + if constraint.operator == "==": + return {canonicalize_name(constraint.value)} + else: + return set() + elif isinstance(constraint, (UnionConstraint, MultiConstraint)): + extras = set() + for c in constraint.constraints: + extras.update(cls._get_extras_from_constraint(c)) + return extras + else: + return set() + + @classmethod + def _get_extras_from_marker(cls, marker: BaseMarker) -> set[NormalizedName]: + if isinstance(marker, SingleMarkerLike): + if marker.name == "extra": + return cls._get_extras_from_constraint(marker.constraint) + else: + return set() + elif isinstance(marker, (MarkerUnion, MultiMarker)): + extras = set() + for m in marker.markers: + extras.update(cls._get_extras_from_marker(m)) + return extras + else: + return set() + @classmethod def configure_package( cls, @@ -188,6 +229,18 @@ def configure_package( dep.in_extras.append(extra_name) package.extras[extra_name].append(dep) + for dep in package.requires: + extras = cls._get_extras_from_marker(dep.marker) + for extra in extras: + if extra not in dep.in_extras: + dep.in_extras.append(extra) + + if extra not in package.extras: + package.extras[extra] = [] + + if dep not in package.extras[extra]: + package.extras[extra].append(dep) + if "build" in config: build = config["build"] if not isinstance(build, dict): diff --git a/tests/fixtures/project_with_extra_in_markers/pyproject.toml b/tests/fixtures/project_with_extra_in_markers/pyproject.toml new file mode 100644 index 000000000..1535e5824 --- /dev/null +++ b/tests/fixtures/project_with_extra_in_markers/pyproject.toml @@ -0,0 +1,14 @@ +[tool.poetry] +name = "extra-package" +version = "1.2.3" +description = "Some description." +authors = ["Your Name "] +license = "MIT" + +[tool.poetry.dependencies] +python = "^3.10" +psycopg = [ + { version = "^3.1.9" }, + { version = "^3.1.9", optional = true , extras = ["binary"], markers = "extra == 'extra-binary'"}, + { version = "^3.1.9", optional = true , extras = ["c"], markers = "extra in 'extra-c, extra-pool'"} +] diff --git a/tests/test_factory.py b/tests/test_factory.py index 43f3c86aa..c13fe79f7 100644 --- a/tests/test_factory.py +++ b/tests/test_factory.py @@ -494,3 +494,55 @@ def test_all_classifiers_unique_even_if_classifiers_is_duplicated() -> None: "Programming Language :: Python :: 3.11", "Topic :: Software Development :: Build Tools", ] + + +def test_extra_in_markers() -> None: + poetry = Factory().create_poetry(fixtures_dir / "project_with_extra_in_markers") + + package = poetry.package + requires = package.requires + extras = package.extras + + assert package.name == "extra-package" + assert package.version.text == "1.2.3" + assert package.description == "Some description." + assert package.authors == ["Your Name "] + assert package.license + assert package.license.id == "MIT" + + assert package.python_versions == "^3.10" + + assert len(requires) == 3 + assert len(extras) == 3 + + def find(name: str, extras: set[str]) -> Dependency: + return next( + iter( + filter( + lambda dep: dep.name == name and dep.extras == frozenset(extras), + requires, + ) + ) + ) + + psycopg = find("psycopg", set()) + assert all(psycopg not in extra for extra in extras.values()) + assert psycopg.pretty_constraint == "^3.1.9" + assert not psycopg.is_optional() + assert len(psycopg.extras) == 0 + assert len(psycopg.in_extras) == 0 + + psycopg_binary = find("psycopg", {"binary"}) + assert [psycopg_binary] == extras[canonicalize_name("extra-binary")] + assert psycopg_binary.pretty_constraint == "^3.1.9" + assert psycopg_binary.is_optional() + assert len(psycopg_binary.extras) == 1 + assert len(psycopg_binary.in_extras) == 1 + + psycopg_c = find("psycopg", {"c"}) + assert [psycopg_c] == extras[canonicalize_name("extra-c")] + assert [psycopg_c] == extras[canonicalize_name("extra-pool")] + assert psycopg_c.pretty_constraint == "^3.1.9" + assert psycopg_c.is_optional() + assert len(psycopg_c.extras) == 1 + assert len(psycopg_c.in_extras) == 2