diff --git a/docs/docs/SUMMARY.md b/docs/docs/SUMMARY.md index f919573a1e..8a76ee0dce 100644 --- a/docs/docs/SUMMARY.md +++ b/docs/docs/SUMMARY.md @@ -356,7 +356,6 @@ search: - [BrokerUsecase](api/faststream/broker/core/usecase/BrokerUsecase.md) - fastapi - [StreamMessage](api/faststream/broker/fastapi/StreamMessage.md) - - [StreamRoute](api/faststream/broker/fastapi/StreamRoute.md) - [StreamRouter](api/faststream/broker/fastapi/StreamRouter.md) - context - [Context](api/faststream/broker/fastapi/context/Context.md) @@ -365,8 +364,9 @@ search: - [get_fastapi_native_dependant](api/faststream/broker/fastapi/get_dependant/get_fastapi_native_dependant.md) - route - [StreamMessage](api/faststream/broker/fastapi/route/StreamMessage.md) - - [StreamRoute](api/faststream/broker/fastapi/route/StreamRoute.md) + - [build_faststream_to_fastapi_parser](api/faststream/broker/fastapi/route/build_faststream_to_fastapi_parser.md) - [make_fastapi_execution](api/faststream/broker/fastapi/route/make_fastapi_execution.md) + - [wrap_callable_to_fastapi_compatible](api/faststream/broker/fastapi/route/wrap_callable_to_fastapi_compatible.md) - router - [StreamRouter](api/faststream/broker/fastapi/router/StreamRouter.md) - message @@ -565,11 +565,13 @@ search: - [HandlerException](api/faststream/exceptions/HandlerException.md) - [IgnoredException](api/faststream/exceptions/IgnoredException.md) - [NackMessage](api/faststream/exceptions/NackMessage.md) + - [OperationForbiddenError](api/faststream/exceptions/OperationForbiddenError.md) - [RejectMessage](api/faststream/exceptions/RejectMessage.md) - [SetupError](api/faststream/exceptions/SetupError.md) - [SkipMessage](api/faststream/exceptions/SkipMessage.md) - [StopApplication](api/faststream/exceptions/StopApplication.md) - [StopConsume](api/faststream/exceptions/StopConsume.md) + - [SubscriberNotFound](api/faststream/exceptions/SubscriberNotFound.md) - [ValidationError](api/faststream/exceptions/ValidationError.md) - kafka - [KafkaBroker](api/faststream/kafka/KafkaBroker.md) @@ -844,6 +846,7 @@ search: - usecase - [LogicPublisher](api/faststream/rabbit/publisher/usecase/LogicPublisher.md) - [PublishKwargs](api/faststream/rabbit/publisher/usecase/PublishKwargs.md) + - [RequestPublishKwargs](api/faststream/rabbit/publisher/usecase/RequestPublishKwargs.md) - response - [RabbitResponse](api/faststream/rabbit/response/RabbitResponse.md) - router @@ -990,8 +993,12 @@ search: - [LogicSubscriber](api/faststream/redis/subscriber/usecase/LogicSubscriber.md) - [StreamSubscriber](api/faststream/redis/subscriber/usecase/StreamSubscriber.md) - testing + - [ChannelVisitor](api/faststream/redis/testing/ChannelVisitor.md) - [FakeProducer](api/faststream/redis/testing/FakeProducer.md) + - [ListVisitor](api/faststream/redis/testing/ListVisitor.md) + - [StreamVisitor](api/faststream/redis/testing/StreamVisitor.md) - [TestRedisBroker](api/faststream/redis/testing/TestRedisBroker.md) + - [Visitor](api/faststream/redis/testing/Visitor.md) - [build_message](api/faststream/redis/testing/build_message.md) - security - [BaseSecurity](api/faststream/security/BaseSecurity.md) @@ -1006,7 +1013,6 @@ search: - [TestApp](api/faststream/testing/app/TestApp.md) - broker - [TestBroker](api/faststream/testing/broker/TestBroker.md) - - [call_handler](api/faststream/testing/broker/call_handler.md) - [patch_broker_calls](api/faststream/testing/broker/patch_broker_calls.md) - types - [LoggerProto](api/faststream/types/LoggerProto.md) @@ -1046,6 +1052,7 @@ search: - 
[call_or_await](api/faststream/utils/functions/call_or_await.md) - [drop_response_type](api/faststream/utils/functions/drop_response_type.md) - [fake_context](api/faststream/utils/functions/fake_context.md) + - [return_input](api/faststream/utils/functions/return_input.md) - [sync_fake_context](api/faststream/utils/functions/sync_fake_context.md) - [timeout_scope](api/faststream/utils/functions/timeout_scope.md) - [to_async](api/faststream/utils/functions/to_async.md) diff --git a/docs/docs/en/api/faststream/broker/fastapi/route/build_faststream_to_fastapi_parser.md b/docs/docs/en/api/faststream/broker/fastapi/route/build_faststream_to_fastapi_parser.md new file mode 100644 index 0000000000..dc05bb190e --- /dev/null +++ b/docs/docs/en/api/faststream/broker/fastapi/route/build_faststream_to_fastapi_parser.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: faststream.broker.fastapi.route.build_faststream_to_fastapi_parser diff --git a/docs/docs/en/api/faststream/broker/fastapi/route/wrap_callable_to_fastapi_compatible.md b/docs/docs/en/api/faststream/broker/fastapi/route/wrap_callable_to_fastapi_compatible.md new file mode 100644 index 0000000000..ab7081c711 --- /dev/null +++ b/docs/docs/en/api/faststream/broker/fastapi/route/wrap_callable_to_fastapi_compatible.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: faststream.broker.fastapi.route.wrap_callable_to_fastapi_compatible diff --git a/docs/docs/en/api/faststream/exceptions/OperationForbiddenError.md b/docs/docs/en/api/faststream/exceptions/OperationForbiddenError.md new file mode 100644 index 0000000000..e34e86542b --- /dev/null +++ b/docs/docs/en/api/faststream/exceptions/OperationForbiddenError.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: faststream.exceptions.OperationForbiddenError diff --git a/docs/docs/en/api/faststream/broker/fastapi/route/StreamRoute.md b/docs/docs/en/api/faststream/exceptions/SubscriberNotFound.md similarity index 69% rename from docs/docs/en/api/faststream/broker/fastapi/route/StreamRoute.md rename to docs/docs/en/api/faststream/exceptions/SubscriberNotFound.md index 4899cbe531..89428f8251 100644 --- a/docs/docs/en/api/faststream/broker/fastapi/route/StreamRoute.md +++ b/docs/docs/en/api/faststream/exceptions/SubscriberNotFound.md @@ -8,4 +8,4 @@ search: boost: 0.5 --- -::: faststream.broker.fastapi.route.StreamRoute +::: faststream.exceptions.SubscriberNotFound diff --git a/docs/docs/en/api/faststream/rabbit/publisher/usecase/RequestPublishKwargs.md b/docs/docs/en/api/faststream/rabbit/publisher/usecase/RequestPublishKwargs.md new file mode 100644 index 0000000000..5668633016 --- /dev/null +++ b/docs/docs/en/api/faststream/rabbit/publisher/usecase/RequestPublishKwargs.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: faststream.rabbit.publisher.usecase.RequestPublishKwargs diff --git a/docs/docs/en/api/faststream/redis/testing/ChannelVisitor.md b/docs/docs/en/api/faststream/redis/testing/ChannelVisitor.md new file mode 100644 index 0000000000..f916be2ae8 --- /dev/null +++ b/docs/docs/en/api/faststream/redis/testing/ChannelVisitor.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - 
Default +search: + boost: 0.5 +--- + +::: faststream.redis.testing.ChannelVisitor diff --git a/docs/docs/en/api/faststream/broker/fastapi/StreamRoute.md b/docs/docs/en/api/faststream/redis/testing/ListVisitor.md similarity index 72% rename from docs/docs/en/api/faststream/broker/fastapi/StreamRoute.md rename to docs/docs/en/api/faststream/redis/testing/ListVisitor.md index 423e8fdbdc..414b8a0400 100644 --- a/docs/docs/en/api/faststream/broker/fastapi/StreamRoute.md +++ b/docs/docs/en/api/faststream/redis/testing/ListVisitor.md @@ -8,4 +8,4 @@ search: boost: 0.5 --- -::: faststream.broker.fastapi.StreamRoute +::: faststream.redis.testing.ListVisitor diff --git a/docs/docs/en/api/faststream/testing/broker/call_handler.md b/docs/docs/en/api/faststream/redis/testing/StreamVisitor.md similarity index 71% rename from docs/docs/en/api/faststream/testing/broker/call_handler.md rename to docs/docs/en/api/faststream/redis/testing/StreamVisitor.md index fd11830902..0b72d99109 100644 --- a/docs/docs/en/api/faststream/testing/broker/call_handler.md +++ b/docs/docs/en/api/faststream/redis/testing/StreamVisitor.md @@ -8,4 +8,4 @@ search: boost: 0.5 --- -::: faststream.testing.broker.call_handler +::: faststream.redis.testing.StreamVisitor diff --git a/docs/docs/en/api/faststream/redis/testing/Visitor.md b/docs/docs/en/api/faststream/redis/testing/Visitor.md new file mode 100644 index 0000000000..746688710f --- /dev/null +++ b/docs/docs/en/api/faststream/redis/testing/Visitor.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: faststream.redis.testing.Visitor diff --git a/docs/docs/en/api/faststream/utils/functions/return_input.md b/docs/docs/en/api/faststream/utils/functions/return_input.md new file mode 100644 index 0000000000..d5514e013f --- /dev/null +++ b/docs/docs/en/api/faststream/utils/functions/return_input.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: faststream.utils.functions.return_input diff --git a/docs/docs/en/nats/rpc.md b/docs/docs/en/nats/rpc.md index d75a7c5c3a..ff390750fd 100644 --- a/docs/docs/en/nats/rpc.md +++ b/docs/docs/en/nats/rpc.md @@ -23,19 +23,15 @@ Just send a message like a regular one and get a response synchronously. It is very close to the common **requests** syntax: -```python hl_lines="1 4" -msg = await broker.publish( +```python hl_lines="3" +from faststream.nats import NatsMessage + +msg: NatsMessage = await broker.request( "Hi!", subject="test", - rpc=True, ) ``` -Also, you have two extra options to control this behavior: - -* `#!python rpc_timeout: Optional[float] = 30.0` - controls how long you are waiting for a response. -* `#!python raise_timeout: bool = False` - by default, a timeout request returns `None`, but if you need to raise a `TimeoutException` directly, you can specify this option. - ## Reply-To Also, if you want to create a permanent request-reply data flow, probably, you should create a permanent subject to consume responses. diff --git a/docs/docs/en/rabbit/ack.md b/docs/docs/en/rabbit/ack.md index 3024256c24..d68b66a0cd 100644 --- a/docs/docs/en/rabbit/ack.md +++ b/docs/docs/en/rabbit/ack.md @@ -44,6 +44,9 @@ async def base_handler(body: str): ... ``` +!!! tip + **FastStream** identifies the message by its `message_id`. To make this option work, you should manually set this field on the producer side (if your library doesn't set it automatically). + !!! 
bug At the moment, attempts are counted only by the current consumer. If the message goes to another consumer, it will have its own counter. Subsequently, this logic will be reworked. diff --git a/docs/docs/en/rabbit/rpc.md b/docs/docs/en/rabbit/rpc.md index 843d3edc3c..2a132a7705 100644 --- a/docs/docs/en/rabbit/rpc.md +++ b/docs/docs/en/rabbit/rpc.md @@ -20,19 +20,15 @@ Just send a message like a regular one and get a response synchronously. It is very close to common **requests** syntax: -```python hl_lines="1 4" -msg = await broker.publish( +```python hl_lines="3" +from faststream.rabbit import RabbitMessage + +msg: RabbitMessage = await broker.request( "Hi!", queue="test", - rpc=True, ) ``` -Also, you have two extra options to control this behavior: - -* `#!python rpc_timeout: Optional[float] = 30.0` - controls how long you are waiting for a response -* `#!python raise_timeout: bool = False` - by default, a timeout request returns `None`, but if you need to raise a `TimeoutException` directly, you can specify this option - ## Reply-To Also, if you want to create a permanent request-reply data flow, probably, you should create a permanent queue to consume responses. diff --git a/docs/docs/en/redis/rpc.md b/docs/docs/en/redis/rpc.md index 91a233eb6f..d6ec7ee06a 100644 --- a/docs/docs/en/redis/rpc.md +++ b/docs/docs/en/redis/rpc.md @@ -39,9 +39,9 @@ To implement **Redis** RPC with `RedisBroker` in **FastStream**, follow the step 3. Send RPC messages through `RedisBroker` and await responses on the correct data type. - After your application has started and the subscribers are ready to receive messages, you can publish messages with the `rpc` option enabled. Additionally, you can set an `rpc_timeout` to decide how long the publisher should wait for a response before timing out. + Additionally, you can set a `timeout` to decide how long the publisher should wait for a response before timing out. 
- ```python linenums="1" + ```python linenums="1" hl_lines="5 12 19" {!> docs_src/redis/rpc/app.py [ln:26-49] !} ``` diff --git a/docs/docs_src/getting_started/serialization/parser_confluent.py b/docs/docs_src/getting_started/serialization/parser_confluent.py index 93dd0cef22..52fe1612f6 100644 --- a/docs/docs_src/getting_started/serialization/parser_confluent.py +++ b/docs/docs_src/getting_started/serialization/parser_confluent.py @@ -11,7 +11,7 @@ async def custom_parser( original_parser: Callable[[Message], Awaitable[KafkaMessage]], ) -> KafkaMessage: parsed_msg = await original_parser(msg) - parsed_msg.message_id = parsed_msg.headers["custom_message_id"] + parsed_msg.message_id = parsed_msg.headers.get("custom_message_id") return parsed_msg diff --git a/docs/docs_src/getting_started/serialization/parser_kafka.py b/docs/docs_src/getting_started/serialization/parser_kafka.py index 7d526740b6..701a5cd903 100644 --- a/docs/docs_src/getting_started/serialization/parser_kafka.py +++ b/docs/docs_src/getting_started/serialization/parser_kafka.py @@ -11,7 +11,7 @@ async def custom_parser( original_parser: Callable[[ConsumerRecord], Awaitable[KafkaMessage]], ) -> KafkaMessage: parsed_msg = await original_parser(msg) - parsed_msg.message_id = parsed_msg.headers["custom_message_id"] + parsed_msg.message_id = parsed_msg.headers.get("custom_message_id") return parsed_msg diff --git a/docs/docs_src/getting_started/serialization/parser_nats.py b/docs/docs_src/getting_started/serialization/parser_nats.py index 5061c06625..2eb0bcac75 100644 --- a/docs/docs_src/getting_started/serialization/parser_nats.py +++ b/docs/docs_src/getting_started/serialization/parser_nats.py @@ -11,7 +11,7 @@ async def custom_parser( original_parser: Callable[[Msg], Awaitable[NatsMessage]], ) -> NatsMessage: parsed_msg = await original_parser(msg) - parsed_msg.message_id = parsed_msg.headers["custom_message_id"] + parsed_msg.message_id = parsed_msg.headers.get("custom_message_id") return parsed_msg diff --git a/docs/docs_src/getting_started/serialization/parser_rabbit.py b/docs/docs_src/getting_started/serialization/parser_rabbit.py index 4674d8033c..5a4bcb61d7 100644 --- a/docs/docs_src/getting_started/serialization/parser_rabbit.py +++ b/docs/docs_src/getting_started/serialization/parser_rabbit.py @@ -11,7 +11,7 @@ async def custom_parser( original_parser: Callable[[IncomingMessage], Awaitable[RabbitMessage]], ) -> RabbitMessage: parsed_msg = await original_parser(msg) - parsed_msg.message_id = parsed_msg.headers["custom_message_id"] + parsed_msg.message_id = parsed_msg.headers.get("custom_message_id") return parsed_msg diff --git a/docs/docs_src/getting_started/serialization/parser_redis.py b/docs/docs_src/getting_started/serialization/parser_redis.py index b15e4ef023..8190461fb2 100644 --- a/docs/docs_src/getting_started/serialization/parser_redis.py +++ b/docs/docs_src/getting_started/serialization/parser_redis.py @@ -10,7 +10,7 @@ async def custom_parser( original_parser: Callable[[PubSubMessage], Awaitable[RedisMessage]], ) -> RedisMessage: parsed_msg = await original_parser(msg) - parsed_msg.message_id = parsed_msg.headers["custom_message_id"] + parsed_msg.message_id = parsed_msg.headers.get("custom_message_id") return parsed_msg diff --git a/docs/docs_src/redis/rpc/app.py b/docs/docs_src/redis/rpc/app.py index 67debe5cdd..7758ba9ad2 100644 --- a/docs/docs_src/redis/rpc/app.py +++ b/docs/docs_src/redis/rpc/app.py @@ -1,5 +1,5 @@ from faststream import FastStream, Logger -from faststream.redis import RedisBroker +from 
faststream.redis import RedisBroker, RedisMessage broker = RedisBroker("redis://localhost:6379") app = FastStream(broker) @@ -27,23 +27,23 @@ async def handle_stream(msg: str, logger: Logger): async def t(): msg = "Hi!" - assert msg == await broker.publish( + response: RedisMessage = await broker.request( "Hi!", channel="test-channel", - rpc=True, - rpc_timeout=3.0, + timeout=3.0, ) + assert await response.decode() == msg - assert msg == await broker.publish( + response: RedisMessage = await broker.request( "Hi!", list="test-list", - rpc=True, - rpc_timeout=3.0, + timeout=3.0, ) + assert await response.decode() == msg - assert msg == await broker.publish( + response: RedisMessage = await broker.request( "Hi!", stream="test-stream", - rpc=True, - rpc_timeout=3.0, + timeout=3.0, ) + assert await response.decode() == msg diff --git a/examples/e05_rpc_request.py b/examples/e05_rpc_request.py index 3ee2611635..74e1f55197 100644 --- a/examples/e05_rpc_request.py +++ b/examples/e05_rpc_request.py @@ -14,4 +14,5 @@ async def handle(msg, logger: Logger): @app.after_startup async def test_publishing(): - assert (await broker.publish("ping", "test-queue", rpc=True)) == "pong" + response = await broker.request("ping", "test-queue") + assert await response.decode() == "pong" diff --git a/examples/e10_middlewares.py b/examples/e10_middlewares.py index 3916c9472b..03a0519d79 100644 --- a/examples/e10_middlewares.py +++ b/examples/e10_middlewares.py @@ -25,7 +25,7 @@ async def subscriber_middleware( msg: RabbitMessage, ) -> Any: print(f"call handler middleware with body: {msg}") - msg.decoded_body = "fake message" + msg._decoded_body = "fake message" result = await call_next(msg) print("handler middleware out") return result diff --git a/examples/nats/e02_basic_rpc.py b/examples/nats/e02_basic_rpc.py index 4be2c44e25..739c09c6e6 100644 --- a/examples/nats/e02_basic_rpc.py +++ b/examples/nats/e02_basic_rpc.py @@ -13,5 +13,5 @@ async def handler(msg: str, logger: Logger): @app.after_startup async def test_send(): - response = await broker.publish("Hi!", "subject", rpc=True) - assert response == "Response" + response = await broker.request("Hi!", "subject") + assert await response.decode() == "Response" diff --git a/examples/redis/rpc.py b/examples/redis/rpc.py index 009d4496fd..4278658484 100644 --- a/examples/redis/rpc.py +++ b/examples/redis/rpc.py @@ -27,23 +27,23 @@ async def handle_stream(msg: str, logger: Logger): async def t(): msg = "Hi!" 
- assert msg == await broker.publish( + response = await broker.request( "Hi!", channel="test-channel", - rpc=True, - rpc_timeout=3.0, + timeout=3.0, ) + assert await response.decode() == msg - assert msg == await broker.publish( + response = await broker.request( "Hi!", list="test-list", - rpc=True, - rpc_timeout=3.0, + timeout=3.0, ) + assert await response.decode() == msg - assert msg == await broker.publish( + response = await broker.request( "Hi!", stream="test-stream", - rpc=True, - rpc_timeout=3.0, + timeout=3.0, ) + assert await response.decode() == msg diff --git a/faststream/__about__.py b/faststream/__about__.py index 6752d941d0..d687b1244e 100644 --- a/faststream/__about__.py +++ b/faststream/__about__.py @@ -1,5 +1,5 @@ """Simple and fast framework to create message brokers based microservices.""" -__version__ = "0.5.18" +__version__ = "0.5.19" SERVICE_NAME = f"faststream-{__version__}" diff --git a/faststream/asgi/app.py b/faststream/asgi/app.py index ac09a500bf..40df520097 100644 --- a/faststream/asgi/app.py +++ b/faststream/asgi/app.py @@ -94,6 +94,7 @@ async def __call__(self, scope: "Scope", receive: "Receive", send: "Send") -> No async def start_lifespan_context(self) -> AsyncIterator[None]: async with anyio.create_task_group() as tg, self.lifespan_context(): tg.start_soon(self._startup) + try: yield finally: diff --git a/faststream/broker/core/usecase.py b/faststream/broker/core/usecase.py index e790d5a8a0..7069dd2652 100644 --- a/faststream/broker/core/usecase.py +++ b/faststream/broker/core/usecase.py @@ -1,5 +1,6 @@ import logging from abc import abstractmethod +from contextlib import AsyncExitStack from functools import partial from typing import ( TYPE_CHECKING, @@ -32,7 +33,7 @@ from faststream.exceptions import NOT_CONNECTED_YET from faststream.log.logging import set_logger_fmt from faststream.utils.context.repository import context -from faststream.utils.functions import to_async +from faststream.utils.functions import return_input, to_async if TYPE_CHECKING: from types import TracebackType @@ -40,6 +41,7 @@ from fast_depends.dependencies import Depends from faststream.asyncapi.schema import Tag, TagDict + from faststream.broker.message import StreamMessage from faststream.broker.publisher.proto import ProducerProto, PublisherProto from faststream.security import BaseSecurity from faststream.types import AnyDict, Decorator, LoggerProto @@ -331,6 +333,7 @@ async def publish( msg: Any, *, producer: Optional["ProducerProto"], + correlation_id: Optional[str] = None, **kwargs: Any, ) -> Optional[Any]: """Publish message directly.""" @@ -341,7 +344,39 @@ async def publish( for m in self._middlewares: publish = partial(m(None).publish_scope, publish) - return await publish(msg, **kwargs) + return await publish(msg, correlation_id=correlation_id, **kwargs) + + async def request( + self, + msg: Any, + *, + producer: Optional["ProducerProto"], + correlation_id: Optional[str] = None, + **kwargs: Any, + ) -> Any: + """Publish a message and wait for the response.""" + assert producer, NOT_CONNECTED_YET # nosec B101 + + request = producer.request + for m in self._middlewares: + request = partial(m(None).publish_scope, request) + + published_msg = await request( + msg, + correlation_id=correlation_id, + **kwargs, + ) + + async with AsyncExitStack() as stack: + return_msg = return_input + for m in self._middlewares: + mid = m(published_msg) + await stack.enter_async_context(mid) + return_msg = partial(mid.consume_scope, return_msg) + + parsed_msg: StreamMessage[Any] = await 
producer._parser(published_msg) + parsed_msg._decoded_body = await producer._decoder(parsed_msg) + return await return_msg(parsed_msg) @abstractmethod async def ping(self, timeout: Optional[float]) -> bool: diff --git a/faststream/broker/fastapi/__init__.py b/faststream/broker/fastapi/__init__.py index 11212a120f..4b683d238c 100644 --- a/faststream/broker/fastapi/__init__.py +++ b/faststream/broker/fastapi/__init__.py @@ -1,8 +1,7 @@ -from faststream.broker.fastapi.route import StreamMessage, StreamRoute +from faststream.broker.fastapi.route import StreamMessage from faststream.broker.fastapi.router import StreamRouter __all__ = ( "StreamMessage", - "StreamRoute", "StreamRouter", ) diff --git a/faststream/broker/fastapi/get_dependant.py b/faststream/broker/fastapi/get_dependant.py index ca85fce295..45d5aaba30 100644 --- a/faststream/broker/fastapi/get_dependant.py +++ b/faststream/broker/fastapi/get_dependant.py @@ -12,13 +12,11 @@ def get_fastapi_dependant( orig_call: Callable[..., Any], dependencies: Iterable["params.Depends"], - path_name: str = "", ) -> Any: """Generate FastStream-Compatible FastAPI Dependant object.""" dependent = get_fastapi_native_dependant( orig_call=orig_call, dependencies=dependencies, - path_name=path_name, ) dependent = _patch_fastapi_dependent(dependent) @@ -29,18 +27,17 @@ def get_fastapi_dependant( def get_fastapi_native_dependant( orig_call: Callable[..., Any], dependencies: Iterable["params.Depends"], - path_name: str = "", ) -> Any: """Generate native FastAPI Dependant.""" dependent = get_dependant( - path=path_name, + path="", call=orig_call, ) for depends in list(dependencies)[::-1]: dependent.dependencies.insert( 0, - get_parameterless_sub_dependant(depends=depends, path=path_name), + get_parameterless_sub_dependant(depends=depends, path=""), ) return dependent diff --git a/faststream/broker/fastapi/route.py b/faststream/broker/fastapi/route.py index 1040acfcd1..a7eca0fd42 100644 --- a/faststream/broker/fastapi/route.py +++ b/faststream/broker/fastapi/route.py @@ -8,7 +8,6 @@ Any, Awaitable, Callable, - Generic, Iterable, List, Optional, @@ -19,12 +18,10 @@ from fastapi.routing import run_endpoint_function, serialize_response from fastapi.utils import create_response_field from starlette.requests import Request -from starlette.routing import BaseRoute from faststream._compat import FASTAPI_V106, raise_fastapi_validation_error from faststream.broker.fastapi.get_dependant import get_fastapi_native_dependant -from faststream.broker.types import MsgType, P_HandlerParams, T_HandlerReturn -from faststream.broker.wrapper.call import HandlerCallWrapper +from faststream.broker.types import P_HandlerParams, T_HandlerReturn if TYPE_CHECKING: from fastapi import params @@ -32,97 +29,10 @@ from fastapi.dependencies.models import Dependant from fastapi.types import IncEx - from faststream.broker.core.usecase import BrokerUsecase from faststream.broker.message import StreamMessage as NativeMessage - from faststream.broker.schemas import NameRequired from faststream.types import AnyDict -class StreamRoute( - BaseRoute, - Generic[MsgType, P_HandlerParams, T_HandlerReturn], -): - """A class representing a stream route.""" - - handler: "HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn]" - - def __init__( - self, - path: Union["NameRequired", str, None], - *extra: Any, - provider_factory: Callable[[], Any], - endpoint: Union[ - Callable[P_HandlerParams, T_HandlerReturn], - "HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn]", - ], - broker: 
"BrokerUsecase[MsgType, Any]", - dependencies: Iterable["params.Depends"], - response_model: Any, - response_model_include: Optional["IncEx"], - response_model_exclude: Optional["IncEx"], - response_model_by_alias: bool, - response_model_exclude_unset: bool, - response_model_exclude_defaults: bool, - response_model_exclude_none: bool, - **handle_kwargs: Any, - ) -> None: - self.path, path_name = path or "", getattr(path, "name", "") - self.broker = broker - - if isinstance(endpoint, HandlerCallWrapper): - orig_call = endpoint._original_call - while hasattr(orig_call, "__consumer__"): - orig_call = orig_call.__wrapped__ # type: ignore[union-attr] - - else: - orig_call = endpoint - - dependent = get_fastapi_native_dependant( - orig_call, - list(dependencies), - path_name=path_name, - ) - - if response_model: - response_field = create_response_field( - name="ResponseModel", - type_=response_model, - mode="serialization", - ) - else: - response_field = None - - call = wraps(orig_call)( - StreamMessage.get_consumer( - dependent=dependent, - provider_factory=provider_factory, - response_field=response_field, - response_model_include=response_model_include, - response_model_exclude=response_model_exclude, - response_model_by_alias=response_model_by_alias, - response_model_exclude_unset=response_model_exclude_unset, - response_model_exclude_defaults=response_model_exclude_defaults, - response_model_exclude_none=response_model_exclude_none, - ) - ) - - handler: HandlerCallWrapper[Any, Any, Any] - if isinstance(endpoint, HandlerCallWrapper): - endpoint._original_call = call - handler = endpoint - - else: - handler = call # type: ignore[assignment] - - self.handler = broker.subscriber( # type: ignore[call-arg] - *extra, - dependencies=list(dependencies), - **handle_kwargs, - )( - handler, - ) - - class StreamMessage(Request): """A class to represent a stream message.""" @@ -147,75 +57,116 @@ def __init__( self.scope = {"path_params": self._query_params} self._cookies = {} - @classmethod - def get_consumer( - cls, - *, - dependent: "Dependant", - provider_factory: Callable[[], Any], - response_field: Optional["ModelField"], - response_model_include: Optional["IncEx"], - response_model_exclude: Optional["IncEx"], - response_model_by_alias: bool, - response_model_exclude_unset: bool, - response_model_exclude_defaults: bool, - response_model_exclude_none: bool, - ) -> Callable[["NativeMessage[Any]"], Awaitable[Any]]: - """Creates a session for handling requests.""" - assert dependent.call # nosec B101 - - consume = make_fastapi_execution( - dependent=dependent, - provider_factory=provider_factory, - response_field=response_field, - response_model_include=response_model_include, - response_model_exclude=response_model_exclude, - response_model_by_alias=response_model_by_alias, - response_model_exclude_unset=response_model_exclude_unset, - response_model_exclude_defaults=response_model_exclude_defaults, - response_model_exclude_none=response_model_exclude_none, - ) - dependencies_names = tuple(i.name for i in dependent.dependencies) +def wrap_callable_to_fastapi_compatible( + user_callable: Callable[P_HandlerParams, T_HandlerReturn], + *, + provider_factory: Callable[[], Any], + dependencies: Iterable["params.Depends"], + response_model: Any, + response_model_include: Optional["IncEx"], + response_model_exclude: Optional["IncEx"], + response_model_by_alias: bool, + response_model_exclude_unset: bool, + response_model_exclude_defaults: bool, + response_model_exclude_none: bool, +) -> 
Callable[["NativeMessage[Any]"], Awaitable[Any]]: + __magic_attr = "__faststream_consumer__" + + if getattr(user_callable, __magic_attr, False): + return user_callable # type: ignore[return-value] - first_arg = next( - dropwhile( - lambda i: i in dependencies_names, - inspect.signature(dependent.call).parameters, - ), - None, + if response_model: + response_field = create_response_field( + name="ResponseModel", + type_=response_model, + mode="serialization", ) + else: + response_field = None + + parsed_callable = build_faststream_to_fastapi_parser( + dependent=get_fastapi_native_dependant(user_callable, list(dependencies)), + provider_factory=provider_factory, + response_field=response_field, + response_model_include=response_model_include, + response_model_exclude=response_model_exclude, + response_model_by_alias=response_model_by_alias, + response_model_exclude_unset=response_model_exclude_unset, + response_model_exclude_defaults=response_model_exclude_defaults, + response_model_exclude_none=response_model_exclude_none, + ) + + setattr(parsed_callable, __magic_attr, True) + return wraps(user_callable)(parsed_callable) + + +def build_faststream_to_fastapi_parser( + *, + dependent: "Dependant", + provider_factory: Callable[[], Any], + response_field: Optional["ModelField"], + response_model_include: Optional["IncEx"], + response_model_exclude: Optional["IncEx"], + response_model_by_alias: bool, + response_model_exclude_unset: bool, + response_model_exclude_defaults: bool, + response_model_exclude_none: bool, +) -> Callable[["NativeMessage[Any]"], Awaitable[Any]]: + """Creates a session for handling requests.""" + assert dependent.call # nosec B101 + + consume = make_fastapi_execution( + dependent=dependent, + provider_factory=provider_factory, + response_field=response_field, + response_model_include=response_model_include, + response_model_exclude=response_model_exclude, + response_model_by_alias=response_model_by_alias, + response_model_exclude_unset=response_model_exclude_unset, + response_model_exclude_defaults=response_model_exclude_defaults, + response_model_exclude_none=response_model_exclude_none, + ) + + dependencies_names = tuple(i.name for i in dependent.dependencies) + + first_arg = next( + dropwhile( + lambda i: i in dependencies_names, + inspect.signature(dependent.call).parameters, + ), + None, + ) + + async def parsed_consumer(message: "NativeMessage[Any]") -> Any: + """Wrapper, that parser FastStream message to FastAPI compatible one.""" + body = await message.decode() + + fastapi_body: Union[AnyDict, List[Any]] + if first_arg is not None: + if isinstance(body, dict): + path = fastapi_body = body or {} + elif isinstance(body, list): + fastapi_body, path = body, {} + else: + path = fastapi_body = {first_arg: body} - async def real_consumer(message: "NativeMessage[Any]") -> Any: - """An asynchronous function that processes an incoming message and returns a sendable message.""" - body = message.decoded_body - - fastapi_body: Union[AnyDict, List[Any]] - if first_arg is not None: - if isinstance(body, dict): - path = fastapi_body = body or {} - elif isinstance(body, list): - fastapi_body, path = body, {} - else: - path = fastapi_body = {first_arg: body} - - stream_message = cls( - body=fastapi_body, - headers=message.headers, - path={**path, **message.path}, - ) + stream_message = StreamMessage( + body=fastapi_body, + headers=message.headers, + path={**path, **message.path}, + ) - else: - stream_message = cls( - body={}, - headers={}, - path={}, - ) + else: + stream_message = 
StreamMessage( + body={}, + headers={}, + path={}, + ) - return await consume(stream_message, message) + return await consume(stream_message, message) - real_consumer.__consumer__ = True # type: ignore[attr-defined] - return real_consumer + return parsed_consumer def make_fastapi_execution( @@ -230,15 +181,15 @@ def make_fastapi_execution( response_model_exclude_defaults: bool, response_model_exclude_none: bool, ) -> Callable[ - [StreamMessage, "NativeMessage[Any]"], + ["StreamMessage", "NativeMessage[Any]"], Awaitable[Any], ]: """Creates a FastAPI application.""" is_coroutine = asyncio.iscoroutinefunction(dependent.call) async def app( - request: StreamMessage, - raw_message: "NativeMessage[Any]", + request: "StreamMessage", + raw_message: "NativeMessage[Any]", # to support BackgroundTasks by middleware ) -> Any: """Consume StreamMessage and return user function result.""" async with AsyncExitStack() as stack: @@ -256,7 +207,13 @@ async def app( **kwargs, # type: ignore[arg-type] ) - values, errors, raw_message.background, _, _2 = solved_result # type: ignore[attr-defined] + ( + values, + errors, + raw_message.background, # type: ignore[attr-defined] + _response, + _dependency_cache, + ) = solved_result if errors: raise_fastapi_validation_error(errors, request._body) # type: ignore[arg-type] diff --git a/faststream/broker/fastapi/router.py b/faststream/broker/fastapi/router.py index 5072476f18..420bfe962c 100644 --- a/faststream/broker/fastapi/router.py +++ b/faststream/broker/fastapi/router.py @@ -33,7 +33,7 @@ from faststream.asyncapi.proto import AsyncAPIApplication from faststream.asyncapi.site import get_asyncapi_html from faststream.broker.fastapi.get_dependant import get_fastapi_dependant -from faststream.broker.fastapi.route import StreamRoute +from faststream.broker.fastapi.route import wrap_callable_to_fastapi_compatible from faststream.broker.middlewares import BaseMiddleware from faststream.broker.types import ( MsgType, @@ -54,6 +54,7 @@ from faststream.asyncapi import schema as asyncapi from faststream.asyncapi.schema import Schema from faststream.broker.core.usecase import BrokerUsecase + from faststream.broker.message import StreamMessage from faststream.broker.publisher.proto import PublisherProto from faststream.broker.schemas import NameRequired from faststream.broker.types import BrokerMiddleware @@ -143,6 +144,7 @@ def __init__( ), _get_dependant=get_fastapi_dependant, tags=asyncapi_tags, + apply_types=False, **connection_kwars, ) @@ -204,9 +206,6 @@ def _get_dependencies_overides_provider(self) -> Optional[Any]: def _add_api_mq_route( self, - path: Union["NameRequired", str], - *extra: Union["NameRequired", str], - endpoint: Callable[P_HandlerParams, T_HandlerReturn], dependencies: Iterable["params.Depends"], response_model: Any, response_model_include: Optional["IncEx"], @@ -215,31 +214,33 @@ def _add_api_mq_route( response_model_exclude_unset: bool, response_model_exclude_defaults: bool, response_model_exclude_none: bool, - **broker_kwargs: Any, - ) -> "HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn]": - """Add an API message queue route.""" - route = StreamRoute[MsgType, P_HandlerParams, T_HandlerReturn]( - path, - *extra, - endpoint=endpoint, - dependencies=(*self.dependencies, *dependencies), - provider_factory=self._get_dependencies_overides_provider, - broker=self.broker, - response_model=response_model, - response_model_include=response_model_include, - response_model_exclude=response_model_exclude, - 
response_model_by_alias=response_model_by_alias, - response_model_exclude_unset=response_model_exclude_unset, - response_model_exclude_defaults=response_model_exclude_defaults, - response_model_exclude_none=response_model_exclude_none, - **broker_kwargs, - ) - self.routes.append(route) - return route.handler + ) -> Callable[ + [Callable[..., Any]], + Callable[["StreamMessage[Any]"], Awaitable[Any]], + ]: + """Decorator applied before `broker.subscriber` that wraps the endpoint into a FastAPI-compatible one.""" + + def wrapper( + endpoint: Callable[..., Any], + ) -> Callable[["StreamMessage[Any]"], Awaitable[Any]]: + """Patch user function to make it FastAPI-compatible.""" + return wrap_callable_to_fastapi_compatible( + user_callable=endpoint, + dependencies=dependencies, + response_model=response_model, + response_model_include=response_model_include, + response_model_exclude=response_model_exclude, + response_model_by_alias=response_model_by_alias, + response_model_exclude_unset=response_model_exclude_unset, + response_model_exclude_defaults=response_model_exclude_defaults, + response_model_exclude_none=response_model_exclude_none, + provider_factory=self._get_dependencies_overides_provider, + ) + + return wrapper def subscriber( self, - path: Union[str, "NameRequired"], *extra: Union["NameRequired", str], dependencies: Iterable["params.Depends"], response_model: Any, @@ -255,15 +256,16 @@ def subscriber( "HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn]", ]: """A function decorator for subscribing to a message queue.""" + dependencies = (*self.dependencies, *dependencies) - def decorator( - func: Callable[P_HandlerParams, T_HandlerReturn], - ) -> "HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn]": - """A decorator function.""" - return self._add_api_mq_route( - path, - *extra, - endpoint=func, + sub = self.broker.subscriber( # type: ignore[call-arg] + *extra, # type: ignore[arg-type] + dependencies=dependencies, + **broker_kwargs, + ) + + sub._call_decorators = ( # type: ignore[attr-defined] + self._add_api_mq_route( dependencies=dependencies, response_model=response_model, response_model_include=response_model_include, response_model_exclude=response_model_exclude, response_model_by_alias=response_model_by_alias, response_model_exclude_unset=response_model_exclude_unset, response_model_exclude_defaults=response_model_exclude_defaults, response_model_exclude_none=response_model_exclude_none, - **broker_kwargs, - ) + ), + ) - return decorator + return sub def _wrap_lifespan( self, lifespan: Optional["Lifespan[Any]"] = None diff --git a/faststream/broker/message.py b/faststream/broker/message.py index be590b4eab..c80cde94ce 100644 --- a/faststream/broker/message.py +++ b/faststream/broker/message.py @@ -15,6 +15,8 @@ ) from uuid import uuid4 +from typing_extensions import deprecated + from faststream._compat import dump_json, json_loads from faststream.constants import ContentTypes from faststream.types import EMPTY @@ -49,9 +51,36 @@ class StreamMessage(Generic[MsgType]): default_factory=gen_cor_id # pragma: no cover ) - decoded_body: Optional["DecodedMessage"] = field(default=None, init=False) processed: bool = field(default=False, init=False) committed: bool = field(default=False, init=False) + _decoded_body: Optional["DecodedMessage"] = field(default=None, init=False) + + async def decode(self) -> Optional["DecodedMessage"]: + """Decode the message body.""" + # TODO: make it lazy after `decoded_body` removed + return self._decoded_body + + @property + @deprecated( + "Deprecated in **FastStream 0.5.19**. 
" + "Please, use `decode` lazy method instead. " + "Argument will be removed in **FastStream 0.6.0**.", + category=DeprecationWarning, + stacklevel=1, + ) + def decoded_body(self) -> Optional["DecodedMessage"]: + return self._decoded_body + + @decoded_body.setter + @deprecated( + "Deprecated in **FastStream 0.5.19**. " + "Please, use `decode` lazy method instead. " + "Argument will be removed in **FastStream 0.6.0**.", + category=DeprecationWarning, + stacklevel=1, + ) + def decoded_body(self, value: Optional["DecodedMessage"]) -> None: + self._decoded_body = value async def ack(self) -> None: self.committed = True diff --git a/faststream/broker/publisher/fake.py b/faststream/broker/publisher/fake.py index d77c43406b..7ef19903d3 100644 --- a/faststream/broker/publisher/fake.py +++ b/faststream/broker/publisher/fake.py @@ -44,3 +44,16 @@ async def publish( call = partial(m, call) return await call(message, **publish_kwargs) + + async def request( + self, + message: "SendableMessage", + /, + *, + correlation_id: Optional[str] = None, + _extra_middlewares: Iterable["PublisherMiddleware"] = (), + ) -> Any: + raise NotImplementedError( + "`FakePublisher` can be used only to publish " + "a response for `reply-to` or `RPC` messages." + ) diff --git a/faststream/broker/publisher/proto.py b/faststream/broker/publisher/proto.py index 747b29b048..67ef329f19 100644 --- a/faststream/broker/publisher/proto.py +++ b/faststream/broker/publisher/proto.py @@ -9,6 +9,7 @@ if TYPE_CHECKING: from faststream.broker.types import ( + AsyncCallable, BrokerMiddleware, P_HandlerParams, PublisherMiddleware, @@ -18,6 +19,9 @@ class ProducerProto(Protocol): + _parser: "AsyncCallable" + _decoder: "AsyncCallable" + @abstractmethod async def publish( self, @@ -29,6 +33,17 @@ async def publish( """Publishes a message asynchronously.""" ... + @abstractmethod + async def request( + self, + message: "SendableMessage", + /, + *, + correlation_id: Optional[str] = None, + ) -> Any: + """Publishes a message synchronously.""" + ... + class BasePublisherProto(Protocol): @abstractmethod @@ -43,6 +58,18 @@ async def publish( """Publishes a message asynchronously.""" ... + @abstractmethod + async def request( + self, + message: "SendableMessage", + /, + *, + correlation_id: Optional[str] = None, + _extra_middlewares: Iterable["PublisherMiddleware"] = (), + ) -> Optional[Any]: + """Publishes a message synchronously.""" + ... 
+ class PublisherProto( AsyncAPIProto, diff --git a/faststream/broker/subscriber/call_item.py b/faststream/broker/subscriber/call_item.py index 77bdb70c9a..c7c32b3609 100644 --- a/faststream/broker/subscriber/call_item.py +++ b/faststream/broker/subscriber/call_item.py @@ -139,7 +139,7 @@ async def is_suitable( "StreamMessage[MsgType]", cache.get(parser) or await parser(msg) ) - message.decoded_body = cache[decoder] = cache.get(decoder) or await decoder( + message._decoded_body = cache[decoder] = cache.get(decoder) or await decoder( message ) diff --git a/faststream/broker/subscriber/proto.py b/faststream/broker/subscriber/proto.py index 116d003d48..612d497196 100644 --- a/faststream/broker/subscriber/proto.py +++ b/faststream/broker/subscriber/proto.py @@ -13,6 +13,7 @@ from faststream.broker.message import StreamMessage from faststream.broker.publisher.proto import BasePublisherProto, ProducerProto + from faststream.broker.response import Response from faststream.broker.subscriber.call_item import HandlerItem from faststream.broker.types import ( BrokerMiddleware, @@ -83,7 +84,7 @@ async def close(self) -> None: ... async def consume(self, msg: MsgType) -> Any: ... @abstractmethod - async def process_message(self, msg: MsgType) -> Any: ... + async def process_message(self, msg: MsgType) -> "Response": ... @abstractmethod def add_call( diff --git a/faststream/broker/subscriber/usecase.py b/faststream/broker/subscriber/usecase.py index 1897826d9a..4141ca17d5 100644 --- a/faststream/broker/subscriber/usecase.py +++ b/faststream/broker/subscriber/usecase.py @@ -30,7 +30,7 @@ ) from faststream.broker.utils import MultiLock, get_watcher_context, resolve_custom_func from faststream.broker.wrapper.call import HandlerCallWrapper -from faststream.exceptions import SetupError, StopConsume +from faststream.exceptions import SetupError, StopConsume, SubscriberNotFound from faststream.utils.context.repository import context from faststream.utils.functions import sync_fake_context, to_async @@ -40,6 +40,7 @@ from faststream.broker.message import StreamMessage from faststream.broker.middlewares import BaseMiddleware from faststream.broker.publisher.proto import BasePublisherProto, ProducerProto + from faststream.broker.response import Response from faststream.broker.types import ( AsyncCallable, BrokerMiddleware, @@ -88,6 +89,7 @@ class SubscriberUsecase( _broker_dependencies: Iterable["Depends"] _call_options: Optional["_CallOptions"] + _call_decorators: Iterable["Decorator"] def __init__( self, @@ -115,6 +117,7 @@ def __init__( self._retry = retry self._call_options = None + self._call_decorators = () self.running = False self.lock = sync_fake_context() @@ -182,7 +185,7 @@ def setup( # type: ignore[override] apply_types=apply_types, is_validate=is_validate, _get_dependant=_get_dependant, - _call_decorators=_call_decorators, + _call_decorators=(*self._call_decorators, *_call_decorators), broker_dependencies=self._broker_dependencies, ) @@ -315,7 +318,7 @@ async def consume(self, msg: MsgType) -> Any: # All other exceptions were logged by CriticalLogMiddleware pass - async def process_message(self, msg: MsgType) -> Any: + async def process_message(self, msg: MsgType) -> "Response": """Execute all message processing stages.""" async with AsyncExitStack() as stack: stack.enter_context(self.lock) @@ -383,7 +386,8 @@ async def process_message(self, msg: MsgType) -> Any: _extra_middlewares=(m.publish_scope for m in middlewares), ) - return result_msg.body + # Return data for tests + return result_msg # 
Suitable handler was not found or # parsing/decoding exception occurred @@ -394,9 +398,10 @@ async def process_message(self, msg: MsgType) -> Any: raise parsing_error else: - raise AssertionError(f"There is no suitable handler for {msg=}") + raise SubscriberNotFound(f"There is no suitable handler for {msg=}") - return None + # An error was raised and processed by some middleware + return ensure_response(None) def __get_response_publisher( self, diff --git a/faststream/broker/wrapper/call.py b/faststream/broker/wrapper/call.py index 0c997eb5b2..1ae083dd7f 100644 --- a/faststream/broker/wrapper/call.py +++ b/faststream/broker/wrapper/call.py @@ -97,7 +97,7 @@ def call_wrapped( assert self._wrapped_call, "You should use `set_wrapped` first" # nosec B101 if self.is_test: assert self.mock # nosec B101 - self.mock(message.decoded_body) + self.mock(message._decoded_body) return self._wrapped_call(message) async def wait_call(self, timeout: Optional[float] = None) -> None: @@ -190,7 +190,7 @@ def _wrap_decode_message( async def decode_wrapper(message: "StreamMessage[MsgType]") -> T_HandlerReturn: """A wrapper function to decode and handle a message.""" - msg = message.decoded_body + msg = await message.decode() if params_ln > 1: if isinstance(msg, Mapping): diff --git a/faststream/confluent/broker/broker.py b/faststream/confluent/broker/broker.py index e290a276f5..e5c9754e00 100644 --- a/faststream/confluent/broker/broker.py +++ b/faststream/confluent/broker/broker.py @@ -431,7 +431,7 @@ async def _connect( # type: ignore[override] security_params = parse_security(self.security) kwargs.update(security_params) - producer = AsyncConfluentProducer( + native_producer = AsyncConfluentProducer( **kwargs, client_id=client_id, logger=self.logger, @@ -439,7 +439,9 @@ async def _connect( # type: ignore[override] ) self._producer = AsyncConfluentFastProducer( - producer=producer, + producer=native_producer, + parser=self._parser, + decoder=self._decoder, ) return partial( @@ -497,6 +499,32 @@ async def publish( # type: ignore[override] **kwargs, ) + @override + async def request( # type: ignore[override] + self, + message: "SendableMessage", + topic: str, + key: Optional[bytes] = None, + partition: Optional[int] = None, + timestamp_ms: Optional[int] = None, + headers: Optional[Dict[str, str]] = None, + correlation_id: Optional[str] = None, + timeout: float = 0.5, + ) -> Optional[Any]: + correlation_id = correlation_id or gen_cor_id() + + return await super().request( + message, + producer=self._producer, + topic=topic, + key=key, + partition=partition, + timestamp_ms=timestamp_ms, + headers=headers, + correlation_id=correlation_id, + timeout=timeout, + ) + async def publish_batch( self, *msgs: "SendableMessage", diff --git a/faststream/confluent/fastapi/fastapi.py b/faststream/confluent/fastapi/fastapi.py index b4a2e6bba4..eacfc7b37a 100644 --- a/faststream/confluent/fastapi/fastapi.py +++ b/faststream/confluent/fastapi/fastapi.py @@ -2210,10 +2210,6 @@ def subscriber( "AsyncAPIDefaultSubscriber", ]: subscriber = super().subscriber( - ( # path - next(iter(topics), "") - or getattr(next(iter(partitions), None), "topic", "") - ), *topics, polling_interval=polling_interval, partitions=partitions, diff --git a/faststream/confluent/message.py b/faststream/confluent/message.py index 29ab56cbb4..14fe05ae7b 100644 --- a/faststream/confluent/message.py +++ b/faststream/confluent/message.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Protocol, Tuple, Union +from typing import TYPE_CHECKING, Any, Optional, 
Protocol, Tuple, Union from faststream.broker.message import StreamMessage @@ -11,6 +11,13 @@ class ConsumerProtocol(Protocol): async def commit(self) -> None: ... + async def seek( + self, + topic: Optional[str], + partition: Optional[int], + offset: Optional[int], + ) -> None: ... + class FakeConsumer: """A fake Kafka consumer.""" @@ -18,7 +25,12 @@ class FakeConsumer: async def commit(self) -> None: pass - async def seek(self, **kwargs: Any) -> None: + async def seek( + self, + topic: Optional[str], + partition: Optional[int], + offset: Optional[int], + ) -> None: pass @@ -64,7 +76,7 @@ async def nack(self) -> None: if isinstance(self.raw_message, tuple) else self.raw_message ) - await self.consumer.seek( # type: ignore[attr-defined] + await self.consumer.seek( topic=raw_message.topic(), partition=raw_message.partition(), offset=raw_message.offset(), diff --git a/faststream/confluent/publisher/producer.py b/faststream/confluent/publisher/producer.py index 99c75d32b7..cf8ae1cc09 100644 --- a/faststream/confluent/publisher/producer.py +++ b/faststream/confluent/publisher/producer.py @@ -1,11 +1,15 @@ -from typing import TYPE_CHECKING, Dict, Optional +from typing import TYPE_CHECKING, Any, Dict, Optional from typing_extensions import override from faststream.broker.message import encode_message from faststream.broker.publisher.proto import ProducerProto +from faststream.broker.utils import resolve_custom_func +from faststream.confluent.parser import AsyncConfluentParser +from faststream.exceptions import OperationForbiddenError if TYPE_CHECKING: + from faststream.broker.types import CustomCallable from faststream.confluent.client import AsyncConfluentProducer from faststream.types import SendableMessage @@ -16,9 +20,16 @@ class AsyncConfluentFastProducer(ProducerProto): def __init__( self, producer: "AsyncConfluentProducer", + parser: Optional["CustomCallable"], + decoder: Optional["CustomCallable"], ) -> None: self._producer = producer + # NOTE: register default parser to be compatible with request + default = AsyncConfluentParser + self._parser = resolve_custom_func(parser, default.parse_message) + self._decoder = resolve_custom_func(decoder, default.decode_message) + @override async def publish( # type: ignore[override] self, @@ -99,3 +110,9 @@ async def publish_batch( ) await self._producer.send_batch(batch, topic, partition=partition) + + @override + async def request(self, *args: Any, **kwargs: Any) -> Optional[Any]: + raise OperationForbiddenError( + "Kafka doesn't support `request` method without test client." 
+ ) diff --git a/faststream/confluent/publisher/usecase.py b/faststream/confluent/publisher/usecase.py index 7d5e07a304..8a55a33b31 100644 --- a/faststream/confluent/publisher/usecase.py +++ b/faststream/confluent/publisher/usecase.py @@ -1,6 +1,18 @@ +from contextlib import AsyncExitStack from functools import partial from itertools import chain -from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Tuple, Union, cast +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Dict, + Iterable, + Optional, + Tuple, + Union, + cast, +) from confluent_kafka import Message from typing_extensions import override @@ -9,9 +21,11 @@ from faststream.broker.publisher.usecase import PublisherUsecase from faststream.broker.types import MsgType from faststream.exceptions import NOT_CONNECTED_YET +from faststream.utils.functions import return_input if TYPE_CHECKING: from faststream.broker.types import BrokerMiddleware, PublisherMiddleware + from faststream.confluent.message import KafkaMessage from faststream.confluent.publisher.producer import AsyncConfluentFastProducer from faststream.types import AnyDict, AsyncFunc, SendableMessage @@ -60,6 +74,60 @@ def __hash__(self) -> int: def add_prefix(self, prefix: str) -> None: self.topic = "".join((prefix, self.topic)) + @override + async def request( + self, + message: "SendableMessage", + topic: str = "", + *, + key: Optional[bytes] = None, + partition: Optional[int] = None, + timestamp_ms: Optional[int] = None, + headers: Optional[Dict[str, str]] = None, + correlation_id: Optional[str] = None, + timeout: float = 0.5, + # publisher specific + _extra_middlewares: Iterable["PublisherMiddleware"] = (), + ) -> "KafkaMessage": + assert self._producer, NOT_CONNECTED_YET # nosec B101 + + kwargs: AnyDict = { + "key": key, + # basic args + "timeout": timeout, + "timestamp_ms": timestamp_ms, + "topic": topic or self.topic, + "partition": partition or self.partition, + "headers": headers or self.headers, + "correlation_id": correlation_id or gen_cor_id(), + } + + request: AsyncFunc = self._producer.request + + for pub_m in chain( + ( + _extra_middlewares + or (m(None).publish_scope for m in self._broker_middlewares) + ), + self._middlewares, + ): + request = partial(pub_m, request) + + published_msg = await request(message, **kwargs) + + async with AsyncExitStack() as stack: + return_msg: Callable[[KafkaMessage], Awaitable[KafkaMessage]] = return_input + for m in self._broker_middlewares: + mid = m(published_msg) + await stack.enter_async_context(mid) + return_msg = partial(mid.consume_scope, return_msg) + + parsed_msg = await self._producer._parser(published_msg) + parsed_msg._decoded_body = await self._producer._decoder(parsed_msg) + return await return_msg(parsed_msg) + + raise AssertionError("unreachable") + class DefaultPublisher(LogicPublisher[Message]): def __init__( @@ -137,6 +205,33 @@ async def publish( return await call(message, **kwargs) + @override + async def request( + self, + message: "SendableMessage", + topic: str = "", + *, + key: Optional[bytes] = None, + partition: Optional[int] = None, + timestamp_ms: Optional[int] = None, + headers: Optional[Dict[str, str]] = None, + correlation_id: Optional[str] = None, + timeout: float = 0.5, + # publisher specific + _extra_middlewares: Iterable["PublisherMiddleware"] = (), + ) -> "KafkaMessage": + return await super().request( + message=message, + topic=topic, + key=key or self.key, + partition=partition, + timestamp_ms=timestamp_ms, + headers=headers, + 
correlation_id=correlation_id, + timeout=timeout, + _extra_middlewares=_extra_middlewares, + ) + class BatchPublisher(LogicPublisher[Tuple[Message, ...]]): @override diff --git a/faststream/confluent/testing.py b/faststream/confluent/testing.py index 7fba2a7bc1..3d02527ee2 100644 --- a/faststream/confluent/testing.py +++ b/faststream/confluent/testing.py @@ -2,15 +2,20 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple from unittest.mock import AsyncMock, MagicMock +import anyio from typing_extensions import override from faststream.broker.message import encode_message, gen_cor_id +from faststream.broker.utils import resolve_custom_func from faststream.confluent.broker import KafkaBroker +from faststream.confluent.parser import AsyncConfluentParser from faststream.confluent.publisher.asyncapi import AsyncAPIBatchPublisher from faststream.confluent.publisher.producer import AsyncConfluentFastProducer from faststream.confluent.schemas import TopicPartition from faststream.confluent.subscriber.asyncapi import AsyncAPIBatchSubscriber -from faststream.testing.broker import TestBroker, call_handler +from faststream.exceptions import SubscriberNotFound +from faststream.testing.broker import TestBroker +from faststream.utils.functions import timeout_scope if TYPE_CHECKING: from faststream.broker.wrapper.call import HandlerCallWrapper @@ -90,6 +95,11 @@ class FakeProducer(AsyncConfluentFastProducer): def __init__(self, broker: KafkaBroker) -> None: self.broker = broker + default = AsyncConfluentParser + + self._parser = resolve_custom_func(broker._parser, default.parse_message) + self._decoder = resolve_custom_func(broker._decoder, default.decode_message) + @override async def publish( # type: ignore[override] self, @@ -107,8 +117,6 @@ async def publish( # type: ignore[override] raise_timeout: bool = False, ) -> Optional[Any]: """Publish a message to the Kafka broker.""" - correlation_id = correlation_id or gen_cor_id() - incoming = build_message( message=message, topic=topic, @@ -116,7 +124,7 @@ async def publish( # type: ignore[override] partition=partition, timestamp_ms=timestamp_ms, headers=headers, - correlation_id=correlation_id, + correlation_id=correlation_id or gen_cor_id(), reply_to=reply_to, ) @@ -124,17 +132,20 @@ async def publish( # type: ignore[override] for handler in self.broker._subscribers.values(): # pragma: no branch if _is_handler_matches(handler, topic, partition): - handle_value = await call_handler( - handler=handler, - message=[incoming] + msg_to_send = ( + [incoming] if isinstance(handler, AsyncAPIBatchSubscriber) - else incoming, - rpc=rpc, - rpc_timeout=rpc_timeout, - raise_timeout=raise_timeout, + else incoming ) - return_value = return_value or handle_value + with timeout_scope(rpc_timeout, raise_timeout): + response_msg = await self._execute_handler( + msg_to_send, topic, handler + ) + if rpc: + return_value = return_value or await self._decoder( + await self._parser(response_msg) + ) return return_value @@ -149,8 +160,6 @@ async def publish_batch( correlation_id: Optional[str] = None, ) -> None: """Publish a batch of messages to the Kafka broker.""" - correlation_id = correlation_id or gen_cor_id() - for handler in self.broker._subscribers.values(): # pragma: no branch if _is_handler_matches(handler, topic, partition): messages = ( @@ -160,26 +169,72 @@ async def publish_batch( partition=partition, timestamp_ms=timestamp_ms, headers=headers, - correlation_id=correlation_id, + correlation_id=correlation_id or gen_cor_id(), 
reply_to=reply_to, ) for message in msgs ) if isinstance(handler, AsyncAPIBatchSubscriber): - await call_handler( - handler=handler, - message=list(messages), - ) + await self._execute_handler(list(messages), topic, handler) else: for m in messages: - await call_handler( - handler=handler, - message=m, - ) + await self._execute_handler(m, topic, handler) + return None + @override + async def request( # type: ignore[override] + self, + message: "SendableMessage", + topic: str, + key: Optional[bytes] = None, + partition: Optional[int] = None, + timestamp_ms: Optional[int] = None, + headers: Optional[Dict[str, str]] = None, + correlation_id: Optional[str] = None, + *, + timeout: Optional[float] = 0.5, + ) -> "MockConfluentMessage": + incoming = build_message( + message=message, + topic=topic, + key=key, + partition=partition, + timestamp_ms=timestamp_ms, + headers=headers, + correlation_id=correlation_id or gen_cor_id(), + ) + + for handler in self.broker._subscribers.values(): # pragma: no branch + if _is_handler_matches(handler, topic, partition): + msg_to_send = ( + [incoming] + if isinstance(handler, AsyncAPIBatchSubscriber) + else incoming + ) + + with anyio.fail_after(timeout): + return await self._execute_handler(msg_to_send, topic, handler) + + raise SubscriberNotFound + + async def _execute_handler( + self, + msg: Any, + topic: str, + handler: "LogicSubscriber[Any]", + ) -> "MockConfluentMessage": + result = await handler.process_message(msg) + + return build_message( + topic=topic, + message=result.body, + headers=result.headers, + correlation_id=result.correlation_id or gen_cor_id(), + ) + class MockConfluentMessage: def __init__( diff --git a/faststream/exceptions.py b/faststream/exceptions.py index 916e2e56fc..5de18549ee 100644 --- a/faststream/exceptions.py +++ b/faststream/exceptions.py @@ -102,6 +102,14 @@ def __init__(self, fields: Iterable[str] = ()) -> None: self.fields = fields +class OperationForbiddenError(FastStreamException, NotImplementedError): + """Raises at planned NotImplemented operation call.""" + + +class SubscriberNotFound(FastStreamException): + """Raises as a service message or in tests.""" + + WRONG_PUBLISH_ARGS = SetupError( "You should use `reply_to` to send response to long-living queue " "and `rpc` to get response in sync mode." diff --git a/faststream/kafka/broker/broker.py b/faststream/kafka/broker/broker.py index 3fa4573ffa..b40333c329 100644 --- a/faststream/kafka/broker/broker.py +++ b/faststream/kafka/broker/broker.py @@ -632,6 +632,8 @@ async def _connect( # type: ignore[override] await producer.start() self._producer = AioKafkaFastProducer( producer=producer, + parser=self._parser, + decoder=self._decoder, ) return partial( @@ -742,6 +744,80 @@ async def publish( # type: ignore[override] **kwargs, ) + @override + async def request( # type: ignore[override] + self, + message: Annotated[ + "SendableMessage", + Doc("Message body to send."), + ], + topic: Annotated[ + str, + Doc("Topic where the message will be published."), + ], + *, + key: Annotated[ + Union[bytes, Any, None], + Doc( + """ + A key to associate with the message. Can be used to + determine which partition to send the message to. If partition + is `None` (and producer's partitioner config is left as default), + then messages with the same key will be delivered to the same + partition (but if key is `None`, partition is chosen randomly). + Must be type `bytes`, or be serializable to bytes via configured + `key_serializer`. 
+ """ + ), + ] = None, + partition: Annotated[ + Optional[int], + Doc( + """ + Specify a partition. If not set, the partition will be + selected using the configured `partitioner`. + """ + ), + ] = None, + timestamp_ms: Annotated[ + Optional[int], + Doc( + """ + Epoch milliseconds (from Jan 1 1970 UTC) to use as + the message timestamp. Defaults to current time. + """ + ), + ] = None, + headers: Annotated[ + Optional[Dict[str, str]], + Doc("Message headers to store metainformation."), + ] = None, + correlation_id: Annotated[ + Optional[str], + Doc( + "Manual message **correlation_id** setter. " + "**correlation_id** is a useful option to trace messages." + ), + ] = None, + timeout: Annotated[ + float, + Doc("Timeout to send RPC request."), + ] = 0.5, + ) -> Optional[Any]: + correlation_id = correlation_id or gen_cor_id() + + return await super().request( + message, + producer=self._producer, + topic=topic, + key=key, + partition=partition, + timestamp_ms=timestamp_ms, + headers=headers, + correlation_id=correlation_id, + timeout=timeout, + ) + async def publish_batch( self, *msgs: Annotated[ diff --git a/faststream/kafka/fastapi/fastapi.py b/faststream/kafka/fastapi/fastapi.py index d8fc7331ee..17b8c03192 100644 --- a/faststream/kafka/fastapi/fastapi.py +++ b/faststream/kafka/fastapi/fastapi.py @@ -2623,7 +2623,6 @@ def subscriber( "AsyncAPIDefaultSubscriber", ]: subscriber = super().subscriber( - topics[0], # path *topics, group_id=group_id, key_deserializer=key_deserializer, diff --git a/faststream/kafka/message.py b/faststream/kafka/message.py index 86e1d432dd..d83a57bf6a 100644 --- a/faststream/kafka/message.py +++ b/faststream/kafka/message.py @@ -13,6 +13,13 @@ class ConsumerProtocol(Protocol): async def commit(self) -> None: ... + def seek( + self, + partition: AIOKafkaTopicPartition, + offset: int, + ) -> None: + pass + class FakeConsumer: """A fake Kafka consumer.""" @@ -20,7 +27,11 @@ class FakeConsumer: async def commit(self) -> None: pass - def seek(self, **kwargs: Any) -> None: + def seek( + self, + partition: AIOKafkaTopicPartition, + offset: int, + ) -> None: pass @@ -62,12 +73,10 @@ async def nack(self) -> None: raw_message.topic, raw_message.partition, ) - - self.consumer.seek( # type: ignore[attr-defined] + self.consumer.seek( partition=topic_partition, offset=raw_message.offset, ) - await super().nack() diff --git a/faststream/kafka/publisher/producer.py b/faststream/kafka/publisher/producer.py index 398dba7e4a..f7eb23ce00 100644 --- a/faststream/kafka/publisher/producer.py +++ b/faststream/kafka/publisher/producer.py @@ -4,10 +4,15 @@ from faststream.broker.message import encode_message from faststream.broker.publisher.proto import ProducerProto +from faststream.broker.utils import resolve_custom_func +from faststream.exceptions import OperationForbiddenError +from faststream.kafka.message import KafkaMessage +from faststream.kafka.parser import AioKafkaParser if TYPE_CHECKING: from aiokafka import AIOKafkaProducer + from faststream.broker.types import CustomCallable from faststream.types import SendableMessage @@ -17,9 +22,19 @@ class AioKafkaFastProducer(ProducerProto): def __init__( self, producer: "AIOKafkaProducer", + parser: Optional["CustomCallable"], + decoder: Optional["CustomCallable"], ) -> None: self._producer = producer + # NOTE: register default parser to be compatible with request + default = AioKafkaParser( + msg_class=KafkaMessage, + regex=None, + ) + self._parser = resolve_custom_func(parser, default.parse_message) + self._decoder = 
resolve_custom_func(decoder, default.decode_message) + @override async def publish( # type: ignore[override] self, @@ -100,3 +115,9 @@ async def publish_batch( ) await self._producer.send_batch(batch, topic, partition=partition) + + @override + async def request(self, *args: Any, **kwargs: Any) -> Optional[Any]: + raise OperationForbiddenError( + "Kafka doesn't support `request` method without test client." + ) diff --git a/faststream/kafka/publisher/usecase.py b/faststream/kafka/publisher/usecase.py index a317986d35..076bfb10b3 100644 --- a/faststream/kafka/publisher/usecase.py +++ b/faststream/kafka/publisher/usecase.py @@ -1,6 +1,18 @@ +from contextlib import AsyncExitStack from functools import partial from itertools import chain -from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Tuple, Union, cast +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Dict, + Iterable, + Optional, + Tuple, + Union, + cast, +) from aiokafka import ConsumerRecord from typing_extensions import Annotated, Doc, override @@ -9,9 +21,11 @@ from faststream.broker.publisher.usecase import PublisherUsecase from faststream.broker.types import MsgType from faststream.exceptions import NOT_CONNECTED_YET +from faststream.utils.functions import return_input if TYPE_CHECKING: from faststream.broker.types import BrokerMiddleware, PublisherMiddleware + from faststream.kafka.message import KafkaMessage from faststream.kafka.publisher.producer import AioKafkaFastProducer from faststream.types import AsyncFunc, SendableMessage @@ -60,6 +74,113 @@ def __hash__(self) -> int: def add_prefix(self, prefix: str) -> None: self.topic = "".join((prefix, self.topic)) + @override + async def request( + self, + message: Annotated[ + "SendableMessage", + Doc("Message body to send."), + ], + topic: Annotated[ + str, + Doc("Topic where the message will be published."), + ] = "", + *, + key: Annotated[ + Union[bytes, Any, None], + Doc( + """ + A key to associate with the message. Can be used to + determine which partition to send the message to. If partition + is `None` (and producer's partitioner config is left as default), + then messages with the same key will be delivered to the same + partition (but if key is `None`, partition is chosen randomly). + Must be type `bytes`, or be serializable to bytes via configured + `key_serializer`. + """ + ), + ] = None, + partition: Annotated[ + Optional[int], + Doc( + """ + Specify a partition. If not set, the partition will be + selected using the configured `partitioner`. + """ + ), + ] = None, + timestamp_ms: Annotated[ + Optional[int], + Doc( + """ + Epoch milliseconds (from Jan 1 1970 UTC) to use as + the message timestamp. Defaults to current time. + """ + ), + ] = None, + headers: Annotated[ + Optional[Dict[str, str]], + Doc("Message headers to store metainformation."), + ] = None, + correlation_id: Annotated[ + Optional[str], + Doc( + "Manual message **correlation_id** setter. " + "**correlation_id** is a useful option to trace messages." 
+ ), + ] = None, + timeout: Annotated[ + float, + Doc("Timeout to send RPC request."), + ] = 0.5, + # publisher specific + _extra_middlewares: Annotated[ + Iterable["PublisherMiddleware"], + Doc("Extra middlewares to wrap publishing process."), + ] = (), + ) -> "KafkaMessage": + assert self._producer, NOT_CONNECTED_YET # nosec B101 + + topic = topic or self.topic + partition = partition or self.partition + headers = headers or self.headers + correlation_id = correlation_id or gen_cor_id() + + request: AsyncFunc = self._producer.request + + for pub_m in chain( + ( + _extra_middlewares + or (m(None).publish_scope for m in self._broker_middlewares) + ), + self._middlewares, + ): + request = partial(pub_m, request) + + published_msg = await request( + message, + topic=topic, + key=key, + partition=partition, + headers=headers, + timeout=timeout, + correlation_id=correlation_id, + timestamp_ms=timestamp_ms, + ) + + async with AsyncExitStack() as stack: + return_msg: Callable[[KafkaMessage], Awaitable[KafkaMessage]] = return_input + for m in self._broker_middlewares: + mid = m(published_msg) + await stack.enter_async_context(mid) + return_msg = partial(mid.consume_scope, return_msg) + + parsed_msg = await self._producer._parser(published_msg) + parsed_msg._decoded_body = await self._producer._decoder(parsed_msg) + return await return_msg(parsed_msg) + + raise AssertionError("unreachable") + class DefaultPublisher(LogicPublisher[ConsumerRecord]): def __init__( @@ -192,6 +313,83 @@ async def publish( timestamp_ms=timestamp_ms, ) + @override + async def request( + self, + message: Annotated[ + "SendableMessage", + Doc("Message body to send."), + ], + topic: Annotated[ + str, + Doc("Topic where the message will be published."), + ] = "", + *, + key: Annotated[ + Union[bytes, Any, None], + Doc( + """ + A key to associate with the message. Can be used to + determine which partition to send the message to. If partition + is `None` (and producer's partitioner config is left as default), + then messages with the same key will be delivered to the same + partition (but if key is `None`, partition is chosen randomly). + Must be type `bytes`, or be serializable to bytes via configured + `key_serializer`. + """ + ), + ] = None, + partition: Annotated[ + Optional[int], + Doc( + """ + Specify a partition. If not set, the partition will be + selected using the configured `partitioner`. + """ + ), + ] = None, + timestamp_ms: Annotated[ + Optional[int], + Doc( + """ + Epoch milliseconds (from Jan 1 1970 UTC) to use as + the message timestamp. Defaults to current time. + """ + ), + ] = None, + headers: Annotated[ + Optional[Dict[str, str]], + Doc("Message headers to store metainformation."), + ] = None, + correlation_id: Annotated[ + Optional[str], + Doc( + "Manual message **correlation_id** setter. " + "**correlation_id** is a useful option to trace messages." 
+ ), + ] = None, + timeout: Annotated[ + float, + Doc("Timeout to send RPC request."), + ] = 0.5, + # publisher specific + _extra_middlewares: Annotated[ + Iterable["PublisherMiddleware"], + Doc("Extra middlewares to wrap publishing process."), + ] = (), + ) -> "KafkaMessage": + return await super().request( + message=message, + topic=topic, + key=key or self.key, + partition=partition, + timestamp_ms=timestamp_ms, + headers=headers, + correlation_id=correlation_id, + timeout=timeout, + _extra_middlewares=_extra_middlewares, + ) + class BatchPublisher(LogicPublisher[Tuple["ConsumerRecord", ...]]): @override diff --git a/faststream/kafka/testing.py b/faststream/kafka/testing.py index db676f001f..744778a129 100755 --- a/faststream/kafka/testing.py +++ b/faststream/kafka/testing.py @@ -3,16 +3,22 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional from unittest.mock import AsyncMock, MagicMock +import anyio from aiokafka import ConsumerRecord from typing_extensions import override from faststream.broker.message import encode_message, gen_cor_id +from faststream.broker.utils import resolve_custom_func +from faststream.exceptions import SubscriberNotFound from faststream.kafka import TopicPartition from faststream.kafka.broker import KafkaBroker +from faststream.kafka.message import KafkaMessage +from faststream.kafka.parser import AioKafkaParser from faststream.kafka.publisher.asyncapi import AsyncAPIBatchPublisher from faststream.kafka.publisher.producer import AioKafkaFastProducer from faststream.kafka.subscriber.asyncapi import AsyncAPIBatchSubscriber -from faststream.testing.broker import TestBroker, call_handler +from faststream.testing.broker import TestBroker +from faststream.utils.functions import timeout_scope if TYPE_CHECKING: from faststream.broker.wrapper.call import HandlerCallWrapper @@ -79,6 +85,14 @@ class FakeProducer(AioKafkaFastProducer): def __init__(self, broker: KafkaBroker) -> None: self.broker = broker + default = AioKafkaParser( + msg_class=KafkaMessage, + regex=None, + ) + + self._parser = resolve_custom_func(broker._parser, default.parse_message) + self._decoder = resolve_custom_func(broker._decoder, default.decode_message) + @override async def publish( # type: ignore[override] self, @@ -111,20 +125,59 @@ async def publish( # type: ignore[override] for handler in self.broker._subscribers.values(): # pragma: no branch if _is_handler_matches(handler, topic, partition): - handle_value = await call_handler( - handler=handler, - message=[incoming] + msg_to_send = ( + [incoming] if isinstance(handler, AsyncAPIBatchSubscriber) - else incoming, - rpc=rpc, - rpc_timeout=rpc_timeout, - raise_timeout=raise_timeout, + else incoming ) - return_value = return_value or handle_value + with timeout_scope(rpc_timeout, raise_timeout): + response_msg = await self._execute_handler( + msg_to_send, topic, handler + ) + if rpc: + return_value = return_value or await self._decoder( + await self._parser(response_msg) + ) return return_value + @override + async def request( # type: ignore[override] + self, + message: "SendableMessage", + topic: str, + key: Optional[bytes] = None, + partition: Optional[int] = None, + timestamp_ms: Optional[int] = None, + headers: Optional[Dict[str, str]] = None, + correlation_id: Optional[str] = None, + *, + timeout: Optional[float] = 0.5, + ) -> "ConsumerRecord": + incoming = build_message( + message=message, + topic=topic, + key=key, + partition=partition, + timestamp_ms=timestamp_ms, + headers=headers, + correlation_id=correlation_id, + ) + + 
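+        # the first matching subscriber serves the request; batch subscribers
+        # receive the message wrapped in a list, mirroring real consumption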
for handler in self.broker._subscribers.values(): # pragma: no branch + if _is_handler_matches(handler, topic, partition): + msg_to_send = ( + [incoming] + if isinstance(handler, AsyncAPIBatchSubscriber) + else incoming + ) + + with anyio.fail_after(timeout): + return await self._execute_handler(msg_to_send, topic, handler) + + raise SubscriberNotFound + async def publish_batch( self, *msgs: "SendableMessage", @@ -152,19 +205,28 @@ async def publish_batch( ) if isinstance(handler, AsyncAPIBatchSubscriber): - await call_handler( - handler=handler, - message=list(messages), - ) + await self._execute_handler(list(messages), topic, handler) else: for m in messages: - await call_handler( - handler=handler, - message=m, - ) + await self._execute_handler(m, topic, handler) return None + async def _execute_handler( + self, + msg: Any, + topic: str, + handler: "LogicSubscriber[Any]", + ) -> "ConsumerRecord": + result = await handler.process_message(msg) + + return build_message( + topic=topic, + message=result.body, + headers=result.headers, + correlation_id=result.correlation_id, + ) + def build_message( message: "SendableMessage", @@ -176,7 +238,7 @@ def build_message( correlation_id: Optional[str] = None, *, reply_to: str = "", -) -> ConsumerRecord: +) -> "ConsumerRecord": """Build a Kafka ConsumerRecord for a sendable message.""" msg, content_type = encode_message(message) diff --git a/faststream/nats/broker/broker.py b/faststream/nats/broker/broker.py index f4b732010b..ae956d50e7 100644 --- a/faststream/nats/broker/broker.py +++ b/faststream/nats/broker/broker.py @@ -27,7 +27,7 @@ ) from nats.errors import Error from nats.js.errors import BadRequestError -from typing_extensions import Annotated, Doc, override +from typing_extensions import Annotated, Doc, deprecated, override from faststream.__about__ import SERVICE_NAME from faststream.broker.message import gen_cor_id @@ -65,6 +65,7 @@ BrokerMiddleware, CustomCallable, ) + from faststream.nats.message import NatsMessage from faststream.nats.publisher.asyncapi import AsyncAPIPublisher from faststream.security import BaseSecurity from faststream.types import ( @@ -715,10 +716,20 @@ async def publish( # type: ignore[override] rpc: Annotated[ bool, Doc("Whether to wait for reply in blocking mode."), + deprecated( + "Deprecated in **FastStream 0.5.17**. " + "Please, use `request` method instead. " + "Argument will be removed in **FastStream 0.6.0**." + ), ] = False, rpc_timeout: Annotated[ Optional[float], Doc("RPC reply waiting time."), + deprecated( + "Deprecated in **FastStream 0.5.17**. " + "Please, use `request` method with `timeout` instead. " + "Argument will be removed in **FastStream 0.6.0**." + ), ] = 30.0, raise_timeout: Annotated[ bool, @@ -726,6 +737,11 @@ async def publish( # type: ignore[override] "Whetever to raise `TimeoutError` or return `None` at **rpc_timeout**. " "RPC request returns `None` at timeout by default." ), + deprecated( + "Deprecated in **FastStream 0.5.17**. " + "`request` always raises TimeoutError instead. " + "Argument will be removed in **FastStream 0.6.0**." + ), ] = False, ) -> Optional["DecodedMessage"]: """Publish message directly. 
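Since the hunk above deprecates `rpc=True` in favor of the `request` API added below in this file, a minimal migration sketch may help (subject and payload are illustrative; it assumes a responder bound to `service.echo`):

```python
from faststream.nats import NatsBroker

broker = NatsBroker()

async def ask() -> None:
    async with broker:
        # previously: await broker.publish("ping", "service.echo", rpc=True)
        msg = await broker.request("ping", subject="service.echo", timeout=0.5)
        print(await msg.decode())  # raises TimeoutError instead of returning None
```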
@@ -739,7 +755,6 @@ async def publish( # type: ignore[override] "subject": subject, "headers": headers, "reply_to": reply_to, - "correlation_id": correlation_id or gen_cor_id(), "rpc": rpc, "rpc_timeout": rpc_timeout, "raise_timeout": raise_timeout, @@ -760,8 +775,71 @@ async def publish( # type: ignore[override] return await super().publish( message, producer=producer, + correlation_id=correlation_id or gen_cor_id(), + **publish_kwargs, + ) + + @override + async def request( # type: ignore[override] + self, + message: Annotated[ + "SendableMessage", + Doc( + "Message body to send. " + "Can be any encodable object (native python types or `pydantic.BaseModel`)." + ), + ], + subject: Annotated[ + str, + Doc("NATS subject to send message."), + ], + headers: Annotated[ + Optional[Dict[str, str]], + Doc( + "Message headers to store metainformation. " + "**content-type** and **correlation_id** will be set automatically by framework anyway." + ), + ] = None, + correlation_id: Annotated[ + Optional[str], + Doc( + "Manual message **correlation_id** setter. " + "**correlation_id** is a useful option to trace messages." + ), + ] = None, + stream: Annotated[ + Optional[str], + Doc( + "This option validates that the target subject is in presented stream. " + "Can be omitted without any effect." + ), + ] = None, + timeout: Annotated[ + float, + Doc("Timeout to send message to NATS."), + ] = 0.5, + ) -> "NatsMessage": + publish_kwargs = { + "subject": subject, + "headers": headers, + "timeout": timeout, + } + + producer: Optional[ProducerProto] + if stream is None: + producer = self._producer + + else: + producer = self._js_producer + publish_kwargs.update({"stream": stream}) + + msg: NatsMessage = await super().request( + message, + producer=producer, + correlation_id=correlation_id or gen_cor_id(), **publish_kwargs, ) + return msg @override def setup_subscriber( # type: ignore[override] diff --git a/faststream/nats/fastapi/fastapi.py b/faststream/nats/fastapi/fastapi.py index 5f5dfe3281..263465543e 100644 --- a/faststream/nats/fastapi/fastapi.py +++ b/faststream/nats/fastapi/fastapi.py @@ -866,7 +866,6 @@ def subscriber( # type: ignore[override] return cast( AsyncAPISubscriber, super().subscriber( - path=subject, subject=subject, queue=queue, pending_msgs_limit=pending_msgs_limit, diff --git a/faststream/nats/publisher/producer.py b/faststream/nats/publisher/producer.py index 61230f8c2c..ffc4aedb70 100644 --- a/faststream/nats/publisher/producer.py +++ b/faststream/nats/publisher/producer.py @@ -1,6 +1,7 @@ import asyncio from typing import TYPE_CHECKING, Any, Dict, Optional +import anyio import nats from typing_extensions import override @@ -37,10 +38,10 @@ def __init__( decoder: Optional["CustomCallable"], ) -> None: self._connection = connection - self._parser = resolve_custom_func(parser, NatsParser(pattern="").parse_message) - self._decoder = resolve_custom_func( - decoder, NatsParser(pattern="").decode_message - ) + + default = NatsParser(pattern="") + self._parser = resolve_custom_func(parser, default.parse_message) + self._decoder = resolve_custom_func(decoder, default.decode_message) @override async def publish( # type: ignore[override] @@ -99,6 +100,31 @@ async def publish( # type: ignore[override] return None + @override + async def request( # type: ignore[override] + self, + message: "SendableMessage", + subject: str, + *, + correlation_id: str, + headers: Optional[Dict[str, str]] = None, + timeout: float = 0.5, + ) -> "Msg": + payload, content_type = encode_message(message) + + 
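+        # user-supplied headers are merged in last, so they may override the
+        # framework-set content-type and correlation_id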
headers_to_send = { + "content-type": content_type or "", + "correlation_id": correlation_id, + **(headers or {}), + } + + return await self._connection.request( + subject=subject, + payload=payload, + headers=headers_to_send, + timeout=timeout, + ) + class NatsJSFastProducer(ProducerProto): """A class to represent a NATS JetStream producer.""" @@ -114,10 +140,10 @@ def __init__( decoder: Optional["CustomCallable"], ) -> None: self._connection = connection - self._parser = resolve_custom_func(parser, NatsParser(pattern="").parse_message) - self._decoder = resolve_custom_func( - decoder, NatsParser(pattern="").decode_message - ) + + default = NatsParser(pattern="") + self._parser = resolve_custom_func(parser, default.parse_message) + self._decoder = resolve_custom_func(decoder, default.decode_message) @override async def publish( # type: ignore[override] @@ -178,3 +204,50 @@ async def publish( # type: ignore[override] return await self._decoder(await self._parser(msg)) return None + + @override + async def request( # type: ignore[override] + self, + message: "SendableMessage", + subject: str, + *, + correlation_id: str, + headers: Optional[Dict[str, str]] = None, + stream: Optional[str] = None, + timeout: float = 0.5, + ) -> "Msg": + payload, content_type = encode_message(message) + + reply_to = self._connection._nc.new_inbox() + future: asyncio.Future[Msg] = asyncio.Future() + sub = await self._connection._nc.subscribe(reply_to, future=future, max_msgs=1) + await sub.unsubscribe(limit=1) + + headers_to_send = { + "content-type": content_type or "", + "correlation_id": correlation_id, + "reply_to": reply_to, + **(headers or {}), + } + + with anyio.fail_after(timeout): + await self._connection.publish( + subject=subject, + payload=payload, + headers=headers_to_send, + stream=stream, + timeout=timeout, + ) + + msg = await future + + if ( # pragma: no cover + msg.headers + and ( + msg.headers.get(nats.js.api.Header.STATUS) + == nats.aio.client.NO_RESPONDERS_STATUS + ) + ): + raise nats.errors.NoRespondersError + + return msg diff --git a/faststream/nats/publisher/usecase.py b/faststream/nats/publisher/usecase.py index 291eb94ac2..83f9a7c0e4 100644 --- a/faststream/nats/publisher/usecase.py +++ b/faststream/nats/publisher/usecase.py @@ -1,16 +1,28 @@ +from contextlib import AsyncExitStack from functools import partial from itertools import chain -from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Union +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Dict, + Iterable, + Optional, + Union, +) from nats.aio.msg import Msg -from typing_extensions import override +from typing_extensions import Annotated, Doc, override from faststream.broker.message import gen_cor_id from faststream.broker.publisher.usecase import PublisherUsecase from faststream.exceptions import NOT_CONNECTED_YET +from faststream.utils.functions import return_input if TYPE_CHECKING: from faststream.broker.types import BrokerMiddleware, PublisherMiddleware + from faststream.nats.message import NatsMessage from faststream.nats.publisher.producer import NatsFastProducer, NatsJSFastProducer from faststream.nats.schemas import JStream from faststream.types import AnyDict, AsyncFunc, SendableMessage @@ -127,5 +139,82 @@ async def publish( return await call(message, **kwargs) + @override + async def request( + self, + message: Annotated[ + "SendableMessage", + Doc( + "Message body to send. " + "Can be any encodable object (native python types or `pydantic.BaseModel`)." 
+ ), + ], + subject: Annotated[ + str, + Doc("NATS subject to send message."), + ] = "", + *, + headers: Annotated[ + Optional[Dict[str, str]], + Doc( + "Message headers to store metainformation. " + "**content-type** and **correlation_id** will be set automatically by framework anyway." + ), + ] = None, + correlation_id: Annotated[ + Optional[str], + Doc( + "Manual message **correlation_id** setter. " + "**correlation_id** is a useful option to trace messages." + ), + ] = None, + timeout: Annotated[ + float, + Doc("Timeout to send message to NATS."), + ] = 0.5, + # publisher specific + _extra_middlewares: Annotated[ + Iterable["PublisherMiddleware"], + Doc("Extra middlewares to wrap publishing process."), + ] = (), + ) -> "NatsMessage": + assert self._producer, NOT_CONNECTED_YET # nosec B101 + + kwargs: AnyDict = { + "subject": subject or self.subject, + "headers": headers or self.headers, + "timeout": timeout or self.timeout, + "correlation_id": correlation_id or gen_cor_id(), + } + + request: AsyncFunc = self._producer.request + + for pub_m in chain( + ( + _extra_middlewares + or (m(None).publish_scope for m in self._broker_middlewares) + ), + self._middlewares, + ): + request = partial(pub_m, request) + + published_msg = await request( + message, + **kwargs, + ) + + async with AsyncExitStack() as stack: + return_msg: Callable[[NatsMessage], Awaitable[NatsMessage]] = return_input + for m in self._broker_middlewares: + mid = m(published_msg) + await stack.enter_async_context(mid) + return_msg = partial(mid.consume_scope, return_msg) + + parsed_msg = await self._producer._parser(published_msg) + parsed_msg._decoded_body = await self._producer._decoder(parsed_msg) + return await return_msg(parsed_msg) + + raise AssertionError("unreachable") + def add_prefix(self, prefix: str) -> None: self.subject = prefix + self.subject diff --git a/faststream/nats/testing.py b/faststream/nats/testing.py index 7d9b254118..425af25762 100644 --- a/faststream/nats/testing.py +++ b/faststream/nats/testing.py @@ -1,16 +1,20 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from unittest.mock import AsyncMock +import anyio from nats.aio.msg import Msg from typing_extensions import override from faststream.broker.message import encode_message, gen_cor_id -from faststream.exceptions import WRONG_PUBLISH_ARGS +from faststream.broker.utils import resolve_custom_func +from faststream.exceptions import WRONG_PUBLISH_ARGS, SubscriberNotFound from faststream.nats.broker import NatsBroker +from faststream.nats.parser import NatsParser from faststream.nats.publisher.producer import NatsFastProducer from faststream.nats.schemas.js_stream import is_subject_match_wildcard from faststream.nats.subscriber.usecase import LogicSubscriber -from faststream.testing.broker import TestBroker, call_handler +from faststream.testing.broker import TestBroker +from faststream.utils.functions import timeout_scope if TYPE_CHECKING: from faststream.broker.wrapper.call import HandlerCallWrapper @@ -47,7 +51,9 @@ async def _fake_connect( # type: ignore[override] **kwargs: Any, ) -> AsyncMock: broker.stream = AsyncMock() - broker._js_producer = broker._producer = FakeProducer(broker) # type: ignore[assignment] + broker._js_producer = broker._producer = FakeProducer( # type: ignore[assignment] + broker, + ) return AsyncMock() @staticmethod @@ -63,6 +69,10 @@ class FakeProducer(NatsFastProducer): def __init__(self, broker: NatsBroker) -> None: self.broker = broker + default = NatsParser(pattern="") + self._parser = 
resolve_custom_func(broker._parser, default.parse_message) + self._decoder = resolve_custom_func(broker._decoder, default.decode_message) + @override async def publish( # type: ignore[override] self, @@ -91,35 +101,88 @@ async def publish( # type: ignore[override] ) for handler in self.broker._subscribers.values(): # pragma: no branch - if stream and ( - not (handler_stream := getattr(handler, "stream", None)) - or stream != handler_stream.name - ): - continue - - if is_subject_match_wildcard(subject, handler.clear_subject) or any( - is_subject_match_wildcard(subject, filter_subject) - for filter_subject in (handler.config.filter_subjects or ()) - ): + if _is_handler_suitable(handler, subject, stream): msg: Union[List[PatchedMessage], PatchedMessage] + if (pull := getattr(handler, "pull_sub", None)) and pull.batch: msg = [incoming] else: msg = incoming - r = await call_handler( - handler=handler, - message=msg, - rpc=rpc, - rpc_timeout=rpc_timeout, - raise_timeout=raise_timeout, - ) - - if rpc: - return r + with timeout_scope(rpc_timeout, raise_timeout): + response = await self._execute_handler(msg, subject, handler) + if rpc: + return await self._decoder(await self._parser(response)) return None + @override + async def request( # type: ignore[override] + self, + message: "SendableMessage", + subject: str, + *, + correlation_id: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, + timeout: float = 0.5, + # NatsJSFastProducer compatibility + stream: Optional[str] = None, + ) -> "PatchedMessage": + incoming = build_message( + message=message, + subject=subject, + headers=headers, + correlation_id=correlation_id, + ) + + for handler in self.broker._subscribers.values(): # pragma: no branch + if _is_handler_suitable(handler, subject, stream): + msg: Union[List[PatchedMessage], PatchedMessage] + + if (pull := getattr(handler, "pull_sub", None)) and pull.batch: + msg = [incoming] + else: + msg = incoming + + with anyio.fail_after(timeout): + return await self._execute_handler(msg, subject, handler) + + raise SubscriberNotFound + + async def _execute_handler( + self, msg: Any, subject: str, handler: "LogicSubscriber[Any]" + ) -> "PatchedMessage": + result = await handler.process_message(msg) + + return build_message( + subject=subject, + message=result.body, + headers=result.headers, + correlation_id=result.correlation_id, + ) + + +def _is_handler_suitable( + handler: "LogicSubscriber[Any]", + subject: str, + stream: Optional[str] = None, +) -> bool: + if stream: + if not (handler_stream := getattr(handler, "stream", None)): + return False + + if stream != handler_stream.name: + return False + + if is_subject_match_wildcard(subject, handler.clear_subject): + return True + + for filter_subject in handler.config.filter_subjects or (): + if is_subject_match_wildcard(subject, filter_subject): + return True + + return False + def build_message( message: "SendableMessage", diff --git a/faststream/rabbit/broker/broker.py b/faststream/rabbit/broker/broker.py index da8fe49707..774a88665e 100644 --- a/faststream/rabbit/broker/broker.py +++ b/faststream/rabbit/broker/broker.py @@ -13,7 +13,7 @@ import anyio from aio_pika import connect_robust -from typing_extensions import Annotated, Doc, override +from typing_extensions import Annotated, Doc, deprecated, override from faststream.__about__ import SERVICE_NAME from faststream.broker.message import gen_cor_id @@ -53,6 +53,7 @@ BrokerMiddleware, CustomCallable, ) + from faststream.rabbit.message import RabbitMessage from faststream.rabbit.types 
import AioPikaSendableMessage from faststream.security import BaseSecurity from faststream.types import AnyDict, Decorator, LoggerProto @@ -572,10 +573,20 @@ async def publish( # type: ignore[override] rpc: Annotated[ bool, Doc("Whether to wait for reply in blocking mode."), + deprecated( + "Deprecated in **FastStream 0.5.17**. " + "Please, use `request` method instead. " + "Argument will be removed in **FastStream 0.6.0**." + ), ] = False, rpc_timeout: Annotated[ Optional[float], Doc("RPC reply waiting time."), + deprecated( + "Deprecated in **FastStream 0.5.17**. " + "Please, use `request` method with `timeout` instead. " + "Argument will be removed in **FastStream 0.6.0**." + ), ] = 30.0, raise_timeout: Annotated[ bool, @@ -583,6 +594,11 @@ async def publish( # type: ignore[override] "Whetever to raise `TimeoutError` or return `None` at **rpc_timeout**. " "RPC request returns `None` at timeout by default." ), + deprecated( + "Deprecated in **FastStream 0.5.17**. " + "`request` always raises TimeoutError instead. " + "Argument will be removed in **FastStream 0.6.0**." + ), ] = False, # message args correlation_id: Annotated[ @@ -669,6 +685,126 @@ async def publish( # type: ignore[override] raise_timeout=raise_timeout, ) + @override + async def request( # type: ignore[override] + self, + message: Annotated[ + "AioPikaSendableMessage", + Doc("Message body to send."), + ] = None, + queue: Annotated[ + Union["RabbitQueue", str], + Doc("Message routing key to publish with."), + ] = "", + exchange: Annotated[ + Union["RabbitExchange", str, None], + Doc("Target exchange to publish message to."), + ] = None, + *, + routing_key: Annotated[ + str, + Doc( + "Message routing key to publish with. " + "Overrides `queue` option if presented." + ), + ] = "", + mandatory: Annotated[ + bool, + Doc( + "Client waits for confirmation that the message is placed to some queue. " + "RabbitMQ returns message to client if there is no suitable queue." + ), + ] = True, + immediate: Annotated[ + bool, + Doc( + "Client expects that there is consumer ready to take the message to work. " + "RabbitMQ returns message to client if there is no suitable consumer." + ), + ] = False, + timeout: Annotated[ + "TimeoutType", + Doc("Send confirmation time from RabbitMQ."), + ] = None, + persist: Annotated[ + bool, + Doc("Restore the message on RabbitMQ reboot."), + ] = False, + # message args + correlation_id: Annotated[ + Optional[str], + Doc( + "Manual message **correlation_id** setter. " + "**correlation_id** is a useful option to trace messages." + ), + ] = None, + headers: Annotated[ + Optional["HeadersType"], + Doc("Message headers to store metainformation."), + ] = None, + content_type: Annotated[ + Optional[str], + Doc( + "Message **content-type** header. " + "Used by application, not core RabbitMQ. " + "Will be set automatically if not specified." + ), + ] = None, + content_encoding: Annotated[ + Optional[str], + Doc("Message body content encoding, e.g. **gzip**."), + ] = None, + expiration: Annotated[ + Optional["DateType"], + Doc("Message expiration (lifetime) in seconds (or datetime or timedelta)."), + ] = None, + message_id: Annotated[ + Optional[str], + Doc("Arbitrary message id. Generated automatically if not presented."), + ] = None, + timestamp: Annotated[ + Optional["DateType"], + Doc("Message publish timestamp. Generated automatically if not presented."), + ] = None, + message_type: Annotated[ + Optional[str], + Doc("Application-specific message type, e.g. 
**orders.created**."), + ] = None, + user_id: Annotated[ + Optional[str], + Doc("Publisher connection User ID, validated if set."), + ] = None, + priority: Annotated[ + Optional[int], + Doc("The message priority (0 by default)."), + ] = None, + ) -> "RabbitMessage": + routing = routing_key or RabbitQueue.validate(queue).routing + correlation_id = correlation_id or gen_cor_id() + + msg: RabbitMessage = await super().request( + message, + producer=self._producer, + correlation_id=correlation_id, + routing_key=routing, + app_id=self.app_id, + exchange=exchange, + mandatory=mandatory, + immediate=immediate, + persist=persist, + headers=headers, + content_type=content_type, + content_encoding=content_encoding, + expiration=expiration, + message_id=message_id, + timestamp=timestamp, + message_type=message_type, + user_id=user_id, + timeout=timeout, + priority=priority, + ) + return msg + async def declare_queue( self, queue: Annotated[ diff --git a/faststream/rabbit/fastapi/router.py b/faststream/rabbit/fastapi/router.py index 2d155ee973..b634c2738f 100644 --- a/faststream/rabbit/fastapi/router.py +++ b/faststream/rabbit/fastapi/router.py @@ -691,11 +691,9 @@ def subscriber( # type: ignore[override] ), ] = False, ) -> AsyncAPISubscriber: - queue = RabbitQueue.validate(queue) return cast( AsyncAPISubscriber, super().subscriber( - path=queue.name, queue=queue, exchange=exchange, consume_args=consume_args, diff --git a/faststream/rabbit/message.py b/faststream/rabbit/message.py index 44c4ef4446..4287cf2fd7 100644 --- a/faststream/rabbit/message.py +++ b/faststream/rabbit/message.py @@ -17,10 +17,7 @@ async def ack( """Acknowledge the RabbitMQ message.""" pika_message = self.raw_message await super().ack() - if ( - pika_message._IncomingMessage__processed # type: ignore[attr-defined] - or pika_message._IncomingMessage__no_ack # type: ignore[attr-defined] - ): + if pika_message.locked: return await pika_message.ack(multiple=multiple) @@ -32,10 +29,7 @@ async def nack( """Negative Acknowledgment of the RabbitMQ message.""" pika_message = self.raw_message await super().nack() - if ( - pika_message._IncomingMessage__processed # type: ignore[attr-defined] - or pika_message._IncomingMessage__no_ack # type: ignore[attr-defined] - ): + if pika_message.locked: return await pika_message.nack(multiple=multiple, requeue=requeue) @@ -46,9 +40,6 @@ async def reject( """Reject the RabbitMQ message.""" pika_message = self.raw_message await super().reject() - if ( - pika_message._IncomingMessage__processed # type: ignore[attr-defined] - or pika_message._IncomingMessage__no_ack # type: ignore[attr-defined] - ): + if pika_message.locked: return await pika_message.reject(requeue=requeue) diff --git a/faststream/rabbit/publisher/producer.py b/faststream/rabbit/publisher/producer.py index 71e12ab427..ea83ba0672 100644 --- a/faststream/rabbit/publisher/producer.py +++ b/faststream/rabbit/publisher/producer.py @@ -134,6 +134,58 @@ async def publish( # type: ignore[override] return None + @override + async def request( # type: ignore[override] + self, + message: "AioPikaSendableMessage", + exchange: Union["RabbitExchange", str, None] = None, + *, + correlation_id: str = "", + routing_key: str = "", + mandatory: bool = True, + immediate: bool = False, + timeout: Optional[float] = None, + persist: bool = False, + headers: Optional["HeadersType"] = None, + content_type: Optional[str] = None, + content_encoding: Optional[str] = None, + priority: Optional[int] = None, + expiration: Optional["DateType"] = None, + message_id: 
Optional[str] = None, + timestamp: Optional["DateType"] = None, + message_type: Optional[str] = None, + user_id: Optional[str] = None, + app_id: Optional[str] = None, + ) -> "IncomingMessage": + """Publish a message to a RabbitMQ queue.""" + async with _RPCCallback( + self._rpc_lock, + await self.declarer.declare_queue(RABBIT_REPLY), + ) as response_queue: + with anyio.fail_after(timeout): + await self._publish( + message=message, + exchange=exchange, + routing_key=routing_key, + mandatory=mandatory, + immediate=immediate, + timeout=timeout, + persist=persist, + reply_to=RABBIT_REPLY.name, + headers=headers, + content_type=content_type, + content_encoding=content_encoding, + priority=priority, + correlation_id=correlation_id, + expiration=expiration, + message_id=message_id, + timestamp=timestamp, + message_type=message_type, + user_id=user_id, + app_id=app_id, + ) + return await response_queue.receive() + async def _publish( self, message: "AioPikaSendableMessage", diff --git a/faststream/rabbit/publisher/usecase.py b/faststream/rabbit/publisher/usecase.py index 6df3b1078a..c3306c78d2 100644 --- a/faststream/rabbit/publisher/usecase.py +++ b/faststream/rabbit/publisher/usecase.py @@ -1,21 +1,32 @@ +from contextlib import AsyncExitStack from copy import deepcopy from functools import partial from itertools import chain -from typing import TYPE_CHECKING, Any, Iterable, Optional, Union +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Iterable, + Optional, + Union, +) from aio_pika import IncomingMessage -from typing_extensions import Annotated, Doc, TypedDict, Unpack, override +from typing_extensions import Annotated, Doc, TypedDict, Unpack, deprecated, override from faststream.broker.message import gen_cor_id from faststream.broker.publisher.usecase import PublisherUsecase from faststream.exceptions import NOT_CONNECTED_YET from faststream.rabbit.schemas import BaseRMQInformation, RabbitQueue from faststream.rabbit.subscriber.usecase import LogicSubscriber +from faststream.utils.functions import return_input if TYPE_CHECKING: from aio_pika.abc import DateType, HeadersType, TimeoutType from faststream.broker.types import BrokerMiddleware, PublisherMiddleware + from faststream.rabbit.message import RabbitMessage from faststream.rabbit.publisher.producer import AioPikaFastProducer from faststream.rabbit.schemas.exchange import RabbitExchange from faststream.rabbit.types import AioPikaSendableMessage @@ -23,8 +34,8 @@ # should be public to use in imports -class PublishKwargs(TypedDict, total=False): - """Typed dict to annotate RabbitMQ publishers.""" +class RequestPublishKwargs(TypedDict, total=False): + """Typed dict to annotate RabbitMQ requesters.""" headers: Annotated[ Optional["HeadersType"], @@ -55,12 +66,7 @@ class PublishKwargs(TypedDict, total=False): Optional[bool], Doc("Restore the message on RabbitMQ reboot."), ] - reply_to: Annotated[ - Optional[str], - Doc( - "Reply message routing key to send with (always sending to default exchange)." - ), - ] + priority: Annotated[ Optional[int], Doc("The message priority (0 by default)."), @@ -91,6 +97,17 @@ class PublishKwargs(TypedDict, total=False): ] +class PublishKwargs(RequestPublishKwargs): + """Typed dict to annotate RabbitMQ publishers.""" + + reply_to: Annotated[ + Optional[str], + Doc( + "Reply message routing key to send with (always sending to default exchange)." 
+ ), + ] + + class LogicPublisher( PublisherUsecase[IncomingMessage], BaseRMQInformation, @@ -128,7 +145,10 @@ def __init__( ) self.routing_key = routing_key - self.message_kwargs = message_kwargs + + request_kwargs = dict(message_kwargs) + self.reply_to = request_kwargs.pop("reply_to", None) + self.message_kwargs = request_kwargs # BaseRMQInformation self.queue = queue @@ -161,7 +181,7 @@ def __hash__(self) -> int: ) @override - async def publish( + async def publish( # type: ignore[override] self, message: "AioPikaSendableMessage", queue: Annotated[ @@ -200,10 +220,20 @@ async def publish( rpc: Annotated[ bool, Doc("Whether to wait for reply in blocking mode."), + deprecated( + "Deprecated in **FastStream 0.5.17**. " + "Please, use `request` method instead. " + "Argument will be removed in **FastStream 0.6.0**." + ), ] = False, rpc_timeout: Annotated[ Optional[float], Doc("RPC reply waiting time."), + deprecated( + "Deprecated in **FastStream 0.5.17**. " + "Please, use `request` method with `timeout` instead. " + "Argument will be removed in **FastStream 0.6.0**." + ), ] = 30.0, raise_timeout: Annotated[ bool, @@ -211,6 +241,11 @@ async def publish( "Whetever to raise `TimeoutError` or return `None` at **rpc_timeout**. " "RPC request returns `None` at timeout by default." ), + deprecated( + "Deprecated in **FastStream 0.5.17**. " + "`request` always raises TimeoutError instead. " + "Argument will be removed in **FastStream 0.6.0**." + ), ] = False, # publisher specific _extra_middlewares: Annotated[ @@ -234,6 +269,7 @@ async def publish( "rpc": rpc, "rpc_timeout": rpc_timeout, "raise_timeout": raise_timeout, + "reply_to": self.reply_to, **self.message_kwargs, **publish_kwargs, } @@ -251,6 +287,96 @@ async def publish( return await call(message, **kwargs) + @override + async def request( + self, + message: "AioPikaSendableMessage", + queue: Annotated[ + Union["RabbitQueue", str, None], + Doc("Message routing key to publish with."), + ] = None, + exchange: Annotated[ + Union["RabbitExchange", str, None], + Doc("Target exchange to publish message to."), + ] = None, + *, + routing_key: Annotated[ + str, + Doc( + "Message routing key to publish with. " + "Overrides `queue` option if presented." + ), + ] = "", + # message args + correlation_id: Annotated[ + Optional[str], + Doc( + "Manual message **correlation_id** setter. " + "**correlation_id** is a useful option to trace messages." + ), + ] = None, + message_id: Annotated[ + Optional[str], + Doc("Arbitrary message id. Generated automatically if not presented."), + ] = None, + timestamp: Annotated[ + Optional["DateType"], + Doc("Message publish timestamp. 
Generated automatically if not presented."), + ] = None, + # publisher specific + _extra_middlewares: Annotated[ + Iterable["PublisherMiddleware"], + Doc("Extra middlewares to wrap publishing process."), + ] = (), + **publish_kwargs: "Unpack[RequestPublishKwargs]", + ) -> "RabbitMessage": + assert self._producer, NOT_CONNECTED_YET # nosec B101 + + kwargs: AnyDict = { + "routing_key": routing_key + or self.routing_key + or RabbitQueue.validate(queue or self.queue).routing, + "exchange": exchange or self.exchange.name, + "app_id": self.app_id, + "correlation_id": correlation_id or gen_cor_id(), + "message_id": message_id, + "timestamp": timestamp, + # specific args + **self.message_kwargs, + **publish_kwargs, + } + + request: AsyncFunc = self._producer.request + + for pub_m in chain( + ( + _extra_middlewares + or (m(None).publish_scope for m in self._broker_middlewares) + ), + self._middlewares, + ): + request = partial(pub_m, request) + + published_msg = await request( + message, + **kwargs, + ) + + async with AsyncExitStack() as stack: + return_msg: Callable[[RabbitMessage], Awaitable[RabbitMessage]] = ( + return_input + ) + for m in self._broker_middlewares: + mid = m(published_msg) + await stack.enter_async_context(mid) + return_msg = partial(mid.consume_scope, return_msg) + + parsed_msg = await self._producer._parser(published_msg) + parsed_msg._decoded_body = await self._producer._decoder(parsed_msg) + return await return_msg(parsed_msg) + + raise AssertionError("unreachable") + def add_prefix(self, prefix: str) -> None: """Include Publisher in router.""" new_q = deepcopy(self.queue) diff --git a/faststream/rabbit/testing.py b/faststream/rabbit/testing.py index f2a9183c7c..5e78fd6033 100644 --- a/faststream/rabbit/testing.py +++ b/faststream/rabbit/testing.py @@ -1,16 +1,18 @@ from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Generator, Optional, Union +from typing import TYPE_CHECKING, Any, Generator, Mapping, Optional, Union from unittest import mock from unittest.mock import AsyncMock import aiormq +import anyio from aio_pika.message import IncomingMessage from pamqp import commands as spec from pamqp.header import ContentHeader from typing_extensions import override from faststream.broker.message import gen_cor_id -from faststream.exceptions import WRONG_PUBLISH_ARGS +from faststream.broker.utils import resolve_custom_func +from faststream.exceptions import WRONG_PUBLISH_ARGS, SubscriberNotFound from faststream.rabbit.broker.broker import RabbitBroker from faststream.rabbit.parser import AioPikaParser from faststream.rabbit.publisher.asyncapi import AsyncAPIPublisher @@ -20,8 +22,9 @@ RabbitExchange, RabbitQueue, ) -from faststream.rabbit.subscriber.asyncapi import AsyncAPISubscriber -from faststream.testing.broker import TestBroker, call_handler +from faststream.rabbit.subscriber.usecase import LogicSubscriber +from faststream.testing.broker import TestBroker +from faststream.utils.functions import timeout_scope if TYPE_CHECKING: from aio_pika.abc import DateType, HeadersType, TimeoutType @@ -29,6 +32,7 @@ from faststream.broker.wrapper.call import HandlerCallWrapper from faststream.rabbit.types import AioPikaSendableMessage + __all__ = ("TestRabbitBroker",) @@ -79,7 +83,7 @@ def remove_publisher_fake_subscriber( publisher: AsyncAPIPublisher, ) -> None: broker._subscribers.pop( - AsyncAPISubscriber.get_routing_hash( + LogicSubscriber.get_routing_hash( queue=RabbitQueue.validate(publisher.routing), exchange=RabbitExchange.validate(publisher.exchange), ), 
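The testing changes in this file replace `call_handler` with an in-broker `request` path; a short usage sketch under the test client (queue name and payload are illustrative):

```python
import asyncio

from faststream.rabbit import RabbitBroker, TestRabbitBroker

broker = RabbitBroker()

@broker.subscriber("echo-queue")
async def echo(msg: str) -> str:
    return msg

async def main() -> None:
    async with TestRabbitBroker(broker) as br:
        # FakeProducer.request delivers to the matching subscriber and
        # returns its result, so no real RabbitMQ connection is needed
        response = await br.request("hi", "echo-queue", timeout=0.5)
        assert await response.decode() == "hi"

asyncio.run(main())
```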
@@ -93,6 +97,8 @@ class PatchedMessage(IncomingMessage): This class extends aio_pika's IncomingMessage class and is used to simulate RabbitMQ message handling during tests. """ + routing_key: str + async def ack(self, multiple: bool = False) -> None: """Asynchronously acknowledge a message.""" pass @@ -185,12 +191,19 @@ class FakeProducer(AioPikaFastProducer): def __init__(self, broker: RabbitBroker) -> None: self.broker = broker + default_parser = AioPikaParser() + self._parser = resolve_custom_func(broker._parser, default_parser.parse_message) + self._decoder = resolve_custom_func( + broker._decoder, default_parser.decode_message + ) + @override - async def publish( + async def publish( # type: ignore[override] self, - message: "AioPikaSendableMessage" = "", + message: "AioPikaSendableMessage", exchange: Union["RabbitExchange", str, None] = None, *, + correlation_id: str = "", routing_key: str = "", mandatory: bool = True, immediate: bool = False, @@ -204,7 +217,6 @@ async def publish( content_type: Optional[str] = None, content_encoding: Optional[str] = None, priority: Optional[int] = None, - correlation_id: Optional[str] = None, expiration: Optional["DateType"] = None, message_id: Optional[str] = None, timestamp: Optional["DateType"] = None, @@ -238,61 +250,123 @@ async def publish( ) for handler in self.broker._subscribers.values(): # pragma: no branch - if handler.exchange == exch: - call: bool = False + if _is_handler_suitable( + handler, incoming.routing_key, incoming.headers, exch + ): + with timeout_scope(rpc_timeout, raise_timeout): + response = await self._execute_handler(incoming, handler) + if rpc: + return await self._decoder(await self._parser(response)) - if ( - handler.exchange is None - or handler.exchange.type == ExchangeType.DIRECT - ): - call = handler.queue.name == incoming.routing_key + return None - elif handler.exchange.type == ExchangeType.FANOUT: - call = True + @override + async def request( # type: ignore[override] + self, + message: "AioPikaSendableMessage" = "", + exchange: Union["RabbitExchange", str, None] = None, + *, + correlation_id: str = "", + routing_key: str = "", + mandatory: bool = True, + immediate: bool = False, + timeout: Optional[float] = None, + persist: bool = False, + headers: Optional["HeadersType"] = None, + content_type: Optional[str] = None, + content_encoding: Optional[str] = None, + priority: Optional[int] = None, + expiration: Optional["DateType"] = None, + message_id: Optional[str] = None, + timestamp: Optional["DateType"] = None, + message_type: Optional[str] = None, + user_id: Optional[str] = None, + app_id: Optional[str] = None, + ) -> "PatchedMessage": + """Publish a message to a RabbitMQ queue or exchange.""" + exch = RabbitExchange.validate(exchange) - elif handler.exchange.type == ExchangeType.TOPIC: - call = apply_pattern( - handler.queue.routing, - incoming.routing_key or "", - ) + incoming = build_message( + message=message, + exchange=exch, + routing_key=routing_key, + app_id=app_id, + user_id=user_id, + message_type=message_type, + headers=headers, + persist=persist, + message_id=message_id, + priority=priority, + content_encoding=content_encoding, + content_type=content_type, + correlation_id=correlation_id, + expiration=expiration, + timestamp=timestamp, + ) + + for handler in self.broker._subscribers.values(): # pragma: no branch + if _is_handler_suitable( + handler, incoming.routing_key, incoming.headers, exch + ): + with anyio.fail_after(timeout): + return await self._execute_handler(incoming, handler) + + raise 
SubscriberNotFound + + async def _execute_handler( + self, msg: PatchedMessage, handler: "LogicSubscriber" + ) -> "PatchedMessage": + result = await handler.process_message(msg) + + return build_message( + routing_key=msg.routing_key, + message=result.body, + headers=result.headers, + correlation_id=result.correlation_id, + ) - elif handler.exchange.type == ExchangeType.HEADERS: # pramga: no branch - queue_headers = (handler.queue.bind_arguments or {}).copy() - msg_headers = incoming.headers - if not queue_headers: - call = True +def _is_handler_suitable( + handler: "LogicSubscriber", + routing_key: str, + headers: "Mapping[Any, Any]", + exchange: "RabbitExchange", +) -> bool: + if handler.exchange != exchange: + return False - else: - matcher = queue_headers.pop("x-match", "all") + if handler.exchange is None or handler.exchange.type == ExchangeType.DIRECT: + return handler.queue.name == routing_key - full = True - none = True - for k, v in queue_headers.items(): - if msg_headers.get(k) != v: - full = False - else: - none = False + elif handler.exchange.type == ExchangeType.FANOUT: + return True - if not none: - call = (matcher == "any") or full + elif handler.exchange.type == ExchangeType.TOPIC: + return apply_pattern(handler.queue.routing, routing_key) + elif handler.exchange.type == ExchangeType.HEADERS: + queue_headers = (handler.queue.bind_arguments or {}).copy() + + if not queue_headers: + return True + + else: + match_rule = queue_headers.pop("x-match", "all") + + full_match = True + is_headers_empty = True + for k, v in queue_headers.items(): + if headers.get(k) != v: + full_match = False else: - raise AssertionError("unreachable") + is_headers_empty = False - if call: - r = await call_handler( - handler=handler, - message=incoming, - rpc=rpc, - rpc_timeout=rpc_timeout, - raise_timeout=raise_timeout, - ) + if is_headers_empty: + return False - if rpc: # pragma: no branch - return r + return full_match or (match_rule == "any") - return None + raise AssertionError def apply_pattern(pattern: str, current: str) -> bool: diff --git a/faststream/redis/broker/broker.py b/faststream/redis/broker/broker.py index fcdc472668..fdafa04c02 100644 --- a/faststream/redis/broker/broker.py +++ b/faststream/redis/broker/broker.py @@ -23,7 +23,7 @@ parse_url, ) from redis.exceptions import ConnectionError -from typing_extensions import Annotated, Doc, TypeAlias, override +from typing_extensions import Annotated, Doc, TypeAlias, deprecated, override from faststream.__about__ import __version__ from faststream.broker.message import gen_cor_id @@ -46,7 +46,7 @@ BrokerMiddleware, CustomCallable, ) - from faststream.redis.message import BaseMessage + from faststream.redis.message import BaseMessage, RedisMessage from faststream.security import BaseSecurity from faststream.types import ( AnyDict, @@ -408,10 +408,20 @@ async def publish( # type: ignore[override] rpc: Annotated[ bool, Doc("Whether to wait for reply in blocking mode."), + deprecated( + "Deprecated in **FastStream 0.5.17**. " + "Please, use `request` method instead. " + "Argument will be removed in **FastStream 0.6.0**." + ), ] = False, rpc_timeout: Annotated[ Optional[float], Doc("RPC reply waiting time."), + deprecated( + "Deprecated in **FastStream 0.5.17**. " + "Please, use `request` method with `timeout` instead. " + "Argument will be removed in **FastStream 0.6.0**." 
+ ), ] = 30.0, raise_timeout: Annotated[ bool, @@ -419,6 +429,11 @@ async def publish( # type: ignore[override] "Whetever to raise `TimeoutError` or return `None` at **rpc_timeout**. " "RPC request returns `None` at timeout by default." ), + deprecated( + "Deprecated in **FastStream 0.5.17**. " + "`request` always raises TimeoutError instead. " + "Argument will be removed in **FastStream 0.6.0**." + ), ] = False, ) -> Optional["DecodedMessage"]: """Publish message directly. @@ -428,23 +443,47 @@ async def publish( # type: ignore[override] Please, use `@broker.publisher(...)` or `broker.publisher(...).publish(...)` instead in a regular way. """ - correlation_id = correlation_id or gen_cor_id() - return await super().publish( message, producer=self._producer, + correlation_id=correlation_id or gen_cor_id(), channel=channel, list=list, stream=stream, maxlen=maxlen, reply_to=reply_to, headers=headers, - correlation_id=correlation_id, rpc=rpc, rpc_timeout=rpc_timeout, raise_timeout=raise_timeout, ) + @override + async def request( # type: ignore[override] + self, + message: "SendableMessage", + channel: Optional[str] = None, + *, + list: Optional[str] = None, + stream: Optional[str] = None, + maxlen: Optional[int] = None, + correlation_id: Optional[str] = None, + headers: Optional["AnyDict"] = None, + timeout: Optional[float] = 30.0, + ) -> "RedisMessage": + msg: RedisMessage = await super().request( + message, + producer=self._producer, + correlation_id=correlation_id or gen_cor_id(), + channel=channel, + list=list, + stream=stream, + maxlen=maxlen, + headers=headers, + timeout=timeout, + ) + return msg + async def publish_batch( self, *msgs: Annotated[ diff --git a/faststream/redis/fastapi/fastapi.py b/faststream/redis/fastapi/fastapi.py index ba18c339cc..85ce2bf1e7 100644 --- a/faststream/redis/fastapi/fastapi.py +++ b/faststream/redis/fastapi/fastapi.py @@ -632,19 +632,12 @@ def subscriber( # type: ignore[override] ), ] = False, ) -> AsyncAPISubscriber: - list_sub = ListSub.validate(list) - channel = PubSub.validate(channel) - stream_sub = StreamSub.validate(stream) - - any_of = list_sub or channel or stream_sub - return cast( AsyncAPISubscriber, super().subscriber( - path=getattr(any_of, "name", ""), channel=channel, - list=list_sub, - stream=stream_sub, + list=list, + stream=stream, dependencies=dependencies, parser=parser, decoder=decoder, diff --git a/faststream/redis/opentelemetry/provider.py b/faststream/redis/opentelemetry/provider.py index 1fcfd4e9c3..a809db603d 100644 --- a/faststream/redis/opentelemetry/provider.py +++ b/faststream/redis/opentelemetry/provider.py @@ -30,7 +30,7 @@ def get_consume_attrs_from_message( if cast(str, msg.raw_message.get("type", "")).startswith("b"): attrs[SpanAttributes.MESSAGING_BATCH_MESSAGE_COUNT] = len( - cast(Sized, msg.decoded_body) + cast(Sized, msg._decoded_body) ) return attrs diff --git a/faststream/redis/publisher/producer.py b/faststream/redis/publisher/producer.py index 1fb31d07a1..3dc44271e0 100644 --- a/faststream/redis/publisher/producer.py +++ b/faststream/redis/publisher/producer.py @@ -1,5 +1,6 @@ from typing import TYPE_CHECKING, Any, Optional +import anyio from typing_extensions import override from faststream.broker.publisher.proto import ProducerProto @@ -35,13 +36,15 @@ def __init__( decoder: Optional["CustomCallable"], ) -> None: self._connection = connection + + default = RedisPubSubParser() self._parser = resolve_custom_func( parser, - RedisPubSubParser().parse_message, + default.parse_message, ) self._decoder = 
resolve_custom_func( decoder, - RedisPubSubParser().decode_message, + default.decode_message, ) @override @@ -122,6 +125,68 @@ async def publish( # type: ignore[override] else: return await self._decoder(await self._parser(m)) + @override + async def request( # type: ignore[override] + self, + message: "SendableMessage", + *, + correlation_id: str, + channel: Optional[str] = None, + list: Optional[str] = None, + stream: Optional[str] = None, + maxlen: Optional[int] = None, + headers: Optional["AnyDict"] = None, + timeout: Optional[float] = 30.0, + ) -> "Any": + if not any((channel, list, stream)): + raise SetupError(INCORRECT_SETUP_MSG) + + nuid = NUID() + reply_to = str(nuid.next(), "utf-8") + psub = self._connection.pubsub() + await psub.subscribe(reply_to) + + msg = RawMessage.encode( + message=message, + reply_to=reply_to, + headers=headers, + correlation_id=correlation_id, + ) + + if channel is not None: + await self._connection.publish(channel, msg) + elif list is not None: + await self._connection.rpush(list, msg) + elif stream is not None: + await self._connection.xadd( + name=stream, + fields={DATA_KEY: msg}, + maxlen=maxlen, + ) + else: + raise AssertionError("unreachable") + + with anyio.fail_after(timeout) as scope: + # skip subscribe message + await psub.get_message( + ignore_subscribe_messages=True, + timeout=timeout or 0.0, + ) + + # get real response + response_msg = await psub.get_message( + ignore_subscribe_messages=True, + timeout=timeout or 0.0, + ) + + await psub.unsubscribe() + await psub.aclose() # type: ignore[attr-defined] + + if scope.cancel_called: + raise TimeoutError + + return response_msg + async def publish_batch( self, *msgs: "SendableMessage", diff --git a/faststream/redis/publisher/usecase.py b/faststream/redis/publisher/usecase.py index bf578abf24..dfa4e1c0ef 100644 --- a/faststream/redis/publisher/usecase.py +++ b/faststream/redis/publisher/usecase.py @@ -1,19 +1,22 @@ from abc import abstractmethod +from contextlib import AsyncExitStack from copy import deepcopy from functools import partial from itertools import chain -from typing import TYPE_CHECKING, Any, Iterable, Optional +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterable, Optional -from typing_extensions import Annotated, Doc, override +from typing_extensions import Annotated, Doc, deprecated, override from faststream.broker.message import gen_cor_id from faststream.broker.publisher.usecase import PublisherUsecase from faststream.exceptions import NOT_CONNECTED_YET from faststream.redis.message import UnifyRedisDict from faststream.redis.schemas import ListSub, PubSub, StreamSub +from faststream.utils.functions import return_input if TYPE_CHECKING: from faststream.broker.types import BrokerMiddleware, PublisherMiddleware + from faststream.redis.message import RedisMessage from faststream.redis.publisher.producer import RedisFastProducer from faststream.types import AnyDict, AsyncFunc, SendableMessage @@ -134,10 +137,20 @@ async def publish( rpc: Annotated[ bool, Doc("Whether to wait for reply in blocking mode."), + deprecated( + "Deprecated in **FastStream 0.5.17**. " + "Please, use `request` method instead. " + "Argument will be removed in **FastStream 0.6.0**." + ), ] = False, rpc_timeout: Annotated[ Optional[float], Doc("RPC reply waiting time."), + deprecated( + "Deprecated in **FastStream 0.5.17**. " + "Please, use `request` method with `timeout` instead. " + "Argument will be removed in **FastStream 0.6.0**." 
+ ), ] = 30.0, raise_timeout: Annotated[ bool, @@ -145,6 +158,11 @@ async def publish( "Whetever to raise `TimeoutError` or return `None` at **rpc_timeout**. " "RPC request returns `None` at timeout by default." ), + deprecated( + "Deprecated in **FastStream 0.5.17**. " + "`request` always raises TimeoutError instead. " + "Argument will be removed in **FastStream 0.6.0**." + ), ] = False, # publisher specific _extra_middlewares: Annotated[ @@ -184,6 +202,77 @@ async def publish( raise_timeout=raise_timeout, ) + @override + async def request( + self, + message: Annotated[ + "SendableMessage", + Doc("Message body to send."), + ] = None, + channel: Annotated[ + Optional[str], + Doc("Redis PubSub object name to send message."), + ] = None, + *, + correlation_id: Annotated[ + Optional[str], + Doc( + "Manual message **correlation_id** setter. " + "**correlation_id** is a useful option to trace messages." + ), + ] = None, + headers: Annotated[ + Optional["AnyDict"], + Doc("Message headers to store metainformation."), + ] = None, + timeout: Annotated[ + Optional[float], + Doc("RPC reply waiting time."), + ] = 30.0, + # publisher specific + _extra_middlewares: Annotated[ + Iterable["PublisherMiddleware"], + Doc("Extra middlewares to wrap publishing process."), + ] = (), + ) -> "RedisMessage": + assert self._producer, NOT_CONNECTED_YET # nosec B101 + + kwargs = { + "channel": PubSub.validate(channel or self.channel).name, + # basic args + "headers": headers or self.headers, + "correlation_id": correlation_id or gen_cor_id(), + "timeout": timeout, + } + request: AsyncFunc = self._producer.request + + for pub_m in chain( + ( + _extra_middlewares + or (m(None).publish_scope for m in self._broker_middlewares) + ), + self._middlewares, + ): + request = partial(pub_m, request) + + published_msg = await request( + message, + **kwargs, + ) + + async with AsyncExitStack() as stack: + return_msg: Callable[[RedisMessage], Awaitable[RedisMessage]] = return_input + for m in self._broker_middlewares: + mid = m(published_msg) + await stack.enter_async_context(mid) + return_msg = partial(mid.consume_scope, return_msg) + + parsed_msg = await self._producer._parser(published_msg) + parsed_msg._decoded_body = await self._producer._decoder(parsed_msg) + return await return_msg(parsed_msg) + + raise AssertionError("unreachable") + class ListPublisher(LogicPublisher): def __init__( @@ -261,10 +350,20 @@ async def publish( rpc: Annotated[ bool, Doc("Whether to wait for reply in blocking mode."), + deprecated( + "Deprecated in **FastStream 0.5.17**. " + "Please, use `request` method instead. " + "Argument will be removed in **FastStream 0.6.0**." + ), ] = False, rpc_timeout: Annotated[ Optional[float], Doc("RPC reply waiting time."), + deprecated( + "Deprecated in **FastStream 0.5.17**. " + "Please, use `request` method with `timeout` instead. " + "Argument will be removed in **FastStream 0.6.0**." + ), ] = 30.0, raise_timeout: Annotated[ bool, @@ -272,6 +371,11 @@ async def publish( "Whetever to raise `TimeoutError` or return `None` at **rpc_timeout**. " "RPC request returns `None` at timeout by default." ), + deprecated( + "Deprecated in **FastStream 0.5.17**. " + "`request` always raises TimeoutError instead. " + "Argument will be removed in **FastStream 0.6.0**." 
+ ), ] = False, # publisher specific _extra_middlewares: Annotated[ @@ -310,6 +414,78 @@ async def publish( raise_timeout=raise_timeout, ) + @override + async def request( + self, + message: Annotated[ + "SendableMessage", + Doc("Message body to send."), + ] = None, + list: Annotated[ + Optional[str], + Doc("Redis List object name to send message."), + ] = None, + *, + correlation_id: Annotated[ + Optional[str], + Doc( + "Manual message **correlation_id** setter. " + "**correlation_id** is a useful option to trace messages." + ), + ] = None, + headers: Annotated[ + Optional["AnyDict"], + Doc("Message headers to store metainformation."), + ] = None, + timeout: Annotated[ + Optional[float], + Doc("RPC reply waiting time."), + ] = 30.0, + # publisher specific + _extra_middlewares: Annotated[ + Iterable["PublisherMiddleware"], + Doc("Extra middlewares to wrap publishing process."), + ] = (), + ) -> "RedisMessage": + assert self._producer, NOT_CONNECTED_YET # nosec B101 + + kwargs = { + "list": ListSub.validate(list or self.list).name, + # basic args + "headers": headers or self.headers, + "correlation_id": correlation_id or gen_cor_id(), + "timeout": timeout, + } + + request: AsyncFunc = self._producer.request + + for pub_m in chain( + ( + _extra_middlewares + or (m(None).publish_scope for m in self._broker_middlewares) + ), + self._middlewares, + ): + request = partial(pub_m, request) + + published_msg = await request( + message, + **kwargs, + ) + + async with AsyncExitStack() as stack: + return_msg: Callable[[RedisMessage], Awaitable[RedisMessage]] = return_input + for m in self._broker_middlewares: + mid = m(published_msg) + await stack.enter_async_context(mid) + return_msg = partial(mid.consume_scope, return_msg) + + parsed_msg = await self._producer._parser(published_msg) + parsed_msg._decoded_body = await self._producer._decoder(parsed_msg) + return await return_msg(parsed_msg) + + raise AssertionError("unreachable") + class ListBatchPublisher(ListPublisher): @override @@ -442,10 +618,20 @@ async def publish( rpc: Annotated[ bool, Doc("Whether to wait for reply in blocking mode."), + deprecated( + "Deprecated in **FastStream 0.5.17**. " + "Please, use `request` method instead. " + "Argument will be removed in **FastStream 0.6.0**." + ), ] = False, rpc_timeout: Annotated[ Optional[float], Doc("RPC reply waiting time."), + deprecated( + "Deprecated in **FastStream 0.5.17**. " + "Please, use `request` method with `timeout` instead. " + "Argument will be removed in **FastStream 0.6.0**." + ), ] = 30.0, raise_timeout: Annotated[ bool, @@ -453,6 +639,11 @@ async def publish( "Whetever to raise `TimeoutError` or return `None` at **rpc_timeout**. " "RPC request returns `None` at timeout by default." ), + deprecated( + "Deprecated in **FastStream 0.5.17**. " + "`request` always raises TimeoutError instead. " + "Argument will be removed in **FastStream 0.6.0**." + ), ] = False, # publisher specific _extra_middlewares: Annotated[ @@ -492,3 +683,82 @@ async def publish( rpc_timeout=rpc_timeout, raise_timeout=raise_timeout, ) + + @override + async def request( + self, + message: Annotated[ + "SendableMessage", + Doc("Message body to send."), + ] = None, + stream: Annotated[ + Optional[str], + Doc("Redis Stream object name to send message."), + ] = None, + *, + maxlen: Annotated[ + Optional[int], + Doc( + "Redis Stream maxlen publish option. " + "Remove eldest message if maxlen exceeded." + ), + ] = None, + correlation_id: Annotated[ + Optional[str], + Doc( + "Manual message **correlation_id** setter. 
" + "**correlation_id** is a useful option to trace messages." + ), + ] = None, + headers: Annotated[ + Optional["AnyDict"], + Doc("Message headers to store metainformation."), + ] = None, + timeout: Annotated[ + Optional[float], + Doc("RPC reply waiting time."), + ] = 30.0, + # publisher specific + _extra_middlewares: Annotated[ + Iterable["PublisherMiddleware"], + Doc("Extra middlewares to wrap publishing process."), + ] = (), + ) -> "RedisMessage": + assert self._producer, NOT_CONNECTED_YET # nosec B101 + + kwargs = { + "stream": StreamSub.validate(stream or self.stream).name, + # basic args + "headers": headers or self.headers, + "correlation_id": correlation_id or gen_cor_id(), + "timeout": timeout, + } + + request: AsyncFunc = self._producer.request + + for pub_m in chain( + ( + _extra_middlewares + or (m(None).publish_scope for m in self._broker_middlewares) + ), + self._middlewares, + ): + request = partial(pub_m, request) + + published_msg = await request( + message, + **kwargs, + ) + + async with AsyncExitStack() as stack: + return_msg: Callable[[RedisMessage], Awaitable[RedisMessage]] = return_input + for m in self._broker_middlewares: + mid = m(published_msg) + await stack.enter_async_context(mid) + return_msg = partial(mid.consume_scope, return_msg) + + parsed_msg = await self._producer._parser(published_msg) + parsed_msg._decoded_body = await self._producer._decoder(parsed_msg) + return await return_msg(parsed_msg) + + raise AssertionError("unreachable") diff --git a/faststream/redis/schemas/pub_sub.py b/faststream/redis/schemas/pub_sub.py index 9dd1fc5422..5277cc1213 100644 --- a/faststream/redis/schemas/pub_sub.py +++ b/faststream/redis/schemas/pub_sub.py @@ -30,7 +30,7 @@ def __init__( super().__init__(path) self.path_regex = reg - self.pattern = pattern + self.pattern = channel if pattern else None self.polling_interval = polling_interval def __hash__(self) -> int: diff --git a/faststream/redis/testing.py b/faststream/redis/testing.py index 690dda7df0..c2b35a4eec 100644 --- a/faststream/redis/testing.py +++ b/faststream/redis/testing.py @@ -1,12 +1,13 @@ import re -from typing import TYPE_CHECKING, Any, Optional, Sequence, Union +from typing import TYPE_CHECKING, Any, List, Optional, Protocol, Sequence, Union, cast from unittest.mock import AsyncMock, MagicMock import anyio -from typing_extensions import override +from typing_extensions import TypedDict, override from faststream.broker.message import gen_cor_id -from faststream.exceptions import WRONG_PUBLISH_ARGS, SetupError +from faststream.broker.utils import resolve_custom_func +from faststream.exceptions import WRONG_PUBLISH_ARGS, SetupError, SubscriberNotFound from faststream.redis.broker.broker import RedisBroker from faststream.redis.message import ( BatchListMessage, @@ -16,11 +17,18 @@ PubSubMessage, bDATA_KEY, ) -from faststream.redis.parser import RawMessage +from faststream.redis.parser import RawMessage, RedisPubSubParser from faststream.redis.publisher.producer import RedisFastProducer from faststream.redis.schemas import INCORRECT_SETUP_MSG from faststream.redis.subscriber.factory import create_subscriber -from faststream.testing.broker import TestBroker, call_handler +from faststream.redis.subscriber.usecase import ( + ChannelSubscriber, + LogicSubscriber, + _ListHandlerMixin, + _StreamHandlerMixin, +) +from faststream.testing.broker import TestBroker +from faststream.utils.functions import timeout_scope if TYPE_CHECKING: from faststream.broker.wrapper.call import HandlerCallWrapper @@ -85,6 +93,16 @@ 
class FakeProducer(RedisFastProducer): def __init__(self, broker: RedisBroker) -> None: self.broker = broker + default = RedisPubSubParser() + self._parser = resolve_custom_func( + broker._parser, + default.parse_message, + ) + self._decoder = resolve_custom_func( + broker._decoder, + default.decode_message, + ) + @override async def publish( self, @@ -113,81 +131,62 @@ async def publish( headers=headers, ) - any_of = channel or list or stream - if any_of is None: - raise SetupError(INCORRECT_SETUP_MSG) + destination = _make_destionation_kwargs(channel, list, stream) + visitors = (ChannelVisitor(), ListVisitor(), StreamVisitor()) - msg: Any = None for handler in self.broker._subscribers.values(): # pragma: no branch - call = False - - if channel and (ch := getattr(handler, "channel", None)) is not None: - call = bool( - (not ch.pattern and ch.name == channel) - or ( - ch.pattern - and re.match( - ch.name.replace(".", "\\.").replace("*", ".*"), - channel, - ) + for visitor in visitors: + if visited_ch := visitor.visit(**destination, sub=handler): + msg = visitor.get_message( + visited_ch, + body, + handler, # type: ignore[arg-type] ) - ) - msg = PubSubMessage( - type="message", - data=body, - channel=channel, - pattern=ch.pattern, - ) + with timeout_scope(rpc_timeout, raise_timeout): + response_msg = await self._execute_handler(msg, handler) + if rpc: + return await self._decoder(await self._parser(response_msg)) - elif list and (ls := getattr(handler, "list_sub", None)) is not None: - if ls.batch: - msg = BatchListMessage( - type="blist", - channel=list, - data=[body], - ) - - else: - msg = DefaultListMessage( - type="list", - channel=list, - data=body, - ) + return None - call = list == ls.name + @override + async def request( # type: ignore[override] + self, + message: "SendableMessage", + *, + correlation_id: str, + channel: Optional[str] = None, + list: Optional[str] = None, + stream: Optional[str] = None, + maxlen: Optional[int] = None, + headers: Optional["AnyDict"] = None, + timeout: Optional[float] = 30.0, + ) -> "PubSubMessage": + correlation_id = correlation_id or gen_cor_id() - elif stream and (st := getattr(handler, "stream_sub", None)) is not None: - if st.batch: - msg = BatchStreamMessage( - type="bstream", - channel=stream, - data=[{bDATA_KEY: body}], - message_ids=[], - ) - else: - msg = DefaultStreamMessage( - type="stream", - channel=stream, - data={bDATA_KEY: body}, - message_ids=[], - ) + body = build_message( + message=message, + correlation_id=correlation_id, + headers=headers, + ) - call = stream == st.name + destination = _make_destionation_kwargs(channel, list, stream) + visitors = (ChannelVisitor(), ListVisitor(), StreamVisitor()) - if call: - r = await call_handler( - handler=handler, - message=msg, - rpc=rpc, - rpc_timeout=rpc_timeout, - raise_timeout=raise_timeout, - ) + for handler in self.broker._subscribers.values(): # pragma: no branch + for visitor in visitors: + if visited_ch := visitor.visit(**destination, sub=handler): + msg = visitor.get_message( + visited_ch, + body, + handler, # type: ignore[arg-type] + ) - if rpc: # pragma: no branch - return r + with anyio.fail_after(timeout): + return await self._execute_handler(msg, handler) - return None + raise SubscriberNotFound async def publish_batch( self, @@ -196,30 +195,43 @@ async def publish_batch( headers: Optional["AnyDict"] = None, correlation_id: Optional[str] = None, ) -> None: - correlation_id = correlation_id or gen_cor_id() - + data_to_send = [ + build_message( + m, + 
correlation_id=correlation_id or gen_cor_id(), + headers=headers, + ) + for m in msgs + ] + + visitor = ListVisitor() for handler in self.broker._subscribers.values(): # pragma: no branch - if ( - list_sub := getattr(handler, "list_sub", None) - ) and list_sub.name == list: - await call_handler( - handler=handler, - message=BatchListMessage( - type="blist", - channel=list, - data=[ - build_message( - m, - correlation_id=correlation_id, - headers=headers, - ) - for m in msgs - ], - ), - ) + if visitor.visit(list=list, sub=handler): + casted_handler = cast(_ListHandlerMixin, handler) + + if casted_handler.list_sub.batch: + msg = visitor.get_message(list, data_to_send, casted_handler) + + await self._execute_handler(msg, handler) return None + async def _execute_handler( + self, msg: Any, handler: "LogicSubscriber" + ) -> "PubSubMessage": + result = await handler.process_message(msg) + + return PubSubMessage( + type="message", + data=build_message( + message=result.body, + headers=result.headers, + correlation_id=result.correlation_id or "", + ), + channel="", + pattern=None, + ) + def build_message( message: Union[Sequence["SendableMessage"], "SendableMessage"], @@ -235,3 +247,160 @@ def build_message( correlation_id=correlation_id, ) return data + + +class Visitor(Protocol): + def visit( + self, + *, + channel: Optional[str], + list: Optional[str], + stream: Optional[str], + sub: "LogicSubscriber", + ) -> Optional[str]: ... + + def get_message(self, channel: str, body: Any, sub: "LogicSubscriber") -> Any: ... + + +class ChannelVisitor(Visitor): + def visit( + self, + *, + sub: "LogicSubscriber", + channel: Optional[str] = None, + list: Optional[str] = None, + stream: Optional[str] = None, + ) -> Optional[str]: + if channel is None or not isinstance(sub, ChannelSubscriber): + return None + + sub_channel = sub.channel + + if ( + sub_channel.pattern + and bool( + re.match( + sub_channel.name.replace(".", "\\.").replace("*", ".*"), + channel or "", + ) + ) + ) or channel == sub_channel.name: + return channel + + return None + + def get_message( # type: ignore[override] + self, + channel: str, + body: Any, + sub: "ChannelSubscriber", + ) -> Any: + return PubSubMessage( + type="message", + data=body, + channel=channel, + pattern=sub.channel.pattern.encode() if sub.channel.pattern else None, + ) + + +class ListVisitor(Visitor): + def visit( + self, + *, + sub: "LogicSubscriber", + channel: Optional[str] = None, + list: Optional[str] = None, + stream: Optional[str] = None, + ) -> Optional[str]: + if list is None or not isinstance(sub, _ListHandlerMixin): + return None + + if list == sub.list_sub.name: + return list + + return None + + def get_message( # type: ignore[override] + self, + channel: str, + body: Any, + sub: "_ListHandlerMixin", + ) -> Any: + if sub.list_sub.batch: + return BatchListMessage( + type="blist", + channel=channel, + data=body if isinstance(body, List) else [body], + ) + + else: + return DefaultListMessage( + type="list", + channel=channel, + data=body, + ) + + +class StreamVisitor(Visitor): + def visit( + self, + *, + sub: "LogicSubscriber", + channel: Optional[str] = None, + list: Optional[str] = None, + stream: Optional[str] = None, + ) -> Optional[str]: + if stream is None or not isinstance(sub, _StreamHandlerMixin): + return None + + if stream == sub.stream_sub.name: + return stream + + return None + + def get_message( # type: ignore[override] + self, + channel: str, + body: Any, + sub: "_StreamHandlerMixin", + ) -> Any: + if sub.stream_sub.batch: + return 
BatchStreamMessage( + type="bstream", + channel=channel, + data=[{bDATA_KEY: body}], + message_ids=[], + ) + + else: + return DefaultStreamMessage( + type="stream", + channel=channel, + data={bDATA_KEY: body}, + message_ids=[], + ) + + +class _DestinationKwargs(TypedDict, total=False): + channel: str + list: str + stream: str + + +def _make_destionation_kwargs( + channel: Optional[str], + list: Optional[str], + stream: Optional[str], +) -> _DestinationKwargs: + destination: _DestinationKwargs = {} + if channel: + destination["channel"] = channel + if list: + destination["list"] = list + if stream: + destination["stream"] = stream + + if len(destination) != 1: + raise SetupError(INCORRECT_SETUP_MSG) + + return destination diff --git a/faststream/testing/broker.py b/faststream/testing/broker.py index 38478f8852..fa4499cf1f 100644 --- a/faststream/testing/broker.py +++ b/faststream/testing/broker.py @@ -17,17 +17,15 @@ from unittest.mock import MagicMock from faststream.broker.core.usecase import BrokerUsecase -from faststream.broker.message import StreamMessage, decode_message, encode_message from faststream.broker.middlewares.logging import CriticalLogMiddleware from faststream.broker.wrapper.call import HandlerCallWrapper from faststream.testing.app import TestApp from faststream.utils.ast import is_contains_context_name -from faststream.utils.functions import sync_fake_context, timeout_scope +from faststream.utils.functions import sync_fake_context if TYPE_CHECKING: from types import TracebackType - from faststream.broker.subscriber.proto import SubscriberProto from faststream.broker.types import BrokerMiddleware Broker = TypeVar("Broker", bound=BrokerUsecase[Any, Any]) @@ -69,22 +67,6 @@ def __init__( self.connect_only = connect_only async def __aenter__(self) -> Broker: - # TODO: remove useless middlewares filter - middlewares = tuple( - filter( - lambda x: not isinstance(x, CriticalLogMiddleware), - self.broker._middlewares, - ) - ) - - self.broker._middlewares = middlewares - - for sub in self.broker._subscribers.values(): - sub._broker_middlewares = middlewares - - for pub in self.broker._publishers.values(): - pub._broker_middlewares = middlewares - self._ctx = self._create_ctx() return await self._ctx.__aenter__() @@ -226,27 +208,3 @@ def patch_broker_calls(broker: "BrokerUsecase[Any, Any]") -> None: for handler in broker._subscribers.values(): for h in handler.calls: h.handler.set_test() - - -async def call_handler( - handler: "SubscriberProto[Any]", - message: Any, - rpc: bool = False, - rpc_timeout: Optional[float] = 30.0, - raise_timeout: bool = False, -) -> Any: - """Asynchronously call a handler function.""" - with timeout_scope(rpc_timeout, raise_timeout): - result = await handler.process_message(message) - - if rpc: - message_body, content_type = encode_message(result) - msg_to_publish = StreamMessage( - raw_message=None, - body=message_body, - content_type=content_type, - ) - consumed_data = decode_message(msg_to_publish) - return consumed_data - - return None diff --git a/faststream/utils/functions.py b/faststream/utils/functions.py index 81b1b06db9..453c70ffc7 100644 --- a/faststream/utils/functions.py +++ b/faststream/utils/functions.py @@ -80,3 +80,7 @@ def drop_response_type( ) -> CallModel[F_Spec, F_Return]: model.response_model = None return model + + +async def return_input(x: Any) -> Any: + return x diff --git a/faststream/utils/nuid.py b/faststream/utils/nuid.py index d804d1a19c..a61aa08a8f 100644 --- a/faststream/utils/nuid.py +++ b/faststream/utils/nuid.py @@ 
-21,7 +21,7 @@ BASE = 62 PREFIX_LENGTH = 12 SEQ_LENGTH = 10 -MAX_SEQ = 839299365868340224 # BASE**10 +MAX_SEQ = BASE**10 MIN_INC = 33 MAX_INC = 333 INC = MAX_INC - MIN_INC diff --git a/tests/brokers/base/fastapi.py b/tests/brokers/base/fastapi.py index 009455161b..adb9e0c923 100644 --- a/tests/brokers/base/fastapi.py +++ b/tests/brokers/base/fastapi.py @@ -128,9 +128,11 @@ async def test_double_real(self, mock: Mock, queue: str, event: asyncio.Event): router = self.router_class() args, kwargs = self.get_subscriber_params(queue) + sub1 = router.subscriber(*args, **kwargs) + args2, kwargs2 = self.get_subscriber_params(queue + "2") - @router.subscriber(*args, **kwargs) + @sub1 @router.subscriber(*args2, **kwargs2) async def hello(msg: str): if event.is_set(): @@ -445,9 +447,10 @@ async def test_publisher_mock(self, queue: str): publisher = router.publisher(queue + "resp") args, kwargs = self.get_subscriber_params(queue) + sub = router.subscriber(*args, **kwargs) @publisher - @router.subscriber(*args, **kwargs) + @sub async def m(): return "response" diff --git a/tests/brokers/base/middlewares.py b/tests/brokers/base/middlewares.py index b125cbc700..dc6175198e 100644 --- a/tests/brokers/base/middlewares.py +++ b/tests/brokers/base/middlewares.py @@ -34,7 +34,7 @@ async def test_subscriber_middleware( raw_broker, ): async def mid(call_next, msg): - mock.start(msg.decoded_body) + mock.start(msg._decoded_body) result = await call_next(msg) mock.end() event.set() diff --git a/tests/brokers/base/requests.py b/tests/brokers/base/requests.py new file mode 100644 index 0000000000..78dcdcb58b --- /dev/null +++ b/tests/brokers/base/requests.py @@ -0,0 +1,174 @@ +import anyio +import pytest + +from .basic import BaseTestcaseConfig + + +class RequestsTestcase(BaseTestcaseConfig): + def get_middleware(self, **kwargs): + raise NotImplementedError + + def get_broker(self, **kwargs): + raise NotImplementedError + + def get_router(self, **kwargs): + raise NotImplementedError + + def patch_broker(self, broker, **kwargs): + return broker + + async def test_request_timeout(self, queue: str): + broker = self.get_broker() + + args, kwargs = self.get_subscriber_params(queue) + + @broker.subscriber(*args, **kwargs) + async def handler(msg): + await anyio.sleep(1.0) + return "Response" + + async with self.patch_broker(broker): + await broker.start() + + with pytest.raises(TimeoutError): + await broker.request( + None, + queue, + timeout=1e-24, + ) + + async def test_broker_base_request(self, queue: str): + broker = self.get_broker() + + args, kwargs = self.get_subscriber_params(queue) + + @broker.subscriber(*args, **kwargs) + async def handler(msg): + return "Response" + + async with self.patch_broker(broker): + await broker.start() + + response = await broker.request( + None, + queue, + timeout=self.timeout, + correlation_id="1", + ) + + assert await response.decode() == "Response" + assert response.correlation_id == "1", response.correlation_id + + async def test_publisher_base_request(self, queue: str): + broker = self.get_broker() + + publisher = broker.publisher(queue) + + args, kwargs = self.get_subscriber_params(queue) + + @broker.subscriber(*args, **kwargs) + async def handler(msg): + return "Response" + + async with self.patch_broker(broker): + await broker.start() + + response = await publisher.request( + None, + timeout=self.timeout, + correlation_id="1", + ) + + assert await response.decode() == "Response" + assert response.correlation_id == "1", response.correlation_id + + async def 
test_router_publisher_request(self, queue: str): + router = self.get_router() + + publisher = router.publisher(queue) + + args, kwargs = self.get_subscriber_params(queue) + + @router.subscriber(*args, **kwargs) + async def handler(msg): + return "Response" + + broker = self.get_broker() + broker.include_router(router) + + async with self.patch_broker(broker): + await broker.start() + + response = await publisher.request( + None, + timeout=self.timeout, + correlation_id="1", + ) + + assert await response.decode() == "Response" + assert response.correlation_id == "1", response.correlation_id + + async def test_broker_request_respect_middleware(self, queue: str): + broker = self.get_broker(middlewares=(self.get_middleware(),)) + + args, kwargs = self.get_subscriber_params(queue) + + @broker.subscriber(*args, **kwargs) + async def handler(msg): + return msg + + async with self.patch_broker(broker): + await broker.start() + + response = await broker.request( + "x", + queue, + timeout=self.timeout, + ) + + assert await response.decode() == "x" * 2 * 2 * 2 * 2 + + async def test_broker_publisher_request_respect_middleware(self, queue: str): + broker = self.get_broker(middlewares=(self.get_middleware(),)) + + publisher = broker.publisher(queue) + + args, kwargs = self.get_subscriber_params(queue) + + @broker.subscriber(*args, **kwargs) + async def handler(msg): + return msg + + async with self.patch_broker(broker): + await broker.start() + + response = await publisher.request( + "x", + timeout=self.timeout, + ) + + assert await response.decode() == "x" * 2 * 2 * 2 * 2 + + async def test_router_publisher_request_respect_middleware(self, queue: str): + router = self.get_router(middlewares=(self.get_middleware(),)) + + publisher = router.publisher(queue) + + args, kwargs = self.get_subscriber_params(queue) + + @router.subscriber(*args, **kwargs) + async def handler(msg): + return msg + + broker = self.get_broker() + broker.include_router(router) + + async with self.patch_broker(broker): + await broker.start() + + response = await publisher.request( + "x", + timeout=self.timeout, + ) + + assert await response.decode() == "x" * 2 * 2 * 2 * 2 diff --git a/tests/brokers/base/rpc.py b/tests/brokers/base/rpc.py index be2dcbeea0..dcdd8e85e0 100644 --- a/tests/brokers/base/rpc.py +++ b/tests/brokers/base/rpc.py @@ -28,13 +28,13 @@ async def test_rpc(self, queue: str): @rpc_broker.subscriber(*args, **kwargs) async def m(m): - return "1" + return "Hi!" async with self.patch_broker(rpc_broker) as br: await br.start() r = await br.publish("hello", queue, rpc_timeout=3, rpc=True) - assert r == "1" + assert r == "Hi!" 
@pytest.mark.asyncio async def test_rpc_timeout_raises(self, queue: str): diff --git a/tests/brokers/confluent/test_requests.py b/tests/brokers/confluent/test_requests.py new file mode 100644 index 0000000000..39f4677113 --- /dev/null +++ b/tests/brokers/confluent/test_requests.py @@ -0,0 +1,31 @@ +import pytest + +from faststream import BaseMiddleware +from faststream.confluent import KafkaBroker, KafkaRouter, TestKafkaBroker +from tests.brokers.base.requests import RequestsTestcase + +from .basic import ConfluentTestcaseConfig + + +class Mid(BaseMiddleware): + async def on_receive(self) -> None: + self.msg._raw_msg = self.msg._raw_msg * 2 + + async def consume_scope(self, call_next, msg): + msg._decoded_body = msg._decoded_body * 2 + return await call_next(msg) + + +@pytest.mark.asyncio +class TestRequestTestClient(ConfluentTestcaseConfig, RequestsTestcase): + def get_middleware(self, **kwargs): + return Mid + + def get_broker(self, **kwargs): + return KafkaBroker(**kwargs) + + def get_router(self, **kwargs): + return KafkaRouter(**kwargs) + + def patch_broker(self, broker, **kwargs): + return TestKafkaBroker(broker, **kwargs) diff --git a/tests/brokers/confluent/test_test_client.py b/tests/brokers/confluent/test_test_client.py index 50131549ac..82e9aefe91 100644 --- a/tests/brokers/confluent/test_test_client.py +++ b/tests/brokers/confluent/test_test_client.py @@ -60,7 +60,9 @@ async def test_with_real_testclient( ): broker = self.get_broker() - @broker.subscriber(queue, auto_offset_reset="earliest") + args, kwargs = self.get_subscriber_params(queue) + + @broker.subscriber(*args, **kwargs) def subscriber(m): event.set() @@ -81,7 +83,7 @@ async def test_batch_pub_by_default_pub( ): broker = self.get_broker() - @broker.subscriber(queue, batch=True, auto_offset_reset="earliest") + @broker.subscriber(queue, batch=True) async def m(msg): pass @@ -95,7 +97,7 @@ async def test_batch_pub_by_pub_batch( ): broker = self.get_broker() - @broker.subscriber(queue, batch=True, auto_offset_reset="earliest") + @broker.subscriber(queue, batch=True) async def m(msg): pass @@ -112,7 +114,7 @@ async def test_batch_publisher_mock( publisher = broker.publisher(queue + "1", batch=True) @publisher - @broker.subscriber(queue, auto_offset_reset="earliest") + @broker.subscriber(queue) async def m(msg): return 1, 2, 3 @@ -131,10 +133,10 @@ async def on_receive(self) -> None: broker = KafkaBroker(middlewares=(Middleware,)) - @broker.subscriber(queue, auto_offset_reset="earliest") + @broker.subscriber(queue) async def h1(): ... - @broker.subscriber(queue + "1", auto_offset_reset="earliest") + @broker.subscriber(queue + "1") async def h2(): ... async with TestKafkaBroker(broker) as br: @@ -154,10 +156,14 @@ async def on_receive(self) -> None: broker = KafkaBroker(middlewares=(Middleware,)) - @broker.subscriber(queue, auto_offset_reset="earliest") + args, kwargs = self.get_subscriber_params(queue) + + @broker.subscriber(*args, **kwargs) async def h1(): ... - @broker.subscriber(queue + "1", auto_offset_reset="earliest") + args2, kwargs2 = self.get_subscriber_params(queue + "1") + + @broker.subscriber(*args2, **kwargs2) async def h2(): ... 
async with TestKafkaBroker(broker, with_real=True) as br: diff --git a/tests/brokers/kafka/test_requests.py b/tests/brokers/kafka/test_requests.py new file mode 100644 index 0000000000..a518b2fa43 --- /dev/null +++ b/tests/brokers/kafka/test_requests.py @@ -0,0 +1,29 @@ +import pytest + +from faststream import BaseMiddleware +from faststream.kafka import KafkaBroker, KafkaRouter, TestKafkaBroker +from tests.brokers.base.requests import RequestsTestcase + + +class Mid(BaseMiddleware): + async def on_receive(self) -> None: + self.msg.value = self.msg.value * 2 + + async def consume_scope(self, call_next, msg): + msg._decoded_body = msg._decoded_body * 2 + return await call_next(msg) + + +@pytest.mark.asyncio +class TestRequestTestClient(RequestsTestcase): + def get_middleware(self, **kwargs): + return Mid + + def get_broker(self, **kwargs): + return KafkaBroker(**kwargs) + + def get_router(self, **kwargs): + return KafkaRouter(**kwargs) + + def patch_broker(self, broker, **kwargs): + return TestKafkaBroker(broker, **kwargs) diff --git a/tests/brokers/nats/test_publish.py b/tests/brokers/nats/test_publish.py index 7ecd9e544e..1fb8b799d6 100644 --- a/tests/brokers/nats/test_publish.py +++ b/tests/brokers/nats/test_publish.py @@ -42,7 +42,7 @@ async def handle_next(msg=Context("message")): await asyncio.wait( ( - asyncio.create_task(br.publish("", queue)), + asyncio.create_task(br.publish("", queue, correlation_id="wrong")), asyncio.create_task(event.wait()), ), timeout=3, diff --git a/tests/brokers/nats/test_requests.py b/tests/brokers/nats/test_requests.py new file mode 100644 index 0000000000..19f9c2cb15 --- /dev/null +++ b/tests/brokers/nats/test_requests.py @@ -0,0 +1,85 @@ +import pytest + +from faststream import BaseMiddleware +from faststream.nats import NatsBroker, NatsRouter, TestNatsBroker +from tests.brokers.base.requests import RequestsTestcase + + +class Mid(BaseMiddleware): + async def on_receive(self) -> None: + self.msg.data = self.msg.data * 2 + + async def consume_scope(self, call_next, msg): + msg._decoded_body = msg._decoded_body * 2 + return await call_next(msg) + + +@pytest.mark.asyncio +class NatsRequestsTestcase(RequestsTestcase): + def get_middleware(self, **kwargs): + return Mid + + def get_broker(self, **kwargs): + return NatsBroker(**kwargs) + + def get_router(self, **kwargs): + return NatsRouter(**kwargs) + + async def test_broker_stream_request(self, queue: str): + broker = self.get_broker() + + stream_name = f"{queue}st" + + args, kwargs = self.get_subscriber_params(queue, stream=stream_name) + + @broker.subscriber(*args, **kwargs) + async def handler(msg): + return "Response" + + async with self.patch_broker(broker): + await broker.start() + + response = await broker.request( + None, + queue, + correlation_id="1", + stream=stream_name, + timeout=self.timeout, + ) + + assert await response.decode() == "Response" + assert response.correlation_id == "1" + + async def test_publisher_stream_request(self, queue: str): + broker = self.get_broker() + + stream_name = f"{queue}st" + publisher = broker.publisher(queue, stream=stream_name) + + args, kwargs = self.get_subscriber_params(queue, stream=stream_name) + + @broker.subscriber(*args, **kwargs) + async def handler(msg): + return "Response" + + async with self.patch_broker(broker): + await broker.start() + + response = await publisher.request( + None, + correlation_id="1", + timeout=self.timeout, + ) + + assert await response.decode() == "Response" + assert response.correlation_id == "1" + + +@pytest.mark.nats +class 
TestRealRequests(NatsRequestsTestcase): + pass + + +class TestRequestTestClient(NatsRequestsTestcase): + def patch_broker(self, broker, **kwargs): + return TestNatsBroker(broker, **kwargs) diff --git a/tests/brokers/rabbit/test_requests.py b/tests/brokers/rabbit/test_requests.py new file mode 100644 index 0000000000..c0927eabc8 --- /dev/null +++ b/tests/brokers/rabbit/test_requests.py @@ -0,0 +1,38 @@ +import pytest + +from faststream import BaseMiddleware +from faststream.rabbit import RabbitBroker, RabbitRouter, TestRabbitBroker +from tests.brokers.base.requests import RequestsTestcase + + +class Mid(BaseMiddleware): + async def on_receive(self) -> None: + self.msg._Message__lock = False + self.msg.body = self.msg.body * 2 + + async def consume_scope(self, call_next, msg): + msg._decoded_body = msg._decoded_body * 2 + return await call_next(msg) + + +@pytest.mark.asyncio +class RabbitRequestsTestcase(RequestsTestcase): + def get_middleware(self, **kwargs): + return Mid + + def get_broker(self, **kwargs): + return RabbitBroker(**kwargs) + + def get_router(self, **kwargs): + return RabbitRouter(**kwargs) + + +@pytest.mark.rabbit +class TestRealRequests(RabbitRequestsTestcase): + pass + + +@pytest.mark.asyncio +class TestRequestTestClient(RabbitRequestsTestcase): + def patch_broker(self, broker, **kwargs): + return TestRabbitBroker(broker, **kwargs) diff --git a/tests/brokers/redis/test_requests.py b/tests/brokers/redis/test_requests.py new file mode 100644 index 0000000000..e13fe06e92 --- /dev/null +++ b/tests/brokers/redis/test_requests.py @@ -0,0 +1,40 @@ +import json + +import pytest + +from faststream import BaseMiddleware +from faststream.redis import RedisBroker, RedisRouter, TestRedisBroker +from tests.brokers.base.requests import RequestsTestcase + + +class Mid(BaseMiddleware): + async def on_receive(self) -> None: + data = json.loads(self.msg["data"]) + data["data"] *= 2 + self.msg["data"] = json.dumps(data) + + async def consume_scope(self, call_next, msg): + msg._decoded_body = msg._decoded_body * 2 + return await call_next(msg) + + +@pytest.mark.asyncio +class RedisRequestsTestcase(RequestsTestcase): + def get_middleware(self, **kwargs): + return Mid + + def get_broker(self, **kwargs): + return RedisBroker(**kwargs) + + def get_router(self, **kwargs): + return RedisRouter(**kwargs) + + +@pytest.mark.redis +class TestRealRequests(RedisRequestsTestcase): + pass + + +class TestRequestTestClient(RedisRequestsTestcase): + def patch_broker(self, broker, **kwargs): + return TestRedisBroker(broker, **kwargs)
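
Taken together, these diffs replace the blocking `rpc=True` publish flow with a first-class `request` API across brokers, while `rpc`, `rpc_timeout`, and `raise_timeout` are deprecated until their removal in 0.6.0. A minimal usage sketch, mirroring the new `tests/brokers/base/requests.py` testcase — the broker URL and channel name are illustrative, and a reachable Redis instance is assumed:

```python
import asyncio

from faststream.redis import RedisBroker

broker = RedisBroker("redis://localhost:6379")  # illustrative URL


@broker.subscriber("request-demo")  # illustrative channel name
async def handler(msg: str) -> str:
    # The returned value is published back to the requester's reply channel.
    return "Response"


async def main() -> None:
    async with broker:  # connects on enter, closes on exit
        await broker.start()

        # Before 0.5.17 (now deprecated):
        #   reply = await broker.publish("", "request-demo", rpc=True, rpc_timeout=3.0)
        response = await broker.request(
            "",
            "request-demo",
            correlation_id="1",
            timeout=3.0,  # raises TimeoutError on expiry instead of returning None
        )
        assert await response.decode() == "Response"
        assert response.correlation_id == "1"


asyncio.run(main())
```

Unlike the old `raise_timeout` flag, `request` always raises `TimeoutError` when the reply does not arrive in time, and it returns the fully parsed message (so `headers` and `correlation_id` stay available) rather than only the decoded body. Publishers gain the same method, so `broker.publisher(queue).request(...)` follows the identical pattern.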