Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix missing eof when writer cancelled #7764

Merged
merged 14 commits into from
Nov 3, 2023
1 change: 1 addition & 0 deletions CHANGES/7764.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed an issue when a client request is closed before completing a chunked payload -- by :user:`Dreamsorcerer`
1 change: 1 addition & 0 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1140,6 +1140,7 @@ async def __aexit__(
# explicitly. Otherwise connection error handling should kick in
# and close/recycle the connection as required.
self._resp.release()
await self._resp.wait_for_close()


class _WSRequestContextManager(_BaseRequestContextManager[ClientWebSocketResponse]):
Expand Down
59 changes: 36 additions & 23 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,8 +529,11 @@ async def write_bytes(
"""Support coroutines that yields bytes objects."""
# 100 response
if self._continue is not None:
await writer.drain()
await self._continue
try:
await writer.drain()
await self._continue
except asyncio.CancelledError:
return

protocol = conn.protocol
assert protocol is not None
Expand All @@ -543,8 +546,6 @@ async def write_bytes(

for chunk in self.body:
await writer.write(chunk) # type: ignore[arg-type]

await writer.write_eof()
except OSError as exc:
if exc.errno is None and isinstance(exc, asyncio.TimeoutError):
protocol.set_exception(exc)
Expand All @@ -555,12 +556,12 @@ async def write_bytes(
new_exc.__context__ = exc
new_exc.__cause__ = exc
protocol.set_exception(new_exc)
except asyncio.CancelledError as exc:
if not conn.closed:
protocol.set_exception(exc)
except asyncio.CancelledError:
await writer.write_eof()
except Exception as exc:
protocol.set_exception(exc)
else:
await writer.write_eof()
protocol.start_timeout()
finally:
self._writer = None
Expand Down Expand Up @@ -649,7 +650,8 @@ async def send(self, conn: "Connection") -> "ClientResponse":
async def close(self) -> None:
if self._writer is not None:
try:
await self._writer
with contextlib.suppress(asyncio.CancelledError):
await self._writer
Dismissed Show dismissed Hide dismissed
finally:
self._writer = None

Expand Down Expand Up @@ -914,8 +916,7 @@ def _response_eof(self) -> None:
):
return

self._connection.release()
self._connection = None
self._release_connection()

self._closed = True
self._cleanup_writer()
Expand All @@ -927,30 +928,22 @@ def closed(self) -> bool:
def close(self) -> None:
if not self._released:
self._notify_content()
if self._closed:
return

self._closed = True
if self._loop.is_closed():
return

if self._connection is not None:
self._connection.close()
self._connection = None
self._cleanup_writer()
self._release_connection()

def release(self) -> Any:
if not self._released:
self._notify_content()
if self._closed:
return noop()

self._closed = True
if self._connection is not None:
self._connection.release()
self._connection = None

self._cleanup_writer()
self._release_connection()
return noop()

@property
Expand All @@ -975,10 +968,28 @@ def raise_for_status(self) -> None:
headers=self.headers,
)

def _release_connection(self) -> None:
if self._connection is not None:
if self._writer is None:
self._connection.release()
self._connection = None
else:
self._writer.add_done_callback(lambda f: self._release_connection())

async def _wait_released(self) -> None:
if self._writer is not None:
try:
await self._writer
finally:
self._writer = None
self._release_connection()

def _cleanup_writer(self) -> None:
if self._writer is not None:
self._writer.cancel()
self._writer = None
if self._writer.done():
self._writer = None
else:
self._writer.cancel()
self._session = None

def _notify_content(self) -> None:
Expand Down Expand Up @@ -1008,9 +1019,10 @@ async def read(self) -> bytes:
except BaseException:
self.close()
raise
elif self._released:
elif self._released: # Response explicity released
raise ClientConnectionError("Connection closed")

await self._wait_released() # Underlying connection released
return self._body

def get_encoding(self) -> str:
Expand Down Expand Up @@ -1087,3 +1099,4 @@ async def __aexit__(
# for exceptions, response object can close connection
# if state is broken
self.release()
await self.wait_for_close()
59 changes: 57 additions & 2 deletions tests/test_client_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ async def handler(request):
client = await aiohttp_client(app)
resp = await client.get("/")
assert resp.closed
await resp.wait_for_close()
assert 1 == len(client._session.connector._conns)


Expand All @@ -156,6 +157,60 @@ async def handler(request):
assert content == b""


async def test_stream_request_on_server_eof(aiohttp_client) -> None:
async def handler(request):
return web.Response(text="OK", status=200)

app = web.Application()
app.add_routes([web.get("/", handler)])
app.add_routes([web.put("/", handler)])

client = await aiohttp_client(app)

async def data_gen():
for _ in range(2):
yield b"just data"
await asyncio.sleep(0.1)

async with client.put("/", data=data_gen()) as resp:
assert 200 == resp.status
assert len(client.session.connector._acquired) == 1
conn = next(iter(client.session.connector._acquired))

async with client.get("/") as resp:
assert 200 == resp.status

# Connection should have been reused
conns = next(iter(client.session.connector._conns.values()))
assert len(conns) == 1
assert conns[0][0] is conn


async def test_stream_request_on_server_eof_nested(aiohttp_client) -> None:
async def handler(request):
return web.Response(text="OK", status=200)

app = web.Application()
app.add_routes([web.get("/", handler)])
app.add_routes([web.put("/", handler)])

client = await aiohttp_client(app)

async def data_gen():
for _ in range(2):
yield b"just data"
await asyncio.sleep(0.1)

async with client.put("/", data=data_gen()) as resp:
assert 200 == resp.status
async with client.get("/") as resp:
assert 200 == resp.status

# Should be 2 separate connections
conns = next(iter(client.session.connector._conns.values()))
assert len(conns) == 2


async def test_HTTP_304_WITH_BODY(aiohttp_client: Any) -> None:
async def handler(request):
body = await request.read()
Expand Down Expand Up @@ -238,8 +293,8 @@ async def handler(request):
client = await aiohttp_client(app)

with io.BytesIO(data) as file_handle:
resp = await client.post("/", data=file_handle)
assert 200 == resp.status
async with client.post("/", data=file_handle) as resp:
assert 200 == resp.status


async def test_post_data_with_bytesio_file(aiohttp_client: Any) -> None:
Expand Down
Loading
Loading