From 2bb56208f71be275e518d9cbd40abd59488e0026 Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Tue, 11 Jun 2024 22:03:21 +0300 Subject: [PATCH 1/3] feat: support NATS multiple subjects JS subscription --- faststream/nats/broker/registrator.py | 6 +++--- faststream/nats/subscriber/factory.py | 3 +++ 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/faststream/nats/broker/registrator.py b/faststream/nats/broker/registrator.py index a77b439b98..bcd0bab0a2 100644 --- a/faststream/nats/broker/registrator.py +++ b/faststream/nats/broker/registrator.py @@ -42,7 +42,7 @@ def subscriber( # type: ignore[override] subject: Annotated[ str, Doc("NATS subject to subscribe."), - ], + ] = "", queue: Annotated[ str, Doc( @@ -209,7 +209,7 @@ def subscriber( # type: ignore[override] You can use it as a handler decorator `@broker.subscriber(...)`. """ - if stream := self._stream_builder.create(stream): + if (stream := self._stream_builder.create(stream)) and subject: stream.add_subject(subject) subscriber = cast( @@ -323,7 +323,7 @@ def publisher( # type: ignore[override] Or you can create a publisher object to call it lately - `broker.publisher(...).publish(...)`. """ - if stream := self._stream_builder.create(stream): + if (stream := self._stream_builder.create(stream)) and subject: stream.add_subject(subject) publisher = cast( diff --git a/faststream/nats/subscriber/factory.py b/faststream/nats/subscriber/factory.py index 2ae7c9b820..57ba827784 100644 --- a/faststream/nats/subscriber/factory.py +++ b/faststream/nats/subscriber/factory.py @@ -80,6 +80,9 @@ def create_subscriber( if pull_sub is not None and stream is None: raise SetupError("Pull subscriber can be used only with a stream") + if not subject and not config: + raise SetupError("You must provide either `subject` or `config` option.") + if stream: # TODO: pull & queue warning # TODO: push & durable warning From 56dffa9b097b6969003db6833c8caf1cd8ad8ddf Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Tue, 11 Jun 2024 22:28:28 +0300 Subject: [PATCH 2/3] feat: NATS test client supports filter subsciption --- faststream/nats/subscriber/factory.py | 13 ++++++++- faststream/nats/subscriber/usecase.py | 40 +++++++++++++++++++++++--- faststream/nats/testing.py | 5 +++- tests/brokers/nats/test_consume.py | 34 ++++++++++++++++++++-- tests/brokers/nats/test_test_client.py | 25 ++++++++++++---- 5 files changed, 104 insertions(+), 13 deletions(-) diff --git a/faststream/nats/subscriber/factory.py b/faststream/nats/subscriber/factory.py index 57ba827784..1161c66550 100644 --- a/faststream/nats/subscriber/factory.py +++ b/faststream/nats/subscriber/factory.py @@ -4,6 +4,7 @@ DEFAULT_SUB_PENDING_BYTES_LIMIT, DEFAULT_SUB_PENDING_MSGS_LIMIT, ) +from nats.js.api import ConsumerConfig from nats.js.client import ( DEFAULT_JS_SUB_PENDING_BYTES_LIMIT, DEFAULT_JS_SUB_PENDING_MSGS_LIMIT, @@ -83,6 +84,8 @@ def create_subscriber( if not subject and not config: raise SetupError("You must provide either `subject` or `config` option.") + config = config or ConsumerConfig(filter_subjects=[]) + if stream: # TODO: pull & queue warning # TODO: push & durable warning @@ -94,7 +97,6 @@ def create_subscriber( or DEFAULT_JS_SUB_PENDING_BYTES_LIMIT, "durable": durable, "stream": stream.name, - "config": config, } if pull_sub is not None: @@ -123,6 +125,7 @@ def create_subscriber( if obj_watch is not None: return AsyncAPIObjStoreWatchSubscriber( subject=subject, + config=config, obj_watch=obj_watch, broker_dependencies=broker_dependencies, broker_middlewares=broker_middlewares, @@ -134,6 +137,7 @@ def create_subscriber( if kv_watch is not None: return AsyncAPIKeyValueWatchSubscriber( subject=subject, + config=config, kv_watch=kv_watch, broker_dependencies=broker_dependencies, broker_middlewares=broker_middlewares, @@ -147,6 +151,7 @@ def create_subscriber( return AsyncAPIConcurrentCoreSubscriber( max_workers=max_workers, subject=subject, + config=config, queue=queue, # basic args extra_options=extra_options, @@ -165,6 +170,7 @@ def create_subscriber( else: return AsyncAPICoreSubscriber( subject=subject, + config=config, queue=queue, # basic args extra_options=extra_options, @@ -188,6 +194,7 @@ def create_subscriber( pull_sub=pull_sub, stream=stream, subject=subject, + config=config, # basic args extra_options=extra_options, # Subscriber args @@ -207,6 +214,7 @@ def create_subscriber( max_workers=max_workers, stream=stream, subject=subject, + config=config, queue=queue, # basic args extra_options=extra_options, @@ -229,6 +237,7 @@ def create_subscriber( pull_sub=pull_sub, stream=stream, subject=subject, + config=config, # basic args extra_options=extra_options, # Subscriber args @@ -248,6 +257,7 @@ def create_subscriber( pull_sub=pull_sub, stream=stream, subject=subject, + config=config, # basic args extra_options=extra_options, # Subscriber args @@ -267,6 +277,7 @@ def create_subscriber( stream=stream, subject=subject, queue=queue, + config=config, # basic args extra_options=extra_options, # Subscriber args diff --git a/faststream/nats/subscriber/usecase.py b/faststream/nats/subscriber/usecase.py index 322ef41aa3..b21d53acdc 100644 --- a/faststream/nats/subscriber/usecase.py +++ b/faststream/nats/subscriber/usecase.py @@ -20,7 +20,7 @@ import anyio from fast_depends.dependencies import Depends from nats.errors import ConnectionClosedError, TimeoutError -from nats.js.api import ObjectInfo +from nats.js.api import ConsumerConfig, ObjectInfo from nats.js.kv import KeyValue from typing_extensions import Annotated, Doc, override @@ -73,6 +73,7 @@ def __init__( self, *, subject: str, + config: "ConsumerConfig", extra_options: Optional[AnyDict], # Subscriber args default_parser: "AsyncCallable", @@ -88,6 +89,7 @@ def __init__( include_in_schema: bool, ) -> None: self.subject = subject + self.config = config self.extra_options = extra_options or {} @@ -205,10 +207,18 @@ def build_log_context( def add_prefix(self, prefix: str) -> None: """Include Subscriber in router.""" - self.subject = "".join((prefix, self.subject)) + if self.subject: + self.subject = "".join((prefix, self.subject)) + else: + self.config.filter_subjects = [ + "".join((prefix, subject)) + for subject in (self.config.filter_subjects or ()) + ] def __hash__(self) -> int: - return self.get_routing_hash(self.subject) + return self.get_routing_hash( + self.subject or "".join(self.config.filter_subjects or ()) + ) @staticmethod def get_routing_hash( @@ -229,6 +239,7 @@ def __init__( self, *, subject: str, + config: "ConsumerConfig", # default args extra_options: Optional[AnyDict], # Subscriber args @@ -246,6 +257,7 @@ def __init__( ) -> None: super().__init__( subject=subject, + config=config, extra_options=extra_options, # subscriber args default_parser=default_parser, @@ -368,6 +380,7 @@ def __init__( *, # default args subject: str, + config: "ConsumerConfig", queue: str, extra_options: Optional[AnyDict], # Subscriber args @@ -387,6 +400,7 @@ def __init__( super().__init__( subject=subject, + config=config, extra_options=extra_options, # subscriber args default_parser=parser_.parse_message, @@ -439,6 +453,7 @@ def __init__( max_workers: int, # default args subject: str, + config: "ConsumerConfig", queue: str, extra_options: Optional[AnyDict], # Subscriber args @@ -456,6 +471,7 @@ def __init__( max_workers=max_workers, # basic args subject=subject, + config=config, queue=queue, extra_options=extra_options, # Propagated args @@ -494,6 +510,7 @@ def __init__( stream: "JStream", # default args subject: str, + config: "ConsumerConfig", queue: str, extra_options: Optional[AnyDict], # Subscriber args @@ -514,6 +531,7 @@ def __init__( super().__init__( subject=subject, + config=config, extra_options=extra_options, # subscriber args default_parser=parser_.parse_message, @@ -540,7 +558,7 @@ def get_log_context( """Log context factory using in `self.consume` scope.""" return self.build_log_context( message=message, - subject=self.subject, + subject=self.subject or ", ".join(self.config.filter_subjects or ()), queue=self.queue, stream=self.stream.name, ) @@ -560,6 +578,7 @@ async def _create_subscription( # type: ignore[override] subject=self.clear_subject, queue=self.queue, cb=self.consume, + config=self.config, **self.extra_options, ) @@ -574,6 +593,7 @@ def __init__( stream: "JStream", # default args subject: str, + config: "ConsumerConfig", queue: str, extra_options: Optional[AnyDict], # Subscriber args @@ -592,6 +612,7 @@ def __init__( # basic args stream=stream, subject=subject, + config=config, queue=queue, extra_options=extra_options, # Propagated args @@ -619,6 +640,7 @@ async def _create_subscription( # type: ignore[override] subject=self.clear_subject, queue=self.queue, cb=self._put_msg, + config=self.config, **self.extra_options, ) @@ -633,6 +655,7 @@ def __init__( stream: "JStream", # default args subject: str, + config: "ConsumerConfig", extra_options: Optional[AnyDict], # Subscriber args no_ack: bool, @@ -651,6 +674,7 @@ def __init__( # basic args stream=stream, subject=subject, + config=config, extra_options=extra_options, queue="", # Propagated args @@ -708,6 +732,7 @@ def __init__( pull_sub: "PullSub", stream: "JStream", subject: str, + config: "ConsumerConfig", extra_options: Optional[AnyDict], # Subscriber args no_ack: bool, @@ -726,6 +751,7 @@ def __init__( pull_sub=pull_sub, stream=stream, subject=subject, + config=config, extra_options=extra_options, # Propagated args no_ack=no_ack, @@ -765,6 +791,7 @@ def __init__( *, # default args subject: str, + config: "ConsumerConfig", stream: "JStream", pull_sub: "PullSub", extra_options: Optional[AnyDict], @@ -786,6 +813,7 @@ def __init__( super().__init__( subject=subject, + config=config, extra_options=extra_options, # subscriber args default_parser=parser.parse_batch, @@ -837,6 +865,7 @@ def __init__( self, *, subject: str, + config: "ConsumerConfig", kv_watch: "KvWatch", broker_dependencies: Iterable[Depends], broker_middlewares: Iterable["BrokerMiddleware[KeyValue.Entry]"], @@ -850,6 +879,7 @@ def __init__( super().__init__( subject=subject, + config=config, extra_options=None, no_ack=True, no_reply=True, @@ -941,6 +971,7 @@ def __init__( self, *, subject: str, + config: "ConsumerConfig", obj_watch: "ObjWatch", broker_dependencies: Iterable[Depends], broker_middlewares: Iterable["BrokerMiddleware[List[Msg]]"], @@ -955,6 +986,7 @@ def __init__( super().__init__( subject=subject, + config=config, extra_options=None, no_ack=True, no_reply=True, diff --git a/faststream/nats/testing.py b/faststream/nats/testing.py index 34230cb788..4d13333c5f 100644 --- a/faststream/nats/testing.py +++ b/faststream/nats/testing.py @@ -97,7 +97,10 @@ async def publish( # type: ignore[override] ): continue - if is_subject_match_wildcard(subject, handler.clear_subject): + if is_subject_match_wildcard(subject, handler.clear_subject) or any( + is_subject_match_wildcard(subject, filter_subject) + for filter_subject in (handler.config.filter_subjects or ()) + ): msg: Union[List[PatchedMessage], PatchedMessage] if (pull := getattr(handler, "pull_sub", None)) and pull.batch: msg = [incoming] diff --git a/tests/brokers/nats/test_consume.py b/tests/brokers/nats/test_consume.py index 60ac90a7f3..5b37bcbc7f 100644 --- a/tests/brokers/nats/test_consume.py +++ b/tests/brokers/nats/test_consume.py @@ -1,11 +1,11 @@ import asyncio -from unittest.mock import patch +from unittest.mock import Mock, patch import pytest from nats.aio.msg import Msg from faststream.exceptions import AckMessage -from faststream.nats import JStream, NatsBroker, PullSub +from faststream.nats import ConsumerConfig, JStream, NatsBroker, PullSub from faststream.nats.annotations import NatsMessage from tests.brokers.base.consume import BrokerRealConsumeTestcase from tests.tools import spy_decorator @@ -40,6 +40,36 @@ def subscriber(m): assert event.is_set() + async def test_consume_with_filter( + self, + queue, + mock: Mock, + event: asyncio.Event, + ): + consume_broker = self.get_broker() + + @consume_broker.subscriber( + config=ConsumerConfig(filter_subjects=[f"{queue}.a"]), + stream=JStream(queue, subjects=[f"{queue}.*"]), + ) + def subscriber(m): + mock(m) + event.set() + + async with self.patch_broker(consume_broker) as br: + await br.start() + await asyncio.wait( + ( + asyncio.create_task(br.publish(1, f"{queue}.b")), + asyncio.create_task(br.publish(2, f"{queue}.a")), + asyncio.create_task(event.wait()), + ), + timeout=3, + ) + + assert event.is_set() + mock.assert_called_once_with(2) + async def test_consume_pull( self, queue: str, diff --git a/tests/brokers/nats/test_test_client.py b/tests/brokers/nats/test_test_client.py index ebbd1c7887..9718b558b6 100644 --- a/tests/brokers/nats/test_test_client.py +++ b/tests/brokers/nats/test_test_client.py @@ -4,7 +4,7 @@ from faststream import BaseMiddleware from faststream.exceptions import SetupError -from faststream.nats import JStream, NatsBroker, PullSub, TestNatsBroker +from faststream.nats import ConsumerConfig, JStream, NatsBroker, PullSub, TestNatsBroker from tests.brokers.base.testclient import BrokerTestclientTestcase @@ -208,8 +208,6 @@ async def test_consume_batch( self, queue: str, stream: JStream, - event: asyncio.Event, - mock, ): broker = self.get_broker() @@ -219,9 +217,26 @@ async def test_consume_batch( pull_sub=PullSub(1, batch=True), ) def subscriber(m): - mock(m) - event.set() + pass async with TestNatsBroker(broker) as br: await br.publish("hello", queue) subscriber.mock.assert_called_once_with(["hello"]) + + async def test_consume_with_filter( + self, + queue, + ): + broker = self.get_broker() + + @broker.subscriber( + config=ConsumerConfig(filter_subjects=[f"{queue}.a"]), + stream=JStream(queue, subjects=[f"{queue}.*"]), + ) + def subscriber(m): + pass + + async with TestNatsBroker(broker) as br: + await br.publish(1, f"{queue}.b") + await br.publish(2, f"{queue}.a") + subscriber.mock.assert_called_once_with(2) From 77d55609b7626d569099560f5cf5c9c221a22767 Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Tue, 11 Jun 2024 22:58:00 +0300 Subject: [PATCH 3/3] refactor: add cache for NATS log subject calculation --- faststream/nats/subscriber/usecase.py | 10 ++++++---- tests/brokers/nats/test_consume.py | 1 - 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/faststream/nats/subscriber/usecase.py b/faststream/nats/subscriber/usecase.py index b21d53acdc..d64cc2cf2d 100644 --- a/faststream/nats/subscriber/usecase.py +++ b/faststream/nats/subscriber/usecase.py @@ -215,10 +215,12 @@ def add_prefix(self, prefix: str) -> None: for subject in (self.config.filter_subjects or ()) ] + @cached_property + def _resolved_subject_string(self) -> str: + return self.subject or ", ".join(self.config.filter_subjects or ()) + def __hash__(self) -> int: - return self.get_routing_hash( - self.subject or "".join(self.config.filter_subjects or ()) - ) + return self.get_routing_hash(self._resolved_subject_string) @staticmethod def get_routing_hash( @@ -558,7 +560,7 @@ def get_log_context( """Log context factory using in `self.consume` scope.""" return self.build_log_context( message=message, - subject=self.subject or ", ".join(self.config.filter_subjects or ()), + subject=self._resolved_subject_string, queue=self.queue, stream=self.stream.name, ) diff --git a/tests/brokers/nats/test_consume.py b/tests/brokers/nats/test_consume.py index 5b37bcbc7f..96e40f447b 100644 --- a/tests/brokers/nats/test_consume.py +++ b/tests/brokers/nats/test_consume.py @@ -60,7 +60,6 @@ def subscriber(m): await br.start() await asyncio.wait( ( - asyncio.create_task(br.publish(1, f"{queue}.b")), asyncio.create_task(br.publish(2, f"{queue}.a")), asyncio.create_task(event.wait()), ),