diff --git a/griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py b/griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py index 44e54cfdf8..181a246f6c 100644 --- a/griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py +++ b/griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py @@ -1,6 +1,5 @@ from __future__ import annotations -import uuid import os import requests @@ -18,9 +17,16 @@ class GriptapeCloudEventListenerDriver(BaseEventListenerDriver): headers: dict = field( default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True), kw_only=True ) - run_id: str = field(default=os.getenv("GT_CLOUD_RUN_ID", uuid.uuid4().hex), kw_only=True) + run_id: str = field(default=Factory(lambda: os.getenv("GT_CLOUD_RUN_ID")), kw_only=True) + + @run_id.validator # pyright: ignore + def validate_run_id(self, _, run_id: str): + if run_id is None: + raise ValueError( + "run_id must be set either in the constructor or as an environment variable (GT_CLOUD_RUN_ID)." + ) def try_publish_event(self, event: BaseEvent) -> None: - url = urljoin(self.base_url.strip("/"), "/api/events") + url = urljoin(self.base_url.strip("/"), f"/api/runs/{self.run_id}/events/") - requests.post(url=url, json={"run_id": self.run_id, "event": event.to_dict()}, headers=self.headers) + requests.post(url=url, json=event.to_dict(), headers=self.headers) diff --git a/griptape/utils/stream.py b/griptape/utils/stream.py index 6cd7566dde..80d3ea5a17 100644 --- a/griptape/utils/stream.py +++ b/griptape/utils/stream.py @@ -4,7 +4,6 @@ from threading import Thread from queue import Queue from griptape.artifacts.text_artifact import TextArtifact -from griptape.drivers.event_listener.local_event_listener_driver import LocalEventListenerDriver from griptape.events.completion_chunk_event import CompletionChunkEvent from griptape.events.event_listener import EventListener from griptape.events.base_event import BaseEvent @@ -54,6 +53,8 @@ def run(self, *args) -> Iterator[TextArtifact]: t.join() def _run_structure(self, *args): + from griptape.drivers import LocalEventListenerDriver + def event_handler(event: BaseEvent): self._event_queue.put(event) diff --git a/tests/unit/drivers/event_listener/test_griptape_cloud_event_listener_driver.py b/tests/unit/drivers/event_listener/test_griptape_cloud_event_listener_driver.py index 729c2f7c44..be6340b45b 100644 --- a/tests/unit/drivers/event_listener/test_griptape_cloud_event_listener_driver.py +++ b/tests/unit/drivers/event_listener/test_griptape_cloud_event_listener_driver.py @@ -1,5 +1,6 @@ from unittest.mock import Mock from pytest import fixture +import pytest from tests.mocks.mock_event import MockEvent from griptape.drivers.event_listener.griptape_cloud_event_listener_driver import GriptapeCloudEventListenerDriver @@ -26,7 +27,11 @@ def test_try_publish_event(self, mock_post, driver): driver.try_publish_event(event=event) mock_post.assert_called_once_with( - url="https://cloud.griptape.ai/api/events", - json={"run_id": driver.run_id, "event": event.to_dict()}, + url=f"https://cloud.griptape.ai/api/runs/{driver.run_id}/events/", + json=event.to_dict(), headers={"Authorization": "Bearer foo bar"}, ) + + def test_no_run_id(self): + with pytest.raises(ValueError): + GriptapeCloudEventListenerDriver(api_key="foo bar")