Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Apr 9, 2024
1 parent 3126514 commit d8926c1
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 7 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import uuid
import os
import requests

Expand All @@ -18,9 +17,16 @@ class GriptapeCloudEventListenerDriver(BaseEventListenerDriver):
headers: dict = field(
default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True), kw_only=True
)
run_id: str = field(default=os.getenv("GT_CLOUD_RUN_ID", uuid.uuid4().hex), kw_only=True)
run_id: str = field(default=Factory(lambda: os.getenv("GT_CLOUD_RUN_ID")), kw_only=True)

@run_id.validator # pyright: ignore
def validate_run_id(self, _, run_id: str):
if run_id is None:
raise ValueError(
"run_id must be set either in the constructor or as an environment variable (GT_CLOUD_RUN_ID)."
)

def try_publish_event(self, event: BaseEvent) -> None:
url = urljoin(self.base_url.strip("/"), "/api/events")
url = urljoin(self.base_url.strip("/"), f"/api/runs/{self.run_id}/events/")

requests.post(url=url, json={"run_id": self.run_id, "event": event.to_dict()}, headers=self.headers)
requests.post(url=url, json=event.to_dict(), headers=self.headers)
3 changes: 2 additions & 1 deletion griptape/utils/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from threading import Thread
from queue import Queue
from griptape.artifacts.text_artifact import TextArtifact
from griptape.drivers.event_listener.local_event_listener_driver import LocalEventListenerDriver
from griptape.events.completion_chunk_event import CompletionChunkEvent
from griptape.events.event_listener import EventListener
from griptape.events.base_event import BaseEvent
Expand Down Expand Up @@ -54,6 +53,8 @@ def run(self, *args) -> Iterator[TextArtifact]:
t.join()

def _run_structure(self, *args):
from griptape.drivers import LocalEventListenerDriver

def event_handler(event: BaseEvent):
self._event_queue.put(event)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from unittest.mock import Mock
from pytest import fixture
import pytest
from tests.mocks.mock_event import MockEvent
from griptape.drivers.event_listener.griptape_cloud_event_listener_driver import GriptapeCloudEventListenerDriver

Expand All @@ -26,7 +27,11 @@ def test_try_publish_event(self, mock_post, driver):
driver.try_publish_event(event=event)

mock_post.assert_called_once_with(
url="https://cloud.griptape.ai/api/events",
json={"run_id": driver.run_id, "event": event.to_dict()},
url=f"https://cloud.griptape.ai/api/runs/{driver.run_id}/events/",
json=event.to_dict(),
headers={"Authorization": "Bearer foo bar"},
)

def test_no_run_id(self):
with pytest.raises(ValueError):
GriptapeCloudEventListenerDriver(api_key="foo bar")

0 comments on commit d8926c1

Please sign in to comment.