Skip to content

Commit

Permalink
Refactor event driver batching, tests
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Oct 10, 2024
1 parent 0a3fc68 commit dd39c2a
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 38 deletions.
34 changes: 13 additions & 21 deletions griptape/drivers/event_listener/base_event_listener_driver.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import logging
import threading
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

Expand All @@ -19,7 +18,6 @@
class BaseEventListenerDriver(FuturesExecutorMixin, ABC):
batched: bool = field(default=True, kw_only=True)
batch_size: int = field(default=10, kw_only=True)
thread_lock: threading.Lock = field(default=Factory(lambda: threading.Lock()))

_batch: list[dict] = field(default=Factory(list), kw_only=True)

Expand All @@ -28,34 +26,28 @@ def batch(self) -> list[dict]:
return self._batch

def publish_event(self, event: BaseEvent | dict) -> None:
self.futures_executor.submit(self._safe_try_publish_event, event)
event_payload = event if isinstance(event, dict) else event.to_dict()

try:
if self.batched:
self._batch.append(event_payload)
if len(self.batch) >= self.batch_size:
self._flush_events()
else:
self.futures_executor.submit(self.try_publish_event_payload, event_payload)
except Exception as e:
logger.error(e)

def flush_events(self) -> None:
if self.batch:
with self.thread_lock:
self._flush_events()
self._flush_events()

@abstractmethod
def try_publish_event_payload(self, event_payload: dict) -> None: ...

@abstractmethod
def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None: ...

def _safe_try_publish_event(self, event: BaseEvent | dict) -> None:
try:
event_payload = event if isinstance(event, dict) else event.to_dict()

if self.batched:
with self.thread_lock:
self._batch.append(event_payload)
if len(self.batch) >= self.batch_size:
self._flush_events()
return
else:
self.try_publish_event_payload(event_payload)
except Exception as e:
logger.error(e)

def _flush_events(self) -> None:
self.try_publish_event_payload_batch(self.batch)
self.futures_executor.submit(self.try_publish_event_payload_batch, self.batch)
self._batch = []
Original file line number Diff line number Diff line change
Expand Up @@ -5,33 +5,38 @@


class TestBaseEventListenerDriver:
def test_publish_event(self):
executor = MagicMock()
executor.__enter__.return_value = executor
driver = MockEventListenerDriver(futures_executor_fn=lambda: executor)
def test_publish_event_no_batched(self):
driver = MockEventListenerDriver(batched=False)
driver.try_publish_event_payload = MagicMock(side_effect=driver.try_publish_event_payload)

driver.publish_event(MockEvent().to_dict())

executor.submit.assert_called_once()
driver.try_publish_event_payload.assert_called_once()

def test__safe_try_publish_event(self):
driver = MockEventListenerDriver(batched=False)
def test_publish_event_yes_batched(self):
driver = MockEventListenerDriver(batched=True)
driver.try_publish_event_payload_batch = MagicMock(side_effect=driver.try_publish_event_payload)

for _ in range(4):
driver._safe_try_publish_event(MockEvent().to_dict())
assert len(driver.batch) == 0
for _ in range(0, 9):
driver.publish_event(MockEvent().to_dict())

def test__safe_try_publish_event_batch(self):
driver = MockEventListenerDriver(batched=True)
assert len(driver._batch) == 9
driver.try_publish_event_payload_batch.assert_not_called()

for _ in range(0, 3):
driver._safe_try_publish_event(MockEvent().to_dict())
assert len(driver.batch) == 3
# Publish the 10th event to trigger the batch publish
driver.publish_event(MockEvent().to_dict())

assert len(driver._batch) == 0
driver.try_publish_event_payload_batch.assert_called_once()

def test__safe_try_publish_event_batch_flush(self):
def test_flush_events(self):
driver = MockEventListenerDriver(batched=True)
driver.try_publish_event_payload_batch = MagicMock(side_effect=driver.try_publish_event_payload)

for _ in range(0, 3):
driver._safe_try_publish_event(MockEvent().to_dict())
driver.publish_event(MockEvent().to_dict())
assert len(driver.batch) == 3

driver.flush_events()
driver.try_publish_event_payload_batch.assert_called_once()
assert len(driver.batch) == 0

0 comments on commit dd39c2a

Please sign in to comment.