diff --git a/CHANGES/9348.feature.rst b/CHANGES/9348.feature.rst new file mode 100644 index 0000000000..66fa5c1a06 --- /dev/null +++ b/CHANGES/9348.feature.rst @@ -0,0 +1 @@ +Added :py:meth:`~aiohttp.ClientWebSocketResponse.send_frame` and :py:meth:`~aiohttp.web.WebSocketResponse.send_frame` for WebSockets -- by :user:`bdraco`. diff --git a/aiohttp/client_ws.py b/aiohttp/client_ws.py index 28b0b4b90e..850dd8ca8a 100644 --- a/aiohttp/client_ws.py +++ b/aiohttp/client_ws.py @@ -230,15 +230,23 @@ async def ping(self, message: bytes = b"") -> None: async def pong(self, message: bytes = b"") -> None: await self._writer.pong(message) + async def send_frame( + self, message: bytes, opcode: WSMsgType, compress: Optional[int] = None + ) -> None: + """Send a frame over the websocket.""" + await self._writer.send_frame(message, opcode, compress) + async def send_str(self, data: str, compress: Optional[int] = None) -> None: if not isinstance(data, str): raise TypeError("data argument must be str (%r)" % type(data)) - await self._writer.send(data, binary=False, compress=compress) + await self._writer.send_frame( + data.encode("utf-8"), WSMsgType.TEXT, compress=compress + ) async def send_bytes(self, data: bytes, compress: Optional[int] = None) -> None: if not isinstance(data, (bytes, bytearray, memoryview)): raise TypeError("data argument must be byte-ish (%r)" % type(data)) - await self._writer.send(data, binary=True, compress=compress) + await self._writer.send_frame(data, WSMsgType.BINARY, compress=compress) async def send_json( self, diff --git a/aiohttp/http_websocket.py b/aiohttp/http_websocket.py index b63d41f860..268f1b624d 100644 --- a/aiohttp/http_websocket.py +++ b/aiohttp/http_websocket.py @@ -605,7 +605,7 @@ def __init__( self._output_size = 0 self._compressobj: Any = None # actually compressobj - async def _send_frame( + async def send_frame( self, message: bytes, opcode: int, compress: Optional[int] = None ) -> None: """Send a frame over the websocket with message as its payload.""" @@ -710,32 +710,18 @@ def _write(self, data: bytes) -> None: async def pong(self, message: bytes = b"") -> None: """Send pong message.""" - await self._send_frame(message, WSMsgType.PONG) + await self.send_frame(message, WSMsgType.PONG) async def ping(self, message: bytes = b"") -> None: """Send ping message.""" - await self._send_frame(message, WSMsgType.PING) - - async def send( - self, - message: Union[str, bytes], - binary: bool = False, - compress: Optional[int] = None, - ) -> None: - """Send a frame over the websocket with message as its payload.""" - if isinstance(message, str): - message = message.encode("utf-8") - if binary: - await self._send_frame(message, WSMsgType.BINARY, compress) - else: - await self._send_frame(message, WSMsgType.TEXT, compress) + await self.send_frame(message, WSMsgType.PING) async def close(self, code: int = 1000, message: Union[bytes, str] = b"") -> None: """Close the websocket, sending the specified code and message.""" if isinstance(message, str): message = message.encode("utf-8") try: - await self._send_frame( + await self.send_frame( PACK_CLOSE_CODE(code) + message, opcode=WSMsgType.CLOSE ) finally: diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index 14d47b66e4..56e4ad1cf8 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -404,19 +404,29 @@ async def pong(self, message: bytes = b"") -> None: raise RuntimeError("Call .prepare() first") await self._writer.pong(message) + async def send_frame( + self, message: bytes, opcode: WSMsgType, compress: Optional[int] = None + ) -> None: + """Send a frame over the websocket.""" + if self._writer is None: + raise RuntimeError("Call .prepare() first") + await self._writer.send_frame(message, opcode, compress) + async def send_str(self, data: str, compress: Optional[int] = None) -> None: if self._writer is None: raise RuntimeError("Call .prepare() first") if not isinstance(data, str): raise TypeError("data argument must be str (%r)" % type(data)) - await self._writer.send(data, binary=False, compress=compress) + await self._writer.send_frame( + data.encode("utf-8"), WSMsgType.TEXT, compress=compress + ) async def send_bytes(self, data: bytes, compress: Optional[int] = None) -> None: if self._writer is None: raise RuntimeError("Call .prepare() first") if not isinstance(data, (bytes, bytearray, memoryview)): raise TypeError("data argument must be byte-ish (%r)" % type(data)) - await self._writer.send(data, binary=True, compress=compress) + await self._writer.send_frame(data, WSMsgType.BINARY, compress=compress) async def send_json( self, diff --git a/docs/client_reference.rst b/docs/client_reference.rst index 086458edce..4e6224c78b 100644 --- a/docs/client_reference.rst +++ b/docs/client_reference.rst @@ -1594,6 +1594,32 @@ manually. The method is converted into :term:`coroutine`, *compress* parameter added. + .. method:: send_frame(message, opcode, compress=None) + :async: + + Send a :const:`~aiohttp.WSMsgType` message *message* to peer. + + This method is low-level and should be used with caution as it + only accepts bytes which must conform to the correct message type + for *message*. + + It is recommended to use the :meth:`send_str`, :meth:`send_bytes` + or :meth:`send_json` methods instead of this method. + + The primary use case for this method is to send bytes that are + have already been encoded without having to decode and + re-encode them. + + :param bytes message: message to send. + + :param ~aiohttp.WSMsgType opcode: opcode of the message. + + :param int compress: sets specific level of compression for + single message, + ``None`` for not overriding per-socket setting. + + .. versionadded:: 3.11 + .. method:: close(*, code=WSCloseCode.OK, message=b'') :async: diff --git a/docs/web_reference.rst b/docs/web_reference.rst index 9e351c998b..5b92e2206c 100644 --- a/docs/web_reference.rst +++ b/docs/web_reference.rst @@ -934,8 +934,8 @@ and :ref:`aiohttp-web-signals` handlers:: To enable back-pressure from slow websocket clients treat methods :meth:`ping`, :meth:`pong`, :meth:`send_str`, - :meth:`send_bytes`, :meth:`send_json` as coroutines. By - default write buffer size is set to 64k. + :meth:`send_bytes`, :meth:`send_json`, :meth:`send_frame` as coroutines. + By default write buffer size is set to 64k. :param bool autoping: Automatically send :const:`~aiohttp.WSMsgType.PONG` on @@ -1149,6 +1149,32 @@ and :ref:`aiohttp-web-signals` handlers:: The method is converted into :term:`coroutine`, *compress* parameter added. + .. method:: send_frame(message, opcode, compress=None) + :async: + + Send a :const:`~aiohttp.WSMsgType` message *message* to peer. + + This method is low-level and should be used with caution as it + only accepts bytes which must conform to the correct message type + for *message*. + + It is recommended to use the :meth:`send_str`, :meth:`send_bytes` + or :meth:`send_json` methods instead of this method. + + The primary use case for this method is to send bytes that are + have already been encoded without having to decode and + re-encode them. + + :param bytes message: message to send. + + :param ~aiohttp.WSMsgType opcode: opcode of the message. + + :param int compress: sets specific level of compression for + single message, + ``None`` for not overriding per-socket setting. + + .. versionadded:: 3.11 + .. method:: close(*, code=WSCloseCode.OK, message=b'', drain=True) :async: diff --git a/tests/test_client_ws.py b/tests/test_client_ws.py index b4c3f6820b..8276a2e2fe 100644 --- a/tests/test_client_ws.py +++ b/tests/test_client_ws.py @@ -569,6 +569,7 @@ async def test_send_data_after_close( (resp.send_str, ("s",)), (resp.send_bytes, (b"b",)), (resp.send_json, ({},)), + (resp.send_frame, (b"", aiohttp.WSMsgType.BINARY)), ): with pytest.raises(exc): # Verify exc can be caught with both classes await meth(*args) @@ -775,19 +776,28 @@ async def test_ws_connect_deflate_per_message( m_req.return_value = loop.create_future() m_req.return_value.set_result(mresp) writer = WebSocketWriter.return_value = mock.Mock() - send = writer.send = make_mocked_coro() + send_frame = writer.send_frame = make_mocked_coro() session = aiohttp.ClientSession() resp = await session.ws_connect("http://test.org") await resp.send_str("string", compress=-1) - send.assert_called_with("string", binary=False, compress=-1) + send_frame.assert_called_with( + b"string", aiohttp.WSMsgType.TEXT, compress=-1 + ) await resp.send_bytes(b"bytes", compress=15) - send.assert_called_with(b"bytes", binary=True, compress=15) + send_frame.assert_called_with( + b"bytes", aiohttp.WSMsgType.BINARY, compress=15 + ) await resp.send_json([{}], compress=-9) - send.assert_called_with("[{}]", binary=False, compress=-9) + send_frame.assert_called_with( + b"[{}]", aiohttp.WSMsgType.TEXT, compress=-9 + ) + + await resp.send_frame(b"[{}]", aiohttp.WSMsgType.TEXT, compress=-9) + send_frame.assert_called_with(b"[{}]", aiohttp.WSMsgType.TEXT, -9) await session.close() diff --git a/tests/test_client_ws_functional.py b/tests/test_client_ws_functional.py index dc86691bb1..bc70bbd212 100644 --- a/tests/test_client_ws_functional.py +++ b/tests/test_client_ws_functional.py @@ -131,6 +131,28 @@ async def handler(request: web.Request) -> web.WebSocketResponse: await resp.close() +async def test_send_recv_frame(aiohttp_client: AiohttpClient) -> None: + async def handler(request: web.Request) -> web.WebSocketResponse: + ws = web.WebSocketResponse() + await ws.prepare(request) + + data = await ws.receive() + await ws.send_frame(data.data, data.type) + await ws.close() + return ws + + app = web.Application() + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + resp = await client.ws_connect("/") + await resp.send_frame(b"test", WSMsgType.BINARY) + + data = await resp.receive() + assert data.data == b"test" + assert data.type is WSMsgType.BINARY + await resp.close() + + async def test_ping_pong(aiohttp_client: AiohttpClient) -> None: loop = asyncio.get_event_loop() closed = loop.create_future() diff --git a/tests/test_web_websocket.py b/tests/test_web_websocket.py index bc89bc7896..76656b2ab7 100644 --- a/tests/test_web_websocket.py +++ b/tests/test_web_websocket.py @@ -81,6 +81,12 @@ async def test_nonstarted_pong() -> None: await ws.pong() +async def test_nonstarted_send_frame() -> None: + ws = web.WebSocketResponse() + with pytest.raises(RuntimeError): + await ws.send_frame(b"string", WSMsgType.TEXT) + + async def test_nonstarted_send_str() -> None: ws = web.WebSocketResponse() with pytest.raises(RuntimeError): @@ -268,6 +274,18 @@ async def test_send_json_closed(make_request: _RequestMaker) -> None: await ws.send_json({"type": "json"}) +async def test_send_frame_closed(make_request: _RequestMaker) -> None: + req = make_request("GET", "/") + ws = web.WebSocketResponse() + await ws.prepare(req) + assert ws._reader is not None + ws._reader.feed_data(WS_CLOSED_MESSAGE) + await ws.close() + + with pytest.raises(ConnectionError): + await ws.send_frame(b'{"type": "json"}', WSMsgType.TEXT) + + async def test_ping_closed(make_request: _RequestMaker) -> None: req = make_request("GET", "/") ws = web.WebSocketResponse() @@ -560,15 +578,18 @@ async def test_send_with_per_message_deflate( req = make_request("GET", "/") ws = web.WebSocketResponse() await ws.prepare(req) - with mock.patch.object(ws._writer, "send", autospec=True, spec_set=True) as m: + with mock.patch.object(ws._writer, "send_frame", autospec=True, spec_set=True) as m: await ws.send_str("string", compress=15) - m.assert_called_with("string", binary=False, compress=15) + m.assert_called_with(b"string", WSMsgType.TEXT, compress=15) await ws.send_bytes(b"bytes", compress=0) - m.assert_called_with(b"bytes", binary=True, compress=0) + m.assert_called_with(b"bytes", WSMsgType.BINARY, compress=0) await ws.send_json("[{}]", compress=9) - m.assert_called_with('"[{}]"', binary=False, compress=9) + m.assert_called_with(b'"[{}]"', WSMsgType.TEXT, compress=9) + + await ws.send_frame(b"[{}]", WSMsgType.TEXT, compress=9) + m.assert_called_with(b"[{}]", WSMsgType.TEXT, compress=9) async def test_no_transfer_encoding_header( diff --git a/tests/test_web_websocket_functional.py b/tests/test_web_websocket_functional.py index f718126b63..2f20cfa8ac 100644 --- a/tests/test_web_websocket_functional.py +++ b/tests/test_web_websocket_functional.py @@ -397,7 +397,7 @@ async def handler(request: web.Request) -> web.WebSocketResponse: ws = await client.ws_connect("/", protocols=("eggs", "bar")) - await ws._writer._send_frame(b"", WSMsgType.CLOSE) + await ws._writer.send_frame(b"", WSMsgType.CLOSE) msg = await ws.receive() assert msg.type == WSMsgType.CLOSE diff --git a/tests/test_websocket_writer.py b/tests/test_websocket_writer.py index 1abebf40fc..57745d237d 100644 --- a/tests/test_websocket_writer.py +++ b/tests/test_websocket_writer.py @@ -5,7 +5,7 @@ import pytest -from aiohttp import DataQueue, WSMessage +from aiohttp import DataQueue, WSMessage, WSMsgType from aiohttp.base_protocol import BaseProtocol from aiohttp.http import WebSocketReader, WebSocketWriter from aiohttp.test_utils import make_mocked_coro @@ -41,22 +41,22 @@ async def test_ping(writer: WebSocketWriter) -> None: async def test_send_text(writer: WebSocketWriter) -> None: - await writer.send(b"text") + await writer.send_frame(b"text", WSMsgType.TEXT) writer.transport.write.assert_called_with(b"\x81\x04text") # type: ignore[attr-defined] async def test_send_binary(writer: WebSocketWriter) -> None: - await writer.send("binary", True) + await writer.send_frame(b"binary", WSMsgType.BINARY) writer.transport.write.assert_called_with(b"\x82\x06binary") # type: ignore[attr-defined] async def test_send_binary_long(writer: WebSocketWriter) -> None: - await writer.send(b"b" * 127, True) + await writer.send_frame(b"b" * 127, WSMsgType.BINARY) assert writer.transport.write.call_args[0][0].startswith(b"\x82~\x00\x7fb") # type: ignore[attr-defined] async def test_send_binary_very_long(writer: WebSocketWriter) -> None: - await writer.send(b"b" * 65537, True) + await writer.send_frame(b"b" * 65537, WSMsgType.BINARY) assert ( writer.transport.write.call_args_list[0][0][0] # type: ignore[attr-defined] == b"\x82\x7f\x00\x00\x00\x00\x00\x01\x00\x01" @@ -82,7 +82,7 @@ async def test_send_text_masked( writer = WebSocketWriter( protocol, transport, use_mask=True, random=random.Random(123) ) - await writer.send(b"text") + await writer.send_frame(b"text", WSMsgType.TEXT) writer.transport.write.assert_called_with(b"\x81\x84\rg\xb3fy\x02\xcb\x12") # type: ignore[attr-defined] @@ -90,9 +90,9 @@ async def test_send_compress_text( protocol: BaseProtocol, transport: asyncio.Transport ) -> None: writer = WebSocketWriter(protocol, transport, compress=15) - await writer.send(b"text") + await writer.send_frame(b"text", WSMsgType.TEXT) writer.transport.write.assert_called_with(b"\xc1\x06*I\xad(\x01\x00") # type: ignore[attr-defined] - await writer.send(b"text") + await writer.send_frame(b"text", WSMsgType.TEXT) writer.transport.write.assert_called_with(b"\xc1\x05*\x01b\x00\x00") # type: ignore[attr-defined] @@ -100,9 +100,9 @@ async def test_send_compress_text_notakeover( protocol: BaseProtocol, transport: asyncio.Transport ) -> None: writer = WebSocketWriter(protocol, transport, compress=15, notakeover=True) - await writer.send(b"text") + await writer.send_frame(b"text", WSMsgType.TEXT) writer.transport.write.assert_called_with(b"\xc1\x06*I\xad(\x01\x00") # type: ignore[attr-defined] - await writer.send(b"text") + await writer.send_frame(b"text", WSMsgType.TEXT) writer.transport.write.assert_called_with(b"\xc1\x06*I\xad(\x01\x00") # type: ignore[attr-defined] @@ -110,11 +110,11 @@ async def test_send_compress_text_per_message( protocol: BaseProtocol, transport: asyncio.Transport ) -> None: writer = WebSocketWriter(protocol, transport) - await writer.send(b"text", compress=15) + await writer.send_frame(b"text", WSMsgType.TEXT, compress=15) writer.transport.write.assert_called_with(b"\xc1\x06*I\xad(\x01\x00") # type: ignore[attr-defined] - await writer.send(b"text") + await writer.send_frame(b"text", WSMsgType.TEXT) writer.transport.write.assert_called_with(b"\x81\x04text") # type: ignore[attr-defined] - await writer.send(b"text", compress=15) + await writer.send_frame(b"text", WSMsgType.TEXT, compress=15) writer.transport.write.assert_called_with(b"\xc1\x06*I\xad(\x01\x00") # type: ignore[attr-defined] @@ -161,7 +161,7 @@ async def test_concurrent_messages( point = payload_point_generator(count) payload = bytes((point,)) * point payloads.append(payload) - writers.append(writer.send(payload, binary=True)) + writers.append(writer.send_frame(payload, WSMsgType.BINARY)) await asyncio.gather(*writers) for call in writer.transport.write.call_args_list: # type: ignore[attr-defined]