Skip to content

Commit

Permalink
feat: add utils (#23)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco authored Dec 10, 2023
1 parent 37da1ec commit d89311d
Show file tree
Hide file tree
Showing 6 changed files with 190 additions and 2 deletions.
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@ socket = await start_connection(addr_infos, local_addr_infos=local_addr_infos, h
transport, protocol = await loop.create_connection(
MyProtocol, sock=socket, ...)

# Remove the first address for each family from addr_info
pop_addr_infos_interleave(addr_info, 1)

# Remove all matching address from addr_info
remove_addr_infos(addr_info, "dead::beef::")

```

## Credits
Expand Down
6 changes: 6 additions & 0 deletions docs/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,10 @@ socket = await aiohappyeyeballs.start_connection(addr_infos, local_addr_infos=lo

transport, protocol = await loop.create_connection(
MyProtocol, sock=socket, ...)

# Remove the first address for each family from addr_info
aiohappyeyeballs.pop_addr_infos_interleave(addr_info, 1)

# Remove all matching address from addr_info
aiohappyeyeballs.remove_addr_infos(addr_info, "dead::beef::")
```
11 changes: 9 additions & 2 deletions src/aiohappyeyeballs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
__version__ = "1.7.0"

from .impl import AddrInfoType, start_connection
from .impl import start_connection
from .types import AddrInfoType
from .utils import pop_addr_infos_interleave, remove_addr_infos

__all__ = ("start_connection", "AddrInfoType")
__all__ = (
"start_connection",
"AddrInfoType",
"remove_addr_infos",
"pop_addr_infos_interleave",
)
11 changes: 11 additions & 0 deletions src/aiohappyeyeballs/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""Base implementation."""
import socket
from typing import Tuple, Union

AddrInfoType = Tuple[
Union[int, socket.AddressFamily],
Union[int, socket.SocketKind],
int,
str,
Tuple, # type: ignore[type-arg]
]
51 changes: 51 additions & 0 deletions src/aiohappyeyeballs/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""Utility functions for aiohappyeyeballs."""

import ipaddress
from typing import Dict, List

from .types import AddrInfoType


def pop_addr_infos_interleave(addr_infos: List[AddrInfoType], interleave: int) -> None:
"""
Pop addr_info from the list of addr_infos by family up to interleave times.
The interleave parameter is used to know how many addr_infos for
each family should be popped of the top of the list.
"""
seen: Dict[int, int] = {}
to_remove: List[AddrInfoType] = []
for addr_info in addr_infos:
family = addr_info[0]
if family not in seen:
seen[family] = 0
if seen[family] < interleave:
to_remove.append(addr_info)
seen[family] += 1
for addr_info in to_remove:
addr_infos.remove(addr_info)


def remove_addr_infos(
addr_infos: List[AddrInfoType],
address: str,
) -> None:
"""Remove an address from the list of addr_infos."""
bad_addrs_infos: List[AddrInfoType] = []
for addr_info in addr_infos:
if addr_info[-1][0] == address:
bad_addrs_infos.append(addr_info)
if bad_addrs_infos:
for bad_addr_info in bad_addrs_infos:
addr_infos.remove(bad_addr_info)
return
# Slow path in case addr is formatted differently
ip_address = ipaddress.ip_address(address)
for addr_info in addr_infos:
if ip_address == ipaddress.ip_address(addr_info[-1][0]):
bad_addrs_infos.append(addr_info)
if bad_addrs_infos:
for bad_addr_info in bad_addrs_infos:
addr_infos.remove(bad_addr_info)
return
raise ValueError(f"Address {address} not found in addr_infos")
107 changes: 107 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import socket
from typing import List

import pytest

from aiohappyeyeballs import AddrInfoType, pop_addr_infos_interleave, remove_addr_infos


def test_pop_addr_infos_interleave():
"""Test pop_addr_infos_interleave."""
ipv6_addr_info = (
socket.AF_INET6,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("dead:beef::", 80, 0, 0),
)
ipv6_addr_info_2 = (
socket.AF_INET6,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("dead:aaaa::", 80, 0, 0),
)
ipv4_addr_info = (
socket.AF_INET,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("107.6.106.83", 80),
)
addr_info: List[AddrInfoType] = [ipv6_addr_info, ipv6_addr_info_2, ipv4_addr_info]
addr_info_copy = addr_info.copy()
pop_addr_infos_interleave(addr_info_copy, 1)
assert addr_info_copy == [ipv6_addr_info_2]
pop_addr_infos_interleave(addr_info_copy, 1)
assert addr_info_copy == []
addr_info_copy = addr_info.copy()
pop_addr_infos_interleave(addr_info_copy, 2)
assert addr_info_copy == []


def test_remove_addr_infos():
"""Test remove_addr_infos."""
ipv6_addr_info = (
socket.AF_INET6,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("dead:beef::", 80, 0, 0),
)
ipv6_addr_info_2 = (
socket.AF_INET6,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("dead:aaaa::", 80, 0, 0),
)
ipv4_addr_info = (
socket.AF_INET,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("107.6.106.83", 80),
)
addr_info: List[AddrInfoType] = [ipv6_addr_info, ipv6_addr_info_2, ipv4_addr_info]
addr_info_copy = addr_info.copy()
remove_addr_infos(addr_info_copy, "dead:beef::")
assert addr_info_copy == [ipv6_addr_info_2, ipv4_addr_info]
remove_addr_infos(addr_info_copy, "dead:aaaa::")
assert addr_info_copy == [ipv4_addr_info]
remove_addr_infos(addr_info_copy, "107.6.106.83")
assert addr_info_copy == []


def test_remove_addr_infos_slow_path():
"""Test remove_addr_infos with mis-matched formatting."""
ipv6_addr_info = (
socket.AF_INET6,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("dead:beef::", 80, 0, 0),
)
ipv6_addr_info_2 = (
socket.AF_INET6,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("dead:aaaa::", 80, 0, 0),
)
ipv4_addr_info = (
socket.AF_INET,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("107.6.106.83", 80),
)
addr_info: List[AddrInfoType] = [ipv6_addr_info, ipv6_addr_info_2, ipv4_addr_info]
addr_info_copy = addr_info.copy()
remove_addr_infos(addr_info_copy, "dead:beef:0000:0000:0000:0000:0000:0000")
assert addr_info_copy == [ipv6_addr_info_2, ipv4_addr_info]
remove_addr_infos(addr_info_copy, "dead:aaaa:0000:0000:0000:0000:0000:0000")
assert addr_info_copy == [ipv4_addr_info]
with pytest.raises(ValueError, match="Address 107.6.106.2 not found in addr_infos"):
remove_addr_infos(addr_info_copy, "107.6.106.2")
assert addr_info_copy == [ipv4_addr_info]

0 comments on commit d89311d

Please sign in to comment.