Skip to content

Commit

Permalink
feat: add more tests (#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco authored Dec 9, 2023
1 parent 4d21d43 commit 4428c07
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 6 deletions.
20 changes: 19 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ python = "^3.8"
[tool.poetry.group.dev.dependencies]
pytest = "^7.0"
pytest-cov = "^3.0"
pytest-asyncio = "^0.23.2"

[tool.poetry.group.docs]
optional = true
Expand Down
3 changes: 0 additions & 3 deletions src/aiohappyeyeballs/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ async def create_connection(
local_addr_infos: Optional[Sequence[AddrInfoType]] = None,
happy_eyeballs_delay: Optional[float] = None,
interleave: Optional[int] = None,
all_errors: bool = False,
loop: Optional[asyncio.AbstractEventLoop] = None,
) -> socket.socket:
"""
Expand Down Expand Up @@ -70,8 +69,6 @@ async def create_connection(
if sock is None:
all_exceptions = [exc for sub in exceptions for exc in sub]
try:
if all_errors:
raise ExceptionGroup("create_connection failed", all_exceptions)
if len(all_exceptions) == 1:
raise all_exceptions[0]
else:
Expand Down
23 changes: 21 additions & 2 deletions tests/test_impl.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import socket
from test.test_asyncio import utils as test_utils
from types import ModuleType
Expand All @@ -7,8 +8,6 @@

from aiohappyeyeballs import create_connection

MOCK_ANY = mock.ANY


def mock_socket_module():
m_socket = mock.MagicMock(spec=socket)
Expand Down Expand Up @@ -56,3 +55,23 @@ def _socket(*args, **kw):
addr_info = [(2, 1, 6, "", ("107.6.106.82", 80))]
with pytest.raises(OSError, match=errors[0]):
await create_connection(addr_info)


@pytest.mark.asyncio
@patch_socket
async def test_create_connection_single_addr_success(m_socket: ModuleType) -> None:
mock_socket = mock.MagicMock(
family=socket.AF_INET,
type=socket.SOCK_STREAM,
proto=socket.IPPROTO_TCP,
fileno=mock.MagicMock(return_value=1),
)

def _socket(*args, **kw):
return mock_socket

m_socket.socket = _socket # type: ignore
addr_info = [(2, 1, 6, "", ("107.6.106.82", 80))]
loop = asyncio.get_running_loop()
with mock.patch.object(loop, "sock_connect", return_value=None):
assert await create_connection(addr_info) == mock_socket

0 comments on commit 4428c07

Please sign in to comment.