Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: preserve errno if all exceptions have the same errno #77

Merged
merged 2 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 18 additions & 10 deletions src/aiohappyeyeballs/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ async def start_connection(
addr_infos = _interleave_addrinfos(addr_infos, interleave)

sock: Optional[socket.socket] = None
exceptions: List[List[Exception]] = []
exceptions: List[List[OSError]] = []
if happy_eyeballs_delay is None or single_addr_info:
# not using happy eyeballs
for addrinfo in addr_infos:
Expand All @@ -99,20 +99,28 @@ async def start_connection(
if sock is None:
all_exceptions = [exc for sub in exceptions for exc in sub]
try:
first_exception = all_exceptions[0]
if len(all_exceptions) == 1:
raise all_exceptions[0]
raise first_exception
else:
# If they all have the same str(), raise one.
model = str(all_exceptions[0])
model = str(first_exception)
if all(str(exc) == model for exc in all_exceptions):
raise all_exceptions[0]
raise first_exception
# Raise a combined exception so the user can see all
# the various error messages.
raise OSError(
"Multiple exceptions: {}".format(
", ".join(str(exc) for exc in all_exceptions)
)
msg = "Multiple exceptions: {}".format(
", ".join(str(exc) for exc in all_exceptions)
)
# If the errno is the same for all exceptions, raise
# an OSError with that errno.
first_errno = first_exception.errno
if all(
isinstance(exc, OSError) and exc.errno == first_errno
for exc in all_exceptions
):
raise OSError(first_errno, msg)
raise OSError(msg)
finally:
all_exceptions = None # type: ignore[assignment]
exceptions = None # type: ignore[assignment]
Expand All @@ -122,12 +130,12 @@ async def start_connection(

async def _connect_sock(
loop: asyncio.AbstractEventLoop,
exceptions: List[List[Exception]],
exceptions: List[List[OSError]],
addr_info: AddrInfoType,
local_addr_infos: Optional[Sequence[AddrInfoType]] = None,
) -> socket.socket:
"""Create, bind and connect one socket."""
my_exceptions: list[Exception] = []
my_exceptions: list[OSError] = []
exceptions.append(my_exceptions)
family, type_, proto, _, address = addr_info
sock = None
Expand Down
97 changes: 95 additions & 2 deletions tests/test_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1191,9 +1191,10 @@ async def _sock_connect(

@patch_socket
@pytest.mark.asyncio
async def test_all_same_exception(
async def test_all_same_exception_and_same_errno(
m_socket: ModuleType,
) -> None:
"""Test that all exceptions are the same and have the same errno."""
mock_socket = mock.MagicMock(
family=socket.AF_INET,
type=socket.SOCK_STREAM,
Expand Down Expand Up @@ -1256,7 +1257,96 @@ async def _sock_connect(
# We should get the same exception raised if they are all the same
with mock.patch.object(loop, "sock_connect", _sock_connect), pytest.raises(
OSError, match="all fail"
):
) as exc_info:
assert (
await start_connection(
addr_info,
happy_eyeballs_delay=0.3,
interleave=2,
local_addr_infos=local_addr_infos,
)
== mock_socket
)

assert exc_info.value.errno == 5

# All calls failed
assert create_calls == [
("dead:beef::", 80, 0, 0),
("dead:aaaa::", 80, 0, 0),
("107.6.106.83", 80),
]


@patch_socket
@pytest.mark.asyncio
async def test_all_same_exception_and_with_different_errno(
m_socket: ModuleType,
) -> None:
"""Test no errno is set if all OSError have different errno."""
mock_socket = mock.MagicMock(
family=socket.AF_INET,
type=socket.SOCK_STREAM,
proto=socket.IPPROTO_TCP,
fileno=mock.MagicMock(return_value=1),
)
create_calls = []

def _socket(*args, **kw):
for attr in kw:
setattr(mock_socket, attr, kw[attr])
return mock_socket

async def _sock_connect(
sock: socket.socket, address: Tuple[str, int, int, int]
) -> None:
create_calls.append(address)
raise OSError(len(create_calls), "all fail")

m_socket.socket = _socket # type: ignore
ipv6_addr_info = (
socket.AF_INET6,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("dead:beef::", 80, 0, 0),
)
ipv6_addr_info_2 = (
socket.AF_INET6,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("dead:aaaa::", 80, 0, 0),
)
ipv4_addr_info = (
socket.AF_INET,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("107.6.106.83", 80),
)
addr_info = [ipv6_addr_info, ipv6_addr_info_2, ipv4_addr_info]
local_addr_infos = [
(
socket.AF_INET6,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("::1", 0, 0, 0),
),
(
socket.AF_INET,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("127.0.0.1", 0),
),
]
loop = asyncio.get_running_loop()
# We should get the same exception raised if they are all the same
with mock.patch.object(loop, "sock_connect", _sock_connect), pytest.raises(
OSError, match="all fail"
) as exc_info:
assert (
await start_connection(
addr_info,
Expand All @@ -1267,6 +1357,9 @@ async def _sock_connect(
== mock_socket
)

# No errno is set if they are all different
assert exc_info.value.errno is None

# All calls failed
assert create_calls == [
("dead:beef::", 80, 0, 0),
Expand Down
Loading