From 35d6885eafae1e51df961f33740e60b925df12fb Mon Sep 17 00:00:00 2001
From: ekerstens <49325583+ekerstens@users.noreply.github.com>
Date: Wed, 19 Jan 2022 12:53:43 -0800
Subject: [PATCH] Support waiting for fetcher to finish (#263)

Co-authored-by: Eric Kerstens
---
 faust/transport/consumer.py           | 65 ++++++++++++++++++---------
 tests/unit/transport/test_consumer.py |  1 -
 2 files changed, 43 insertions(+), 23 deletions(-)

diff --git a/faust/transport/consumer.py b/faust/transport/consumer.py
index c95e79e6c..f48c355a2 100644
--- a/faust/transport/consumer.py
+++ b/faust/transport/consumer.py
@@ -436,6 +436,7 @@ class Consumer(Service, ConsumerT):
     flow_active: bool = True
     can_resume_flow: Event
     suspend_flow: Event
+    not_waiting_next_records: Event
 
     def __init__(
         self,
@@ -477,6 +478,8 @@ def __init__(
         self.randomly_assigned_topics = set()
         self.can_resume_flow = Event()
         self.suspend_flow = Event()
+        self.not_waiting_next_records = Event()
+        self.not_waiting_next_records.set()
         self._reset_state()
         super().__init__(loop=loop or self.transport.loop, **kwargs)
         self.transactions = self.transport.create_transaction_manager(
@@ -500,6 +503,7 @@ def _reset_state(self) -> None:
         self._buffered_partitions = set()
         self.can_resume_flow.clear()
         self.suspend_flow.clear()
+        self.not_waiting_next_records.set()
         self.flow_active = True
         self._time_start = monotonic()
 
@@ -587,6 +591,18 @@ def resume_flow(self) -> None:
         self.can_resume_flow.set()
         self.suspend_flow.clear()
 
+    async def wait_for_stopped_flow(self) -> None:
+        """Wait until the consumer is not waiting on any newly fetched records.
+
+        Useful for scenarios where the consumer needs to be stopped to change the
+        position of the fetcher to something other than the committed offset. There is a
+        chance that getmany forces a seek to the committed offsets if the fetcher
+        returns while the consumer is stopped. This can be prevented by waiting for the
+        fetcher to finish (by default every second).
+        """
+        if not self.not_waiting_next_records.is_set():
+            await self.not_waiting_next_records.wait()
+
     def pause_partitions(self, tps: Iterable[TP]) -> None:
         """Pause fetching from partitions."""
         tpset = ensure_TPset(tps)
@@ -745,28 +761,33 @@ async def _wait_next_records(
         if not self.flow_active:
             await self.wait(self.can_resume_flow)
 
-        # Implementation for the Fetcher service.
-        is_client_only = self.app.client_only
-
-        active_partitions: Optional[Set[TP]]
-        if is_client_only:
-            active_partitions = None
-        else:
-            active_partitions = self._get_active_partitions()
-
-        records: RecordMap = {}
-        if is_client_only or active_partitions:
-            # Fetch records only if active partitions to avoid the risk of
-            # fetching all partitions in the beginning when none of the
-            # partitions is paused/resumed.
-            records = await self._getmany(
-                active_partitions=active_partitions,
-                timeout=timeout,
-            )
-        else:
-            # We should still release to the event loop
-            await self.sleep(1)
-        return records, active_partitions
+        try:
+            # Mark that _wait_next_records is waiting on the fetcher service.
+            self.not_waiting_next_records.clear()
+            # Implementation for the Fetcher service.
+            is_client_only = self.app.client_only
+
+            active_partitions: Optional[Set[TP]]
+            if is_client_only:
+                active_partitions = None
+            else:
+                active_partitions = self._get_active_partitions()
+
+            records: RecordMap = {}
+            if is_client_only or active_partitions:
+                # Fetch records only if active partitions to avoid the risk of
+                # fetching all partitions in the beginning when none of the
+                # partitions is paused/resumed.
+                records = await self._getmany(
+                    active_partitions=active_partitions,
+                    timeout=timeout,
+                )
+            else:
+                # We should still release to the event loop
+                await self.sleep(1)
+            return records, active_partitions
+        finally:
+            self.not_waiting_next_records.set()
 
     @abc.abstractmethod
     def _to_message(self, tp: TP, record: Any) -> ConsumerMessage:
diff --git a/tests/unit/transport/test_consumer.py b/tests/unit/transport/test_consumer.py
index e686f95c7..485e47610 100644
--- a/tests/unit/transport/test_consumer.py
+++ b/tests/unit/transport/test_consumer.py
@@ -544,7 +544,6 @@ def to_message(tp, record):
         assert not consumer.should_stop
         consumer.flow_active = False
         consumer.can_resume_flow.set()
-        # Test is hanging here
         assert [a async for a in consumer.getmany(1.0)] == []
         assert not consumer.should_stop
         consumer.flow_active = True