diff --git a/CHANGELOG.md b/CHANGELOG.md index 98aa397b0..4ddbe989d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,15 +8,23 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - `list_files_from_disk` activity to `FileManager` Tool. +- Support for Drivers in `EventListener`. +- `AmazonSqsEventListenerDriver` for sending events to an Amazon SQS queue. +- `AwsIotCoreEventListenerDriver` for sending events to a topic on AWS IoT Core. +- `GriptapeCloudEventListenerDriver` for sending events to Griptape Cloud. +- `WebhookEventListenerDriver` for sending events to a webhook. +- `LocalEventListenerDriver` for sending events to a callback function. + ### Changed -- Improved RAG performance in `VectorQueryEngine`. - **BREAKING**: Secret fields (ex: api_key) removed from serialized Drivers. - **BREAKING**: Remove `FileLoader`. - **BREAKING**: `CsvLoader` no longer accepts `str` file paths as a source. It will now accept the content of the CSV file as a `str` or `bytes` object. - **BREAKING**: `PdfLoader` no longer accepts `str` file content, `Path` file paths or `IO` objects as sources. Instead, it will only accept the content of the PDF file as a `bytes` object. - **BREAKING**: `TextLoader` no longer accepts `Path` file paths as a source. It will now accept the content of the text file as a `str` or `bytes` object. - **BREAKING**: `FileManager.default_loader` is now `None` by default. +- **BREAKING**: Replaced `EventListener.handler` with `EventListener.driver` and `LocalEventListenerDriver`. +- Improved RAG performance in `VectorQueryEngine`. ## [0.24.2] - 2024-04-04 diff --git a/griptape/drivers/__init__.py b/griptape/drivers/__init__.py index cb453a699..2adc56d7f 100644 --- a/griptape/drivers/__init__.py +++ b/griptape/drivers/__init__.py @@ -86,6 +86,13 @@ from .web_scraper.trafilatura_web_scraper_driver import TrafilaturaWebScraperDriver from .web_scraper.markdownify_web_scraper_driver import MarkdownifyWebScraperDriver +from .event_listener.base_event_listener_driver import BaseEventListenerDriver +from .event_listener.amazon_sqs_event_listener_driver import AmazonSqsEventListenerDriver +from .event_listener.webhook_event_listener_driver import WebhookEventListenerDriver +from .event_listener.aws_iot_core_event_listener_driver import AwsIotCoreEventListenerDriver +from .event_listener.griptape_cloud_event_listener_driver import GriptapeCloudEventListenerDriver +from .event_listener.local_event_listener_driver import LocalEventListenerDriver + __all__ = [ "BasePromptDriver", "OpenAiChatPromptDriver", @@ -161,4 +168,10 @@ "BaseWebScraperDriver", "TrafilaturaWebScraperDriver", "MarkdownifyWebScraperDriver", + "BaseEventListenerDriver", + "AmazonSqsEventListenerDriver", + "WebhookEventListenerDriver", + "AwsIotCoreEventListenerDriver", + "GriptapeCloudEventListenerDriver", + "LocalEventListenerDriver", ] diff --git a/griptape/drivers/event_listener/__init__.py b/griptape/drivers/event_listener/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/griptape/drivers/event_listener/amazon_sqs_event_listener_driver.py b/griptape/drivers/event_listener/amazon_sqs_event_listener_driver.py new file mode 100644 index 000000000..0db63726b --- /dev/null +++ b/griptape/drivers/event_listener/amazon_sqs_event_listener_driver.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any +import json + +from attr import Factory, define, field + +from griptape.drivers.event_listener.base_event_listener_driver import BaseEventListenerDriver +from griptape.events.base_event import BaseEvent +from griptape.utils import import_optional_dependency + +if TYPE_CHECKING: + import boto3 + + +@define +class AmazonSqsEventListenerDriver(BaseEventListenerDriver): + queue_url: str = field(kw_only=True) + session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True) + sqs_client: Any = field(default=Factory(lambda self: self.session.client("sqs"), takes_self=True)) + + def try_publish_event(self, event: BaseEvent) -> None: + self.sqs_client.send_message(QueueUrl=self.queue_url, MessageBody=json.dumps({"event": event.to_dict()})) diff --git a/griptape/drivers/event_listener/aws_iot_core_event_listener_driver.py b/griptape/drivers/event_listener/aws_iot_core_event_listener_driver.py new file mode 100644 index 000000000..876b790e8 --- /dev/null +++ b/griptape/drivers/event_listener/aws_iot_core_event_listener_driver.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import json +from attr import Factory, define, field + +from griptape.drivers.event_listener.base_event_listener_driver import BaseEventListenerDriver +from griptape.events.base_event import BaseEvent +from griptape.utils import import_optional_dependency + +if TYPE_CHECKING: + import boto3 + + +@define +class AwsIotCoreEventListenerDriver(BaseEventListenerDriver): + iot_endpoint: str = field(kw_only=True) + topic: str = field(kw_only=True) + session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True) + iotdata_client: Any = field(default=Factory(lambda self: self.session.client("iot-data"), takes_self=True)) + + def try_publish_event(self, event: BaseEvent) -> None: + self.iotdata_client.publish(topic=self.topic, payload=json.dumps({"event": event.to_dict()})) diff --git a/griptape/drivers/event_listener/base_event_listener_driver.py b/griptape/drivers/event_listener/base_event_listener_driver.py new file mode 100644 index 000000000..5bfbe6709 --- /dev/null +++ b/griptape/drivers/event_listener/base_event_listener_driver.py @@ -0,0 +1,20 @@ +from __future__ import annotations +from abc import ABC, abstractmethod +from concurrent import futures +from attr import define, field, Factory +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from griptape.events import BaseEvent + + +@define +class BaseEventListenerDriver(ABC): + futures_executor: futures.Executor = field(default=Factory(lambda: futures.ThreadPoolExecutor()), kw_only=True) + + def publish_event(self, event: BaseEvent) -> None: + self.futures_executor.submit(self.try_publish_event, event) + + @abstractmethod + def try_publish_event(self, event: BaseEvent) -> None: + ... diff --git a/griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py b/griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py new file mode 100644 index 000000000..181a246f6 --- /dev/null +++ b/griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +import os +import requests + +from urllib.parse import urljoin +from attr import define, field, Factory + +from griptape.drivers.event_listener.base_event_listener_driver import BaseEventListenerDriver +from griptape.events.base_event import BaseEvent + + +@define +class GriptapeCloudEventListenerDriver(BaseEventListenerDriver): + base_url: str = field(default="https://cloud.griptape.ai", kw_only=True) + api_key: str = field(kw_only=True) + headers: dict = field( + default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True), kw_only=True + ) + run_id: str = field(default=Factory(lambda: os.getenv("GT_CLOUD_RUN_ID")), kw_only=True) + + @run_id.validator # pyright: ignore + def validate_run_id(self, _, run_id: str): + if run_id is None: + raise ValueError( + "run_id must be set either in the constructor or as an environment variable (GT_CLOUD_RUN_ID)." + ) + + def try_publish_event(self, event: BaseEvent) -> None: + url = urljoin(self.base_url.strip("/"), f"/api/runs/{self.run_id}/events/") + + requests.post(url=url, json=event.to_dict(), headers=self.headers) diff --git a/griptape/drivers/event_listener/local_event_listener_driver.py b/griptape/drivers/event_listener/local_event_listener_driver.py new file mode 100644 index 000000000..b276b94ef --- /dev/null +++ b/griptape/drivers/event_listener/local_event_listener_driver.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from typing import Callable, Any +from attr import define, field + +from griptape.drivers.event_listener.base_event_listener_driver import BaseEventListenerDriver +from griptape.events.base_event import BaseEvent + + +@define +class LocalEventListenerDriver(BaseEventListenerDriver): + handler: Callable[[BaseEvent], Any] = field(default=None, kw_only=True) + + def publish_event(self, event: BaseEvent) -> None: + self.try_publish_event(event) + + def try_publish_event(self, event: BaseEvent) -> None: + self.handler(event) diff --git a/griptape/drivers/event_listener/webhook_event_listener_driver.py b/griptape/drivers/event_listener/webhook_event_listener_driver.py new file mode 100644 index 000000000..d2f0046d0 --- /dev/null +++ b/griptape/drivers/event_listener/webhook_event_listener_driver.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +import requests + +from attr import define, field + +from griptape.drivers.event_listener.base_event_listener_driver import BaseEventListenerDriver +from griptape.events.base_event import BaseEvent + + +@define +class WebhookEventListenerDriver(BaseEventListenerDriver): + webhook_url: str = field(kw_only=True) + headers: dict = field(default=None, kw_only=True) + + def try_publish_event(self, event: BaseEvent) -> None: + requests.post(url=self.webhook_url, json={"event": event.to_dict()}, headers=self.headers) diff --git a/griptape/events/event_listener.py b/griptape/events/event_listener.py index e757b1b06..aa7da5dcd 100644 --- a/griptape/events/event_listener.py +++ b/griptape/events/event_listener.py @@ -1,9 +1,20 @@ -from typing import Callable, Optional, Type, Any +from __future__ import annotations +from typing import Optional, TYPE_CHECKING from attrs import define, field from .base_event import BaseEvent +if TYPE_CHECKING: + from griptape.drivers import BaseEventListenerDriver + @define class EventListener: - handler: Callable[[BaseEvent], Any] = field() - event_types: Optional[list[Type[BaseEvent]]] = field(default=None, kw_only=True) + event_types: Optional[list[type[BaseEvent]]] = field(default=None, kw_only=True) + driver: Optional[BaseEventListenerDriver] = field(default=None, kw_only=True) + + def publish_event(self, event: BaseEvent) -> None: + event_types = self.event_types + + if event_types is None or type(event) in event_types: + if self.driver is not None: + self.driver.publish_event(event) diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index 4e14d087d..e05d25ca5 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -251,11 +251,7 @@ def remove_event_listener(self, event_listener: EventListener) -> None: def publish_event(self, event: BaseEvent) -> None: for event_listener in self.event_listeners: - handler = event_listener.handler - event_types = event_listener.event_types - - if event_types is None or type(event) in event_types: - handler(event) + event_listener.publish_event(event) def context(self, task: BaseTask) -> dict[str, Any]: return {"args": self.execution_args, "structure": self} diff --git a/griptape/utils/stream.py b/griptape/utils/stream.py index a6c5c0db4..80d3ea5a1 100644 --- a/griptape/utils/stream.py +++ b/griptape/utils/stream.py @@ -53,11 +53,14 @@ def run(self, *args) -> Iterator[TextArtifact]: t.join() def _run_structure(self, *args): + from griptape.drivers import LocalEventListenerDriver + def event_handler(event: BaseEvent): self._event_queue.put(event) stream_event_listener = EventListener( - event_handler, event_types=[CompletionChunkEvent, FinishPromptEvent, FinishStructureRunEvent] + driver=LocalEventListenerDriver(handler=event_handler), + event_types=[CompletionChunkEvent, FinishPromptEvent, FinishStructureRunEvent], ) self.structure.add_event_listener(stream_event_listener) diff --git a/poetry.lock b/poetry.lock index 21abd3413..3cc0fbee2 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. [[package]] name = "aiohttp" @@ -1943,6 +1943,17 @@ files = [ {file = "jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe"}, ] +[[package]] +name = "jsondiff" +version = "2.0.0" +description = "Diff JSON and JSON-like structures in Python" +optional = false +python-versions = "*" +files = [ + {file = "jsondiff-2.0.0-py3-none-any.whl", hash = "sha256:689841d66273fc88fc79f7d33f4c074774f4f214b6466e3aff0e5adaf889d1e0"}, + {file = "jsondiff-2.0.0.tar.gz", hash = "sha256:2795844ef075ec8a2b8d385c4d59f5ea48b08e7180fce3cb2787be0db00b1fb4"}, +] + [[package]] name = "justext" version = "3.0.0" @@ -2049,7 +2060,6 @@ files = [ {file = "lxml-4.9.4-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:e8f9f93a23634cfafbad6e46ad7d09e0f4a25a2400e4a64b1b7b7c0fbaa06d9d"}, {file = "lxml-4.9.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:3f3f00a9061605725df1816f5713d10cd94636347ed651abdbc75828df302b20"}, {file = "lxml-4.9.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:953dd5481bd6252bd480d6ec431f61d7d87fdcbbb71b0d2bdcfc6ae00bb6fb10"}, - {file = "lxml-4.9.4-cp312-cp312-win32.whl", hash = "sha256:266f655d1baff9c47b52f529b5f6bec33f66042f65f7c56adde3fcf2ed62ae8b"}, {file = "lxml-4.9.4-cp312-cp312-win_amd64.whl", hash = "sha256:f1faee2a831fe249e1bae9cbc68d3cd8a30f7e37851deee4d7962b17c410dd56"}, {file = "lxml-4.9.4-cp35-cp35m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:23d891e5bdc12e2e506e7d225d6aa929e0a0368c9916c1fddefab88166e98b20"}, {file = "lxml-4.9.4-cp35-cp35m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:e96a1788f24d03e8d61679f9881a883ecdf9c445a38f9ae3f3f193ab6c591c66"}, @@ -2349,6 +2359,7 @@ botocore = ">=1.12.201" cryptography = ">=3.3.1" docker = {version = ">=3.0.0", optional = true, markers = "extra == \"dynamodb\""} Jinja2 = ">=2.10.1" +jsondiff = {version = ">=1.1.2", optional = true, markers = "extra == \"iotdata\""} py-partiql-parser = {version = "0.5.0", optional = true, markers = "extra == \"dynamodb\""} python-dateutil = ">=2.1,<3.0.0" requests = ">=2.5" @@ -3034,6 +3045,7 @@ files = [ {file = "psycopg2_binary-2.9.9-cp311-cp311-win32.whl", hash = "sha256:dc4926288b2a3e9fd7b50dc6a1909a13bbdadfc67d93f3374d984e56f885579d"}, {file = "psycopg2_binary-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:b76bedd166805480ab069612119ea636f5ab8f8771e640ae103e05a4aae3e417"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:8532fd6e6e2dc57bcb3bc90b079c60de896d2128c5d9d6f24a63875a95a088cf"}, + {file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b0605eaed3eb239e87df0d5e3c6489daae3f7388d455d0c0b4df899519c6a38d"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f8544b092a29a6ddd72f3556a9fcf249ec412e10ad28be6a0c0d948924f2212"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2d423c8d8a3c82d08fe8af900ad5b613ce3632a1249fd6a223941d0735fce493"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2e5afae772c00980525f6d6ecf7cbca55676296b580c0e6abb407f15f3706996"}, @@ -3042,6 +3054,8 @@ files = [ {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:cb16c65dcb648d0a43a2521f2f0a2300f40639f6f8c1ecbc662141e4e3e1ee07"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:911dda9c487075abd54e644ccdf5e5c16773470a6a5d3826fda76699410066fb"}, {file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:57fede879f08d23c85140a360c6a77709113efd1c993923c59fde17aa27599fe"}, + {file = "psycopg2_binary-2.9.9-cp312-cp312-win32.whl", hash = "sha256:64cf30263844fa208851ebb13b0732ce674d8ec6a0c86a4e160495d299ba3c93"}, + {file = "psycopg2_binary-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:81ff62668af011f9a48787564ab7eded4e9fb17a4a6a74af5ffa6a457400d2ab"}, {file = "psycopg2_binary-2.9.9-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:2293b001e319ab0d869d660a704942c9e2cce19745262a8aba2115ef41a0a42a"}, {file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03ef7df18daf2c4c07e2695e8cfd5ee7f748a1d54d802330985a78d2a5a6dca9"}, {file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a602ea5aff39bb9fac6308e9c9d82b9a35c2bf288e184a816002c9fae930b77"}, @@ -5043,6 +5057,8 @@ drivers-embedding-amazon-sagemaker = ["boto3"] drivers-embedding-google = ["google-generativeai"] drivers-embedding-huggingface = ["huggingface-hub", "transformers"] drivers-embedding-voyageai = ["voyageai"] +drivers-event-listener-amazon-iot = ["boto3"] +drivers-event-listener-amazon-sqs = ["boto3"] drivers-memory-conversation-amazon-dynamodb = ["boto3"] drivers-prompt-amazon-bedrock = ["anthropic", "boto3"] drivers-prompt-amazon-sagemaker = ["boto3", "transformers"] @@ -5070,4 +5086,4 @@ loaders-pdf = ["pypdf"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "4eac3f039981703a8c3d3a2d8ead8a9286c6e346e2ba39b0ff77c0e0b0de6aef" +content-hash = "7b1203869463a403ae5ecbf8b3b267a919d6ef1baff302a28084e18c70e52205" diff --git a/pyproject.toml b/pyproject.toml index 6755f02c5..0a52e0a88 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -89,6 +89,9 @@ drivers-embedding-google = ["google-generativeai"] drivers-web-scraper-trafilatura = ["trafilatura"] drivers-web-scraper-markdownify = ["playwright", "beautifulsoup4", "markdownify"] +drivers-event-listener-amazon-sqs = ["boto3"] +drivers-event-listener-amazon-iot = ["boto3"] + loaders-dataframe = ["pandas"] loaders-pdf = ["pypdf"] loaders-image = ["pillow"] @@ -134,7 +137,7 @@ pytest-mock = "*" mongomock = "*" twine = ">=4" -moto = {extras = ["dynamodb"], version = "^4.1.8"} +moto = {extras = ["dynamodb", "iotdata", "sqs"], version = "^4.2.13"} pytest-xdist = "^3.3.1" pytest-cov = "^4.1.0" pytest-env = "^1.1.1" diff --git a/tests/unit/drivers/event_listener/__init__.py b/tests/unit/drivers/event_listener/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/drivers/event_listener/test_amazon_sqs_event_listener_driver.py b/tests/unit/drivers/event_listener/test_amazon_sqs_event_listener_driver.py new file mode 100644 index 000000000..4a19011be --- /dev/null +++ b/tests/unit/drivers/event_listener/test_amazon_sqs_event_listener_driver.py @@ -0,0 +1,31 @@ +from pytest import fixture +from moto import mock_sqs +import boto3 +from tests.mocks.mock_event import MockEvent +from griptape.drivers.event_listener.amazon_sqs_event_listener_driver import AmazonSqsEventListenerDriver +from tests.utils.aws import mock_aws_credentials + + +class TestAmazonSqsEventListenerDriver: + @fixture() + def run_before_and_after_tests(self): + mock_aws_credentials() + + @fixture() + def driver(self): + mock = mock_sqs() + mock.start() + + session = boto3.Session(region_name="us-east-1") + response = session.client("sqs").create_queue(QueueName="foo-bar") + queue_url = response["QueueUrl"] + + yield AmazonSqsEventListenerDriver(queue_url=queue_url, session=session) + + mock.stop() + + def test_init(self, driver): + assert driver + + def test_try_publish_event(self, driver): + driver.try_publish_event(event=MockEvent()) diff --git a/tests/unit/drivers/event_listener/test_aws_iot_event_listener_driver.py b/tests/unit/drivers/event_listener/test_aws_iot_event_listener_driver.py new file mode 100644 index 000000000..c56ead9fe --- /dev/null +++ b/tests/unit/drivers/event_listener/test_aws_iot_event_listener_driver.py @@ -0,0 +1,25 @@ +from pytest import fixture +from moto import mock_iotdata +import boto3 +from tests.mocks.mock_event import MockEvent +from griptape.drivers.event_listener.aws_iot_core_event_listener_driver import AwsIotCoreEventListenerDriver +from tests.utils.aws import mock_aws_credentials + + +@mock_iotdata +class TestAwsIotCoreEventListenerDriver: + @fixture() + def run_before_and_after_tests(self): + mock_aws_credentials() + + @fixture() + def driver(self): + return AwsIotCoreEventListenerDriver( + iot_endpoint="foo bar", topic="fizz buzz", session=boto3.Session(region_name="us-east-1") + ) + + def test_init(self, driver): + assert driver + + def test_try_publish_event(self, driver): + driver.try_publish_event(event=MockEvent()) diff --git a/tests/unit/drivers/event_listener/test_griptape_cloud_event_listener_driver.py b/tests/unit/drivers/event_listener/test_griptape_cloud_event_listener_driver.py new file mode 100644 index 000000000..be6340b45 --- /dev/null +++ b/tests/unit/drivers/event_listener/test_griptape_cloud_event_listener_driver.py @@ -0,0 +1,37 @@ +from unittest.mock import Mock +from pytest import fixture +import pytest +from tests.mocks.mock_event import MockEvent +from griptape.drivers.event_listener.griptape_cloud_event_listener_driver import GriptapeCloudEventListenerDriver + + +class TestGriptapeCloudEventListenerDriver: + @fixture(autouse=True) + def mock_post(self, mocker): + data = {"data": {"id": "test"}} + + mock_post = mocker.patch("requests.post") + mock_post.return_value = Mock(status_code=201, json=data) + + return mock_post + + @fixture() + def driver(self): + return GriptapeCloudEventListenerDriver(api_key="foo bar", run_id="baz") + + def test_init(self, driver): + assert driver + + def test_try_publish_event(self, mock_post, driver): + event = MockEvent() + driver.try_publish_event(event=event) + + mock_post.assert_called_once_with( + url=f"https://cloud.griptape.ai/api/runs/{driver.run_id}/events/", + json=event.to_dict(), + headers={"Authorization": "Bearer foo bar"}, + ) + + def test_no_run_id(self): + with pytest.raises(ValueError): + GriptapeCloudEventListenerDriver(api_key="foo bar") diff --git a/tests/unit/drivers/event_listener/test_local_event_listener_driver.py b/tests/unit/drivers/event_listener/test_local_event_listener_driver.py new file mode 100644 index 000000000..a92276fa8 --- /dev/null +++ b/tests/unit/drivers/event_listener/test_local_event_listener_driver.py @@ -0,0 +1,14 @@ +from moto import mock_iotdata +from unittest.mock import Mock +from tests.mocks.mock_event import MockEvent +from griptape.drivers.event_listener.local_event_listener_driver import LocalEventListenerDriver + + +@mock_iotdata +class TestLocalEventListenerDriver: + def test_try_publish_event(self): + mock = Mock() + event = MockEvent() + driver = LocalEventListenerDriver(handler=mock) + driver.try_publish_event(event=event) + mock.assert_called_once_with(event) diff --git a/tests/unit/drivers/event_listener/test_webhook_event_listener_driver.py b/tests/unit/drivers/event_listener/test_webhook_event_listener_driver.py new file mode 100644 index 000000000..207fd48cf --- /dev/null +++ b/tests/unit/drivers/event_listener/test_webhook_event_listener_driver.py @@ -0,0 +1,25 @@ +from unittest.mock import Mock +from pytest import fixture +from tests.mocks.mock_event import MockEvent +from griptape.drivers.event_listener.webhook_event_listener_driver import WebhookEventListenerDriver + + +class TestWebhookEventListenerDriver: + @fixture(autouse=True) + def mock_post(self, mocker): + mock_post = mocker.patch("requests.post") + mock_post.return_value = Mock(status_code=201) + + return mock_post + + def test_init(self): + assert WebhookEventListenerDriver(webhook_url="") + + def test_try_publish_event(self, mock_post): + driver = WebhookEventListenerDriver(webhook_url="foo bar", headers={"Authorization": "Bearer foo bar"}) + event = MockEvent() + driver.try_publish_event(event=event) + + mock_post.assert_called_once_with( + url="foo bar", json={"event": event.to_dict()}, headers={"Authorization": "Bearer foo bar"} + ) diff --git a/tests/unit/events/test_event_listener.py b/tests/unit/events/test_event_listener.py index 17f2dc8c0..39b59ea94 100644 --- a/tests/unit/events/test_event_listener.py +++ b/tests/unit/events/test_event_listener.py @@ -1,5 +1,6 @@ from unittest.mock import Mock import pytest +from griptape.drivers.event_listener.local_event_listener_driver import LocalEventListenerDriver from griptape.structures import Pipeline from griptape.tasks import ToolkitTask, ActionsSubtask from griptape.events import ( @@ -33,7 +34,10 @@ def test_untyped_listeners(self, pipeline): event_handler_1 = Mock() event_handler_2 = Mock() - pipeline.event_listeners = [EventListener(handler=event_handler_1), EventListener(handler=event_handler_2)] + pipeline.event_listeners = [ + EventListener(driver=LocalEventListenerDriver(handler=event_handler_1)), + EventListener(driver=LocalEventListenerDriver(handler=event_handler_2)), + ] # can't mock subtask events, so must manually call pipeline.tasks[0].subtasks[0].before_run() pipeline.tasks[0].subtasks[0].after_run() @@ -54,15 +58,37 @@ def test_typed_listeners(self, pipeline): completion_chunk_handler = Mock() pipeline.event_listeners = [ - EventListener(start_prompt_event_handler, event_types=[StartPromptEvent]), - EventListener(finish_prompt_event_handler, event_types=[FinishPromptEvent]), - EventListener(start_task_event_handler, event_types=[StartTaskEvent]), - EventListener(finish_task_event_handler, event_types=[FinishTaskEvent]), - EventListener(start_subtask_event_handler, event_types=[StartActionsSubtaskEvent]), - EventListener(finish_subtask_event_handler, event_types=[FinishActionsSubtaskEvent]), - EventListener(start_structure_run_event_handler, event_types=[StartStructureRunEvent]), - EventListener(finish_structure_run_event_handler, event_types=[FinishStructureRunEvent]), - EventListener(completion_chunk_handler, event_types=[CompletionChunkEvent]), + EventListener( + driver=LocalEventListenerDriver(handler=start_prompt_event_handler), event_types=[StartPromptEvent] + ), + EventListener( + driver=LocalEventListenerDriver(handler=finish_prompt_event_handler), event_types=[FinishPromptEvent] + ), + EventListener( + driver=LocalEventListenerDriver(handler=start_task_event_handler), event_types=[StartTaskEvent] + ), + EventListener( + driver=LocalEventListenerDriver(handler=finish_task_event_handler), event_types=[FinishTaskEvent] + ), + EventListener( + driver=LocalEventListenerDriver(handler=start_subtask_event_handler), + event_types=[StartActionsSubtaskEvent], + ), + EventListener( + driver=LocalEventListenerDriver(handler=finish_subtask_event_handler), + event_types=[FinishActionsSubtaskEvent], + ), + EventListener( + driver=LocalEventListenerDriver(handler=start_structure_run_event_handler), + event_types=[StartStructureRunEvent], + ), + EventListener( + driver=LocalEventListenerDriver(handler=finish_structure_run_event_handler), + event_types=[FinishStructureRunEvent], + ), + EventListener( + driver=LocalEventListenerDriver(handler=completion_chunk_handler), event_types=[CompletionChunkEvent] + ), ] # can't mock subtask events, so must manually call @@ -84,19 +110,23 @@ def test_add_remove_event_listener(self, pipeline): pipeline.event_listeners = [] mock1 = Mock() mock2 = Mock() - # duplicate event listeners will only get added once - event_listener_1 = pipeline.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) - pipeline.add_event_listener(EventListener(mock1, event_types=[StartPromptEvent])) + event_listener_1 = pipeline.add_event_listener( + EventListener(driver=LocalEventListenerDriver(handler=mock1), event_types=[StartPromptEvent]) + ) - event_listener_3 = pipeline.add_event_listener(EventListener(mock1, event_types=[FinishPromptEvent])) - event_listener_4 = pipeline.add_event_listener(EventListener(mock2, event_types=[StartPromptEvent])) + event_listener_2 = pipeline.add_event_listener( + EventListener(driver=LocalEventListenerDriver(handler=mock1), event_types=[FinishPromptEvent]) + ) + event_listener_3 = pipeline.add_event_listener( + EventListener(driver=LocalEventListenerDriver(handler=mock2), event_types=[StartPromptEvent]) + ) - event_listener_5 = pipeline.add_event_listener(EventListener(mock2)) + event_listener_4 = pipeline.add_event_listener(EventListener(driver=LocalEventListenerDriver(handler=mock2))) assert len(pipeline.event_listeners) == 4 pipeline.remove_event_listener(event_listener_1) + pipeline.remove_event_listener(event_listener_2) pipeline.remove_event_listener(event_listener_3) pipeline.remove_event_listener(event_listener_4) - pipeline.remove_event_listener(event_listener_5) assert len(pipeline.event_listeners) == 0