Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add alias to GriptapeCloudConversationMemoryDriver #1237

Merged
merged 5 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `BaseRulesetDriver` for loading a `Ruleset` from an external source.
- `LocalRulesetDriver` for loading a `Ruleset` from a local `.json` file.
- `GriptapeCloudRulesetDriver` for loading a `Ruleset` resource from Griptape Cloud.
- Parameter `alias` on `GriptapeCloudConversationMemoryDriver` for fetching a Thread by alias.

### Changed
- **BREAKING**: Renamed parameters on several classes to `client`:
Expand Down
3 changes: 3 additions & 0 deletions docs/griptape-cloud/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,8 @@ Connect to your data with our [Data Sources](data-sources/create-data-source.md)
## Host and Run Your Code
Have Griptape code? Have existing code with another LLM framework? You can host your Python code using [Structures](structures/create-structure.md) whether it uses the Griptape Framework or not.

## Store Configuration for LLM Agents
[Rules and Rulesets](rules/rulesets.md) enable rapid and collabortive iteration for managing LLM behavior. [Threads and Messages](threads/threads.md) allow for persisted and editable conversation memory across any LLM invocation.

## APIs
All of our features can be called via API with a [Griptape Cloud API Key](https://cloud.griptape.ai/configuration/api-keys). See the [API Reference](api/api-reference.md) for detailed information.
2 changes: 1 addition & 1 deletion docs/griptape-cloud/rules/rulesets.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ A [Ruleset can be created](https://cloud.griptape.ai/rulesets/create) to store s
```bash
export GT_CLOUD_API_KEY=<your API key here>
export ALIAS=<your ruleset alias>
curl -H "Authorization: Bearer ${GT_CLOUD_API_KEY}" https://cloud.griptape.ai/rulesets?alias=${ALIAS}
curl -H "Authorization: Bearer ${GT_CLOUD_API_KEY}" https://cloud.griptape.ai/api/rulesets?alias=${ALIAS}
```
11 changes: 11 additions & 0 deletions docs/griptape-cloud/threads/threads.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Threads

A [Thread can be created](https://cloud.griptape.ai/threads/create) to store conversation history across any LLM invocation. A Thread contains a list of [Messages](https://cloud.griptape.ai/messages/create). Messages can be updated and deleted, in order to control how the LLM recalls past conversations.

A Thread can be given an `alias` so it can be referenced by a user-provided unique identifier:

```bash
export GT_CLOUD_API_KEY=<your API key here>
export ALIAS=<your thread alias>
curl -H "Authorization: Bearer ${GT_CLOUD_API_KEY}" https://cloud.griptape.ai/api/threads?alias=${ALIAS}
```
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import os
import uuid

from griptape.drivers import GriptapeCloudConversationMemoryDriver
from griptape.memory.structure import ConversationMemory
from griptape.structures import Agent

conversation_id = uuid.uuid4().hex
cloud_conversation_driver = GriptapeCloudConversationMemoryDriver(
api_key=os.environ["GT_CLOUD_API_KEY"],
alias="my_thread_alias",
)
agent = Agent(conversation_memory=ConversationMemory(conversation_memory_driver=cloud_conversation_driver))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ class GriptapeCloudConversationMemoryDriver(BaseConversationMemoryDriver):

Attributes:
thread_id: The ID of the Thread to store the conversation memory in. If not provided, the driver will attempt to
retrieve the ID from the environment variable `GT_CLOUD_THREAD_ID`. If that is not set, a new Thread will be
created.
retrieve the ID from the environment variable `GT_CLOUD_THREAD_ID`.
alias: The alias of the Thread to store the conversation memory in.
base_url: The base URL of the Griptape Cloud API. Defaults to the value of the environment variable
`GT_CLOUD_BASE_URL` or `https://cloud.griptape.ai`.
api_key: The API key to use for authenticating with the Griptape Cloud API. If not provided, the driver will
Expand All @@ -33,7 +33,11 @@ class GriptapeCloudConversationMemoryDriver(BaseConversationMemoryDriver):
ValueError: If `api_key` is not provided.
"""

thread_id: str = field(
thread_id: Optional[str] = field(
default=None,
metadata={"serializable": True},
)
alias: Optional[str] = field(
default=None,
metadata={"serializable": True},
)
Expand All @@ -45,17 +49,45 @@ class GriptapeCloudConversationMemoryDriver(BaseConversationMemoryDriver):
default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True),
init=False,
)

def __attrs_post_init__(self) -> None:
if self.thread_id is None:
self.thread_id = os.getenv("GT_CLOUD_THREAD_ID", self._get_thread_id())
_thread: Optional[dict] = field(default=None, init=False)

@api_key.validator # pyright: ignore[reportAttributeAccessIssue]
def validate_api_key(self, _: Attribute, value: Optional[str]) -> str:
if value is None:
raise ValueError(f"{self.__class__.__name__} requires an API key")
return value

@property
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we make this a @cached_property since it's so expensive?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

my concern was that the store method would invalidate the cached value, but i guess i can update the method to invalidate self._thread

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah just call del self.thread to invalidate.

def thread(self) -> dict:
"""Try to get the Thread by ID, alias, or create a new one."""
if self._thread is None:
thread = None
if self.thread_id is None:
self.thread_id = os.getenv("GT_CLOUD_THREAD_ID")

if self.thread_id is not None:
res = self._call_api("get", f"/threads/{self.thread_id}", raise_for_status=False)
if res.status_code == 200:
thread = res.json()

# use name as 'alias' to get thread
if thread is None and self.alias is not None:
res = self._call_api("get", f"/threads?alias={self.alias}").json()
if res.get("threads"):
thread = res["threads"][0]
self.thread_id = thread.get("thread_id")

# no thread by name or thread_id
if thread is None:
data = {"name": uuid.uuid4().hex} if self.alias is None else {"name": self.alias, "alias": self.alias}
thread = self._call_api("post", "/threads", data).json()
self.thread_id = thread["thread_id"]
self.alias = thread.get("alias")

self._thread = thread

return self._thread # pyright: ignore[reportReturnType]

def store(self, runs: list[Run], metadata: dict[str, Any]) -> None:
# serialize the run artifacts to json strings
messages = [
Expand All @@ -79,25 +111,20 @@ def store(self, runs: list[Run], metadata: dict[str, Any]) -> None:

# patch the Thread with the new messages and metadata
# all old Messages are replaced with the new ones
response = requests.patch(
self._get_url(f"/threads/{self.thread_id}"),
json=body,
headers=self.headers,
)
response.raise_for_status()
thread_id = self.thread["thread_id"] if self.thread_id is None else self.thread_id
self._call_api("patch", f"/threads/{thread_id}", body)
self._thread = None

def load(self) -> tuple[list[Run], dict[str, Any]]:
from griptape.memory.structure import Run

thread_id = self.thread["thread_id"] if self.thread_id is None else self.thread_id

# get the Messages from the Thread
messages_response = requests.get(self._get_url(f"/threads/{self.thread_id}/messages"), headers=self.headers)
messages_response.raise_for_status()
messages_response = messages_response.json()
messages_response = self._call_api("get", f"/threads/{thread_id}/messages").json()

# retrieve the Thread to get the metadata
thread_response = requests.get(self._get_url(f"/threads/{self.thread_id}"), headers=self.headers)
thread_response.raise_for_status()
thread_response = thread_response.json()
thread_response = self._call_api("get", f"/threads/{thread_id}").json()

runs = [
Run(
Expand All @@ -110,11 +137,14 @@ def load(self) -> tuple[list[Run], dict[str, Any]]:
]
return runs, thread_response.get("metadata", {})

def _get_thread_id(self) -> str:
res = requests.post(self._get_url("/threads"), json={"name": uuid.uuid4().hex}, headers=self.headers)
res.raise_for_status()
return res.json().get("thread_id")

def _get_url(self, path: str) -> str:
path = path.lstrip("/")
return urljoin(self.base_url, f"/api/{path}")

def _call_api(
self, method: str, path: str, json: Optional[dict] = None, *, raise_for_status: bool = True
) -> requests.Response:
res = requests.request(method, self._get_url(path), json=json, headers=self.headers)
if raise_for_status:
res.raise_for_status()
return res
Original file line number Diff line number Diff line change
Expand Up @@ -13,61 +13,70 @@
class TestGriptapeCloudConversationMemoryDriver:
@pytest.fixture(autouse=True)
def _mock_requests(self, mocker):
def get(*args, **kwargs):
if str(args[0]).endswith("/messages"):
return mocker.Mock(
raise_for_status=lambda: None,
json=lambda: {
"messages": [
{
"message_id": "123",
"input": '{"type": "TextArtifact", "id": "1234", "value": "Hi There, Hello"}',
"output": '{"type": "TextArtifact", "id": "123", "value": "Hello! How can I assist you today?"}',
"index": 0,
"metadata": {"run_id": "1234"},
}
]
},
)
else:
thread_id = args[0].split("/")[-1]
return mocker.Mock(
raise_for_status=lambda: None,
json=lambda: {
"metadata": {"foo": "bar"},
"name": "test",
"thread_id": "test_metadata",
}
if thread_id == "test_metadata"
else {"name": "test", "thread_id": "test"},
)

mocker.patch(
"requests.get",
side_effect=get,
)

def post(*args, **kwargs):
if str(args[0]).endswith("/threads"):
return mocker.Mock(
raise_for_status=lambda: None,
json=lambda: {"thread_id": "test", "name": "test"},
)
def request(*args, **kwargs):
if args[0] == "get":
if "/messages" in str(args[1]):
thread_id = args[1].split("/")[-2]
return mocker.Mock(
raise_for_status=lambda: None,
json=lambda: {
"messages": [
{
"message_id": f"{thread_id}_message",
"input": '{"type": "TextArtifact", "id": "1234", "value": "Hi There, Hello"}',
"output": '{"type": "TextArtifact", "id": "123", "value": "Hello! How can I assist you today?"}',
"metadata": {"run_id": "1234"},
}
]
}
if thread_id != "no_messages"
else {"messages": []},
status_code=200,
)
elif "/threads/" in str(args[1]):
thread_id = args[1].split("/")[-1]
return mocker.Mock(
# raise for status if thread_id is == not_found
raise_for_status=lambda: None if thread_id != "not_found" else ValueError,
json=lambda: {
"metadata": {"foo": "bar"},
"name": "test",
"thread_id": thread_id,
},
status_code=200 if thread_id != "not_found" else 404,
)
elif "/threads?alias=" in str(args[1]):
alias = args[1].split("=")[-1]
return mocker.Mock(
raise_for_status=lambda: None,
json=lambda: {"threads": [{"thread_id": alias, "alias": alias, "metadata": {"foo": "bar"}}]}
if alias != "not_found"
else {"threads": []},
status_code=200,
)
else:
return mocker.Mock()
elif args[0] == "post":
if str(args[1]).endswith("/threads"):
body = kwargs["json"]
body["thread_id"] = "test"
return mocker.Mock(
raise_for_status=lambda: None,
json=lambda: body,
)
else:
return mocker.Mock(
raise_for_status=lambda: None,
json=lambda: {"message_id": "test"},
)
else:
return mocker.Mock(
raise_for_status=lambda: None,
json=lambda: {"message_id": "test"},
)

mocker.patch(
"requests.post",
side_effect=post,
)
mocker.patch(
"requests.patch",
return_value=mocker.Mock(
raise_for_status=lambda: None,
),
"requests.request",
side_effect=request,
)

@pytest.fixture()
Expand All @@ -80,12 +89,22 @@ def test_no_api_key(self):

def test_thread_id(self):
driver = GriptapeCloudConversationMemoryDriver(api_key="test")
assert driver.thread_id is None
assert driver.thread.get("thread_id") == "test"
assert driver.thread_id == "test"
os.environ["GT_CLOUD_THREAD_ID"] = "test_env"
driver = GriptapeCloudConversationMemoryDriver(api_key="test")
assert driver.thread_id is None
assert driver.thread.get("thread_id") == "test_env"
assert driver.thread_id == "test_env"
driver = GriptapeCloudConversationMemoryDriver(api_key="test", thread_id="test_init")
assert driver.thread_id == "test_init"
os.environ.pop("GT_CLOUD_THREAD_ID")

def test_thread_alias(self):
driver = GriptapeCloudConversationMemoryDriver(api_key="test", alias="test")
assert driver.thread_id is None
assert driver.thread.get("thread_id") == "test"
assert driver.thread_id == "test"
assert driver.alias == "test"

def test_store(self, driver: GriptapeCloudConversationMemoryDriver):
runs = [
Expand All @@ -98,8 +117,4 @@ def test_load(self, driver):
runs, metadata = driver.load()
assert len(runs) == 1
assert runs[0].id == "1234"
assert metadata == {}
driver.thread_id = "test_metadata"
runs, metadata = driver.load()
assert len(runs) == 1
assert metadata == {"foo": "bar"}
Loading