Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stop raise_for_status() releasing when in a context #9239

Merged
merged 11 commits into from
Sep 23, 2024
1 change: 1 addition & 0 deletions CHANGES/9239.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Changed :py:meth:`ClientResponse.raise_for_status() <aiohttp.ClientResponse.raise_for_status>` to only release the connection when invoked outside an ``async with`` context -- by :user:`Dreamsorcerer`.
32 changes: 7 additions & 25 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ class ClientTimeout:
# https://www.rfc-editor.org/rfc/rfc9110#section-9.2.2
IDEMPOTENT_METHODS = frozenset({"GET", "HEAD", "OPTIONS", "TRACE", "PUT", "DELETE"})

_RetType = TypeVar("_RetType")
_RetType = TypeVar("_RetType", ClientResponse, ClientWebSocketResponse)
_CharsetResolver = Callable[[ClientResponse, bytes], str]


Expand Down Expand Up @@ -1275,7 +1275,7 @@ class _BaseRequestContextManager(Coroutine[Any, Any, _RetType], Generic[_RetType
__slots__ = ("_coro", "_resp")

def __init__(self, coro: Coroutine["asyncio.Future[Any]", None, _RetType]) -> None:
self._coro = coro
self._coro: Coroutine["asyncio.Future[Any]", None, _RetType] = coro

def send(self, arg: None) -> "asyncio.Future[Any]":
return self._coro.send(arg)
Expand All @@ -1294,38 +1294,20 @@ def __iter__(self) -> Generator[Any, None, _RetType]:
return self.__await__()

async def __aenter__(self) -> _RetType:
self._resp = await self._coro
return self._resp


class _RequestContextManager(_BaseRequestContextManager[ClientResponse]):
__slots__ = ()
self._resp: _RetType = await self._coro
return await self._resp.__aenter__()

async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
tb: Optional[TracebackType],
) -> None:
# We're basing behavior on the exception as it can be caused by
# user code unrelated to the status of the connection. If you
# would like to close a connection you must do that
# explicitly. Otherwise connection error handling should kick in
# and close/recycle the connection as required.
self._resp.release()
await self._resp.wait_for_close()
await self._resp.__aexit__(exc_type, exc, tb)


class _WSRequestContextManager(_BaseRequestContextManager[ClientWebSocketResponse]):
__slots__ = ()

async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
tb: Optional[TracebackType],
) -> None:
await self._resp.close()
_RequestContextManager = _BaseRequestContextManager[ClientResponse]
_WSRequestContextManager = _BaseRequestContextManager[ClientWebSocketResponse]


class _SessionRequestContextManager:
Expand Down
10 changes: 9 additions & 1 deletion aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,7 @@ class ClientResponse(HeadersMixin):
# post-init stage allows to not change ctor signature
_closed = True # to allow __del__ for non-initialized properly response
_released = False
_in_context = False
__writer = None

def __init__(
Expand Down Expand Up @@ -1022,7 +1023,12 @@ def raise_for_status(self) -> None:
if not self.ok:
# reason should always be not None for a started response
assert self.reason is not None
self.release()

# If we're in a context we can rely on __aexit__() to release as the
webknjaz marked this conversation as resolved.
Show resolved Hide resolved
# exception propagates.
if not self._in_context:
self.release()

raise ClientResponseError(
self.request_info,
self.history,
Expand Down Expand Up @@ -1144,6 +1150,7 @@ async def json(
return loads(self._body.decode(encoding)) # type: ignore[union-attr]

async def __aenter__(self) -> "ClientResponse":
self._in_context = True
bdraco marked this conversation as resolved.
Show resolved Hide resolved
return self

async def __aexit__(
Expand All @@ -1152,6 +1159,7 @@ async def __aexit__(
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
self._in_context = False
# similar to _RequestContextManager, we do not need to check
# for exceptions, response object can close connection
# if state is broken
Expand Down
14 changes: 13 additions & 1 deletion aiohttp/client_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import asyncio
import dataclasses
import sys
from typing import Any, Final, Optional, cast
from types import TracebackType
from typing import Any, Final, Optional, Type, cast

from .client_exceptions import ClientError, ServerTimeoutError
from .client_reqrep import ClientResponse
Expand Down Expand Up @@ -395,3 +396,14 @@ async def __anext__(self) -> WSMessage:
if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED):
raise StopAsyncIteration
return msg

async def __aenter__(self) -> "ClientWebSocketResponse":
return self

async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
await self.close()
37 changes: 37 additions & 0 deletions tests/test_client_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from aiohttp import Fingerprint, ServerFingerprintMismatch, hdrs, web
from aiohttp.abc import AbstractResolver, ResolveResult
from aiohttp.client_exceptions import (
ClientResponseError,
InvalidURL,
InvalidUrlClientError,
InvalidUrlRedirectClientError,
Expand Down Expand Up @@ -3688,6 +3689,42 @@ async def handler(request: web.Request) -> web.Response:
await resp.read()


async def test_read_after_catch_raise_for_status(aiohttp_client: AiohttpClient) -> None:
async def handler(request: web.Request) -> web.Response:
return web.Response(body=b"data", status=404)

app = web.Application()
app.add_routes([web.get("/", handler)])

client = await aiohttp_client(app)

async with client.get("/") as resp:
with pytest.raises(ClientResponseError, match="404"):
# Should not release response when in async with context.
resp.raise_for_status()

result = await resp.read()
assert result == b"data"


async def test_read_after_raise_outside_context(aiohttp_client: AiohttpClient) -> None:
async def handler(request: web.Request) -> web.Response:
return web.Response(body=b"data", status=404)

app = web.Application()
app.add_routes([web.get("/", handler)])

client = await aiohttp_client(app)

resp = await client.get("/")
with pytest.raises(ClientResponseError, match="404"):
# No async with, so should release and therefore read() will fail.
resp.raise_for_status()

with pytest.raises(aiohttp.ClientConnectionError, match=r"^Connection closed$"):
await resp.read()


async def test_read_from_closed_content(aiohttp_client: AiohttpClient) -> None:
async def handler(request: web.Request) -> web.Response:
return web.Response(body=b"data")
Expand Down
Loading