diff --git a/README.md b/README.md index f52543b..a65b542 100644 --- a/README.md +++ b/README.md @@ -68,6 +68,12 @@ socket = await start_connection(addr_infos, local_addr_infos=local_addr_infos, h transport, protocol = await loop.create_connection( MyProtocol, sock=socket, ...) +# Remove the first address for each family from addr_info +pop_addr_infos_interleave(addr_info, 1) + +# Remove all matching address from addr_info +remove_addr_infos(addr_info, "dead::beef::") + ``` ## Credits diff --git a/docs/usage.md b/docs/usage.md index 7fbd3c1..ec2e596 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -16,4 +16,10 @@ socket = await aiohappyeyeballs.start_connection(addr_infos, local_addr_infos=lo transport, protocol = await loop.create_connection( MyProtocol, sock=socket, ...) + +# Remove the first address for each family from addr_info +aiohappyeyeballs.pop_addr_infos_interleave(addr_info, 1) + +# Remove all matching address from addr_info +aiohappyeyeballs.remove_addr_infos(addr_info, "dead::beef::") ``` diff --git a/src/aiohappyeyeballs/__init__.py b/src/aiohappyeyeballs/__init__.py index 618bb0f..5025300 100644 --- a/src/aiohappyeyeballs/__init__.py +++ b/src/aiohappyeyeballs/__init__.py @@ -1,5 +1,12 @@ __version__ = "1.7.0" -from .impl import AddrInfoType, start_connection +from .impl import start_connection +from .types import AddrInfoType +from .utils import pop_addr_infos_interleave, remove_addr_infos -__all__ = ("start_connection", "AddrInfoType") +__all__ = ( + "start_connection", + "AddrInfoType", + "remove_addr_infos", + "pop_addr_infos_interleave", +) diff --git a/src/aiohappyeyeballs/types.py b/src/aiohappyeyeballs/types.py new file mode 100644 index 0000000..da2a803 --- /dev/null +++ b/src/aiohappyeyeballs/types.py @@ -0,0 +1,11 @@ +"""Base implementation.""" +import socket +from typing import Tuple, Union + +AddrInfoType = Tuple[ + Union[int, socket.AddressFamily], + Union[int, socket.SocketKind], + int, + str, + Tuple, # type: ignore[type-arg] +] diff --git a/src/aiohappyeyeballs/utils.py b/src/aiohappyeyeballs/utils.py new file mode 100644 index 0000000..d40c644 --- /dev/null +++ b/src/aiohappyeyeballs/utils.py @@ -0,0 +1,51 @@ +"""Utility functions for aiohappyeyeballs.""" + +import ipaddress +from typing import Dict, List + +from .types import AddrInfoType + + +def pop_addr_infos_interleave(addr_infos: List[AddrInfoType], interleave: int) -> None: + """ + Pop addr_info from the list of addr_infos by family up to interleave times. + + The interleave parameter is used to know how many addr_infos for + each family should be popped of the top of the list. + """ + seen: Dict[int, int] = {} + to_remove: List[AddrInfoType] = [] + for addr_info in addr_infos: + family = addr_info[0] + if family not in seen: + seen[family] = 0 + if seen[family] < interleave: + to_remove.append(addr_info) + seen[family] += 1 + for addr_info in to_remove: + addr_infos.remove(addr_info) + + +def remove_addr_infos( + addr_infos: List[AddrInfoType], + address: str, +) -> None: + """Remove an address from the list of addr_infos.""" + bad_addrs_infos: List[AddrInfoType] = [] + for addr_info in addr_infos: + if addr_info[-1][0] == address: + bad_addrs_infos.append(addr_info) + if bad_addrs_infos: + for bad_addr_info in bad_addrs_infos: + addr_infos.remove(bad_addr_info) + return + # Slow path in case addr is formatted differently + ip_address = ipaddress.ip_address(address) + for addr_info in addr_infos: + if ip_address == ipaddress.ip_address(addr_info[-1][0]): + bad_addrs_infos.append(addr_info) + if bad_addrs_infos: + for bad_addr_info in bad_addrs_infos: + addr_infos.remove(bad_addr_info) + return + raise ValueError(f"Address {address} not found in addr_infos") diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..c1e1022 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,107 @@ +import socket +from typing import List + +import pytest + +from aiohappyeyeballs import AddrInfoType, pop_addr_infos_interleave, remove_addr_infos + + +def test_pop_addr_infos_interleave(): + """Test pop_addr_infos_interleave.""" + ipv6_addr_info = ( + socket.AF_INET6, + socket.SOCK_STREAM, + socket.IPPROTO_TCP, + "", + ("dead:beef::", 80, 0, 0), + ) + ipv6_addr_info_2 = ( + socket.AF_INET6, + socket.SOCK_STREAM, + socket.IPPROTO_TCP, + "", + ("dead:aaaa::", 80, 0, 0), + ) + ipv4_addr_info = ( + socket.AF_INET, + socket.SOCK_STREAM, + socket.IPPROTO_TCP, + "", + ("107.6.106.83", 80), + ) + addr_info: List[AddrInfoType] = [ipv6_addr_info, ipv6_addr_info_2, ipv4_addr_info] + addr_info_copy = addr_info.copy() + pop_addr_infos_interleave(addr_info_copy, 1) + assert addr_info_copy == [ipv6_addr_info_2] + pop_addr_infos_interleave(addr_info_copy, 1) + assert addr_info_copy == [] + addr_info_copy = addr_info.copy() + pop_addr_infos_interleave(addr_info_copy, 2) + assert addr_info_copy == [] + + +def test_remove_addr_infos(): + """Test remove_addr_infos.""" + ipv6_addr_info = ( + socket.AF_INET6, + socket.SOCK_STREAM, + socket.IPPROTO_TCP, + "", + ("dead:beef::", 80, 0, 0), + ) + ipv6_addr_info_2 = ( + socket.AF_INET6, + socket.SOCK_STREAM, + socket.IPPROTO_TCP, + "", + ("dead:aaaa::", 80, 0, 0), + ) + ipv4_addr_info = ( + socket.AF_INET, + socket.SOCK_STREAM, + socket.IPPROTO_TCP, + "", + ("107.6.106.83", 80), + ) + addr_info: List[AddrInfoType] = [ipv6_addr_info, ipv6_addr_info_2, ipv4_addr_info] + addr_info_copy = addr_info.copy() + remove_addr_infos(addr_info_copy, "dead:beef::") + assert addr_info_copy == [ipv6_addr_info_2, ipv4_addr_info] + remove_addr_infos(addr_info_copy, "dead:aaaa::") + assert addr_info_copy == [ipv4_addr_info] + remove_addr_infos(addr_info_copy, "107.6.106.83") + assert addr_info_copy == [] + + +def test_remove_addr_infos_slow_path(): + """Test remove_addr_infos with mis-matched formatting.""" + ipv6_addr_info = ( + socket.AF_INET6, + socket.SOCK_STREAM, + socket.IPPROTO_TCP, + "", + ("dead:beef::", 80, 0, 0), + ) + ipv6_addr_info_2 = ( + socket.AF_INET6, + socket.SOCK_STREAM, + socket.IPPROTO_TCP, + "", + ("dead:aaaa::", 80, 0, 0), + ) + ipv4_addr_info = ( + socket.AF_INET, + socket.SOCK_STREAM, + socket.IPPROTO_TCP, + "", + ("107.6.106.83", 80), + ) + addr_info: List[AddrInfoType] = [ipv6_addr_info, ipv6_addr_info_2, ipv4_addr_info] + addr_info_copy = addr_info.copy() + remove_addr_infos(addr_info_copy, "dead:beef:0000:0000:0000:0000:0000:0000") + assert addr_info_copy == [ipv6_addr_info_2, ipv4_addr_info] + remove_addr_infos(addr_info_copy, "dead:aaaa:0000:0000:0000:0000:0000:0000") + assert addr_info_copy == [ipv4_addr_info] + with pytest.raises(ValueError, match="Address 107.6.106.2 not found in addr_infos"): + remove_addr_infos(addr_info_copy, "107.6.106.2") + assert addr_info_copy == [ipv4_addr_info]