Skip to content

Commit

Permalink
Prioritize valid dists to invalid dists when retrieving by name.
Browse files Browse the repository at this point in the history
Closes #489
  • Loading branch information
jaraco committed Jul 23, 2024
1 parent 48f6b14 commit a65c29a
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 3 deletions.
14 changes: 12 additions & 2 deletions importlib_metadata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
install,
)
from ._functools import method_cache, pass_none
from ._itertools import always_iterable, unique_everseen
from ._itertools import always_iterable, bucket, unique_everseen
from ._meta import PackageMetadata, SimplePath

from contextlib import suppress
Expand Down Expand Up @@ -388,7 +388,7 @@ def from_name(cls, name: str) -> Distribution:
if not name:
raise ValueError("A distribution name is required.")
try:
return next(iter(cls.discover(name=name)))
return next(iter(cls._prefer_valid(cls.discover(name=name))))
except StopIteration:
raise PackageNotFoundError(name)

Expand All @@ -412,6 +412,16 @@ def discover(
resolver(context) for resolver in cls._discover_resolvers()
)

@staticmethod
def _prefer_valid(dists: Iterable[Distribution]) -> Iterable[Distribution]:
"""
Prefer (move to the front) distributions that have metadata.
Ref python/importlib_resources#489.
"""
buckets = bucket(dists, lambda dist: bool(dist.metadata))
return itertools.chain(buckets[True], buckets[False])

@staticmethod
def at(path: str | os.PathLike[str]) -> Distribution:
"""Return a Distribution for the indicated metadata path.
Expand Down
98 changes: 98 additions & 0 deletions importlib_metadata/_itertools.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections import defaultdict, deque
from itertools import filterfalse


Expand Down Expand Up @@ -71,3 +72,100 @@ def always_iterable(obj, base_type=(str, bytes)):
return iter(obj)
except TypeError:
return iter((obj,))


# Copied from more_itertools 10.3
class bucket:
"""Wrap *iterable* and return an object that buckets the iterable into
child iterables based on a *key* function.
>>> iterable = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2', 'b3']
>>> s = bucket(iterable, key=lambda x: x[0]) # Bucket by 1st character
>>> sorted(list(s)) # Get the keys
['a', 'b', 'c']
>>> a_iterable = s['a']
>>> next(a_iterable)
'a1'
>>> next(a_iterable)
'a2'
>>> list(s['b'])
['b1', 'b2', 'b3']
The original iterable will be advanced and its items will be cached until
they are used by the child iterables. This may require significant storage.
By default, attempting to select a bucket to which no items belong will
exhaust the iterable and cache all values.
If you specify a *validator* function, selected buckets will instead be
checked against it.
>>> from itertools import count
>>> it = count(1, 2) # Infinite sequence of odd numbers
>>> key = lambda x: x % 10 # Bucket by last digit
>>> validator = lambda x: x in {1, 3, 5, 7, 9} # Odd digits only
>>> s = bucket(it, key=key, validator=validator)
>>> 2 in s
False
>>> list(s[2])
[]
"""

def __init__(self, iterable, key, validator=None):
self._it = iter(iterable)
self._key = key
self._cache = defaultdict(deque)
self._validator = validator or (lambda x: True)

def __contains__(self, value):
if not self._validator(value):
return False

try:
item = next(self[value])
except StopIteration:
return False
else:
self._cache[value].appendleft(item)

return True

def _get_values(self, value):
"""
Helper to yield items from the parent iterator that match *value*.
Items that don't match are stored in the local cache as they
are encountered.
"""
while True:
# If we've cached some items that match the target value, emit
# the first one and evict it from the cache.
if self._cache[value]:
yield self._cache[value].popleft()
# Otherwise we need to advance the parent iterator to search for
# a matching item, caching the rest.
else:
while True:
try:
item = next(self._it)
except StopIteration:
return
item_value = self._key(item)
if item_value == value:
yield item
break
elif self._validator(item_value):
self._cache[item_value].append(item)

def __iter__(self):
for item in self._it:
item_value = self._key(item)
if self._validator(item_value):
self._cache[item_value].append(item)

yield from self._cache.keys()

def __getitem__(self, value):
if not self._validator(value):
return iter(())

return self._get_values(value)
1 change: 1 addition & 0 deletions newsfragments/489.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Prioritize valid dists to invalid dists when retrieving by name.
1 change: 0 additions & 1 deletion tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,6 @@ def make_pkg(name, files=dict(METADATA="VERSION: 1.0")):
f'{name}.dist-info': files,
}

@__import__('pytest').mark.xfail(reason="#489")
def test_valid_dists_preferred(self):
"""
Dists with metadata should be preferred when discovered by name.
Expand Down

0 comments on commit a65c29a

Please sign in to comment.