Skip to content

Commit

Permalink
feat!: require the full address tuple for the remove_addr_infos util (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco authored Dec 10, 2023
1 parent dc30a22 commit d7e5df1
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 14 deletions.
28 changes: 21 additions & 7 deletions src/aiohappyeyeballs/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Utility functions for aiohappyeyeballs."""

import ipaddress
from typing import Dict, List
from typing import Dict, List, Tuple, Union

from .types import AddrInfoType

Expand All @@ -26,26 +26,40 @@ def pop_addr_infos_interleave(addr_infos: List[AddrInfoType], interleave: int) -
addr_infos.remove(addr_info)


def _addr_tuple_to_ip_address(
addr: Union[Tuple[str, int], Tuple[str, int, int, int]]
) -> Union[
Tuple[ipaddress.IPv4Address, int], Tuple[ipaddress.IPv6Address, int, int, int]
]:
"""Convert an address tuple to an IPv4Address."""
return (ipaddress.ip_address(addr[0]), *addr[1:])


def remove_addr_infos(
addr_infos: List[AddrInfoType],
address: str,
addr: Union[Tuple[str, int], Tuple[str, int, int, int]],
) -> None:
"""Remove an address from the list of addr_infos."""
"""
Remove an address from the list of addr_infos.
The addr value is typically the return value of
sock.getpeername().
"""
bad_addrs_infos: List[AddrInfoType] = []
for addr_info in addr_infos:
if addr_info[-1][0] == address:
if addr_info[-1] == addr:
bad_addrs_infos.append(addr_info)
if bad_addrs_infos:
for bad_addr_info in bad_addrs_infos:
addr_infos.remove(bad_addr_info)
return
# Slow path in case addr is formatted differently
ip_address = ipaddress.ip_address(address)
match_addr = _addr_tuple_to_ip_address(addr)
for addr_info in addr_infos:
if ip_address == ipaddress.ip_address(addr_info[-1][0]):
if match_addr == _addr_tuple_to_ip_address(addr_info[-1]):
bad_addrs_infos.append(addr_info)
if bad_addrs_infos:
for bad_addr_info in bad_addrs_infos:
addr_infos.remove(bad_addr_info)
return
raise ValueError(f"Address {address} not found in addr_infos")
raise ValueError(f"Address {addr} not found in addr_infos")
23 changes: 16 additions & 7 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,14 @@ def test_remove_addr_infos():
)
addr_info: List[AddrInfoType] = [ipv6_addr_info, ipv6_addr_info_2, ipv4_addr_info]
addr_info_copy = addr_info.copy()
remove_addr_infos(addr_info_copy, "dead:beef::")
remove_addr_infos(
addr_info_copy,
("dead:beef::", 80, 0, 0),
)
assert addr_info_copy == [ipv6_addr_info_2, ipv4_addr_info]
remove_addr_infos(addr_info_copy, "dead:aaaa::")
remove_addr_infos(addr_info_copy, ("dead:aaaa::", 80, 0, 0))
assert addr_info_copy == [ipv4_addr_info]
remove_addr_infos(addr_info_copy, "107.6.106.83")
remove_addr_infos(addr_info_copy, ("107.6.106.83", 80))
assert addr_info_copy == []


Expand Down Expand Up @@ -98,10 +101,16 @@ def test_remove_addr_infos_slow_path():
)
addr_info: List[AddrInfoType] = [ipv6_addr_info, ipv6_addr_info_2, ipv4_addr_info]
addr_info_copy = addr_info.copy()
remove_addr_infos(addr_info_copy, "dead:beef:0000:0000:0000:0000:0000:0000")
remove_addr_infos(
addr_info_copy, ("dead:beef:0000:0000:0000:0000:0000:0000", 80, 0, 0)
)
assert addr_info_copy == [ipv6_addr_info_2, ipv4_addr_info]
remove_addr_infos(addr_info_copy, "dead:aaaa:0000:0000:0000:0000:0000:0000")
remove_addr_infos(
addr_info_copy, ("dead:aaaa:0000:0000:0000:0000:0000:0000", 80, 0, 0)
)
assert addr_info_copy == [ipv4_addr_info]
with pytest.raises(ValueError, match="Address 107.6.106.2 not found in addr_infos"):
remove_addr_infos(addr_info_copy, "107.6.106.2")
with pytest.raises(
ValueError, match=r"Address \('107.6.106.2', 80\) not found in addr_infos"
):
remove_addr_infos(addr_info_copy, ("107.6.106.2", 80))
assert addr_info_copy == [ipv4_addr_info]

0 comments on commit d7e5df1

Please sign in to comment.