diff --git a/CHANGES/7764.bugfix b/CHANGES/7764.bugfix new file mode 100644 index 00000000000..6e4c7aa5ba8 --- /dev/null +++ b/CHANGES/7764.bugfix @@ -0,0 +1 @@ +Fixed an issue when a client request is closed before completing a chunked payload -- by :user:`Dreamsorcerer` diff --git a/aiohttp/client.py b/aiohttp/client.py index 1d9d9fe94d1..be6a00ffaa7 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -1140,6 +1140,7 @@ async def __aexit__( # explicitly. Otherwise connection error handling should kick in # and close/recycle the connection as required. self._resp.release() + await self._resp.wait_for_close() class _WSRequestContextManager(_BaseRequestContextManager[ClientWebSocketResponse]): diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index 67664733940..8271f70e445 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -529,8 +529,11 @@ async def write_bytes( """Support coroutines that yields bytes objects.""" # 100 response if self._continue is not None: - await writer.drain() - await self._continue + try: + await writer.drain() + await self._continue + except asyncio.CancelledError: + return protocol = conn.protocol assert protocol is not None @@ -543,8 +546,6 @@ async def write_bytes( for chunk in self.body: await writer.write(chunk) # type: ignore[arg-type] - - await writer.write_eof() except OSError as exc: if exc.errno is None and isinstance(exc, asyncio.TimeoutError): protocol.set_exception(exc) @@ -555,12 +556,12 @@ async def write_bytes( new_exc.__context__ = exc new_exc.__cause__ = exc protocol.set_exception(new_exc) - except asyncio.CancelledError as exc: - if not conn.closed: - protocol.set_exception(exc) + except asyncio.CancelledError: + await writer.write_eof() except Exception as exc: protocol.set_exception(exc) else: + await writer.write_eof() protocol.start_timeout() finally: self._writer = None @@ -649,7 +650,8 @@ async def send(self, conn: "Connection") -> "ClientResponse": async def close(self) -> None: if self._writer is not None: try: - await self._writer + with contextlib.suppress(asyncio.CancelledError): + await self._writer finally: self._writer = None @@ -914,8 +916,7 @@ def _response_eof(self) -> None: ): return - self._connection.release() - self._connection = None + self._release_connection() self._closed = True self._cleanup_writer() @@ -927,30 +928,22 @@ def closed(self) -> bool: def close(self) -> None: if not self._released: self._notify_content() - if self._closed: - return self._closed = True if self._loop.is_closed(): return - if self._connection is not None: - self._connection.close() - self._connection = None self._cleanup_writer() + self._release_connection() def release(self) -> Any: if not self._released: self._notify_content() - if self._closed: - return noop() self._closed = True - if self._connection is not None: - self._connection.release() - self._connection = None self._cleanup_writer() + self._release_connection() return noop() @property @@ -975,10 +968,28 @@ def raise_for_status(self) -> None: headers=self.headers, ) + def _release_connection(self) -> None: + if self._connection is not None: + if self._writer is None: + self._connection.release() + self._connection = None + else: + self._writer.add_done_callback(lambda f: self._release_connection()) + + async def _wait_released(self) -> None: + if self._writer is not None: + try: + await self._writer + finally: + self._writer = None + self._release_connection() + def _cleanup_writer(self) -> None: if self._writer is not None: - self._writer.cancel() - self._writer = None + if self._writer.done(): + self._writer = None + else: + self._writer.cancel() self._session = None def _notify_content(self) -> None: @@ -1008,9 +1019,10 @@ async def read(self) -> bytes: except BaseException: self.close() raise - elif self._released: + elif self._released: # Response explicity released raise ClientConnectionError("Connection closed") + await self._wait_released() # Underlying connection released return self._body def get_encoding(self) -> str: @@ -1087,3 +1099,4 @@ async def __aexit__( # for exceptions, response object can close connection # if state is broken self.release() + await self.wait_for_close() diff --git a/tests/test_client_functional.py b/tests/test_client_functional.py index 6ce43d1501e..e56e44824fd 100644 --- a/tests/test_client_functional.py +++ b/tests/test_client_functional.py @@ -137,6 +137,7 @@ async def handler(request): client = await aiohttp_client(app) resp = await client.get("/") assert resp.closed + await resp.wait_for_close() assert 1 == len(client._session.connector._conns) @@ -156,6 +157,60 @@ async def handler(request): assert content == b"" +async def test_stream_request_on_server_eof(aiohttp_client) -> None: + async def handler(request): + return web.Response(text="OK", status=200) + + app = web.Application() + app.add_routes([web.get("/", handler)]) + app.add_routes([web.put("/", handler)]) + + client = await aiohttp_client(app) + + async def data_gen(): + for _ in range(2): + yield b"just data" + await asyncio.sleep(0.1) + + async with client.put("/", data=data_gen()) as resp: + assert 200 == resp.status + assert len(client.session.connector._acquired) == 1 + conn = next(iter(client.session.connector._acquired)) + + async with client.get("/") as resp: + assert 200 == resp.status + + # Connection should have been reused + conns = next(iter(client.session.connector._conns.values())) + assert len(conns) == 1 + assert conns[0][0] is conn + + +async def test_stream_request_on_server_eof_nested(aiohttp_client) -> None: + async def handler(request): + return web.Response(text="OK", status=200) + + app = web.Application() + app.add_routes([web.get("/", handler)]) + app.add_routes([web.put("/", handler)]) + + client = await aiohttp_client(app) + + async def data_gen(): + for _ in range(2): + yield b"just data" + await asyncio.sleep(0.1) + + async with client.put("/", data=data_gen()) as resp: + assert 200 == resp.status + async with client.get("/") as resp: + assert 200 == resp.status + + # Should be 2 separate connections + conns = next(iter(client.session.connector._conns.values())) + assert len(conns) == 2 + + async def test_HTTP_304_WITH_BODY(aiohttp_client: Any) -> None: async def handler(request): body = await request.read() @@ -238,8 +293,8 @@ async def handler(request): client = await aiohttp_client(app) with io.BytesIO(data) as file_handle: - resp = await client.post("/", data=file_handle) - assert 200 == resp.status + async with client.post("/", data=file_handle) as resp: + assert 200 == resp.status async def test_post_data_with_bytesio_file(aiohttp_client: Any) -> None: diff --git a/tests/test_client_response.py b/tests/test_client_response.py index 07eb1e1e747..64161ac5941 100644 --- a/tests/test_client_response.py +++ b/tests/test_client_response.py @@ -18,6 +18,14 @@ from aiohttp.test_utils import make_mocked_coro +class WriterMock(mock.AsyncMock): + def __await__(self) -> None: + return self().__await__() + + def done(self) -> bool: + return True + + @pytest.fixture def session(): return mock.Mock() @@ -30,7 +38,7 @@ async def test_http_processing_error(session: Any) -> None: "get", URL("http://del-cl-resp.org"), request_info=request_info, - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -58,7 +66,7 @@ def test_del(session: Any) -> None: "get", URL("http://del-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -85,7 +93,7 @@ def test_close(loop: Any, session: Any) -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -106,7 +114,7 @@ def test_wait_for_100_1(loop: Any, session: Any) -> None: URL("http://python.org"), continue100=object(), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), timer=TimerNoop(), traces=[], loop=loop, @@ -122,7 +130,7 @@ def test_wait_for_100_2(loop: Any, session: Any) -> None: URL("http://python.org"), request_info=mock.Mock(), continue100=None, - writer=mock.Mock(), + writer=WriterMock(), timer=TimerNoop(), traces=[], loop=loop, @@ -137,7 +145,7 @@ def test_repr(loop: Any, session: Any) -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -154,7 +162,7 @@ def test_repr_non_ascii_url() -> None: "get", URL("http://fake-host.org/\u03bb"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -169,7 +177,7 @@ def test_repr_non_ascii_reason() -> None: "get", URL("http://fake-host.org/path"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -187,7 +195,7 @@ async def test_read_and_release_connection(loop: Any, session: Any) -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -213,7 +221,7 @@ async def test_read_and_release_connection_with_error(loop: Any, session: Any) - "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -234,7 +242,7 @@ async def test_release(loop: Any, session: Any) -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -263,7 +271,7 @@ def run(conn): "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -283,7 +291,7 @@ async def test_response_eof(loop: Any, session: Any) -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=None, continue100=None, timer=TimerNoop(), traces=[], @@ -304,7 +312,7 @@ async def test_response_eof_upgraded(loop: Any, session: Any) -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -325,7 +333,7 @@ async def test_response_eof_after_connection_detach(loop: Any, session: Any) -> "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=None, continue100=None, timer=TimerNoop(), traces=[], @@ -346,7 +354,7 @@ async def test_text(loop: Any, session: Any) -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -373,7 +381,7 @@ async def test_text_bad_encoding(loop: Any, session: Any) -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -403,7 +411,7 @@ async def test_text_custom_encoding(loop: Any, session: Any) -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -436,7 +444,7 @@ async def test_text_charset_resolver( "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -465,7 +473,7 @@ async def test_get_encoding_body_none(loop: Any, session: Any) -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -495,7 +503,7 @@ async def test_text_after_read(loop: Any, session: Any) -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -522,7 +530,7 @@ async def test_json(loop: Any, session: Any) -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -549,7 +557,7 @@ async def test_json_extended_content_type(loop: Any, session: Any) -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -578,7 +586,7 @@ async def test_json_custom_content_type(loop: Any, session: Any) -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -605,7 +613,7 @@ async def test_json_custom_loader(loop: Any, session: Any) -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -627,7 +635,7 @@ async def test_json_invalid_content_type(loop: Any, session: Any) -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -648,7 +656,7 @@ async def test_json_no_content(loop: Any, session: Any) -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -667,7 +675,7 @@ async def test_json_override_encoding(loop: Any, session: Any) -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -696,7 +704,7 @@ def test_get_encoding_unknown(loop: Any, session: Any) -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -713,7 +721,7 @@ def test_raise_for_status_2xx() -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -730,7 +738,7 @@ def test_raise_for_status_4xx() -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -751,7 +759,7 @@ def test_raise_for_status_4xx_without_reason() -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -772,7 +780,7 @@ def test_resp_host() -> None: "get", URL("http://del-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -787,7 +795,7 @@ def test_content_type() -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -804,7 +812,7 @@ def test_content_type_no_header() -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -821,7 +829,7 @@ def test_charset() -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -838,7 +846,7 @@ def test_charset_no_header() -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -855,7 +863,7 @@ def test_charset_no_charset() -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -872,7 +880,7 @@ def test_content_disposition_full() -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -895,7 +903,7 @@ def test_content_disposition_no_parameters() -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -914,7 +922,7 @@ def test_content_disposition_no_header() -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -933,7 +941,7 @@ def test_response_request_info() -> None: "get", URL(url), request_info=RequestInfo(url, "get", headers, url), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -952,7 +960,7 @@ def test_request_info_in_exception() -> None: "get", URL(url), request_info=RequestInfo(url, "get", headers, url), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -973,7 +981,7 @@ def test_no_redirect_history_in_exception() -> None: "get", URL(url), request_info=RequestInfo(url, "get", headers, url), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -996,7 +1004,7 @@ def test_redirect_history_in_exception() -> None: "get", URL(url), request_info=RequestInfo(url, "get", headers, url), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -1010,7 +1018,7 @@ def test_redirect_history_in_exception() -> None: "get", URL(hist_url), request_info=RequestInfo(url, "get", headers, url), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -1039,7 +1047,7 @@ async def test_response_read_triggers_callback(loop: Any, session: Any) -> None: response_method, response_url, request_info=mock.Mock, - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), loop=loop, @@ -1072,7 +1080,7 @@ def test_response_real_url(loop: Any, session: Any) -> None: "get", url, request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -1089,7 +1097,7 @@ def test_response_links_comma_separated(loop: Any, session: Any) -> None: "get", url, request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -1119,7 +1127,7 @@ def test_response_links_multiple_headers(loop: Any, session: Any) -> None: "get", url, request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -1144,7 +1152,7 @@ def test_response_links_no_rel(loop: Any, session: Any) -> None: "get", url, request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -1163,7 +1171,7 @@ def test_response_links_quoted(loop: Any, session: Any) -> None: "get", url, request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -1186,7 +1194,7 @@ def test_response_links_relative(loop: Any, session: Any) -> None: "get", url, request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -1209,7 +1217,7 @@ def test_response_links_empty(loop: Any, session: Any) -> None: "get", url, request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -1225,7 +1233,7 @@ def test_response_not_closed_after_get_ok(mocker) -> None: "get", URL("http://del-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], diff --git a/tests/test_proxy_functional.py b/tests/test_proxy_functional.py index 3de78e80665..8433dca23de 100644 --- a/tests/test_proxy_functional.py +++ b/tests/test_proxy_functional.py @@ -407,7 +407,8 @@ async def test_proxy_http_acquired_cleanup(proxy_test_server: Any, loop: Any) -> assert 0 == len(conn._acquired) - resp = await sess.get(url, proxy=proxy.url) + async with sess.get(url, proxy=proxy.url) as resp: + pass assert resp.closed assert 0 == len(conn._acquired) diff --git a/tests/test_web_functional.py b/tests/test_web_functional.py index 05ca6afb8bf..47be76255a9 100644 --- a/tests/test_web_functional.py +++ b/tests/test_web_functional.py @@ -1816,7 +1816,6 @@ async def handler(request): resp = await session.get(server.make_url("/")) async with resp: assert resp.status == 200 - assert resp.connection is None assert resp.connection is None await session.close()