From aa00bddb59ef0647e462dbbe66b49bda91a94611 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Wed, 12 Jun 2024 17:26:32 -0700 Subject: [PATCH 01/34] Refactor prompt stack --- .../drivers/prompt-drivers.md | 2 +- griptape/common/__init__.py | 23 ++++ griptape/common/prompt_stack/__init__.py | 0 .../common/prompt_stack/contents/__init__.py | 0 .../base_delta_prompt_stack_content.py | 14 ++ .../contents/base_prompt_stack_content.py | 31 +++++ .../delta_text_prompt_stack_content.py | 9 ++ .../contents/image_prompt_stack_content.py | 17 +++ .../contents/text_prompt_stack_content.py | 20 +++ .../elements/base_prompt_stack_element.py | 25 ++++ .../elements/delta_prompt_stack_element.py | 33 +++++ .../elements/prompt_stack_element.py | 58 ++++++++ griptape/common/prompt_stack/prompt_stack.py | 40 ++++++ griptape/config/google_structure_config.py | 4 +- .../prompt/amazon_bedrock_prompt_driver.py | 83 ++++++++--- ...mazon_sagemaker_jumpstart_prompt_driver.py | 67 ++++++--- .../drivers/prompt/anthropic_prompt_driver.py | 129 +++++++++++++----- .../prompt/azure_openai_chat_prompt_driver.py | 2 +- griptape/drivers/prompt/base_prompt_driver.py | 103 +++++++++----- .../drivers/prompt/cohere_prompt_driver.py | 85 ++++++++++-- .../drivers/prompt/dummy_prompt_driver.py | 29 +++- .../drivers/prompt/google_prompt_driver.py | 92 +++++++++---- .../prompt/huggingface_hub_prompt_driver.py | 53 +++++-- .../huggingface_pipeline_prompt_driver.py | 52 ++++--- .../prompt/openai_chat_prompt_driver.py | 126 ++++++++++++----- .../extraction/csv_extraction_engine.py | 7 +- .../extraction/json_extraction_engine.py | 7 +- griptape/engines/query/vector_query_engine.py | 18 ++- .../engines/summary/prompt_summary_engine.py | 3 +- griptape/events/base_prompt_event.py | 1 - griptape/events/finish_prompt_event.py | 3 + griptape/events/start_prompt_event.py | 3 +- .../structure/base_conversation_memory.py | 2 +- .../memory/structure/conversation_memory.py | 2 +- .../structure/summary_conversation_memory.py | 6 +- griptape/schemas/base_schema.py | 6 +- griptape/structures/agent.py | 15 +- griptape/tasks/actions_subtask.py | 21 ++- griptape/tasks/base_text_input_task.py | 49 +++++-- griptape/tasks/prompt_task.py | 15 +- griptape/tasks/toolkit_task.py | 2 +- .../tokenizers/amazon_bedrock_tokenizer.py | 1 + griptape/tokenizers/huggingface_tokenizer.py | 1 + griptape/tokenizers/openai_tokenizer.py | 7 +- griptape/tokenizers/voyageai_tokenizer.py | 7 +- griptape/utils/__init__.py | 2 - griptape/utils/conversation.py | 2 +- griptape/utils/prompt_stack.py | 48 ------- poetry.lock | 88 +++++++++++- pyproject.toml | 2 +- tests/mocks/mock_failing_prompt_driver.py | 33 +++-- tests/mocks/mock_prompt_driver.py | 34 +++-- tests/mocks/mock_value_prompt_driver.py | 21 --- .../config/test_google_structure_config.py | 2 +- .../test_amazon_bedrock_prompt_driver.py | 20 +-- ...mazon_sagemaker_jumpstart_prompt_driver.py | 5 +- .../prompt/test_anthropic_prompt_driver.py | 12 +- .../test_azure_openai_chat_prompt_driver.py | 13 +- .../drivers/prompt/test_base_prompt_driver.py | 2 +- .../prompt/test_cohere_prompt_driver.py | 10 +- .../prompt/test_google_prompt_driver.py | 33 +---- .../test_hugging_face_hub_prompt_driver.py | 7 +- ...est_hugging_face_pipeline_prompt_driver.py | 3 +- .../prompt/test_openai_chat_prompt_driver.py | 28 ++-- tests/unit/events/test_base_event.py | 39 ++++-- tests/unit/events/test_finish_prompt_event.py | 5 +- tests/unit/events/test_start_prompt_event.py | 10 +- .../structure/test_conversation_memory.py | 16 +-- .../test_summary_conversation_memory.py | 7 +- tests/unit/tasks/test_toolkit_task.py | 7 +- .../unit/tokenizers/test_google_tokenizer.py | 5 +- tests/unit/utils/test_base_tokenizer.py | 13 ++ tests/unit/utils/test_prompt_stack.py | 17 +-- 73 files changed, 1247 insertions(+), 510 deletions(-) create mode 100644 griptape/common/__init__.py create mode 100644 griptape/common/prompt_stack/__init__.py create mode 100644 griptape/common/prompt_stack/contents/__init__.py create mode 100644 griptape/common/prompt_stack/contents/base_delta_prompt_stack_content.py create mode 100644 griptape/common/prompt_stack/contents/base_prompt_stack_content.py create mode 100644 griptape/common/prompt_stack/contents/delta_text_prompt_stack_content.py create mode 100644 griptape/common/prompt_stack/contents/image_prompt_stack_content.py create mode 100644 griptape/common/prompt_stack/contents/text_prompt_stack_content.py create mode 100644 griptape/common/prompt_stack/elements/base_prompt_stack_element.py create mode 100644 griptape/common/prompt_stack/elements/delta_prompt_stack_element.py create mode 100644 griptape/common/prompt_stack/elements/prompt_stack_element.py create mode 100644 griptape/common/prompt_stack/prompt_stack.py delete mode 100644 griptape/utils/prompt_stack.py delete mode 100644 tests/mocks/mock_value_prompt_driver.py create mode 100644 tests/unit/utils/test_base_tokenizer.py diff --git a/docs/griptape-framework/drivers/prompt-drivers.md b/docs/griptape-framework/drivers/prompt-drivers.md index 0100ccbac..883a4f4b4 100644 --- a/docs/griptape-framework/drivers/prompt-drivers.md +++ b/docs/griptape-framework/drivers/prompt-drivers.md @@ -28,7 +28,7 @@ agent.run("I loved the new Batman movie!") Or use them independently: ```python -from griptape.utils import PromptStack +from griptape.common import PromptStack from griptape.drivers import OpenAiChatPromptDriver stack = PromptStack() diff --git a/griptape/common/__init__.py b/griptape/common/__init__.py new file mode 100644 index 000000000..303c52db6 --- /dev/null +++ b/griptape/common/__init__.py @@ -0,0 +1,23 @@ +from .prompt_stack.contents.base_prompt_stack_content import BasePromptStackContent +from .prompt_stack.contents.base_delta_prompt_stack_content import BaseDeltaPromptStackContent +from .prompt_stack.contents.delta_text_prompt_stack_content import DeltaTextPromptStackContent +from .prompt_stack.contents.text_prompt_stack_content import TextPromptStackContent +from .prompt_stack.contents.image_prompt_stack_content import ImagePromptStackContent + +from .prompt_stack.elements.base_prompt_stack_element import BasePromptStackElement +from .prompt_stack.elements.delta_prompt_stack_element import DeltaPromptStackElement +from .prompt_stack.elements.prompt_stack_element import PromptStackElement + +from .prompt_stack.prompt_stack import PromptStack + +__all__ = [ + "BasePromptStackElement", + "BaseDeltaPromptStackContent", + "BasePromptStackContent", + "DeltaPromptStackElement", + "PromptStackElement", + "DeltaTextPromptStackContent", + "TextPromptStackContent", + "ImagePromptStackContent", + "PromptStack", +] diff --git a/griptape/common/prompt_stack/__init__.py b/griptape/common/prompt_stack/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/griptape/common/prompt_stack/contents/__init__.py b/griptape/common/prompt_stack/contents/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/griptape/common/prompt_stack/contents/base_delta_prompt_stack_content.py b/griptape/common/prompt_stack/contents/base_delta_prompt_stack_content.py new file mode 100644 index 000000000..5e06f4ee9 --- /dev/null +++ b/griptape/common/prompt_stack/contents/base_delta_prompt_stack_content.py @@ -0,0 +1,14 @@ +from __future__ import annotations + +from abc import ABC +from typing import Optional + +from attrs import define, field + +from griptape.mixins.serializable_mixin import SerializableMixin + + +@define +class BaseDeltaPromptStackContent(ABC, SerializableMixin): + index: int = field(kw_only=True, default=0, metadata={"serializable": True}) + role: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": True}) diff --git a/griptape/common/prompt_stack/contents/base_prompt_stack_content.py b/griptape/common/prompt_stack/contents/base_prompt_stack_content.py new file mode 100644 index 000000000..74d94bc6d --- /dev/null +++ b/griptape/common/prompt_stack/contents/base_prompt_stack_content.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from abc import ABC +from collections.abc import Sequence + +from attrs import define, field + +from griptape.artifacts.base_artifact import BaseArtifact +from griptape.mixins import SerializableMixin + +from .base_delta_prompt_stack_content import BaseDeltaPromptStackContent + + +@define +class BasePromptStackContent(ABC, SerializableMixin): + artifact: BaseArtifact = field(metadata={"serializable": True}) + + def to_text(self) -> str: + return str(self.artifact) + + def __str__(self) -> str: + return self.artifact.to_text() + + def __bool__(self) -> bool: + return bool(self.artifact) + + def __len__(self) -> int: + return len(self.artifact) + + @classmethod + def from_deltas(cls, deltas: Sequence[BaseDeltaPromptStackContent]) -> BasePromptStackContent: ... diff --git a/griptape/common/prompt_stack/contents/delta_text_prompt_stack_content.py b/griptape/common/prompt_stack/contents/delta_text_prompt_stack_content.py new file mode 100644 index 000000000..25b6c25b3 --- /dev/null +++ b/griptape/common/prompt_stack/contents/delta_text_prompt_stack_content.py @@ -0,0 +1,9 @@ +from __future__ import annotations +from attrs import define, field + +from griptape.common import BaseDeltaPromptStackContent + + +@define +class DeltaTextPromptStackContent(BaseDeltaPromptStackContent): + text: str = field(metadata={"serializable": True}) diff --git a/griptape/common/prompt_stack/contents/image_prompt_stack_content.py b/griptape/common/prompt_stack/contents/image_prompt_stack_content.py new file mode 100644 index 000000000..ab61b3c19 --- /dev/null +++ b/griptape/common/prompt_stack/contents/image_prompt_stack_content.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +from collections.abc import Sequence + +from attrs import define, field + +from griptape.artifacts import ImageArtifact +from griptape.common import BaseDeltaPromptStackContent, BasePromptStackContent + + +@define +class ImagePromptStackContent(BasePromptStackContent): + artifact: ImageArtifact = field(metadata={"serializable": True}) + + @classmethod + def from_deltas(cls, deltas: Sequence[BaseDeltaPromptStackContent]) -> ImagePromptStackContent: + raise NotImplementedError() diff --git a/griptape/common/prompt_stack/contents/text_prompt_stack_content.py b/griptape/common/prompt_stack/contents/text_prompt_stack_content.py new file mode 100644 index 000000000..b82f2fb1f --- /dev/null +++ b/griptape/common/prompt_stack/contents/text_prompt_stack_content.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +from attrs import define, field +from collections.abc import Sequence + +from griptape.artifacts import TextArtifact +from griptape.common import BasePromptStackContent, BaseDeltaPromptStackContent, DeltaTextPromptStackContent + + +@define +class TextPromptStackContent(BasePromptStackContent): + artifact: TextArtifact = field(metadata={"serializable": True}) + + @classmethod + def from_deltas(cls, deltas: Sequence[BaseDeltaPromptStackContent]) -> TextPromptStackContent: + text_deltas = [delta for delta in deltas if isinstance(delta, DeltaTextPromptStackContent)] + + artifact = TextArtifact(value="".join(delta.text for delta in text_deltas)) + + return cls(artifact=artifact) diff --git a/griptape/common/prompt_stack/elements/base_prompt_stack_element.py b/griptape/common/prompt_stack/elements/base_prompt_stack_element.py new file mode 100644 index 000000000..83d9daaac --- /dev/null +++ b/griptape/common/prompt_stack/elements/base_prompt_stack_element.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +from abc import ABC + +from attrs import define, field + +from griptape.mixins import SerializableMixin + + +@define +class BasePromptStackElement(ABC, SerializableMixin): + USER_ROLE = "user" + ASSISTANT_ROLE = "assistant" + SYSTEM_ROLE = "system" + + role: str = field(kw_only=True, metadata={"serializable": True}) + + def is_system(self) -> bool: + return self.role == self.SYSTEM_ROLE + + def is_user(self) -> bool: + return self.role == self.USER_ROLE + + def is_assistant(self) -> bool: + return self.role == self.ASSISTANT_ROLE diff --git a/griptape/common/prompt_stack/elements/delta_prompt_stack_element.py b/griptape/common/prompt_stack/elements/delta_prompt_stack_element.py new file mode 100644 index 000000000..be2dbd500 --- /dev/null +++ b/griptape/common/prompt_stack/elements/delta_prompt_stack_element.py @@ -0,0 +1,33 @@ +from __future__ import annotations +from typing import Optional + +from attrs import define, field + +from griptape.common.prompt_stack.contents.delta_text_prompt_stack_content import DeltaTextPromptStackContent + + +from .base_prompt_stack_element import BasePromptStackElement + + +@define +class DeltaPromptStackElement(BasePromptStackElement): + @define + class DeltaUsage: + input_tokens: Optional[float] = field(kw_only=True, default=None, metadata={"serializable": True}) + output_tokens: Optional[float] = field(kw_only=True, default=None, metadata={"serializable": True}) + + @property + def total_tokens(self) -> float: + return (self.input_tokens or 0) + (self.output_tokens or 0) + + def __add__(self, other: DeltaPromptStackElement.DeltaUsage) -> DeltaPromptStackElement.DeltaUsage: + return DeltaPromptStackElement.DeltaUsage( + input_tokens=(self.input_tokens or 0) + (other.input_tokens or 0), + output_tokens=(self.output_tokens or 0) + (other.output_tokens or 0), + ) + + role: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": True}) + delta_content: Optional[DeltaTextPromptStackContent] = field( + kw_only=True, default=None, metadata={"serializable": True} + ) + delta_usage: DeltaUsage = field(kw_only=True, default=DeltaUsage(), metadata={"serializable": True}) diff --git a/griptape/common/prompt_stack/elements/prompt_stack_element.py b/griptape/common/prompt_stack/elements/prompt_stack_element.py new file mode 100644 index 000000000..0ccd687d3 --- /dev/null +++ b/griptape/common/prompt_stack/elements/prompt_stack_element.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +from typing import Any, Optional + +from attrs import Factory, define, field + +from griptape.artifacts import TextArtifact +from griptape.common import BasePromptStackContent, TextPromptStackContent +from griptape.mixins.serializable_mixin import SerializableMixin + +from .base_prompt_stack_element import BasePromptStackElement + + +@define +class PromptStackElement(BasePromptStackElement): + @define + class Usage(SerializableMixin): + input_tokens: Optional[float] = field(kw_only=True, default=None, metadata={"serializable": True}) + output_tokens: Optional[float] = field(kw_only=True, default=None, metadata={"serializable": True}) + + @property + def total_tokens(self) -> float: + return (self.input_tokens or 0) + (self.output_tokens or 0) + + def __init__(self, content: str | list[BasePromptStackContent], **kwargs: Any): + if isinstance(content, str): + content = [TextPromptStackContent(TextArtifact(value=content))] + self.__attrs_init__(content, **kwargs) # pyright: ignore[reportAttributeAccessIssue] + + content: list[BasePromptStackContent] = field(metadata={"serializable": True}) + usage: Usage = field( + kw_only=True, default=Factory(lambda: PromptStackElement.Usage()), metadata={"serializable": True} + ) + + @property + def value(self) -> Any: + if len(self.content) == 1: + return self.content[0].artifact.value + else: + return [content.artifact for content in self.content] + + def __str__(self) -> str: + return self.to_text() + + def to_text(self) -> str: + return self.to_text_artifact().to_text() + + def to_text_artifact(self) -> TextArtifact: + if all(isinstance(content, TextPromptStackContent) for content in self.content): + artifact = TextArtifact(value="") + + for content in self.content: + if isinstance(content, TextPromptStackContent): + artifact += content.artifact + + return artifact + else: + raise ValueError("Cannot convert to TextArtifact") diff --git a/griptape/common/prompt_stack/prompt_stack.py b/griptape/common/prompt_stack/prompt_stack.py new file mode 100644 index 000000000..f7cc06a5c --- /dev/null +++ b/griptape/common/prompt_stack/prompt_stack.py @@ -0,0 +1,40 @@ +from __future__ import annotations +from attrs import define, field + +from griptape.artifacts import TextArtifact, BaseArtifact, ListArtifact, ImageArtifact +from griptape.mixins import SerializableMixin +from griptape.common import PromptStackElement, TextPromptStackContent, ImagePromptStackContent + + +@define +class PromptStack(SerializableMixin): + inputs: list[PromptStackElement] = field(factory=list, kw_only=True, metadata={"serializable": True}) + + def add_input(self, content: str | BaseArtifact, role: str) -> PromptStackElement: + if isinstance(content, str): + self.inputs.append(PromptStackElement(content=[TextPromptStackContent(TextArtifact(content))], role=role)) + elif isinstance(content, TextArtifact): + self.inputs.append(PromptStackElement(content=[TextPromptStackContent(content)], role=role)) + elif isinstance(content, ListArtifact): + contents = [] + for artifact in content.value: + if isinstance(artifact, TextArtifact): + contents.append(TextPromptStackContent(artifact)) + elif isinstance(artifact, ImageArtifact): + contents.append(ImagePromptStackContent(artifact)) + else: + raise ValueError(f"Unsupported artifact type: {type(artifact)}") + self.inputs.append(PromptStackElement(content=contents, role=role)) + else: + raise ValueError(f"Unsupported content type: {type(content)}") + + return self.inputs[-1] + + def add_system_input(self, content: str) -> PromptStackElement: + return self.add_input(content, PromptStackElement.SYSTEM_ROLE) + + def add_user_input(self, content: str | BaseArtifact) -> PromptStackElement: + return self.add_input(content, PromptStackElement.USER_ROLE) + + def add_assistant_input(self, content: str) -> PromptStackElement: + return self.add_input(content, PromptStackElement.ASSISTANT_ROLE) diff --git a/griptape/config/google_structure_config.py b/griptape/config/google_structure_config.py index 744d08782..76f55d3ef 100644 --- a/griptape/config/google_structure_config.py +++ b/griptape/config/google_structure_config.py @@ -14,7 +14,9 @@ @define class GoogleStructureConfig(StructureConfig): prompt_driver: BasePromptDriver = field( - default=Factory(lambda: GooglePromptDriver(model="gemini-pro")), kw_only=True, metadata={"serializable": True} + default=Factory(lambda: GooglePromptDriver(model="gemini-1.5-flash")), + kw_only=True, + metadata={"serializable": True}, ) embedding_driver: BaseEmbeddingDriver = field( default=Factory(lambda: GoogleEmbeddingDriver(model="models/embedding-001")), diff --git a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py index 849ed0901..d3c13a5da 100644 --- a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py +++ b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py @@ -6,6 +6,15 @@ from attrs import Factory, define, field from griptape.artifacts import TextArtifact +from griptape.common import ( + BaseDeltaPromptStackContent, + DeltaPromptStackElement, + PromptStackElement, + DeltaTextPromptStackContent, + BasePromptStackContent, + TextPromptStackContent, + ImagePromptStackContent, +) from griptape.drivers import BasePromptDriver from griptape.tokenizers import AmazonBedrockTokenizer, BaseTokenizer from griptape.utils import import_optional_dependency @@ -13,7 +22,7 @@ if TYPE_CHECKING: import boto3 - from griptape.utils import PromptStack + from griptape.common import PromptStack @define @@ -27,45 +36,59 @@ class AmazonBedrockPromptDriver(BasePromptDriver): default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True), kw_only=True ) - def try_run(self, prompt_stack: PromptStack) -> TextArtifact: + def try_run(self, prompt_stack: PromptStack) -> PromptStackElement: response = self.bedrock_client.converse(**self._base_params(prompt_stack)) + usage = response["usage"] output_message = response["output"]["message"] - output_content = output_message["content"][0]["text"] - return TextArtifact(output_content) + return PromptStackElement( + content=[TextPromptStackContent(TextArtifact(content["text"])) for content in output_message["content"]], + role=PromptStackElement.ASSISTANT_ROLE, + usage=PromptStackElement.Usage(input_tokens=usage["inputTokens"], output_tokens=usage["outputTokens"]), + ) - def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]: + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackElement | BaseDeltaPromptStackContent]: response = self.bedrock_client.converse_stream(**self._base_params(prompt_stack)) stream = response.get("stream") if stream is not None: for event in stream: - if "contentBlockDelta" in event: - yield TextArtifact(event["contentBlockDelta"]["delta"]["text"]) + if "messageStart" in event: + yield DeltaPromptStackElement(role=event["messageStart"]["role"]) + elif "contentBlockDelta" in event: + content_block_delta = event["contentBlockDelta"] + yield DeltaTextPromptStackContent( + content_block_delta["delta"]["text"], index=content_block_delta["contentBlockIndex"] + ) + elif "metadata" in event: + usage = event["metadata"]["usage"] + yield DeltaPromptStackElement( + delta_usage=DeltaPromptStackElement.DeltaUsage( + input_tokens=usage["inputTokens"], output_tokens=usage["outputTokens"] + ) + ) else: raise Exception("model response is empty") - def _prompt_stack_input_to_message(self, prompt_input: PromptStack.Input) -> dict: - content = [{"text": prompt_input.content}] - - if prompt_input.is_system(): - return {"text": prompt_input.content} - elif prompt_input.is_assistant(): - return {"role": "assistant", "content": content} - else: - return {"role": "user", "content": content} + def _prompt_stack_elements_to_messages(self, elements: list[PromptStackElement]) -> list[dict]: + return [ + { + "role": self.__to_role(input), + "content": [self.__prompt_stack_content_message_content(content) for content in input.content], + } + for input in elements + ] def _base_params(self, prompt_stack: PromptStack) -> dict: system_messages = [ - self._prompt_stack_input_to_message(input) - for input in prompt_stack.inputs - if input.is_system() and input.content - ] - messages = [ - self._prompt_stack_input_to_message(input) for input in prompt_stack.inputs if not input.is_system() + {"text": input.to_text_artifact().to_text()} for input in prompt_stack.inputs if input.is_system() ] + messages = self._prompt_stack_elements_to_messages( + [input for input in prompt_stack.inputs if not input.is_system()] + ) + return { "modelId": self.model, "messages": messages, @@ -73,3 +96,19 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: "inferenceConfig": {"temperature": self.temperature}, "additionalModelRequestFields": self.additional_model_request_fields, } + + def __prompt_stack_content_message_content(self, content: BasePromptStackContent) -> dict: + if isinstance(content, TextPromptStackContent): + return {"text": content.artifact.to_text()} + elif isinstance(content, ImagePromptStackContent): + return {"image": {"format": content.artifact.format, "source": {"bytes": content.artifact.value}}} + else: + raise ValueError(f"Unsupported content type: {type(content)}") + + def __to_role(self, input: PromptStackElement) -> str: + if input.is_system(): + return "system" + elif input.is_assistant(): + return "assistant" + else: + return "user" diff --git a/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py b/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py index 18f8e4b77..596606747 100644 --- a/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py +++ b/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py @@ -7,14 +7,21 @@ from attrs import Factory, define, field from griptape.artifacts import TextArtifact -from griptape.drivers.prompt.base_prompt_driver import BasePromptDriver +from griptape.common import ( + PromptStack, + PromptStackElement, + TextPromptStackContent, + DeltaPromptStackElement, + BaseDeltaPromptStackContent, +) +from griptape.drivers import BasePromptDriver from griptape.tokenizers import HuggingFaceTokenizer from griptape.utils import import_optional_dependency if TYPE_CHECKING: import boto3 - from griptape.utils import PromptStack + from griptape.common import PromptStack @define @@ -40,8 +47,11 @@ def validate_stream(self, _, stream): if stream: raise ValueError("streaming is not supported") - def try_run(self, prompt_stack: PromptStack) -> TextArtifact: - payload = {"inputs": self._to_model_input(prompt_stack), "parameters": self._to_model_params(prompt_stack)} + def try_run(self, prompt_stack: PromptStack) -> PromptStackElement: + payload = { + "inputs": self.prompt_stack_to_string(prompt_stack), + "parameters": {**self._base_params(prompt_stack)}, + } response = self.sagemaker_client.invoke_endpoint( EndpointName=self.endpoint, @@ -59,31 +69,28 @@ def try_run(self, prompt_stack: PromptStack) -> TextArtifact: if isinstance(decoded_body, list): if decoded_body: - return TextArtifact(decoded_body[0]["generated_text"]) + generated_text = decoded_body[0]["generated_text"] else: raise ValueError("model response is empty") else: - return TextArtifact(decoded_body["generated_text"]) - - def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]: - raise NotImplementedError("streaming is not supported") + generated_text = decoded_body["generated_text"] - def _prompt_stack_input_to_message(self, prompt_input: PromptStack.Input) -> dict: - return {"role": prompt_input.role, "content": prompt_input.content} + input_tokens = len(self.__prompt_stack_to_tokens(prompt_stack)) + output_tokens = len(self.tokenizer.tokenizer.encode(generated_text)) - def _to_model_input(self, prompt_stack: PromptStack) -> str: - prompt = self.tokenizer.tokenizer.apply_chat_template( - [self._prompt_stack_input_to_message(i) for i in prompt_stack.inputs], - tokenize=False, - add_generation_prompt=True, + return PromptStackElement( + content=[TextPromptStackContent(TextArtifact(generated_text))], + role=PromptStackElement.ASSISTANT_ROLE, + usage=PromptStackElement.Usage(input_tokens=input_tokens, output_tokens=output_tokens), ) - if isinstance(prompt, str): - return prompt - else: - raise ValueError("Invalid output type.") + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackElement | BaseDeltaPromptStackContent]: + raise NotImplementedError("streaming is not supported") - def _to_model_params(self, prompt_stack: PromptStack) -> dict: + def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str: + return self.tokenizer.tokenizer.decode(self.__prompt_stack_to_tokens(prompt_stack)) + + def _base_params(self, prompt_stack: PromptStack) -> dict: return { "temperature": self.temperature, "max_new_tokens": self.max_tokens, @@ -92,3 +99,21 @@ def _to_model_params(self, prompt_stack: PromptStack) -> dict: "stop_strings": self.tokenizer.stop_sequences, "return_full_text": False, } + + def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]: + messages = [] + + for input in prompt_stack.inputs: + messages.append({"role": input.role, "content": TextPromptStackContent(input.to_text_artifact())}) + + return messages + + def __prompt_stack_to_tokens(self, prompt_stack: PromptStack) -> list[int]: + messages = self._prompt_stack_to_messages(prompt_stack) + + tokens = self.tokenizer.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True) + + if isinstance(tokens, list): + return tokens + else: + raise ValueError("Invalid output type.") diff --git a/griptape/drivers/prompt/anthropic_prompt_driver.py b/griptape/drivers/prompt/anthropic_prompt_driver.py index b74a9d5f6..5b700c35a 100644 --- a/griptape/drivers/prompt/anthropic_prompt_driver.py +++ b/griptape/drivers/prompt/anthropic_prompt_driver.py @@ -1,11 +1,28 @@ from __future__ import annotations -from typing import Optional, Any + from collections.abc import Iterator -from attrs import define, field, Factory +from typing import Any, Optional, TYPE_CHECKING + +from attrs import Factory, define, field + from griptape.artifacts import TextArtifact -from griptape.utils import PromptStack, import_optional_dependency +from griptape.common import ( + BaseDeltaPromptStackContent, + BasePromptStackContent, + DeltaPromptStackElement, + DeltaTextPromptStackContent, + ImagePromptStackContent, + PromptStack, + PromptStackElement, + TextPromptStackContent, +) from griptape.drivers import BasePromptDriver from griptape.tokenizers import AnthropicTokenizer, BaseTokenizer +from griptape.utils import import_optional_dependency + +if TYPE_CHECKING: + from anthropic.types import ContentBlockDeltaEvent + from anthropic.types import ContentBlock @define @@ -32,42 +49,45 @@ class AnthropicPromptDriver(BasePromptDriver): top_k: int = field(default=250, kw_only=True, metadata={"serializable": True}) max_tokens: int = field(default=1000, kw_only=True, metadata={"serializable": True}) - def try_run(self, prompt_stack: PromptStack) -> TextArtifact: + def try_run(self, prompt_stack: PromptStack) -> PromptStackElement: response = self.client.messages.create(**self._base_params(prompt_stack)) - return TextArtifact(value=response.content[0].text) + return PromptStackElement( + content=[self.__message_content_to_prompt_stack_content(content) for content in response.content], + role=response.role, + usage=PromptStackElement.Usage( + input_tokens=response.usage.input_tokens, output_tokens=response.usage.output_tokens + ), + ) - def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]: - response = self.client.messages.create(**self._base_params(prompt_stack), stream=True) + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackElement | BaseDeltaPromptStackContent]: + events = self.client.messages.create(**self._base_params(prompt_stack), stream=True) - for chunk in response: - if chunk.type == "content_block_delta": - yield TextArtifact(value=chunk.delta.text) + for event in events: + if event.type == "content_block_delta": + yield self.__message_content_delta_to_prompt_stack_content_delta(event) + elif event.type == "message_start": + yield DeltaPromptStackElement( + role=event.message.role, + delta_usage=DeltaPromptStackElement.DeltaUsage(input_tokens=event.message.usage.input_tokens), + ) + elif event.type == "message_delta": + yield DeltaPromptStackElement( + delta_usage=DeltaPromptStackElement.DeltaUsage(output_tokens=event.usage.output_tokens) + ) - def _prompt_stack_input_to_message(self, prompt_input: PromptStack.Input) -> dict: - content = prompt_input.content + def _prompt_stack_elements_to_messages(self, elements: list[PromptStackElement]) -> list[dict]: + return [{"role": self.__to_role(input), "content": self.__to_content(input)} for input in elements] - if prompt_input.is_system(): - return {"role": "system", "content": content} - elif prompt_input.is_assistant(): - return {"role": "assistant", "content": content} - else: - return {"role": "user", "content": content} - - def _prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> dict: - messages = [ - self._prompt_stack_input_to_message(prompt_input) - for prompt_input in prompt_stack.inputs - if not prompt_input.is_system() - ] - system = next((self._prompt_stack_input_to_message(i) for i in prompt_stack.inputs if i.is_system()), None) - - if system is None: - return {"messages": messages} + def _base_params(self, prompt_stack: PromptStack) -> dict: + messages = self._prompt_stack_elements_to_messages([i for i in prompt_stack.inputs if not i.is_system()]) + + system_element = next((i for i in prompt_stack.inputs if i.is_system()), None) + if system_element: + system_message = system_element.to_text_artifact().to_text() else: - return {"messages": messages, "system": system["content"]} + system_message = None - def _base_params(self, prompt_stack: PromptStack) -> dict: return { "model": self.model, "temperature": self.temperature, @@ -75,5 +95,50 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: "top_p": self.top_p, "top_k": self.top_k, "max_tokens": self.max_tokens, - **self._prompt_stack_to_model_input(prompt_stack), + "messages": messages, + **({"system": system_message} if system_message else {}), } + + def __to_role(self, input: PromptStackElement) -> str: + if input.is_system(): + return "system" + elif input.is_assistant(): + return "assistant" + else: + return "user" + + def __to_content(self, input: PromptStackElement) -> str | list[dict]: + if all(isinstance(content, TextPromptStackContent) for content in input.content): + return input.to_text_artifact().to_text() + else: + return [self.__prompt_stack_content_message_content(content) for content in input.content] + + def __prompt_stack_content_message_content(self, content: BasePromptStackContent) -> dict: + if isinstance(content, TextPromptStackContent): + return {"type": "text", "text": content.artifact.to_text()} + elif isinstance(content, ImagePromptStackContent): + return { + "type": "image", + "source": {"type": "base64", "media_type": content.artifact.mime_type, "data": content.artifact.base64}, + } + else: + raise ValueError(f"Unsupported prompt content type: {type(content)}") + + def __message_content_to_prompt_stack_content(self, content: ContentBlock) -> BasePromptStackContent: + content_type = content.type + + if content_type == "text": + return TextPromptStackContent(TextArtifact(content.text)) + else: + raise ValueError(f"Unsupported message content type: {content_type}") + + def __message_content_delta_to_prompt_stack_content_delta( + self, content_delta: ContentBlockDeltaEvent + ) -> BaseDeltaPromptStackContent: + index = content_delta.index + delta_type = content_delta.delta.type + + if delta_type == "text_delta": + return DeltaTextPromptStackContent(content_delta.delta.text, index=index) + else: + raise ValueError(f"Unsupported message content delta type : {delta_type}") diff --git a/griptape/drivers/prompt/azure_openai_chat_prompt_driver.py b/griptape/drivers/prompt/azure_openai_chat_prompt_driver.py index 41c91cb65..50e9effe6 100644 --- a/griptape/drivers/prompt/azure_openai_chat_prompt_driver.py +++ b/griptape/drivers/prompt/azure_openai_chat_prompt_driver.py @@ -1,6 +1,6 @@ from attrs import define, field, Factory from typing import Callable, Optional -from griptape.utils import PromptStack +from griptape.common import PromptStack from griptape.drivers import OpenAiChatPromptDriver import openai diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index 9ef076dbc..89539028e 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -1,14 +1,23 @@ from __future__ import annotations + from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Optional from collections.abc import Iterator -from attrs import define, field, Factory -from griptape.events import StartPromptEvent, FinishPromptEvent, CompletionChunkEvent -from griptape.mixins.serializable_mixin import SerializableMixin -from griptape.utils import PromptStack -from griptape.mixins import ExponentialBackoffMixin +from typing import TYPE_CHECKING, Optional + +from attrs import Factory, define, field + +from griptape.artifacts.text_artifact import TextArtifact +from griptape.common import ( + BaseDeltaPromptStackContent, + DeltaPromptStackElement, + DeltaTextPromptStackContent, + PromptStack, + PromptStackElement, + TextPromptStackContent, +) +from griptape.events import CompletionChunkEvent, FinishPromptEvent, StartPromptEvent +from griptape.mixins import ExponentialBackoffMixin, SerializableMixin from griptape.tokenizers import BaseTokenizer -from griptape.artifacts import TextArtifact if TYPE_CHECKING: from griptape.structures import Structure @@ -41,20 +50,16 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC): def before_run(self, prompt_stack: PromptStack) -> None: if self.structure: - self.structure.publish_event( - StartPromptEvent( - model=self.model, - token_count=self.tokenizer.count_tokens(self.prompt_stack_to_string(prompt_stack)), - prompt_stack=prompt_stack, - prompt=self.prompt_stack_to_string(prompt_stack), - ) - ) + self.structure.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack)) - def after_run(self, result: TextArtifact) -> None: + def after_run(self, result: PromptStackElement) -> None: if self.structure: self.structure.publish_event( FinishPromptEvent( - model=self.model, result=result.value, token_count=self.tokenizer.count_tokens(result.value) + model=self.model, + result=result.value, + input_token_count=result.usage.input_tokens, + output_token_count=result.usage.output_tokens, ) ) @@ -64,19 +69,13 @@ def run(self, prompt_stack: PromptStack) -> TextArtifact: self.before_run(prompt_stack) if self.stream: - tokens = [] - completion_chunks = self.try_stream(prompt_stack) - for chunk in completion_chunks: - self.structure.publish_event(CompletionChunkEvent(token=chunk.value)) - tokens.append(chunk.value) - result = TextArtifact(value="".join(tokens).strip()) + result = self.__process_stream(prompt_stack) else: - result = self.try_run(prompt_stack) - result.value = result.value.strip() + result = self.__process_run(prompt_stack) self.after_run(result) - return result + return result.to_text_artifact() else: raise Exception("prompt driver failed after all retry attempts") @@ -93,19 +92,61 @@ def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str: prompt_lines = [] for i in prompt_stack.inputs: + content = i.to_text_artifact().to_text() if i.is_user(): - prompt_lines.append(f"User: {i.content}") + prompt_lines.append(f"User: {content}") elif i.is_assistant(): - prompt_lines.append(f"Assistant: {i.content}") + prompt_lines.append(f"Assistant: {content}") else: - prompt_lines.append(i.content) + prompt_lines.append(content) prompt_lines.append("Assistant:") return "\n\n".join(prompt_lines) @abstractmethod - def try_run(self, prompt_stack: PromptStack) -> TextArtifact: ... + def try_run(self, prompt_stack: PromptStack) -> PromptStackElement: ... @abstractmethod - def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]: ... + def try_stream( + self, prompt_stack: PromptStack + ) -> Iterator[DeltaPromptStackElement | BaseDeltaPromptStackContent]: ... + + def __process_run(self, prompt_stack: PromptStack) -> PromptStackElement: + result = self.try_run(prompt_stack) + + return result + + def __process_stream(self, prompt_stack: PromptStack) -> PromptStackElement: + delta_contents: dict[int, list[BaseDeltaPromptStackContent]] = {} + delta_usage = DeltaPromptStackElement.DeltaUsage() + + deltas = self.try_stream(prompt_stack) + + for delta in deltas: + if isinstance(delta, DeltaPromptStackElement): + delta_usage += delta.delta_usage + elif isinstance(delta, BaseDeltaPromptStackContent): + if delta.index in delta_contents: + delta_contents[delta.index].append(delta) + else: + delta_contents[delta.index] = [delta] + + if isinstance(delta, DeltaTextPromptStackContent): + self.structure.publish_event(CompletionChunkEvent(token=delta.text)) + + content = [] + for index, deltas in delta_contents.items(): + text_deltas = [delta for delta in deltas if isinstance(delta, DeltaTextPromptStackContent)] + if text_deltas: + content.append(TextPromptStackContent.from_deltas(text_deltas)) + + result = PromptStackElement( + content=content, + role=PromptStackElement.ASSISTANT_ROLE, + usage=PromptStackElement.Usage( + input_tokens=delta_usage.input_tokens or 0, output_tokens=delta_usage.output_tokens or 0 + ), + ) + + return result diff --git a/griptape/drivers/prompt/cohere_prompt_driver.py b/griptape/drivers/prompt/cohere_prompt_driver.py index 3ff2c9e89..047ea5a35 100644 --- a/griptape/drivers/prompt/cohere_prompt_driver.py +++ b/griptape/drivers/prompt/cohere_prompt_driver.py @@ -4,8 +4,18 @@ from attrs import define, field, Factory from griptape.artifacts import TextArtifact from griptape.drivers import BasePromptDriver -from griptape.utils import PromptStack, import_optional_dependency -from griptape.tokenizers import BaseTokenizer, CohereTokenizer +from griptape.tokenizers import CohereTokenizer +from griptape.common import ( + PromptStack, + PromptStackElement, + DeltaPromptStackElement, + BaseDeltaPromptStackContent, + TextPromptStackContent, + BasePromptStackContent, + DeltaTextPromptStackContent, +) +from griptape.utils import import_optional_dependency +from griptape.tokenizers import BaseTokenizer if TYPE_CHECKING: from cohere import Client @@ -31,30 +41,60 @@ class CoherePromptDriver(BasePromptDriver): kw_only=True, ) - def try_run(self, prompt_stack: PromptStack) -> TextArtifact: + def try_run(self, prompt_stack: PromptStack) -> PromptStackElement: result = self.client.chat(**self._base_params(prompt_stack)) + usage = result.meta.tokens - return TextArtifact(value=result.text) + return PromptStackElement( + content=[TextPromptStackContent(TextArtifact(result.text))], + role=PromptStackElement.ASSISTANT_ROLE, + usage=PromptStackElement.Usage(input_tokens=usage.input_tokens, output_tokens=usage.output_tokens), + ) - def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]: + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackElement | BaseDeltaPromptStackContent]: result = self.client.chat_stream(**self._base_params(prompt_stack)) for event in result: if event.event_type == "text-generation": - yield TextArtifact(value=event.text) + yield DeltaTextPromptStackContent(event.text, index=0) + if event.event_type == "stream-end": + usage = event.response.meta.tokens - def _prompt_stack_input_to_message(self, prompt_input: PromptStack.Input) -> dict: - if prompt_input.is_system(): - return {"role": "SYSTEM", "text": prompt_input.content} - elif prompt_input.is_user(): - return {"role": "USER", "text": prompt_input.content} - else: - return {"role": "ASSISTANT", "text": prompt_input.content} + yield DeltaPromptStackElement( + role=PromptStackElement.ASSISTANT_ROLE, + delta_usage=DeltaPromptStackElement.DeltaUsage( + input_tokens=usage.input_tokens, output_tokens=usage.output_tokens + ), + ) + + def _prompt_stack_elements_to_messages(self, elements: list[PromptStackElement]) -> list[dict]: + return [ + { + "role": self.__to_role(input), + "content": [self.__prompt_stack_content_message_content(content) for content in input.content], + } + for input in elements + ] def _base_params(self, prompt_stack: PromptStack) -> dict: - user_message = prompt_stack.inputs[-1].content + last_input = prompt_stack.inputs[-1] + if last_input is not None and len(last_input.content) == 1: + user_message = last_input.content[0].artifact.to_text() + else: + raise ValueError("User element must have exactly one content.") - history_messages = [self._prompt_stack_input_to_message(input) for input in prompt_stack.inputs[:-1]] + history_messages = self._prompt_stack_elements_to_messages( + [input for input in prompt_stack.inputs[:-1] if not input.is_system()] + ) + + system_element = next((input for input in prompt_stack.inputs if input.is_system()), None) + if system_element is not None: + if len(system_element.content) == 1: + preamble = system_element.content[0].artifact.to_text() + else: + raise ValueError("System element must have exactly one content.") + else: + preamble = None return { "message": user_message, @@ -62,4 +102,19 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: "temperature": self.temperature, "stop_sequences": self.tokenizer.stop_sequences, "max_tokens": self.max_tokens, + **({"preamble": preamble} if preamble else {}), } + + def __prompt_stack_content_message_content(self, content: BasePromptStackContent) -> dict: + if isinstance(content, TextPromptStackContent): + return {"text": content.artifact.to_text()} + else: + raise ValueError(f"Unsupported content type: {type(content)}") + + def __to_role(self, input: PromptStackElement) -> str: + if input.is_system(): + return "SYSTEM" + elif input.is_user(): + return "USER" + else: + return "CHATBOT" diff --git a/griptape/drivers/prompt/dummy_prompt_driver.py b/griptape/drivers/prompt/dummy_prompt_driver.py index a55ecd4fe..2c9794fbb 100644 --- a/griptape/drivers/prompt/dummy_prompt_driver.py +++ b/griptape/drivers/prompt/dummy_prompt_driver.py @@ -1,10 +1,19 @@ +from __future__ import annotations from collections.abc import Iterator -from attrs import field, Factory, define -from griptape.tokenizers import DummyTokenizer +from typing import Any + +from attrs import Factory, define, field + +from griptape.common import ( + BasePromptStackContent, + PromptStack, + PromptStackElement, + DeltaPromptStackElement, + BaseDeltaPromptStackContent, +) from griptape.drivers import BasePromptDriver -from griptape.artifacts import TextArtifact from griptape.exceptions import DummyException -from griptape.utils.prompt_stack import PromptStack +from griptape.tokenizers import DummyTokenizer @define @@ -12,11 +21,17 @@ class DummyPromptDriver(BasePromptDriver): model: None = field(init=False, default=None, kw_only=True) tokenizer: DummyTokenizer = field(default=Factory(lambda: DummyTokenizer()), kw_only=True) - def try_run(self, prompt_stack: PromptStack) -> TextArtifact: + def try_run(self, prompt_stack: PromptStack) -> PromptStackElement: raise DummyException(__class__.__name__, "try_run") - def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]: + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackElement | BaseDeltaPromptStackContent]: raise DummyException(__class__.__name__, "try_stream") - def _prompt_stack_input_to_message(self, prompt_input: PromptStack.Input) -> dict: + def _prompt_stack_input_to_message(self, prompt_input: PromptStackElement) -> dict: raise DummyException(__class__.__name__, "_prompt_stack_input_to_message") + + def _prompt_stack_content_to_message_content(self, content: BasePromptStackContent) -> Any: + raise DummyException(__class__.__name__, "_prompt_stack_content_to_message_content") + + def _message_content_to_prompt_stack_content(self, message_content: Any) -> BasePromptStackContent: + raise DummyException(__class__.__name__, "_message_content_to_prompt_stack_content") diff --git a/griptape/drivers/prompt/google_prompt_driver.py b/griptape/drivers/prompt/google_prompt_driver.py index 67bc19e24..664fd1bec 100644 --- a/griptape/drivers/prompt/google_prompt_driver.py +++ b/griptape/drivers/prompt/google_prompt_driver.py @@ -1,15 +1,28 @@ from __future__ import annotations + from collections.abc import Iterator -from typing import TYPE_CHECKING, Optional, Any -from attrs import define, field, Factory -from griptape.utils import PromptStack, import_optional_dependency +from typing import TYPE_CHECKING, Any, Optional + +from attrs import Factory, define, field + from griptape.artifacts import TextArtifact +from griptape.common import ( + BaseDeltaPromptStackContent, + BasePromptStackContent, + DeltaPromptStackElement, + DeltaTextPromptStackContent, + ImagePromptStackContent, + PromptStack, + PromptStackElement, + TextPromptStackContent, +) from griptape.drivers import BasePromptDriver -from griptape.tokenizers import GoogleTokenizer, BaseTokenizer +from griptape.tokenizers import BaseTokenizer, GoogleTokenizer +from griptape.utils import import_optional_dependency if TYPE_CHECKING: from google.generativeai import GenerativeModel - from google.generativeai.types import ContentDict + from google.generativeai.types import ContentDict, GenerateContentResponse @define @@ -33,12 +46,12 @@ class GooglePromptDriver(BasePromptDriver): top_p: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True}) top_k: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True}) - def try_run(self, prompt_stack: PromptStack) -> TextArtifact: + def try_run(self, prompt_stack: PromptStack) -> PromptStackElement: GenerationConfig = import_optional_dependency("google.generativeai.types").GenerationConfig - inputs = self._prompt_stack_to_model_input(prompt_stack) - response = self.model_client.generate_content( - inputs, + messages = self._prompt_stack_to_messages(prompt_stack) + response: GenerateContentResponse = self.model_client.generate_content( + messages, generation_config=GenerationConfig( stop_sequences=self.tokenizer.stop_sequences, max_output_tokens=self.max_tokens, @@ -48,14 +61,22 @@ def try_run(self, prompt_stack: PromptStack) -> TextArtifact: ), ) - return TextArtifact(value=response.text) + usage_metadata = response.usage_metadata + + return PromptStackElement( + content=[TextPromptStackContent(TextArtifact(response.text))], + role=PromptStackElement.ASSISTANT_ROLE, + usage=PromptStackElement.Usage( + input_tokens=usage_metadata.prompt_token_count, output_tokens=usage_metadata.candidates_token_count + ), + ) - def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]: + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackElement | BaseDeltaPromptStackContent]: GenerationConfig = import_optional_dependency("google.generativeai.types").GenerationConfig - inputs = self._prompt_stack_to_model_input(prompt_stack) - response = self.model_client.generate_content( - inputs, + messages = self._prompt_stack_to_messages(prompt_stack) + response: Iterator[GenerateContentResponse] = self.model_client.generate_content( + messages, stream=True, generation_config=GenerationConfig( stop_sequences=self.tokenizer.stop_sequences, @@ -67,15 +88,17 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]: ) for chunk in response: - yield TextArtifact(value=chunk.text) + usage_metadata = chunk.usage_metadata - def _prompt_stack_input_to_message(self, prompt_input: PromptStack.Input) -> dict: - parts = [prompt_input.content] + yield DeltaTextPromptStackContent(chunk.text) - if prompt_input.is_assistant(): - return {"role": "model", "parts": parts} - else: - return {"role": "user", "parts": parts} + # TODO: Only yield the first one + yield DeltaPromptStackElement( + role=PromptStackElement.ASSISTANT_ROLE, + delta_usage=DeltaPromptStackElement.DeltaUsage( + input_tokens=usage_metadata.prompt_token_count, output_tokens=usage_metadata.candidates_token_count + ), + ) def _default_model_client(self) -> GenerativeModel: genai = import_optional_dependency("google.generativeai") @@ -83,20 +106,35 @@ def _default_model_client(self) -> GenerativeModel: return genai.GenerativeModel(self.model) - def _prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> list[ContentDict]: + def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]: inputs = [ - self.__to_content_dict(prompt_input) for prompt_input in prompt_stack.inputs if not prompt_input.is_system() + {"role": self.__to_role(input), "parts": self.__to_content(input)} + for input in prompt_stack.inputs + if not input.is_system() ] # Gemini does not have the notion of a system message, so we insert it as part of the first message in the history. system = next((i for i in prompt_stack.inputs if i.is_system()), None) if system is not None: - inputs[0]["parts"].insert(0, system.content) + inputs[0]["parts"].insert(0, "\n".join(content.to_text() for content in system.content)) return inputs - def __to_content_dict(self, prompt_input: PromptStack.Input) -> ContentDict: + def __prompt_stack_content_message_content(self, content: BasePromptStackContent) -> ContentDict | str: ContentDict = import_optional_dependency("google.generativeai.types").ContentDict - message = self._prompt_stack_input_to_message(prompt_input) - return ContentDict(message) + if isinstance(content, TextPromptStackContent): + return content.artifact.to_text() + elif isinstance(content, ImagePromptStackContent): + return ContentDict(mime_type=content.artifact.mime_type, data=content.artifact.value) + else: + raise ValueError(f"Unsupported content type: {type(content)}") + + def __to_role(self, input: PromptStackElement) -> str: + if input.is_assistant(): + return "model" + else: + return "user" + + def __to_content(self, input: PromptStackElement) -> list[ContentDict | str]: + return [self.__prompt_stack_content_message_content(content) for content in input.content] diff --git a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py index 3edd252cb..e77f1504c 100644 --- a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py +++ b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py @@ -5,10 +5,17 @@ from attrs import Factory, define, field -from griptape.artifacts import TextArtifact from griptape.drivers import BasePromptDriver from griptape.tokenizers import HuggingFaceTokenizer -from griptape.utils import PromptStack, import_optional_dependency +from griptape.common import ( + PromptStack, + PromptStackElement, + DeltaPromptStackElement, + BaseDeltaPromptStackContent, + TextPromptStackContent, + DeltaTextPromptStackContent, +) +from griptape.utils import import_optional_dependency if TYPE_CHECKING: from huggingface_hub import InferenceClient @@ -47,37 +54,57 @@ class HuggingFaceHubPromptDriver(BasePromptDriver): kw_only=True, ) - def try_run(self, prompt_stack: PromptStack) -> TextArtifact: + def try_run(self, prompt_stack: PromptStack) -> PromptStackElement: prompt = self.prompt_stack_to_string(prompt_stack) response = self.client.text_generation( prompt, return_full_text=False, max_new_tokens=self.max_tokens, **self.params ) + input_tokens = len(self.__prompt_stack_to_tokens(prompt_stack)) + output_tokens = len(self.tokenizer.tokenizer.encode(response)) - return TextArtifact(value=response) + return PromptStackElement( + content=response, + role=PromptStackElement.ASSISTANT_ROLE, + usage=PromptStackElement.Usage(input_tokens=input_tokens, output_tokens=output_tokens), + ) - def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]: + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackElement | BaseDeltaPromptStackContent]: prompt = self.prompt_stack_to_string(prompt_stack) response = self.client.text_generation( prompt, return_full_text=False, max_new_tokens=self.max_tokens, stream=True, **self.params ) + input_tokens = len(self.__prompt_stack_to_tokens(prompt_stack)) + + full_text = "" for token in response: - yield TextArtifact(value=token) + full_text += token + yield DeltaTextPromptStackContent(token, index=0) - def _prompt_stack_input_to_message(self, prompt_input: PromptStack.Input) -> dict: - return {"role": prompt_input.role, "content": prompt_input.content} + output_tokens = len(self.tokenizer.tokenizer.encode(full_text)) + yield DeltaPromptStackElement( + role=PromptStackElement.ASSISTANT_ROLE, + delta_usage=DeltaPromptStackElement.DeltaUsage(input_tokens=input_tokens, output_tokens=output_tokens), + ) def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str: return self.tokenizer.tokenizer.decode(self.__prompt_stack_to_tokens(prompt_stack)) + def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]: + messages = [] + for i in prompt_stack.inputs: + if len(i.content) == 1: + messages.append({"role": i.role, "content": TextPromptStackContent(i.to_text_artifact())}) + else: + raise ValueError("Invalid input content length.") + + return messages + def __prompt_stack_to_tokens(self, prompt_stack: PromptStack) -> list[int]: - tokens = self.tokenizer.tokenizer.apply_chat_template( - [self._prompt_stack_input_to_message(i) for i in prompt_stack.inputs], - add_generation_prompt=True, - tokenize=True, - ) + messages = self._prompt_stack_to_messages(prompt_stack) + tokens = self.tokenizer.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True) if isinstance(tokens, list): return tokens diff --git a/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py b/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py index 4fa291877..180ee9b45 100644 --- a/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py +++ b/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py @@ -1,13 +1,21 @@ from __future__ import annotations -from collections.abc import Iterator +from collections.abc import Iterator from typing import TYPE_CHECKING + from attrs import Factory, define, field from griptape.artifacts import TextArtifact +from griptape.common import ( + BaseDeltaPromptStackContent, + DeltaPromptStackElement, + PromptStack, + PromptStackElement, + TextPromptStackContent, +) from griptape.drivers import BasePromptDriver from griptape.tokenizers import HuggingFaceTokenizer -from griptape.utils import PromptStack, import_optional_dependency +from griptape.utils import import_optional_dependency if TYPE_CHECKING: from transformers import TextGenerationPipeline @@ -40,43 +48,49 @@ class HuggingFacePipelinePromptDriver(BasePromptDriver): ) ) - def try_run(self, prompt_stack: PromptStack) -> TextArtifact: - messages = [self._prompt_stack_input_to_message(input) for input in prompt_stack.inputs] + def try_run(self, prompt_stack: PromptStack) -> PromptStackElement: + messages = self._prompt_stack_to_messages(prompt_stack) result = self.pipe( - messages, - max_new_tokens=self.max_tokens, - tokenizer=self.tokenizer.tokenizer, - stop_strings=self.tokenizer.stop_sequences, - temperature=self.temperature, - do_sample=True, + messages, max_new_tokens=self.max_tokens, temperature=self.temperature, do_sample=True, **self.params ) if isinstance(result, list): if len(result) == 1: generated_text = result[0]["generated_text"][-1]["content"] - return TextArtifact(value=generated_text) + input_tokens = len(self.__prompt_stack_to_tokens(prompt_stack)) + output_tokens = len(self.tokenizer.tokenizer.encode(generated_text)) + + return PromptStackElement( + content=[TextPromptStackContent(TextArtifact(generated_text))], + role=PromptStackElement.ASSISTANT_ROLE, + usage=PromptStackElement.Usage(input_tokens=input_tokens, output_tokens=output_tokens), + ) else: raise Exception("completion with more than one choice is not supported yet") else: raise Exception("invalid output format") - def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]: + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackElement | BaseDeltaPromptStackContent]: raise NotImplementedError("streaming is not supported") def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str: return self.tokenizer.tokenizer.decode(self.__prompt_stack_to_tokens(prompt_stack)) - def _prompt_stack_input_to_message(self, prompt_input: PromptStack.Input) -> dict: - return {"role": prompt_input.role, "content": prompt_input.content} + def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]: + messages = [] + for i in prompt_stack.inputs: + if len(i.content) == 1: + messages.append({"role": i.role, "content": TextPromptStackContent(i.to_text_artifact())}) + else: + raise ValueError("Invalid input content length.") + + return messages def __prompt_stack_to_tokens(self, prompt_stack: PromptStack) -> list[int]: - tokens = self.tokenizer.tokenizer.apply_chat_template( - [self._prompt_stack_input_to_message(i) for i in prompt_stack.inputs], - add_generation_prompt=True, - tokenize=True, - ) + messages = self._prompt_stack_to_messages(prompt_stack) + tokens = self.tokenizer.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True) if isinstance(tokens, list): return tokens diff --git a/griptape/drivers/prompt/openai_chat_prompt_driver.py b/griptape/drivers/prompt/openai_chat_prompt_driver.py index 9545bd45a..c35186976 100644 --- a/griptape/drivers/prompt/openai_chat_prompt_driver.py +++ b/griptape/drivers/prompt/openai_chat_prompt_driver.py @@ -1,12 +1,29 @@ from __future__ import annotations -from typing import Optional, Literal + from collections.abc import Iterator +from typing import Literal, Optional, TYPE_CHECKING + import openai -from attrs import define, field, Factory +from attrs import Factory, define, field + from griptape.artifacts import TextArtifact -from griptape.utils import PromptStack +from griptape.common import ( + BaseDeltaPromptStackContent, + BasePromptStackContent, + DeltaPromptStackElement, + DeltaTextPromptStackContent, + ImagePromptStackContent, + PromptStack, + PromptStackElement, + TextPromptStackContent, +) from griptape.drivers import BasePromptDriver -from griptape.tokenizers import OpenAiTokenizer, BaseTokenizer +from griptape.tokenizers import BaseTokenizer, OpenAiTokenizer + + +if TYPE_CHECKING: + from openai.types.chat.chat_completion_message import ChatCompletionMessage + from openai.types.chat.chat_completion_chunk import ChoiceDelta @define @@ -57,45 +74,54 @@ class OpenAiChatPromptDriver(BasePromptDriver): kw_only=True, ) - def try_run(self, prompt_stack: PromptStack) -> TextArtifact: + def try_run(self, prompt_stack: PromptStack) -> PromptStackElement: result = self.client.chat.completions.create(**self._base_params(prompt_stack)) if len(result.choices) == 1: - return TextArtifact(value=result.choices[0].message.content.strip()) + message = result.choices[0].message + + return PromptStackElement( + content=[self.__message_to_prompt_stack_content(message)], + role=message.role, + usage=PromptStackElement.Usage( + input_tokens=result.usage.prompt_tokens, output_tokens=result.usage.completion_tokens + ), + ) else: raise Exception("Completion with more than one choice is not supported yet.") - def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]: - result = self.client.chat.completions.create(**self._base_params(prompt_stack), stream=True) + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackElement | BaseDeltaPromptStackContent]: + result = self.client.chat.completions.create( + **self._base_params(prompt_stack), stream=True, stream_options={"include_usage": True} + ) for chunk in result: - if len(chunk.choices) == 1: - delta = chunk.choices[0].delta - else: - raise Exception("Completion with more than one choice is not supported yet.") - - if delta.content is not None: - delta_content = delta.content - - yield TextArtifact(value=delta_content) - - def _prompt_stack_input_to_message(self, prompt_input: PromptStack.Input) -> dict: - content = prompt_input.content - - if prompt_input.is_system(): - return {"role": "system", "content": content} - elif prompt_input.is_assistant(): - return {"role": "assistant", "content": content} - else: - return {"role": "user", "content": content} + if chunk.usage is not None: + yield DeltaPromptStackElement( + delta_usage=DeltaPromptStackElement.DeltaUsage( + input_tokens=chunk.usage.prompt_tokens, output_tokens=chunk.usage.completion_tokens + ) + ) + elif chunk.choices is not None: + if len(chunk.choices) == 1: + choice = chunk.choices[0] + delta = choice.delta + + yield self.__message_delta_to_prompt_stack_content_delta(delta) + else: + raise Exception("Completion with more than one choice is not supported yet.") + + def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]: + return [{"role": self.__to_role(input), "content": self.__to_content(input)} for input in prompt_stack.inputs] def _base_params(self, prompt_stack: PromptStack) -> dict: params = { "model": self.model, "temperature": self.temperature, - "stop": self.tokenizer.stop_sequences, "user": self.user, "seed": self.seed, + **({"stop": self.tokenizer.stop_sequences} if self.tokenizer.stop_sequences else {}), + **({"max_tokens": self.max_tokens} if self.max_tokens is not None else {}), } if self.response_format == "json_object": @@ -103,11 +129,47 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: # JSON mode still requires a system input instructing the LLM to output JSON. prompt_stack.add_system_input("Provide your response as a valid JSON object.") - messages = [self._prompt_stack_input_to_message(input) for input in prompt_stack.inputs] - - if self.max_tokens is not None: - params["max_tokens"] = self.max_tokens + messages = self._prompt_stack_to_messages(prompt_stack) params["messages"] = messages return params + + def __to_role(self, input: PromptStackElement) -> str: + if input.is_system(): + return "system" + elif input.is_assistant(): + return "assistant" + else: + return "user" + + def __to_content(self, input: PromptStackElement) -> str | list[dict]: + if all(isinstance(content, TextPromptStackContent) for content in input.content): + return input.to_text_artifact().to_text() + else: + return [self.__prompt_stack_content_message_content(content) for content in input.content] + + def __prompt_stack_content_message_content(self, content: BasePromptStackContent) -> dict: + if isinstance(content, TextPromptStackContent): + return {"type": "text", "text": content.artifact.to_text()} + elif isinstance(content, ImagePromptStackContent): + return { + "type": "image_url", + "image_url": {"url": f"data:{content.artifact.mime_type};base64,{content.artifact.base64}"}, + } + else: + raise ValueError(f"Unsupported content type: {type(content)}") + + def __message_to_prompt_stack_content(self, message: ChatCompletionMessage) -> BasePromptStackContent: + if message.content is not None: + return TextPromptStackContent(TextArtifact(message.content)) + else: + raise ValueError(f"Unsupported message type: {message}") + + def __message_delta_to_prompt_stack_content_delta(self, content_delta: ChoiceDelta) -> BaseDeltaPromptStackContent: + if content_delta.content is not None: + delta_content = content_delta.content + + return DeltaTextPromptStackContent(delta_content, role=content_delta.role) + else: + return DeltaTextPromptStackContent("", role=content_delta.role) diff --git a/griptape/engines/extraction/csv_extraction_engine.py b/griptape/engines/extraction/csv_extraction_engine.py index fe4d0e6c7..481a2da4d 100644 --- a/griptape/engines/extraction/csv_extraction_engine.py +++ b/griptape/engines/extraction/csv_extraction_engine.py @@ -4,7 +4,8 @@ import io from attrs import field, Factory, define from griptape.artifacts import TextArtifact, CsvRowArtifact, ListArtifact, ErrorArtifact -from griptape.utils import PromptStack +from griptape.common import PromptStack +from griptape.common.prompt_stack.elements.prompt_stack_element import PromptStackElement from griptape.engines import BaseExtractionEngine from griptape.utils import J2 from griptape.rules import Ruleset @@ -64,7 +65,7 @@ def _extract_rec( rows.extend( self.text_to_csv_rows( self.prompt_driver.run( - PromptStack(inputs=[PromptStack.Input(full_text, role=PromptStack.USER_ROLE)]) + PromptStack(inputs=[PromptStackElement(full_text, role=PromptStackElement.USER_ROLE)]) ).value, column_names, ) @@ -82,7 +83,7 @@ def _extract_rec( rows.extend( self.text_to_csv_rows( self.prompt_driver.run( - PromptStack(inputs=[PromptStack.Input(partial_text, role=PromptStack.USER_ROLE)]) + PromptStack(inputs=[PromptStackElement(partial_text, role=PromptStackElement.USER_ROLE)]) ).value, column_names, ) diff --git a/griptape/engines/extraction/json_extraction_engine.py b/griptape/engines/extraction/json_extraction_engine.py index 05db19d40..cfa76f8af 100644 --- a/griptape/engines/extraction/json_extraction_engine.py +++ b/griptape/engines/extraction/json_extraction_engine.py @@ -3,9 +3,10 @@ import json from attrs import field, Factory, define from griptape.artifacts import TextArtifact, ListArtifact, ErrorArtifact +from griptape.common.prompt_stack.elements.prompt_stack_element import PromptStackElement from griptape.engines import BaseExtractionEngine from griptape.utils import J2 -from griptape.utils import PromptStack +from griptape.common import PromptStack from griptape.rules import Ruleset @@ -59,7 +60,7 @@ def _extract_rec( extractions.extend( self.json_to_text_artifacts( self.prompt_driver.run( - PromptStack(inputs=[PromptStack.Input(full_text, role=PromptStack.USER_ROLE)]) + PromptStack(inputs=[PromptStackElement(full_text, role=PromptStackElement.USER_ROLE)]) ).value ) ) @@ -76,7 +77,7 @@ def _extract_rec( extractions.extend( self.json_to_text_artifacts( self.prompt_driver.run( - PromptStack(inputs=[PromptStack.Input(partial_text, role=PromptStack.USER_ROLE)]) + PromptStack(inputs=[PromptStackElement(partial_text, role=PromptStackElement.USER_ROLE)]) ).value ) ) diff --git a/griptape/engines/query/vector_query_engine.py b/griptape/engines/query/vector_query_engine.py index 24338b348..d4db926b0 100644 --- a/griptape/engines/query/vector_query_engine.py +++ b/griptape/engines/query/vector_query_engine.py @@ -2,7 +2,8 @@ from typing import TYPE_CHECKING, Optional from attrs import define, field, Factory from griptape.artifacts import TextArtifact, BaseArtifact, ListArtifact -from griptape.utils import PromptStack +from griptape.common import PromptStack +from griptape.common.prompt_stack.elements.prompt_stack_element import PromptStackElement from griptape.engines import BaseQueryEngine from griptape.utils.j2 import J2 from griptape.rules import Ruleset @@ -53,8 +54,8 @@ def query( self.prompt_driver.prompt_stack_to_string( PromptStack( inputs=[ - PromptStack.Input(system_message, role=PromptStack.SYSTEM_ROLE), - PromptStack.Input(user_message, role=PromptStack.USER_ROLE), + PromptStackElement(system_message, role=PromptStackElement.SYSTEM_ROLE), + PromptStackElement(user_message, role=PromptStackElement.USER_ROLE), ] ) ) @@ -71,15 +72,20 @@ def query( break - return self.prompt_driver.run( + result = self.prompt_driver.run( PromptStack( inputs=[ - PromptStack.Input(system_message, role=PromptStack.SYSTEM_ROLE), - PromptStack.Input(user_message, role=PromptStack.USER_ROLE), + PromptStackElement(system_message, role=PromptStackElement.SYSTEM_ROLE), + PromptStackElement(user_message, role=PromptStackElement.USER_ROLE), ] ) ) + if isinstance(result, TextArtifact): + return result + else: + raise ValueError("Prompt Driver did not return a TextArtifact.") + def upsert_text_artifact(self, artifact: TextArtifact, namespace: Optional[str] = None) -> str: result = self.vector_store_driver.upsert_text_artifact(artifact, namespace=namespace) diff --git a/griptape/engines/summary/prompt_summary_engine.py b/griptape/engines/summary/prompt_summary_engine.py index 9d3e8db78..e5968c4df 100644 --- a/griptape/engines/summary/prompt_summary_engine.py +++ b/griptape/engines/summary/prompt_summary_engine.py @@ -2,7 +2,8 @@ from attrs import define, Factory, field from griptape.artifacts import TextArtifact, ListArtifact from griptape.chunkers import BaseChunker, TextChunker -from griptape.utils import PromptStack +from griptape.common import PromptStack +from griptape.common.prompt_stack.elements.prompt_stack_element import PromptStackElement from griptape.drivers import BasePromptDriver from griptape.engines import BaseSummaryEngine from griptape.utils import J2 diff --git a/griptape/events/base_prompt_event.py b/griptape/events/base_prompt_event.py index b9dcd0c57..4a44599cc 100644 --- a/griptape/events/base_prompt_event.py +++ b/griptape/events/base_prompt_event.py @@ -7,4 +7,3 @@ @define class BasePromptEvent(BaseEvent, ABC): model: str = field(kw_only=True, metadata={"serializable": True}) - token_count: int = field(kw_only=True, metadata={"serializable": True}) diff --git a/griptape/events/finish_prompt_event.py b/griptape/events/finish_prompt_event.py index 83bc1b9ef..79e338871 100644 --- a/griptape/events/finish_prompt_event.py +++ b/griptape/events/finish_prompt_event.py @@ -1,7 +1,10 @@ from attrs import define, field +from typing import Optional from griptape.events.base_prompt_event import BasePromptEvent @define class FinishPromptEvent(BasePromptEvent): result: str = field(kw_only=True, metadata={"serializable": True}) + input_token_count: Optional[float] = field(kw_only=True, metadata={"serializable": True}) + output_token_count: Optional[float] = field(kw_only=True, metadata={"serializable": True}) diff --git a/griptape/events/start_prompt_event.py b/griptape/events/start_prompt_event.py index 7ab418adb..35dae95d6 100644 --- a/griptape/events/start_prompt_event.py +++ b/griptape/events/start_prompt_event.py @@ -5,10 +5,9 @@ from griptape.events.base_prompt_event import BasePromptEvent if TYPE_CHECKING: - from griptape.utils import PromptStack + from griptape.common import PromptStack @define class StartPromptEvent(BasePromptEvent): prompt_stack: PromptStack = field(kw_only=True, metadata={"serializable": True}) - prompt: str = field(kw_only=True, metadata={"serializable": True}) diff --git a/griptape/memory/structure/base_conversation_memory.py b/griptape/memory/structure/base_conversation_memory.py index f8cc51743..503b35fa1 100644 --- a/griptape/memory/structure/base_conversation_memory.py +++ b/griptape/memory/structure/base_conversation_memory.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Optional from attrs import define, field from griptape.memory.structure import Run -from griptape.utils import PromptStack +from griptape.common import PromptStack from griptape.mixins import SerializableMixin from abc import ABC, abstractmethod diff --git a/griptape/memory/structure/conversation_memory.py b/griptape/memory/structure/conversation_memory.py index 94e73d80c..b4043c33c 100644 --- a/griptape/memory/structure/conversation_memory.py +++ b/griptape/memory/structure/conversation_memory.py @@ -2,7 +2,7 @@ from attrs import define from typing import Optional from griptape.memory.structure import Run, BaseConversationMemory -from griptape.utils import PromptStack +from griptape.common import PromptStack @define diff --git a/griptape/memory/structure/summary_conversation_memory.py b/griptape/memory/structure/summary_conversation_memory.py index e4d5597d5..4b5db6def 100644 --- a/griptape/memory/structure/summary_conversation_memory.py +++ b/griptape/memory/structure/summary_conversation_memory.py @@ -2,7 +2,9 @@ import logging from typing import TYPE_CHECKING, Optional from attrs import define, field, Factory -from griptape.utils import J2, PromptStack +from griptape.common.prompt_stack.elements.prompt_stack_element import PromptStackElement +from griptape.utils import J2 +from griptape.common import PromptStack from griptape.memory.structure import ConversationMemory if TYPE_CHECKING: @@ -73,7 +75,7 @@ def summarize_runs(self, previous_summary: str | None, runs: list[Run]) -> str | if len(runs) > 0: summary = self.summarize_conversation_template_generator.render(summary=previous_summary, runs=runs) return self.prompt_driver.run( - prompt_stack=PromptStack(inputs=[PromptStack.Input(summary, role=PromptStack.USER_ROLE)]) + prompt_stack=PromptStack(inputs=[PromptStackElement(summary, role=PromptStackElement.USER_ROLE)]) ).to_text() else: return previous_summary diff --git a/griptape/schemas/base_schema.py b/griptape/schemas/base_schema.py index 6ba23d6fe..e61307c64 100644 --- a/griptape/schemas/base_schema.py +++ b/griptape/schemas/base_schema.py @@ -35,7 +35,7 @@ def make_obj(self, data, **kwargs): cls._resolve_types(attrs_cls) return SubSchema.from_dict( { - a.name: cls._get_field_for_type(a.type) + a.alias or a.name: cls._get_field_for_type(a.type) for a in attrs.fields(attrs_cls) if a.metadata.get("serializable") }, @@ -105,7 +105,7 @@ def _resolve_types(cls, attrs_cls: type) -> None: # These modules are required to avoid `NameError`s when resolving types. from griptape.drivers import BaseConversationMemoryDriver, BasePromptDriver from griptape.structures import Structure - from griptape.utils import PromptStack + from griptape.common import PromptStack, PromptStackElement from griptape.tokenizers.base_tokenizer import BaseTokenizer from typing import Any @@ -116,7 +116,7 @@ def _resolve_types(cls, attrs_cls: type) -> None: attrs_cls, localns={ "PromptStack": PromptStack, - "Input": PromptStack.Input, + "Usage": PromptStackElement.Usage, "Structure": Structure, "BaseConversationMemoryDriver": BaseConversationMemoryDriver, "BasePromptDriver": BasePromptDriver, diff --git a/griptape/structures/agent.py b/griptape/structures/agent.py index d0446aff0..12512341e 100644 --- a/griptape/structures/agent.py +++ b/griptape/structures/agent.py @@ -1,10 +1,11 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Callable from attrs import define, field from griptape.tools import BaseTool from griptape.memory.structure import Run from griptape.structures import Structure -from griptape.tasks import PromptTask, ToolkitTask +from griptape.tasks import PromptTask, ToolkitTask, BaseTextInputTask +from griptape.artifacts import BaseArtifact if TYPE_CHECKING: from griptape.tasks import BaseTask @@ -12,7 +13,9 @@ @define class Agent(Structure): - input_template: str = field(default=PromptTask.DEFAULT_INPUT_TEMPLATE) + input: str | list | tuple | BaseArtifact | Callable[[BaseTask], BaseArtifact] = field( + default=BaseTextInputTask.DEFAULT_INPUT_TEMPLATE + ) tools: list[BaseTool] = field(factory=list, kw_only=True) max_meta_memory_entries: Optional[int] = field(default=20, kw_only=True) @@ -20,11 +23,9 @@ def __attrs_post_init__(self) -> None: super().__attrs_post_init__() if len(self.tasks) == 0: if self.tools: - task = ToolkitTask( - self.input_template, tools=self.tools, max_meta_memory_entries=self.max_meta_memory_entries - ) + task = ToolkitTask(self.input, tools=self.tools, max_meta_memory_entries=self.max_meta_memory_entries) else: - task = PromptTask(self.input_template, max_meta_memory_entries=self.max_meta_memory_entries) + task = PromptTask(self.input, max_meta_memory_entries=self.max_meta_memory_entries) self.add_task(task) diff --git a/griptape/tasks/actions_subtask.py b/griptape/tasks/actions_subtask.py index 1546a825d..4aa2b2783 100644 --- a/griptape/tasks/actions_subtask.py +++ b/griptape/tasks/actions_subtask.py @@ -38,19 +38,6 @@ class Action: _input: Optional[str | TextArtifact | Callable[[BaseTask], TextArtifact]] = field(default=None) _memory: Optional[TaskMemory] = None - @property - def input(self) -> TextArtifact: - if isinstance(self._input, TextArtifact): - return self._input - elif isinstance(self._input, Callable): - return self._input(self) - else: - return TextArtifact(self._input) - - @input.setter - def input(self, value: str | TextArtifact | Callable[[BaseTask], TextArtifact]) -> None: - self._input = value - @property def origin_task(self) -> BaseTask: if self.parent_task_id: @@ -178,6 +165,14 @@ def actions_to_dicts(self) -> list[dict]: def actions_to_json(self) -> str: return json.dumps(self.actions_to_dicts()) + def _process_task_input( + self, task_input: str | list | BaseArtifact | Callable[[BaseTask], BaseArtifact] + ) -> BaseArtifact: + if isinstance(task_input, TextArtifact): + return task_input + else: + return super()._process_task_input(task_input) + def __init_from_prompt(self, value: str) -> None: thought_matches = re.findall(self.THOUGHT_PATTERN, value, re.MULTILINE) actions_matches = re.findall(self.ACTIONS_PATTERN, value, re.DOTALL) diff --git a/griptape/tasks/base_text_input_task.py b/griptape/tasks/base_text_input_task.py index c5641bb14..9281f5a7e 100644 --- a/griptape/tasks/base_text_input_task.py +++ b/griptape/tasks/base_text_input_task.py @@ -2,10 +2,11 @@ from abc import ABC from typing import Callable +from collections.abc import Sequence from attrs import define, field -from griptape.artifacts import TextArtifact +from griptape.artifacts import TextArtifact, BaseArtifact, ListArtifact from griptape.mixins.rule_mixin import RuleMixin from griptape.tasks import BaseTask from griptape.utils import J2 @@ -15,21 +16,22 @@ class BaseTextInputTask(RuleMixin, BaseTask, ABC): DEFAULT_INPUT_TEMPLATE = "{{ args[0] }}" - _input: str | TextArtifact | Callable[[BaseTask], TextArtifact] = field( + _input: str | list | tuple | BaseArtifact | Callable[[BaseTask], BaseArtifact] = field( default=DEFAULT_INPUT_TEMPLATE, alias="input" ) @property - def input(self) -> TextArtifact: - if isinstance(self._input, TextArtifact): - return self._input - elif isinstance(self._input, Callable): - return self._input(self) + def input(self) -> BaseArtifact: + if isinstance(self._input, list) or isinstance(self._input, tuple): + artifacts = [self._process_task_input(input) for input in self._input] + flattened_artifacts = self.__flatten_artifacts(artifacts) + + return ListArtifact(flattened_artifacts) else: - return TextArtifact(J2().render_from_string(self._input, **self.full_context)) + return self._process_task_input(self._input) @input.setter - def input(self, value: str | TextArtifact | Callable[[BaseTask], TextArtifact]) -> None: + def input(self, value: str | list | tuple | BaseArtifact | Callable[[BaseTask], BaseArtifact]) -> None: self._input = value def before_run(self) -> None: @@ -41,3 +43,32 @@ def after_run(self) -> None: super().after_run() self.structure.logger.info(f"{self.__class__.__name__} {self.id}\nOutput: {self.output.to_text()}") + + def _process_task_input( + self, task_input: str | list | BaseArtifact | Callable[[BaseTask], BaseArtifact] + ) -> BaseArtifact: + if isinstance(task_input, TextArtifact): + task_input.value = J2().render_from_string(task_input.value, **self.full_context) + + return task_input + elif isinstance(task_input, Callable): + return task_input(self) + elif isinstance(task_input, str): + return self._process_task_input(TextArtifact(task_input)) + elif isinstance(task_input, BaseArtifact): + return task_input + elif isinstance(task_input, list): + return ListArtifact([self._process_task_input(elem) for elem in task_input]) + else: + raise ValueError(f"Invalid input type: {type(task_input)} ") + + def __flatten_artifacts(self, artifacts: Sequence[BaseArtifact]) -> Sequence[BaseArtifact]: + result = [] + + for elem in artifacts: + if isinstance(elem, ListArtifact): + result.extend(self.__flatten_artifacts(elem.value)) + else: + result.append(elem) + + return result diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index 694a5050d..5727a499b 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -1,10 +1,13 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional, Callable -from attrs import define, field, Factory -from griptape.utils import PromptStack -from griptape.utils import J2 -from griptape.tasks import BaseTextInputTask + +from typing import TYPE_CHECKING, Callable, Optional + +from attrs import Factory, define, field + from griptape.artifacts import BaseArtifact +from griptape.common import PromptStack +from griptape.tasks import BaseTextInputTask +from griptape.utils import J2 if TYPE_CHECKING: from griptape.drivers import BasePromptDriver @@ -27,7 +30,7 @@ def prompt_stack(self) -> PromptStack: stack.add_system_input(self.generate_system_template(self)) - stack.add_user_input(self.input.to_text()) + stack.add_user_input(self.input) if self.output: stack.add_assistant_input(self.output.to_text()) diff --git a/griptape/tasks/toolkit_task.py b/griptape/tasks/toolkit_task.py index c99f9e23f..de291d1a9 100644 --- a/griptape/tasks/toolkit_task.py +++ b/griptape/tasks/toolkit_task.py @@ -10,7 +10,7 @@ from griptape.tasks import ActionsSubtask from griptape.tasks import PromptTask from griptape.utils import J2 -from griptape.utils import PromptStack +from griptape.common import PromptStack if TYPE_CHECKING: from griptape.tools import BaseTool diff --git a/griptape/tokenizers/amazon_bedrock_tokenizer.py b/griptape/tokenizers/amazon_bedrock_tokenizer.py index 670b5739a..951802d59 100644 --- a/griptape/tokenizers/amazon_bedrock_tokenizer.py +++ b/griptape/tokenizers/amazon_bedrock_tokenizer.py @@ -1,4 +1,5 @@ from __future__ import annotations + from attrs import define, field from griptape.tokenizers.base_tokenizer import BaseTokenizer diff --git a/griptape/tokenizers/huggingface_tokenizer.py b/griptape/tokenizers/huggingface_tokenizer.py index a8312567d..fdebd23da 100644 --- a/griptape/tokenizers/huggingface_tokenizer.py +++ b/griptape/tokenizers/huggingface_tokenizer.py @@ -1,4 +1,5 @@ from __future__ import annotations + from typing import TYPE_CHECKING from attrs import define, field, Factory from griptape.utils import import_optional_dependency diff --git a/griptape/tokenizers/openai_tokenizer.py b/griptape/tokenizers/openai_tokenizer.py index 39a2a033e..a58ec5ce5 100644 --- a/griptape/tokenizers/openai_tokenizer.py +++ b/griptape/tokenizers/openai_tokenizer.py @@ -1,8 +1,11 @@ from __future__ import annotations + import logging -from attrs import define, field, Factory -import tiktoken from typing import Optional + +import tiktoken +from attrs import Factory, define, field + from griptape.tokenizers import BaseTokenizer diff --git a/griptape/tokenizers/voyageai_tokenizer.py b/griptape/tokenizers/voyageai_tokenizer.py index d8fb5adf1..649f6e0cc 100644 --- a/griptape/tokenizers/voyageai_tokenizer.py +++ b/griptape/tokenizers/voyageai_tokenizer.py @@ -1,8 +1,11 @@ from __future__ import annotations -from attrs import define, field, Factory + from typing import TYPE_CHECKING, Optional -from griptape.utils import import_optional_dependency + +from attrs import Factory, define, field + from griptape.tokenizers import BaseTokenizer +from griptape.utils import import_optional_dependency if TYPE_CHECKING: from voyageai import Client diff --git a/griptape/utils/__init__.py b/griptape/utils/__init__.py index daac63f4e..43c275723 100644 --- a/griptape/utils/__init__.py +++ b/griptape/utils/__init__.py @@ -7,7 +7,6 @@ from .chat import Chat from .futures import execute_futures_dict from .token_counter import TokenCounter -from .prompt_stack import PromptStack from .dict_utils import remove_null_values_in_dict_recursively, dict_merge from .file_utils import load_file, load_files from .hash import str_to_hash @@ -36,7 +35,6 @@ def minify_json(value: str) -> str: "is_dependency_installed", "execute_futures_dict", "TokenCounter", - "PromptStack", "remove_null_values_in_dict_recursively", "dict_merge", "Stream", diff --git a/griptape/utils/conversation.py b/griptape/utils/conversation.py index 2d87563ae..ef076b168 100644 --- a/griptape/utils/conversation.py +++ b/griptape/utils/conversation.py @@ -23,7 +23,7 @@ def prompt_stack(self) -> list[str]: lines = [] for stack in self.memory.to_prompt_stack().inputs: - lines.append(f"{stack.role}: {stack.content}") + lines.append(f"{stack.role}: {stack.to_text_artifact().to_text()}") return lines diff --git a/griptape/utils/prompt_stack.py b/griptape/utils/prompt_stack.py deleted file mode 100644 index 378f9dd1e..000000000 --- a/griptape/utils/prompt_stack.py +++ /dev/null @@ -1,48 +0,0 @@ -from __future__ import annotations -from attrs import define, field - -from griptape.mixins import SerializableMixin - - -@define -class PromptStack(SerializableMixin): - GENERIC_ROLE = "generic" - USER_ROLE = "user" - ASSISTANT_ROLE = "assistant" - SYSTEM_ROLE = "system" - - @define - class Input(SerializableMixin): - content: str = field(metadata={"serializable": True}) - role: str = field(metadata={"serializable": True}) - - def is_generic(self) -> bool: - return self.role == PromptStack.GENERIC_ROLE - - def is_system(self) -> bool: - return self.role == PromptStack.SYSTEM_ROLE - - def is_user(self) -> bool: - return self.role == PromptStack.USER_ROLE - - def is_assistant(self) -> bool: - return self.role == PromptStack.ASSISTANT_ROLE - - inputs: list[Input] = field(factory=list, kw_only=True, metadata={"serializable": True}) - - def add_input(self, content: str, role: str) -> Input: - self.inputs.append(self.Input(content=content, role=role)) - - return self.inputs[-1] - - def add_generic_input(self, content: str) -> Input: - return self.add_input(content, self.GENERIC_ROLE) - - def add_system_input(self, content: str) -> Input: - return self.add_input(content, self.SYSTEM_ROLE) - - def add_user_input(self, content: str) -> Input: - return self.add_input(content, self.USER_ROLE) - - def add_assistant_input(self, content: str) -> Input: - return self.add_input(content, self.ASSISTANT_ROLE) diff --git a/poetry.lock b/poetry.lock index b71ad227b..477d4ad08 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1543,17 +1543,18 @@ dev = ["flake8", "markdown", "twine", "wheel"] [[package]] name = "google-ai-generativelanguage" -version = "0.4.0" +version = "0.6.4" description = "Google Ai Generativelanguage API client library" optional = true python-versions = ">=3.7" files = [ - {file = "google-ai-generativelanguage-0.4.0.tar.gz", hash = "sha256:c8199066c08f74c4e91290778329bb9f357ba1ea5d6f82de2bc0d10552bf4f8c"}, - {file = "google_ai_generativelanguage-0.4.0-py3-none-any.whl", hash = "sha256:e4c425376c1ee26c78acbc49a24f735f90ebfa81bf1a06495fae509a2433232c"}, + {file = "google-ai-generativelanguage-0.6.4.tar.gz", hash = "sha256:1750848c12af96cb24ae1c3dd05e4bfe24867dc4577009ed03e1042d8421e874"}, + {file = "google_ai_generativelanguage-0.6.4-py3-none-any.whl", hash = "sha256:730e471aa549797118fb1c88421ba1957741433ada575cf5dd08d3aebf903ab1"}, ] [package.dependencies] -google-api-core = {version = ">=1.34.0,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]} +google-api-core = {version = ">=1.34.1,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]} +google-auth = ">=2.14.1,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0dev" proto-plus = ">=1.22.3,<2.0.0dev" protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev" @@ -1588,6 +1589,24 @@ grpc = ["grpcio (>=1.33.2,<2.0dev)", "grpcio (>=1.49.1,<2.0dev)", "grpcio-status grpcgcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"] grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"] +[[package]] +name = "google-api-python-client" +version = "2.132.0" +description = "Google API Client Library for Python" +optional = true +python-versions = ">=3.7" +files = [ + {file = "google-api-python-client-2.132.0.tar.gz", hash = "sha256:d6340dc83b72d72333cee5d50f7dcfecbff66a8783164090e945f985ec4c374d"}, + {file = "google_api_python_client-2.132.0-py2.py3-none-any.whl", hash = "sha256:cde87700bd4d37f39f5e940292c1c6cd0910990b5b01f50b1332a8cea38e8595"}, +] + +[package.dependencies] +google-api-core = ">=1.31.5,<2.0.dev0 || >2.3.0,<3.0.0.dev0" +google-auth = ">=1.32.0,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0.dev0" +google-auth-httplib2 = ">=0.2.0,<1.0.0" +httplib2 = ">=0.19.0,<1.dev0" +uritemplate = ">=3.0.1,<5" + [[package]] name = "google-auth" version = "2.29.0" @@ -1611,19 +1630,35 @@ pyopenssl = ["cryptography (>=38.0.3)", "pyopenssl (>=20.0.0)"] reauth = ["pyu2f (>=0.1.5)"] requests = ["requests (>=2.20.0,<3.0.0.dev0)"] +[[package]] +name = "google-auth-httplib2" +version = "0.2.0" +description = "Google Authentication Library: httplib2 transport" +optional = true +python-versions = "*" +files = [ + {file = "google-auth-httplib2-0.2.0.tar.gz", hash = "sha256:38aa7badf48f974f1eb9861794e9c0cb2a0511a4ec0679b1f886d108f5640e05"}, + {file = "google_auth_httplib2-0.2.0-py2.py3-none-any.whl", hash = "sha256:b65a0a2123300dd71281a7bf6e64d65a0759287df52729bdd1ae2e47dc311a3d"}, +] + +[package.dependencies] +google-auth = "*" +httplib2 = ">=0.19.0" + [[package]] name = "google-generativeai" -version = "0.4.1" +version = "0.6.0" description = "Google Generative AI High level API client library and tools." optional = true python-versions = ">=3.9" files = [ - {file = "google_generativeai-0.4.1-py3-none-any.whl", hash = "sha256:89be3c00c2e688108fccefc50f47f45fc9d37ecd53c1ade9d86b5d982919c24a"}, + {file = "google_generativeai-0.6.0-py3-none-any.whl", hash = "sha256:ba1d3b826b872bffe330aaac0dc6de2f0e4610df861c8ce7ec6433771611b676"}, ] [package.dependencies] -google-ai-generativelanguage = "0.4.0" +google-ai-generativelanguage = "0.6.4" google-api-core = "*" +google-api-python-client = "*" google-auth = ">=2.15.0" protobuf = "*" pydantic = "*" @@ -1863,6 +1898,20 @@ http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] trio = ["trio (>=0.22.0,<0.26.0)"] +[[package]] +name = "httplib2" +version = "0.22.0" +description = "A comprehensive HTTP client library." +optional = true +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +files = [ + {file = "httplib2-0.22.0-py3-none-any.whl", hash = "sha256:14ae0a53c1ba8f3d37e9e27cf37eabb0fb9980f435ba405d546948b009dd64dc"}, + {file = "httplib2-0.22.0.tar.gz", hash = "sha256:d7a10bc5ef5ab08322488bde8c726eeee5c8618723fdb399597ec58f3d82df81"}, +] + +[package.dependencies] +pyparsing = {version = ">=2.4.2,<3.0.0 || >3.0.0,<3.0.1 || >3.0.1,<3.0.2 || >3.0.2,<3.0.3 || >3.0.3,<4", markers = "python_version > \"3.0\""} + [[package]] name = "httpx" version = "0.27.0" @@ -4241,6 +4290,20 @@ cryptography = ">=41.0.5,<43" docs = ["sphinx (!=5.2.0,!=5.2.0.post0,!=7.2.5)", "sphinx-rtd-theme"] test = ["pretend", "pytest (>=3.0.1)", "pytest-rerunfailures"] +[[package]] +name = "pyparsing" +version = "3.1.2" +description = "pyparsing module - Classes and methods to define and execute parsing grammars" +optional = true +python-versions = ">=3.6.8" +files = [ + {file = "pyparsing-3.1.2-py3-none-any.whl", hash = "sha256:f9db75911801ed778fe61bb643079ff86601aca99fcae6345aa67292038fb742"}, + {file = "pyparsing-3.1.2.tar.gz", hash = "sha256:a1bac0ce561155ecc3ed78ca94d3c9378656ad4c94c1270de543f621420f94ad"}, +] + +[package.extras] +diagrams = ["jinja2", "railroad-diagrams"] + [[package]] name = "pypdf" version = "3.17.4" @@ -5942,6 +6005,17 @@ tzdata = {version = "*", markers = "platform_system == \"Windows\""} [package.extras] devenv = ["check-manifest", "pytest (>=4.3)", "pytest-cov", "pytest-mock (>=3.3)", "zest.releaser"] +[[package]] +name = "uritemplate" +version = "4.1.1" +description = "Implementation of RFC 6570 URI Templates" +optional = true +python-versions = ">=3.6" +files = [ + {file = "uritemplate-4.1.1-py2.py3-none-any.whl", hash = "sha256:830c08b8d99bdd312ea4ead05994a38e8936266f84b9a7878232db50b044e02e"}, + {file = "uritemplate-4.1.1.tar.gz", hash = "sha256:4346edfc5c3b79f694bccd6d6099a322bbeb628dbf2cd86eea55a456ce5124f0"}, +] + [[package]] name = "urllib3" version = "1.26.18" diff --git a/pyproject.toml b/pyproject.toml index 377ba46f9..7ecad7fe5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,7 @@ redis = { version = "^4.6.0", optional = true } opensearch-py = { version = "^2.3.1", optional = true } pgvector = { version = "^0.2.3", optional = true } psycopg2-binary = { version = "^2.9.9", optional = true } -google-generativeai = { version = "^0.4.1", optional = true } +google-generativeai = { version = "^0.6.0", optional = true } trafilatura = {version = "^1.6", optional = true} playwright = {version = "^1.42", optional = true} beautifulsoup4 = {version = "^4.12.3", optional = true} diff --git a/tests/mocks/mock_failing_prompt_driver.py b/tests/mocks/mock_failing_prompt_driver.py index c97b25d86..33127bf4a 100644 --- a/tests/mocks/mock_failing_prompt_driver.py +++ b/tests/mocks/mock_failing_prompt_driver.py @@ -1,10 +1,18 @@ +from __future__ import annotations from collections.abc import Iterator from attrs import define -from griptape.utils import PromptStack -from griptape.drivers import BasePromptDriver -from griptape.tokenizers import OpenAiTokenizer, BaseTokenizer from griptape.artifacts import TextArtifact +from griptape.common import ( + PromptStack, + PromptStackElement, + TextPromptStackContent, + DeltaPromptStackElement, + DeltaTextPromptStackContent, + BaseDeltaPromptStackContent, +) +from griptape.drivers import BasePromptDriver +from griptape.tokenizers import BaseTokenizer, OpenAiTokenizer @define @@ -14,18 +22,25 @@ class MockFailingPromptDriver(BasePromptDriver): model: str = "test-model" tokenizer: BaseTokenizer = OpenAiTokenizer(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL) - def try_run(self, prompt_stack: PromptStack) -> TextArtifact: + def try_run(self, prompt_stack: PromptStack) -> PromptStackElement: if self.current_attempt < self.max_failures: self.current_attempt += 1 - raise Exception(f"failed attempt") + raise Exception("failed attempt") else: - return TextArtifact("success") + return PromptStackElement( + content=[TextPromptStackContent(TextArtifact("success"))], + role=PromptStackElement.ASSISTANT_ROLE, + usage=PromptStackElement.Usage(input_tokens=100, output_tokens=100), + ) - def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]: + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackElement | BaseDeltaPromptStackContent]: if self.current_attempt < self.max_failures: self.current_attempt += 1 - raise Exception(f"failed attempt") + raise Exception("failed attempt") else: - yield TextArtifact("success") + yield DeltaPromptStackElement( + delta_content=DeltaTextPromptStackContent("success"), + delta_usage=DeltaPromptStackElement.DeltaUsage(input_tokens=100, output_tokens=100), + ) diff --git a/tests/mocks/mock_prompt_driver.py b/tests/mocks/mock_prompt_driver.py index 3235f7cd5..7f59e6bd8 100644 --- a/tests/mocks/mock_prompt_driver.py +++ b/tests/mocks/mock_prompt_driver.py @@ -1,11 +1,22 @@ from __future__ import annotations + from collections.abc import Iterator from typing import Callable + from attrs import define, field -from griptape.utils import PromptStack + +from griptape.artifacts import TextArtifact +from griptape.common import ( + PromptStack, + PromptStackElement, + DeltaPromptStackElement, + BaseDeltaPromptStackContent, + TextPromptStackContent, + DeltaTextPromptStackContent, +) from griptape.drivers import BasePromptDriver from griptape.tokenizers import BaseTokenizer -from griptape.artifacts import TextArtifact + from tests.mocks.mock_tokenizer import MockTokenizer @@ -16,12 +27,19 @@ class MockPromptDriver(BasePromptDriver): mock_input: str | Callable[[], str] = field(default="mock input", kw_only=True) mock_output: str | Callable[[PromptStack], str] = field(default="mock output", kw_only=True) - def try_run(self, prompt_stack: PromptStack) -> TextArtifact: - return TextArtifact( - value=self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output + def try_run(self, prompt_stack: PromptStack) -> PromptStackElement: + output = self.mock_output() if isinstance(self.mock_output, Callable) else self.mock_output + + return PromptStackElement( + content=[TextPromptStackContent(TextArtifact(output))], + role=PromptStackElement.ASSISTANT_ROLE, + usage=PromptStackElement.Usage(input_tokens=100, output_tokens=100), ) - def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]: - yield TextArtifact( - value=self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackElement | BaseDeltaPromptStackContent]: + output = self.mock_output() if isinstance(self.mock_output, Callable) else self.mock_output + + yield DeltaTextPromptStackContent(output) + yield DeltaPromptStackElement( + delta_usage=DeltaPromptStackElement.DeltaUsage(input_tokens=100, output_tokens=100) ) diff --git a/tests/mocks/mock_value_prompt_driver.py b/tests/mocks/mock_value_prompt_driver.py deleted file mode 100644 index 12ddeec9f..000000000 --- a/tests/mocks/mock_value_prompt_driver.py +++ /dev/null @@ -1,21 +0,0 @@ -from collections.abc import Iterator -from attrs import define, field, Factory -from griptape.drivers import BasePromptDriver -from griptape.tokenizers import OpenAiTokenizer, BaseTokenizer -from griptape.artifacts import TextArtifact -from griptape.utils.prompt_stack import PromptStack - - -@define -class MockValuePromptDriver(BasePromptDriver): - value: str = field(kw_only=True) - model: str = field(default="test-model") - tokenizer: BaseTokenizer = field( - default=Factory(lambda: OpenAiTokenizer(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL)) - ) - - def try_run(self, prompt_stack: PromptStack) -> TextArtifact: - return TextArtifact(value=self.value) - - def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]: - yield TextArtifact(value=self.value) diff --git a/tests/unit/config/test_google_structure_config.py b/tests/unit/config/test_google_structure_config.py index 72e623ff0..47c46f181 100644 --- a/tests/unit/config/test_google_structure_config.py +++ b/tests/unit/config/test_google_structure_config.py @@ -19,7 +19,7 @@ def test_to_dict(self, config): "temperature": 0.1, "max_tokens": None, "stream": False, - "model": "gemini-pro", + "model": "gemini-1.5-flash", "top_p": None, "top_k": None, }, diff --git a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py index 8aa345595..7c7112ce0 100644 --- a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py @@ -1,6 +1,7 @@ import pytest -from griptape.utils import PromptStack +from griptape.common import PromptStack +from griptape.common.prompt_stack.contents.delta_text_prompt_stack_content import DeltaTextPromptStackContent from griptape.drivers import AmazonBedrockPromptDriver @@ -9,7 +10,10 @@ class TestAmazonBedrockPromptDriver: def mock_converse(self, mocker): mock_converse = mocker.patch("boto3.Session").return_value.client.return_value.converse - mock_converse.return_value = {"output": {"message": {"content": [{"text": "model-output"}]}}} + mock_converse.return_value = { + "output": {"message": {"content": [{"text": "model-output"}]}}, + "usage": {"inputTokens": 100, "outputTokens": 100}, + } return mock_converse @@ -17,14 +21,15 @@ def mock_converse(self, mocker): def mock_converse_stream(self, mocker): mock_converse_stream = mocker.patch("boto3.Session").return_value.client.return_value.converse_stream - mock_converse_stream.return_value = {"stream": [{"contentBlockDelta": {"delta": {"text": "model-output"}}}]} + mock_converse_stream.return_value = { + "stream": [{"contentBlockDelta": {"contentBlockIndex": 0, "delta": {"text": "model-output"}}}] + } return mock_converse_stream @pytest.fixture def prompt_stack(self): prompt_stack = PromptStack() - prompt_stack.add_generic_input("generic-input") prompt_stack.add_system_input("system-input") prompt_stack.add_user_input("user-input") prompt_stack.add_assistant_input("assistant-input") @@ -34,7 +39,6 @@ def prompt_stack(self): @pytest.fixture def messages(self): return [ - {"role": "user", "content": [{"text": "generic-input"}]}, {"role": "system", "content": [{"text": "system-input"}]}, {"role": "user", "content": [{"text": "user-input"}]}, {"role": "assistant", "content": [{"text": "assistant-input"}]}, @@ -51,7 +55,6 @@ def test_try_run(self, mock_converse, prompt_stack, messages): mock_converse.assert_called_once_with( modelId=driver.model, messages=[ - {"role": "user", "content": [{"text": "generic-input"}]}, {"role": "user", "content": [{"text": "user-input"}]}, {"role": "assistant", "content": [{"text": "assistant-input"}]}, ], @@ -72,7 +75,6 @@ def test_try_stream_run(self, mock_converse_stream, prompt_stack, messages): mock_converse_stream.assert_called_once_with( modelId=driver.model, messages=[ - {"role": "user", "content": [{"text": "generic-input"}]}, {"role": "user", "content": [{"text": "user-input"}]}, {"role": "assistant", "content": [{"text": "assistant-input"}]}, ], @@ -80,4 +82,6 @@ def test_try_stream_run(self, mock_converse_stream, prompt_stack, messages): inferenceConfig={"temperature": driver.temperature}, additionalModelRequestFields={}, ) - assert text_artifact.value == "model-output" + + if isinstance(text_artifact, DeltaTextPromptStackContent): + assert text_artifact.text == "model-output" diff --git a/tests/unit/drivers/prompt/test_amazon_sagemaker_jumpstart_prompt_driver.py b/tests/unit/drivers/prompt/test_amazon_sagemaker_jumpstart_prompt_driver.py index 4ae8fe944..1f3a3963b 100644 --- a/tests/unit/drivers/prompt/test_amazon_sagemaker_jumpstart_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_amazon_sagemaker_jumpstart_prompt_driver.py @@ -2,7 +2,7 @@ from botocore.response import StreamingBody from griptape.tokenizers import HuggingFaceTokenizer from griptape.drivers.prompt.amazon_sagemaker_jumpstart_prompt_driver import AmazonSageMakerJumpstartPromptDriver -from griptape.utils import PromptStack +from griptape.common import PromptStack from io import BytesIO import json import pytest @@ -18,7 +18,8 @@ class TestAmazonSageMakerJumpstartPromptDriver: @pytest.fixture(autouse=True) def tokenizer(self, mocker): from_pretrained = mocker.patch("transformers.AutoTokenizer").from_pretrained - from_pretrained.return_value.apply_chat_template.return_value = "foo\n\nUser: bar" + from_pretrained.return_value.decode.return_value = "foo\n\nUser: bar" + from_pretrained.return_value.apply_chat_template.return_value = ["foo", "\nbar"] from_pretrained.return_value.model_max_length = 8000 from_pretrained.return_value.eos_token_id = 1 diff --git a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py index 22178bbf3..9e2e82534 100644 --- a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py @@ -1,5 +1,6 @@ +from griptape.common.prompt_stack.contents.delta_text_prompt_stack_content import DeltaTextPromptStackContent from griptape.drivers import AnthropicPromptDriver -from griptape.utils import PromptStack +from griptape.common import PromptStack from unittest.mock import Mock import pytest @@ -9,6 +10,7 @@ class TestAnthropicPromptDriver: def mock_client(self, mocker): mock_client = mocker.patch("anthropic.Anthropic") mock_content = Mock() + mock_content.type = "text" mock_content.text = "model-output" mock_client.return_value.messages.create.return_value.content = [mock_content] mock_client.return_value.count_tokens.return_value = 5 @@ -20,6 +22,7 @@ def mock_stream_client(self, mocker): mock_stream_client = mocker.patch("anthropic.Anthropic") mock_chunk = Mock() mock_chunk.type = "content_block_delta" + mock_chunk.delta.type = "text_delta" mock_chunk.delta.text = "model-output" mock_stream_client.return_value.messages.create.return_value = iter([mock_chunk]) mock_stream_client.return_value.count_tokens.return_value = 5 @@ -45,14 +48,12 @@ def test_init(self, model): def test_try_run(self, mock_client, model, system_enabled): # Given prompt_stack = PromptStack() - prompt_stack.add_generic_input("generic-input") if system_enabled: prompt_stack.add_system_input("system-input") prompt_stack.add_user_input("user-input") prompt_stack.add_assistant_input("assistant-input") driver = AnthropicPromptDriver(model=model, api_key="api-key") expected_messages = [ - {"role": "user", "content": "generic-input"}, {"role": "user", "content": "user-input"}, {"role": "assistant", "content": "assistant-input"}, ] @@ -88,13 +89,11 @@ def test_try_run(self, mock_client, model, system_enabled): def test_try_stream_run(self, mock_stream_client, model, system_enabled): # Given prompt_stack = PromptStack() - prompt_stack.add_generic_input("generic-input") if system_enabled: prompt_stack.add_system_input("system-input") prompt_stack.add_user_input("user-input") prompt_stack.add_assistant_input("assistant-input") expected_messages = [ - {"role": "user", "content": "generic-input"}, {"role": "user", "content": "user-input"}, {"role": "assistant", "content": "assistant-input"}, ] @@ -115,7 +114,8 @@ def test_try_stream_run(self, mock_stream_client, model, system_enabled): top_k=250, **{"system": "system-input"} if system_enabled else {}, ) - assert text_artifact.value == "model-output" + if isinstance(text_artifact, DeltaTextPromptStackContent): + assert text_artifact.text == "model-output" def test_try_run_throws_when_prompt_stack_is_string(self): # Given diff --git a/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py b/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py index f6bd12d80..b2b3be062 100644 --- a/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py @@ -1,5 +1,6 @@ import pytest from unittest.mock import Mock +from griptape.common.prompt_stack.contents.delta_text_prompt_stack_content import DeltaTextPromptStackContent from griptape.drivers import AzureOpenAiChatPromptDriver from tests.unit.drivers.prompt.test_openai_chat_prompt_driver import TestOpenAiChatPromptDriverFixtureMixin @@ -37,11 +38,7 @@ def test_try_run(self, mock_chat_completion_create, prompt_stack, messages): # Then mock_chat_completion_create.assert_called_once_with( - model=driver.model, - temperature=driver.temperature, - stop=driver.tokenizer.stop_sequences, - user=driver.user, - messages=messages, + model=driver.model, temperature=driver.temperature, user=driver.user, messages=messages ) assert text_artifact.value == "model-output" @@ -58,9 +55,11 @@ def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, mock_chat_completion_stream_create.assert_called_once_with( model=driver.model, temperature=driver.temperature, - stop=driver.tokenizer.stop_sequences, user=driver.user, stream=True, messages=messages, + stream_options={"include_usage": True}, ) - assert text_artifact.value == "model-output" + + if isinstance(text_artifact, DeltaTextPromptStackContent): + assert text_artifact == "model-output" diff --git a/tests/unit/drivers/prompt/test_base_prompt_driver.py b/tests/unit/drivers/prompt/test_base_prompt_driver.py index 0743402aa..794df1186 100644 --- a/tests/unit/drivers/prompt/test_base_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_base_prompt_driver.py @@ -1,5 +1,5 @@ from griptape.events import FinishPromptEvent, StartPromptEvent -from griptape.utils import PromptStack +from griptape.common import PromptStack from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_failing_prompt_driver import MockFailingPromptDriver from griptape.artifacts import ErrorArtifact, TextArtifact diff --git a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py index b3ceb11a4..ac434996c 100644 --- a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py @@ -1,8 +1,10 @@ -from griptape.drivers import CoherePromptDriver -from griptape.utils import PromptStack from unittest.mock import Mock + import pytest +from griptape.common import DeltaTextPromptStackContent, PromptStack +from griptape.drivers import CoherePromptDriver + class TestCoherePromptDriver: @pytest.fixture @@ -27,7 +29,6 @@ def mock_tokenizer(self, mocker): @pytest.fixture def prompt_stack(self): prompt_stack = PromptStack() - prompt_stack.add_generic_input("generic-input") prompt_stack.add_system_input("system-input") prompt_stack.add_user_input("user-input") prompt_stack.add_assistant_input("assistant-input") @@ -54,4 +55,5 @@ def test_try_stream_run(self, mock_stream_client, prompt_stack): # pyright: ign text_artifact = next(driver.try_stream(prompt_stack)) # Then - assert text_artifact.value == "model-output" + if isinstance(text_artifact, DeltaTextPromptStackContent): + assert text_artifact.text == "model-output" diff --git a/tests/unit/drivers/prompt/test_google_prompt_driver.py b/tests/unit/drivers/prompt/test_google_prompt_driver.py index f655d3e51..d9eb9313b 100644 --- a/tests/unit/drivers/prompt/test_google_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_google_prompt_driver.py @@ -1,6 +1,7 @@ from google.generativeai.types import GenerationConfig +from griptape.common.prompt_stack.contents.delta_text_prompt_stack_content import DeltaTextPromptStackContent from griptape.drivers import GooglePromptDriver -from griptape.utils import PromptStack +from griptape.common import PromptStack from unittest.mock import Mock import pytest @@ -30,7 +31,6 @@ def test_try_run(self, mock_generative_model): prompt_stack.add_system_input("system-input") prompt_stack.add_user_input("user-input") prompt_stack.add_assistant_input("assistant-input") - prompt_stack.add_generic_input("generic-input") driver = GooglePromptDriver(model="gemini-pro", api_key="api-key", top_p=0.5, top_k=50) # When @@ -41,7 +41,6 @@ def test_try_run(self, mock_generative_model): [ {"parts": ["system-input", "user-input"], "role": "user"}, {"parts": ["assistant-input"], "role": "model"}, - {"parts": ["generic-input"], "role": "user"}, ], generation_config=GenerationConfig( max_output_tokens=None, temperature=0.1, top_p=0.5, top_k=50, stop_sequences=[] @@ -55,7 +54,6 @@ def test_try_stream(self, mock_stream_generative_model): prompt_stack.add_system_input("system-input") prompt_stack.add_user_input("user-input") prompt_stack.add_assistant_input("assistant-input") - prompt_stack.add_generic_input("generic-input") driver = GooglePromptDriver(model="gemini-pro", api_key="api-key", stream=True, top_p=0.5, top_k=50) # When @@ -67,32 +65,9 @@ def test_try_stream(self, mock_stream_generative_model): [ {"parts": ["system-input", "user-input"], "role": "user"}, {"parts": ["assistant-input"], "role": "model"}, - {"parts": ["generic-input"], "role": "user"}, ], stream=True, generation_config=GenerationConfig(temperature=0.1, top_p=0.5, top_k=50, stop_sequences=[]), ) - assert text_artifact.value == "model-output" - - def test_prompt_stack_to_model_input(self): - # Given - driver = GooglePromptDriver(model="gemini-pro", api_key="1234") - prompt_stack = PromptStack() - prompt_stack.add_system_input("system-input") - prompt_stack.add_user_input("user-input") - prompt_stack.add_assistant_input("assistant-input") - prompt_stack.add_generic_input("generic-input") - prompt_stack.add_assistant_input("assistant-input") - prompt_stack.add_user_input("user-input") - - # When - model_input = driver._prompt_stack_to_model_input(prompt_stack) - - # Then - assert model_input == [ - {"role": "user", "parts": ["system-input", "user-input"]}, - {"role": "model", "parts": ["assistant-input"]}, - {"role": "user", "parts": ["generic-input"]}, - {"role": "model", "parts": ["assistant-input"]}, - {"role": "user", "parts": ["user-input"]}, - ] + if isinstance(text_artifact, DeltaTextPromptStackContent): + assert text_artifact.text == "model-output" diff --git a/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py b/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py index 15bbb4ead..eed4f0922 100644 --- a/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py @@ -1,5 +1,6 @@ +from griptape.common.prompt_stack.contents.delta_text_prompt_stack_content import DeltaTextPromptStackContent from griptape.drivers import HuggingFaceHubPromptDriver -from griptape.utils import PromptStack +from griptape.common import PromptStack import pytest @@ -27,7 +28,6 @@ def mock_client_stream(self, mocker): @pytest.fixture def prompt_stack(self): prompt_stack = PromptStack() - prompt_stack.add_generic_input("generic-input") prompt_stack.add_system_input("system-input") prompt_stack.add_user_input("user-input") prompt_stack.add_assistant_input("assistant-input") @@ -60,4 +60,5 @@ def test_try_stream(self, prompt_stack, mock_client_stream): text_artifact = next(driver.try_stream(prompt_stack)) # Then - assert text_artifact.value == "model-output" + if isinstance(text_artifact, DeltaTextPromptStackContent): + assert text_artifact.text == "model-output" diff --git a/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py b/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py index b2746ca58..16691b474 100644 --- a/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py @@ -1,5 +1,5 @@ from griptape.drivers import HuggingFacePipelinePromptDriver -from griptape.utils import PromptStack +from griptape.common import PromptStack import pytest @@ -27,7 +27,6 @@ def mock_autotokenizer(self, mocker): @pytest.fixture def prompt_stack(self): prompt_stack = PromptStack() - prompt_stack.add_generic_input("generic-input") prompt_stack.add_system_input("system-input") prompt_stack.add_user_input("user-input") prompt_stack.add_assistant_input("assistant-input") diff --git a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py index a2900d4d3..80825efc6 100644 --- a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py @@ -1,8 +1,8 @@ from griptape.drivers import OpenAiChatPromptDriver -from griptape.tokenizers.huggingface_tokenizer import HuggingFaceTokenizer -from griptape.utils import PromptStack +from griptape.common import PromptStack, DeltaTextPromptStackContent from griptape.tokenizers import OpenAiTokenizer from unittest.mock import Mock +from tests.mocks.mock_tokenizer import MockTokenizer import pytest @@ -29,7 +29,6 @@ def mock_chat_completion_stream_create(self, mocker): @pytest.fixture def prompt_stack(self): prompt_stack = PromptStack() - prompt_stack.add_generic_input("generic-input") prompt_stack.add_system_input("system-input") prompt_stack.add_user_input("user-input") prompt_stack.add_assistant_input("assistant-input") @@ -38,7 +37,6 @@ def prompt_stack(self): @pytest.fixture def messages(self): return [ - {"role": "user", "content": "generic-input"}, {"role": "system", "content": "system-input"}, {"role": "user", "content": "user-input"}, {"role": "assistant", "content": "assistant-input"}, @@ -91,12 +89,7 @@ def test_try_run(self, mock_chat_completion_create, prompt_stack, messages): # Then mock_chat_completion_create.assert_called_once_with( - model=driver.model, - temperature=driver.temperature, - stop=driver.tokenizer.stop_sequences, - user=driver.user, - messages=messages, - seed=driver.seed, + model=driver.model, temperature=driver.temperature, user=driver.user, messages=messages, seed=driver.seed ) assert text_artifact.value == "model-output" @@ -107,19 +100,18 @@ def test_try_run_response_format(self, mock_chat_completion_create, prompt_stack ) # When - text_artifact = driver.try_run(prompt_stack) + element = driver.try_run(prompt_stack) # Then mock_chat_completion_create.assert_called_once_with( model=driver.model, temperature=driver.temperature, - stop=driver.tokenizer.stop_sequences, user=driver.user, messages=[*messages, {"role": "system", "content": "Provide your response as a valid JSON object."}], seed=driver.seed, response_format={"type": "json_object"}, ) - assert text_artifact.value == "model-output" + assert element.value == "model-output" def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, messages): # Given @@ -132,13 +124,15 @@ def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, mock_chat_completion_stream_create.assert_called_once_with( model=driver.model, temperature=driver.temperature, - stop=driver.tokenizer.stop_sequences, user=driver.user, stream=True, messages=messages, seed=driver.seed, + stream_options={"include_usage": True}, ) - assert text_artifact.value == "model-output" + + if isinstance(text_artifact, DeltaTextPromptStackContent): + assert text_artifact.text == "model-output" def test_try_run_with_max_tokens(self, mock_chat_completion_create, prompt_stack, messages): # Given @@ -151,7 +145,6 @@ def test_try_run_with_max_tokens(self, mock_chat_completion_create, prompt_stack mock_chat_completion_create.assert_called_once_with( model=driver.model, temperature=driver.temperature, - stop=driver.tokenizer.stop_sequences, user=driver.user, messages=messages, max_tokens=1, @@ -186,7 +179,7 @@ def test_try_run_throws_when_multiple_choices_returned(self, choices, mock_chat_ def test_custom_tokenizer(self, mock_chat_completion_create, prompt_stack, messages): driver = OpenAiChatPromptDriver( model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, - tokenizer=HuggingFaceTokenizer(model="gpt2", max_output_tokens=1000), + tokenizer=MockTokenizer(model="mock-model", stop_sequences=["mock-stop"]), max_tokens=1, ) @@ -200,7 +193,6 @@ def test_custom_tokenizer(self, mock_chat_completion_create, prompt_stack, messa stop=driver.tokenizer.stop_sequences, user=driver.user, messages=[ - {"role": "user", "content": "generic-input"}, {"role": "system", "content": "system-input"}, {"role": "user", "content": "user-input"}, {"role": "assistant", "content": "assistant-input"}, diff --git a/tests/unit/events/test_base_event.py b/tests/unit/events/test_base_event.py index 7656b6b0d..6e2166f6a 100644 --- a/tests/unit/events/test_base_event.py +++ b/tests/unit/events/test_base_event.py @@ -29,30 +29,48 @@ def test_to_dict(self): def test_start_prompt_event_from_dict(self): dict_value = { "type": "StartPromptEvent", - "timestamp": 123.0, - "token_count": 10, - "prompt_stack": {"inputs": [{"content": "foo", "role": "user"}, {"content": "bar", "role": "system"}]}, - "prompt": "foo bar", + "id": "917298d4bf894b0a824a8fdb26717a0c", + "timestamp": 123, "model": "foo bar", + "prompt_stack": { + "type": "PromptStack", + "inputs": [ + { + "type": "PromptStackElement", + "role": "user", + "content": [ + {"type": "TextPromptStackContent", "artifact": {"type": "TextArtifact", "value": "foo"}} + ], + "usage": {"type": "Usage", "input_tokens": None, "output_tokens": None}, + }, + { + "type": "PromptStackElement", + "role": "system", + "content": [ + {"type": "TextPromptStackContent", "artifact": {"type": "TextArtifact", "value": "bar"}} + ], + "usage": {"type": "Usage", "input_tokens": None, "output_tokens": None}, + }, + ], + }, } event = BaseEvent.from_dict(dict_value) assert isinstance(event, StartPromptEvent) assert event.timestamp == 123 - assert event.token_count == 10 - assert event.prompt_stack.inputs[0].content == "foo" + assert event.prompt_stack.inputs[0].content[0].artifact.value == "foo" assert event.prompt_stack.inputs[0].role == "user" - assert event.prompt_stack.inputs[1].content == "bar" + assert event.prompt_stack.inputs[1].content[0].artifact.value == "bar" assert event.prompt_stack.inputs[1].role == "system" - assert event.prompt == "foo bar" assert event.model == "foo bar" def test_finish_prompt_event_from_dict(self): dict_value = { "type": "FinishPromptEvent", "timestamp": 123.0, - "token_count": 10, + "input_token_count": 10, + "output_token_count": 12, "result": "foo bar", "model": "foo bar", } @@ -61,7 +79,8 @@ def test_finish_prompt_event_from_dict(self): assert isinstance(event, FinishPromptEvent) assert event.timestamp == 123 - assert event.token_count == 10 + assert event.input_token_count == 10 + assert event.output_token_count == 12 assert event.result == "foo bar" assert event.model == "foo bar" diff --git a/tests/unit/events/test_finish_prompt_event.py b/tests/unit/events/test_finish_prompt_event.py index b788c67f9..7443fce0c 100644 --- a/tests/unit/events/test_finish_prompt_event.py +++ b/tests/unit/events/test_finish_prompt_event.py @@ -5,12 +5,13 @@ class TestFinishPromptEvent: @pytest.fixture def finish_prompt_event(self): - return FinishPromptEvent(token_count=123, result="foo bar", model="foo bar") + return FinishPromptEvent(input_token_count=321, output_token_count=123, result="foo bar", model="foo bar") def test_to_dict(self, finish_prompt_event): assert "timestamp" in finish_prompt_event.to_dict() - assert finish_prompt_event.to_dict()["token_count"] == 123 + assert finish_prompt_event.to_dict()["input_token_count"] == 321 + assert finish_prompt_event.to_dict()["output_token_count"] == 123 assert finish_prompt_event.to_dict()["result"] == "foo bar" assert finish_prompt_event.to_dict()["model"] == "foo bar" diff --git a/tests/unit/events/test_start_prompt_event.py b/tests/unit/events/test_start_prompt_event.py index a80f8cdfc..6f9268a63 100644 --- a/tests/unit/events/test_start_prompt_event.py +++ b/tests/unit/events/test_start_prompt_event.py @@ -1,6 +1,6 @@ import pytest from griptape.events import StartPromptEvent -from griptape.utils import PromptStack +from griptape.common import PromptStack class TestStartPromptEvent: @@ -9,16 +9,14 @@ def start_prompt_event(self): prompt_stack = PromptStack() prompt_stack.add_user_input("foo") prompt_stack.add_system_input("bar") - return StartPromptEvent(token_count=123, prompt_stack=prompt_stack, prompt="foo bar", model="foo bar") + return StartPromptEvent(prompt_stack=prompt_stack, model="foo bar") def test_to_dict(self, start_prompt_event): assert "timestamp" in start_prompt_event.to_dict() - assert start_prompt_event.to_dict()["token_count"] == 123 - assert start_prompt_event.to_dict()["prompt_stack"]["inputs"][0]["content"] == "foo" + assert start_prompt_event.to_dict()["prompt_stack"]["inputs"][0]["content"][0]["artifact"]["value"] == "foo" assert start_prompt_event.to_dict()["prompt_stack"]["inputs"][0]["role"] == "user" - assert start_prompt_event.to_dict()["prompt_stack"]["inputs"][1]["content"] == "bar" + assert start_prompt_event.to_dict()["prompt_stack"]["inputs"][1]["content"][0]["artifact"]["value"] == "bar" assert start_prompt_event.to_dict()["prompt_stack"]["inputs"][1]["role"] == "system" - assert start_prompt_event.to_dict()["prompt"] == "foo bar" assert start_prompt_event.to_dict()["model"] == "foo bar" diff --git a/tests/unit/memory/structure/test_conversation_memory.py b/tests/unit/memory/structure/test_conversation_memory.py index 298e5ac3f..8909956fe 100644 --- a/tests/unit/memory/structure/test_conversation_memory.py +++ b/tests/unit/memory/structure/test_conversation_memory.py @@ -1,6 +1,6 @@ import json from griptape.structures import Agent -from griptape.utils import PromptStack +from griptape.common import PromptStack from griptape.memory.structure import ConversationMemory, Run, BaseConversationMemory from griptape.structures import Pipeline from tests.mocks.mock_prompt_driver import MockPromptDriver @@ -37,8 +37,8 @@ def test_to_prompt_stack(self): prompt_stack = memory.to_prompt_stack() - assert prompt_stack.inputs[0].content == "foo" - assert prompt_stack.inputs[1].content == "bar" + assert prompt_stack.inputs[0].content[0].artifact.value == "foo" + assert prompt_stack.inputs[1].content[0].artifact.value == "bar" def test_from_dict(self): memory = ConversationMemory() @@ -161,8 +161,8 @@ def test_add_to_prompt_stack_autopruning_enabled(self): # We expect one run (2 prompt stack inputs) to be pruned. assert len(prompt_stack.inputs) == 11 - assert prompt_stack.inputs[0].content == "fizz" - assert prompt_stack.inputs[1].content == "foo2" - assert prompt_stack.inputs[2].content == "bar2" - assert prompt_stack.inputs[-2].content == "foo" - assert prompt_stack.inputs[-1].content == "bar" + assert prompt_stack.inputs[0].content[0].artifact.value == "fizz" + assert prompt_stack.inputs[1].content[0].artifact.value == "foo2" + assert prompt_stack.inputs[2].content[0].artifact.value == "bar2" + assert prompt_stack.inputs[-2].content[0].artifact.value == "foo" + assert prompt_stack.inputs[-1].content[0].artifact.value == "bar" diff --git a/tests/unit/memory/structure/test_summary_conversation_memory.py b/tests/unit/memory/structure/test_summary_conversation_memory.py index 09792ff5d..236689284 100644 --- a/tests/unit/memory/structure/test_summary_conversation_memory.py +++ b/tests/unit/memory/structure/test_summary_conversation_memory.py @@ -1,6 +1,5 @@ import json -import pytest from griptape.memory.structure import Run, SummaryConversationMemory from griptape.structures import Pipeline @@ -59,9 +58,9 @@ def test_to_prompt_stack(self): prompt_stack = memory.to_prompt_stack() - assert prompt_stack.inputs[0].content == "Summary of the conversation so far: foobar" - assert prompt_stack.inputs[1].content == "foo" - assert prompt_stack.inputs[2].content == "bar" + assert prompt_stack.inputs[0].content[0].artifact.value == "Summary of the conversation so far: foobar" + assert prompt_stack.inputs[1].content[0].artifact.value == "foo" + assert prompt_stack.inputs[2].content[0].artifact.value == "bar" def test_from_dict(self): memory = SummaryConversationMemory() diff --git a/tests/unit/tasks/test_toolkit_task.py b/tests/unit/tasks/test_toolkit_task.py index f63b06d5b..14bd5ef06 100644 --- a/tests/unit/tasks/test_toolkit_task.py +++ b/tests/unit/tasks/test_toolkit_task.py @@ -7,7 +7,6 @@ from tests.mocks.mock_embedding_driver import MockEmbeddingDriver from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_tool.tool import MockTool -from tests.mocks.mock_value_prompt_driver import MockValuePromptDriver from tests.utils import defaults @@ -161,7 +160,7 @@ def test_run(self): output = """Answer: done""" task = ToolkitTask("test", tools=[MockTool(name="Tool1"), MockTool(name="Tool2")]) - agent = Agent(prompt_driver=MockValuePromptDriver(value=output)) + agent = Agent(prompt_driver=MockPromptDriver(mock_output=output)) agent.add_task(task) @@ -175,7 +174,7 @@ def test_run_max_subtasks(self): output = """Actions: [{"name": "blah"}]""" task = ToolkitTask("test", tools=[MockTool(name="Tool1")], max_subtasks=3) - agent = Agent(prompt_driver=MockValuePromptDriver(value=output)) + agent = Agent(prompt_driver=MockPromptDriver(mock_output=output)) agent.add_task(task) @@ -188,7 +187,7 @@ def test_run_invalid_react_prompt(self): output = """foo bar""" task = ToolkitTask("test", tools=[MockTool(name="Tool1")], max_subtasks=3) - agent = Agent(prompt_driver=MockValuePromptDriver(value=output)) + agent = Agent(prompt_driver=MockPromptDriver(mock_output=output)) agent.add_task(task) diff --git a/tests/unit/tokenizers/test_google_tokenizer.py b/tests/unit/tokenizers/test_google_tokenizer.py index 955a0517f..70b441000 100644 --- a/tests/unit/tokenizers/test_google_tokenizer.py +++ b/tests/unit/tokenizers/test_google_tokenizer.py @@ -1,6 +1,7 @@ import pytest from unittest.mock import Mock -from griptape.utils import PromptStack +from griptape.common import PromptStack +from griptape.common.prompt_stack.elements.prompt_stack_element import PromptStackElement from griptape.tokenizers import GoogleTokenizer @@ -19,7 +20,7 @@ def tokenizer(self, request): @pytest.mark.parametrize("tokenizer,expected", [("gemini-pro", 5)], indirect=["tokenizer"]) def test_token_count(self, tokenizer, expected): assert tokenizer.count_tokens("foo bar huzzah") == expected - assert tokenizer.count_tokens(PromptStack(inputs=[PromptStack.Input(content="foo", role="user")])) == expected + assert tokenizer.count_tokens(PromptStack(inputs=[PromptStackElement(content="foo", role="user")])) == expected assert tokenizer.count_tokens(["foo", "bar", "huzzah"]) == expected @pytest.mark.parametrize("tokenizer,expected", [("gemini-pro", 30715)], indirect=["tokenizer"]) diff --git a/tests/unit/utils/test_base_tokenizer.py b/tests/unit/utils/test_base_tokenizer.py new file mode 100644 index 000000000..eed15b9b2 --- /dev/null +++ b/tests/unit/utils/test_base_tokenizer.py @@ -0,0 +1,13 @@ +import logging +from tests.mocks.mock_tokenizer import MockTokenizer + + +class TestBaseTokenizer: + def test_default_tokens(self, caplog): + with caplog.at_level(logging.WARNING): + tokenizer = MockTokenizer(model="gpt2") + + assert tokenizer.max_input_tokens == 4096 + assert tokenizer.max_output_tokens == 1000 + + assert "gpt2 not found" in caplog.text diff --git a/tests/unit/utils/test_prompt_stack.py b/tests/unit/utils/test_prompt_stack.py index 80010abec..0732bc733 100644 --- a/tests/unit/utils/test_prompt_stack.py +++ b/tests/unit/utils/test_prompt_stack.py @@ -1,5 +1,6 @@ import pytest -from griptape.utils import PromptStack + +from griptape.common import PromptStack class TestPromptStack: @@ -14,28 +15,22 @@ def test_add_input(self, prompt_stack): prompt_stack.add_input("foo", "role") assert prompt_stack.inputs[0].role == "role" - assert prompt_stack.inputs[0].content == "foo" - - def test_add_generic_input(self, prompt_stack): - prompt_stack.add_generic_input("foo") - - assert prompt_stack.inputs[0].role == "generic" - assert prompt_stack.inputs[0].content == "foo" + assert prompt_stack.inputs[0].content[0].artifact.value == "foo" def test_add_system_input(self, prompt_stack): prompt_stack.add_system_input("foo") assert prompt_stack.inputs[0].role == "system" - assert prompt_stack.inputs[0].content == "foo" + assert prompt_stack.inputs[0].content[0].artifact.value == "foo" def test_add_user_input(self, prompt_stack): prompt_stack.add_user_input("foo") assert prompt_stack.inputs[0].role == "user" - assert prompt_stack.inputs[0].content == "foo" + assert prompt_stack.inputs[0].content[0].artifact.value == "foo" def test_add_assistant_input(self, prompt_stack): prompt_stack.add_assistant_input("foo") assert prompt_stack.inputs[0].role == "assistant" - assert prompt_stack.inputs[0].content == "foo" + assert prompt_stack.inputs[0].content[0].artifact.value == "foo" From d3995b839292bf251f1d90ef09493d3632ab3bd2 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 14 Jun 2024 14:34:47 -0700 Subject: [PATCH 02/34] Add support for more modalities to conversation memory --- .../elements/prompt_stack_element.py | 12 ++-- griptape/common/prompt_stack/prompt_stack.py | 42 ++++++------ .../engines/summary/prompt_summary_engine.py | 8 +-- griptape/memory/structure/run.py | 5 +- griptape/structures/agent.py | 7 +- griptape/structures/pipeline.py | 7 +- griptape/structures/workflow.py | 7 +- griptape/tasks/base_text_input_task.py | 2 +- griptape/tasks/prompt_task.py | 2 +- griptape/tasks/toolkit_task.py | 2 +- .../prompt_image_generation_client/tool.py | 9 +-- ...est_dynamodb_conversation_memory_driver.py | 4 +- .../test_local_conversation_memory_driver.py | 8 +-- .../test_redis_conversation_memory_driver.py | 2 +- .../structure/test_conversation_memory.py | 67 ++++++++++--------- .../test_summary_conversation_memory.py | 21 +++--- 16 files changed, 94 insertions(+), 111 deletions(-) diff --git a/griptape/common/prompt_stack/elements/prompt_stack_element.py b/griptape/common/prompt_stack/elements/prompt_stack_element.py index 0ccd687d3..b94c8a5f4 100644 --- a/griptape/common/prompt_stack/elements/prompt_stack_element.py +++ b/griptape/common/prompt_stack/elements/prompt_stack_element.py @@ -46,13 +46,9 @@ def to_text(self) -> str: return self.to_text_artifact().to_text() def to_text_artifact(self) -> TextArtifact: - if all(isinstance(content, TextPromptStackContent) for content in self.content): - artifact = TextArtifact(value="") + artifact = TextArtifact(value="") - for content in self.content: - if isinstance(content, TextPromptStackContent): - artifact += content.artifact + for content in self.content: + artifact.value += content.artifact.to_text() - return artifact - else: - raise ValueError("Cannot convert to TextArtifact") + return artifact diff --git a/griptape/common/prompt_stack/prompt_stack.py b/griptape/common/prompt_stack/prompt_stack.py index f7cc06a5c..a82c47216 100644 --- a/griptape/common/prompt_stack/prompt_stack.py +++ b/griptape/common/prompt_stack/prompt_stack.py @@ -3,7 +3,7 @@ from griptape.artifacts import TextArtifact, BaseArtifact, ListArtifact, ImageArtifact from griptape.mixins import SerializableMixin -from griptape.common import PromptStackElement, TextPromptStackContent, ImagePromptStackContent +from griptape.common import PromptStackElement, TextPromptStackContent, BasePromptStackContent, ImagePromptStackContent @define @@ -11,30 +11,34 @@ class PromptStack(SerializableMixin): inputs: list[PromptStackElement] = field(factory=list, kw_only=True, metadata={"serializable": True}) def add_input(self, content: str | BaseArtifact, role: str) -> PromptStackElement: - if isinstance(content, str): - self.inputs.append(PromptStackElement(content=[TextPromptStackContent(TextArtifact(content))], role=role)) - elif isinstance(content, TextArtifact): - self.inputs.append(PromptStackElement(content=[TextPromptStackContent(content)], role=role)) - elif isinstance(content, ListArtifact): - contents = [] - for artifact in content.value: - if isinstance(artifact, TextArtifact): - contents.append(TextPromptStackContent(artifact)) - elif isinstance(artifact, ImageArtifact): - contents.append(ImagePromptStackContent(artifact)) - else: - raise ValueError(f"Unsupported artifact type: {type(artifact)}") - self.inputs.append(PromptStackElement(content=contents, role=role)) - else: - raise ValueError(f"Unsupported content type: {type(content)}") + new_content = self.__process_content(content) + + self.inputs.append(PromptStackElement(content=new_content, role=role)) return self.inputs[-1] - def add_system_input(self, content: str) -> PromptStackElement: + def add_system_input(self, content: str | BaseArtifact) -> PromptStackElement: return self.add_input(content, PromptStackElement.SYSTEM_ROLE) def add_user_input(self, content: str | BaseArtifact) -> PromptStackElement: return self.add_input(content, PromptStackElement.USER_ROLE) - def add_assistant_input(self, content: str) -> PromptStackElement: + def add_assistant_input(self, content: str | BaseArtifact) -> PromptStackElement: return self.add_input(content, PromptStackElement.ASSISTANT_ROLE) + + def __process_content(self, content: str | BaseArtifact) -> list[BasePromptStackContent]: + if isinstance(content, str): + return [TextPromptStackContent(TextArtifact(content))] + elif isinstance(content, TextArtifact): + return [TextPromptStackContent(content)] + elif isinstance(content, ImageArtifact): + return [ImagePromptStackContent(content)] + elif isinstance(content, ListArtifact): + processed_contents = [self.__process_content(artifact) for artifact in content.value] + flattened_content = [ + sub_content for processed_content in processed_contents for sub_content in processed_content + ] + + return flattened_content + else: + raise ValueError(f"Unsupported content type: {type(content)}") diff --git a/griptape/engines/summary/prompt_summary_engine.py b/griptape/engines/summary/prompt_summary_engine.py index e5968c4df..18b5f3a07 100644 --- a/griptape/engines/summary/prompt_summary_engine.py +++ b/griptape/engines/summary/prompt_summary_engine.py @@ -64,8 +64,8 @@ def summarize_artifacts_rec( return self.prompt_driver.run( PromptStack( inputs=[ - PromptStack.Input(system_prompt, role=PromptStack.SYSTEM_ROLE), - PromptStack.Input(user_prompt, role=PromptStack.USER_ROLE), + PromptStackElement(system_prompt, role=PromptStackElement.SYSTEM_ROLE), + PromptStackElement(user_prompt, role=PromptStackElement.USER_ROLE), ] ) ) @@ -79,8 +79,8 @@ def summarize_artifacts_rec( self.prompt_driver.run( PromptStack( inputs=[ - PromptStack.Input(system_prompt, role=PromptStack.SYSTEM_ROLE), - PromptStack.Input(partial_text, role=PromptStack.USER_ROLE), + PromptStackElement(system_prompt, role=PromptStackElement.SYSTEM_ROLE), + PromptStackElement(partial_text, role=PromptStackElement.USER_ROLE), ] ) ).value, diff --git a/griptape/memory/structure/run.py b/griptape/memory/structure/run.py index c5a2b9b55..b91df2ae9 100644 --- a/griptape/memory/structure/run.py +++ b/griptape/memory/structure/run.py @@ -1,10 +1,11 @@ import uuid from attrs import define, field, Factory +from griptape.artifacts.base_artifact import BaseArtifact from griptape.mixins import SerializableMixin @define class Run(SerializableMixin): id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True, metadata={"serializable": True}) - input: str = field(kw_only=True, metadata={"serializable": True}) - output: str = field(kw_only=True, metadata={"serializable": True}) + input: BaseArtifact = field(kw_only=True, metadata={"serializable": True}) + output: BaseArtifact = field(kw_only=True, metadata={"serializable": True}) diff --git a/griptape/structures/agent.py b/griptape/structures/agent.py index 12512341e..f767d74a2 100644 --- a/griptape/structures/agent.py +++ b/griptape/structures/agent.py @@ -51,12 +51,7 @@ def try_run(self, *args) -> Agent: self.task.execute() if self.conversation_memory and self.output is not None: - if isinstance(self.task.input, tuple): - input_text = self.task.input[0].to_text() - else: - input_text = self.task.input.to_text() - - run = Run(input=input_text, output=self.task.output.to_text()) + run = Run(input=self.input_task.input, output=self.output) self.conversation_memory.add_run(run) diff --git a/griptape/structures/pipeline.py b/griptape/structures/pipeline.py index d5724244e..fe8fcbdf1 100644 --- a/griptape/structures/pipeline.py +++ b/griptape/structures/pipeline.py @@ -46,12 +46,7 @@ def try_run(self, *args) -> Pipeline: self.__run_from_task(self.input_task) if self.conversation_memory and self.output is not None: - if isinstance(self.input_task.input, tuple): - input_text = self.input_task.input[0].to_text() - else: - input_text = self.input_task.input.to_text() - - run = Run(input=input_text, output=self.output.to_text()) + run = Run(input=self.input_task.input, output=self.output) self.conversation_memory.add_run(run) diff --git a/griptape/structures/workflow.py b/griptape/structures/workflow.py index 6552fba89..bede3fbf4 100644 --- a/griptape/structures/workflow.py +++ b/griptape/structures/workflow.py @@ -114,12 +114,7 @@ def try_run(self, *args) -> Workflow: break if self.conversation_memory and self.output is not None: - if isinstance(self.input_task.input, tuple): - input_text = self.input_task.input[0].to_text() - else: - input_text = self.input_task.input.to_text() - - run = Run(input=input_text, output=self.output_task.output.to_text()) + run = Run(input=self.input_task.input, output=self.output) self.conversation_memory.add_run(run) diff --git a/griptape/tasks/base_text_input_task.py b/griptape/tasks/base_text_input_task.py index 9281f5a7e..40b6e75c5 100644 --- a/griptape/tasks/base_text_input_task.py +++ b/griptape/tasks/base_text_input_task.py @@ -52,7 +52,7 @@ def _process_task_input( return task_input elif isinstance(task_input, Callable): - return task_input(self) + return self._process_task_input(task_input(self)) elif isinstance(task_input, str): return self._process_task_input(TextArtifact(task_input)) elif isinstance(task_input, BaseArtifact): diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index 5727a499b..4fd8a34db 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -33,7 +33,7 @@ def prompt_stack(self) -> PromptStack: stack.add_user_input(self.input) if self.output: - stack.add_assistant_input(self.output.to_text()) + stack.add_assistant_input(self.output) if memory: # inserting at index 1 to place memory right after system prompt diff --git a/griptape/tasks/toolkit_task.py b/griptape/tasks/toolkit_task.py index de291d1a9..60eebc405 100644 --- a/griptape/tasks/toolkit_task.py +++ b/griptape/tasks/toolkit_task.py @@ -67,7 +67,7 @@ def prompt_stack(self) -> PromptStack: stack.add_system_input(self.generate_system_template(self)) - stack.add_user_input(self.input.to_text()) + stack.add_user_input(self.input) if self.output: stack.add_assistant_input(self.output.to_text()) diff --git a/griptape/tools/prompt_image_generation_client/tool.py b/griptape/tools/prompt_image_generation_client/tool.py index 50020a1ea..6f4ce0e1f 100644 --- a/griptape/tools/prompt_image_generation_client/tool.py +++ b/griptape/tools/prompt_image_generation_client/tool.py @@ -30,20 +30,15 @@ class PromptImageGenerationClient(BlobArtifactFileOutputMixin, BaseTool): Literal( "prompts", description="A detailed list of features and descriptions to include in the generated image.", - ): list[str], - Literal( - "negative_prompts", - description="A detailed list of features and descriptions to avoid in the generated image.", - ): list[str], + ): list[str] } ), } ) def generate_image(self, params: dict[str, dict[str, list[str]]]) -> ImageArtifact | ErrorArtifact: prompts = params["values"]["prompts"] - negative_prompts = params["values"]["negative_prompts"] - output_artifact = self.engine.run(prompts=prompts, negative_prompts=negative_prompts) + output_artifact = self.engine.run(prompts=prompts) if self.output_dir or self.output_file: self._write_to_file(output_artifact) diff --git a/tests/unit/drivers/memory/conversation/test_dynamodb_conversation_memory_driver.py b/tests/unit/drivers/memory/conversation/test_dynamodb_conversation_memory_driver.py index ef3b0e1df..80d77d24d 100644 --- a/tests/unit/drivers/memory/conversation/test_dynamodb_conversation_memory_driver.py +++ b/tests/unit/drivers/memory/conversation/test_dynamodb_conversation_memory_driver.py @@ -81,5 +81,5 @@ def test_load(self): assert new_memory.type == "ConversationMemory" assert len(new_memory.runs) == 2 - assert new_memory.runs[0].input == "test" - assert new_memory.runs[0].output == "mock output" + assert new_memory.runs[0].input.value == "test" + assert new_memory.runs[0].output.value == "mock output" diff --git a/tests/unit/drivers/memory/conversation/test_local_conversation_memory_driver.py b/tests/unit/drivers/memory/conversation/test_local_conversation_memory_driver.py index d12d5d3d2..c794afd0e 100644 --- a/tests/unit/drivers/memory/conversation/test_local_conversation_memory_driver.py +++ b/tests/unit/drivers/memory/conversation/test_local_conversation_memory_driver.py @@ -52,8 +52,8 @@ def test_load(self): assert new_memory.type == "ConversationMemory" assert len(new_memory.runs) == 2 - assert new_memory.runs[0].input == "test" - assert new_memory.runs[0].output == "mock output" + assert new_memory.runs[0].input.value == "test" + assert new_memory.runs[0].output.value == "mock output" assert new_memory.max_runs == 5 def test_autoload(self): @@ -71,8 +71,8 @@ def test_autoload(self): assert autoloaded_memory.type == "ConversationMemory" assert len(autoloaded_memory.runs) == 2 - assert autoloaded_memory.runs[0].input == "test" - assert autoloaded_memory.runs[0].output == "mock output" + assert autoloaded_memory.runs[0].input.value == "test" + assert autoloaded_memory.runs[0].output.value == "mock output" def __delete_file(self, file_path): try: diff --git a/tests/unit/drivers/memory/conversation/test_redis_conversation_memory_driver.py b/tests/unit/drivers/memory/conversation/test_redis_conversation_memory_driver.py index dee840508..7a74b6921 100644 --- a/tests/unit/drivers/memory/conversation/test_redis_conversation_memory_driver.py +++ b/tests/unit/drivers/memory/conversation/test_redis_conversation_memory_driver.py @@ -3,7 +3,7 @@ from griptape.memory.structure.base_conversation_memory import BaseConversationMemory from griptape.drivers.memory.conversation.redis_conversation_memory_driver import RedisConversationMemoryDriver -TEST_CONVERSATION = '{"type": "ConversationMemory", "runs": [{"type": "Run", "id": "729ca6be5d79433d9762eb06dfd677e2", "input": "Hi There, Hello", "output": "Hello! How can I assist you today?"}], "max_runs": 2}' +TEST_CONVERSATION = '{"type": "ConversationMemory", "runs": [{"type": "Run", "id": "729ca6be5d79433d9762eb06dfd677e2", "input": {"type": "TextArtifact", "id": "1234", "value": "Hi There, Hello"}, "output": {"type": "TextArtifact", "id": "123", "value": "Hello! How can I assist you today?"}}], "max_runs": 2}' CONVERSATION_ID = "117151897f344ff684b553d0655d8f39" INDEX = "griptape_converstaion" HOST = "127.0.0.1" diff --git a/tests/unit/memory/structure/test_conversation_memory.py b/tests/unit/memory/structure/test_conversation_memory.py index 8909956fe..642ca4354 100644 --- a/tests/unit/memory/structure/test_conversation_memory.py +++ b/tests/unit/memory/structure/test_conversation_memory.py @@ -6,12 +6,13 @@ from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_tokenizer import MockTokenizer from griptape.tasks import PromptTask +from griptape.artifacts import TextArtifact class TestConversationMemory: def test_add_run(self): memory = ConversationMemory() - run = Run(input="test", output="test") + run = Run(input=TextArtifact("foo"), output=TextArtifact("bar")) memory.add_run(run) @@ -19,21 +20,21 @@ def test_add_run(self): def test_to_json(self): memory = ConversationMemory() - memory.add_run(Run(input="foo", output="bar")) + memory.add_run(Run(input=TextArtifact("foo"), output=TextArtifact("bar"))) assert json.loads(memory.to_json())["type"] == "ConversationMemory" - assert json.loads(memory.to_json())["runs"][0]["input"] == "foo" + assert json.loads(memory.to_json())["runs"][0]["input"]["value"] == "foo" def test_to_dict(self): memory = ConversationMemory() - memory.add_run(Run(input="foo", output="bar")) + memory.add_run(Run(input=TextArtifact("foo"), output=TextArtifact("bar"))) assert memory.to_dict()["type"] == "ConversationMemory" - assert memory.to_dict()["runs"][0]["input"] == "foo" + assert memory.to_dict()["runs"][0]["input"]["value"] == "foo" def test_to_prompt_stack(self): memory = ConversationMemory() - memory.add_run(Run(input="foo", output="bar")) + memory.add_run(Run(input=TextArtifact("foo"), output=TextArtifact("bar"))) prompt_stack = memory.to_prompt_stack() @@ -42,19 +43,19 @@ def test_to_prompt_stack(self): def test_from_dict(self): memory = ConversationMemory() - memory.add_run(Run(input="foo", output="bar")) + memory.add_run(Run(input=TextArtifact("foo"), output=TextArtifact("bar"))) memory_dict = memory.to_dict() assert isinstance(BaseConversationMemory.from_dict(memory_dict), ConversationMemory) - assert BaseConversationMemory.from_dict(memory_dict).runs[0].input == "foo" + assert BaseConversationMemory.from_dict(memory_dict).runs[0].input.value == "foo" def test_from_json(self): memory = ConversationMemory() - memory.add_run(Run(input="foo", output="bar")) + memory.add_run(Run(input=TextArtifact("foo"), output=TextArtifact("bar"))) memory_dict = memory.to_dict() assert isinstance(memory.from_dict(memory_dict), ConversationMemory) - assert memory.from_dict(memory_dict).runs[0].input == "foo" + assert memory.from_dict(memory_dict).runs[0].input.value == "foo" def test_buffering(self): memory = ConversationMemory(max_runs=2) @@ -70,24 +71,24 @@ def test_buffering(self): pipeline.run("run5") assert len(pipeline.conversation_memory.runs) == 2 - assert pipeline.conversation_memory.runs[0].input == "run4" - assert pipeline.conversation_memory.runs[1].input == "run5" + assert pipeline.conversation_memory.runs[0].input.value == "run4" + assert pipeline.conversation_memory.runs[1].input.value == "run5" def test_add_to_prompt_stack_autopruing_disabled(self): agent = Agent(prompt_driver=MockPromptDriver()) memory = ConversationMemory( autoprune=False, runs=[ - Run(input="foo1", output="bar1"), - Run(input="foo2", output="bar2"), - Run(input="foo3", output="bar3"), - Run(input="foo4", output="bar4"), - Run(input="foo5", output="bar5"), + Run(input=TextArtifact("foo1"), output=TextArtifact("bar1")), + Run(input=TextArtifact("foo2"), output=TextArtifact("bar2")), + Run(input=TextArtifact("foo3"), output=TextArtifact("bar3")), + Run(input=TextArtifact("foo4"), output=TextArtifact("bar4")), + Run(input=TextArtifact("foo5"), output=TextArtifact("bar5")), ], ) memory.structure = agent prompt_stack = PromptStack() - prompt_stack.add_user_input("foo") + prompt_stack.add_user_input(TextArtifact("foo")) prompt_stack.add_assistant_input("bar") memory.add_to_prompt_stack(prompt_stack) @@ -99,11 +100,11 @@ def test_add_to_prompt_stack_autopruning_enabled(self): memory = ConversationMemory( autoprune=True, runs=[ - Run(input="foo1", output="bar1"), - Run(input="foo2", output="bar2"), - Run(input="foo3", output="bar3"), - Run(input="foo4", output="bar4"), - Run(input="foo5", output="bar5"), + Run(input=TextArtifact("foo1"), output=TextArtifact("bar1")), + Run(input=TextArtifact("foo2"), output=TextArtifact("bar2")), + Run(input=TextArtifact("foo3"), output=TextArtifact("bar3")), + Run(input=TextArtifact("foo4"), output=TextArtifact("bar4")), + Run(input=TextArtifact("foo5"), output=TextArtifact("bar5")), ], ) memory.structure = agent @@ -120,11 +121,11 @@ def test_add_to_prompt_stack_autopruning_enabled(self): memory = ConversationMemory( autoprune=True, runs=[ - Run(input="foo1", output="bar1"), - Run(input="foo2", output="bar2"), - Run(input="foo3", output="bar3"), - Run(input="foo4", output="bar4"), - Run(input="foo5", output="bar5"), + Run(input=TextArtifact("foo1"), output=TextArtifact("bar1")), + Run(input=TextArtifact("foo2"), output=TextArtifact("bar2")), + Run(input=TextArtifact("foo3"), output=TextArtifact("bar3")), + Run(input=TextArtifact("foo4"), output=TextArtifact("bar4")), + Run(input=TextArtifact("foo5"), output=TextArtifact("bar5")), ], ) memory.structure = agent @@ -144,11 +145,11 @@ def test_add_to_prompt_stack_autopruning_enabled(self): autoprune=True, runs=[ # All of these sum to 155 tokens with the MockTokenizer. - Run(input="foo1", output="bar1"), - Run(input="foo2", output="bar2"), - Run(input="foo3", output="bar3"), - Run(input="foo4", output="bar4"), - Run(input="foo5", output="bar5"), + Run(input=TextArtifact("foo1"), output=TextArtifact("bar1")), + Run(input=TextArtifact("foo2"), output=TextArtifact("bar2")), + Run(input=TextArtifact("foo3"), output=TextArtifact("bar3")), + Run(input=TextArtifact("foo4"), output=TextArtifact("bar4")), + Run(input=TextArtifact("foo5"), output=TextArtifact("bar5")), ], ) memory.structure = agent diff --git a/tests/unit/memory/structure/test_summary_conversation_memory.py b/tests/unit/memory/structure/test_summary_conversation_memory.py index 236689284..5ca99f07a 100644 --- a/tests/unit/memory/structure/test_summary_conversation_memory.py +++ b/tests/unit/memory/structure/test_summary_conversation_memory.py @@ -3,6 +3,7 @@ from griptape.memory.structure import Run, SummaryConversationMemory from griptape.structures import Pipeline +from griptape.artifacts import TextArtifact from griptape.tasks import PromptTask from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_structure_config import MockStructureConfig @@ -40,21 +41,21 @@ def test_after_run(self): def test_to_json(self): memory = SummaryConversationMemory() - memory.add_run(Run(input="foo", output="bar")) + memory.add_run(Run(input=TextArtifact("foo"), output=TextArtifact("bar"))) assert json.loads(memory.to_json())["type"] == "SummaryConversationMemory" - assert json.loads(memory.to_json())["runs"][0]["input"] == "foo" + assert json.loads(memory.to_json())["runs"][0]["input"]["value"] == "foo" def test_to_dict(self): memory = SummaryConversationMemory() - memory.add_run(Run(input="foo", output="bar")) + memory.add_run(Run(input=TextArtifact("foo"), output=TextArtifact("bar"))) assert memory.to_dict()["type"] == "SummaryConversationMemory" - assert memory.to_dict()["runs"][0]["input"] == "foo" + assert memory.to_dict()["runs"][0]["input"]["value"] == "foo" def test_to_prompt_stack(self): memory = SummaryConversationMemory(summary="foobar") - memory.add_run(Run(input="foo", output="bar")) + memory.add_run(Run(input=TextArtifact("foo"), output=TextArtifact("bar"))) prompt_stack = memory.to_prompt_stack() @@ -64,12 +65,12 @@ def test_to_prompt_stack(self): def test_from_dict(self): memory = SummaryConversationMemory() - memory.add_run(Run(input="foo", output="bar")) + memory.add_run(Run(input=TextArtifact("foo"), output=TextArtifact("bar"))) memory_dict = memory.to_dict() assert isinstance(memory.from_dict(memory_dict), SummaryConversationMemory) - assert memory.from_dict(memory_dict).runs[0].input == "foo" - assert memory.from_dict(memory_dict).runs[0].output == "bar" + assert memory.from_dict(memory_dict).runs[0].input.value == "foo" + assert memory.from_dict(memory_dict).runs[0].output.value == "bar" assert memory.from_dict(memory_dict).offset == memory.offset assert memory.from_dict(memory_dict).summary == memory.summary assert memory.from_dict(memory_dict).summary_index == memory.summary_index @@ -77,11 +78,11 @@ def test_from_dict(self): def test_from_json(self): memory = SummaryConversationMemory() - memory.add_run(Run(input="foo", output="bar")) + memory.add_run(Run(input=TextArtifact("foo"), output=TextArtifact("bar"))) memory_dict = memory.to_dict() assert isinstance(memory.from_dict(memory_dict), SummaryConversationMemory) - assert memory.from_dict(memory_dict).runs[0].input == "foo" + assert memory.from_dict(memory_dict).runs[0].input.value == "foo" def test_config_prompt_driver(self): memory = SummaryConversationMemory() From 524a30d4fd58c6908f6078ecc82ebd801735ef13 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 14 Jun 2024 16:47:58 -0700 Subject: [PATCH 03/34] Update default artifact --- griptape/structures/agent.py | 5 +++-- griptape/tasks/base_text_input_task.py | 5 ++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/griptape/structures/agent.py b/griptape/structures/agent.py index f767d74a2..22937338c 100644 --- a/griptape/structures/agent.py +++ b/griptape/structures/agent.py @@ -1,10 +1,11 @@ from __future__ import annotations from typing import TYPE_CHECKING, Optional, Callable from attrs import define, field +from griptape.artifacts.text_artifact import TextArtifact from griptape.tools import BaseTool from griptape.memory.structure import Run from griptape.structures import Structure -from griptape.tasks import PromptTask, ToolkitTask, BaseTextInputTask +from griptape.tasks import PromptTask, ToolkitTask from griptape.artifacts import BaseArtifact if TYPE_CHECKING: @@ -14,7 +15,7 @@ @define class Agent(Structure): input: str | list | tuple | BaseArtifact | Callable[[BaseTask], BaseArtifact] = field( - default=BaseTextInputTask.DEFAULT_INPUT_TEMPLATE + default=lambda task: task.full_context["args"][0] if task.full_context["args"] else TextArtifact(value="") ) tools: list[BaseTool] = field(factory=list, kw_only=True) max_meta_memory_entries: Optional[int] = field(default=20, kw_only=True) diff --git a/griptape/tasks/base_text_input_task.py b/griptape/tasks/base_text_input_task.py index 40b6e75c5..f0b5bd9e0 100644 --- a/griptape/tasks/base_text_input_task.py +++ b/griptape/tasks/base_text_input_task.py @@ -14,10 +14,9 @@ @define class BaseTextInputTask(RuleMixin, BaseTask, ABC): - DEFAULT_INPUT_TEMPLATE = "{{ args[0] }}" - _input: str | list | tuple | BaseArtifact | Callable[[BaseTask], BaseArtifact] = field( - default=DEFAULT_INPUT_TEMPLATE, alias="input" + default=lambda task: task.full_context["args"][0] if task.full_context["args"] else TextArtifact(value=""), + alias="input", ) @property From 7f382889d99af11bf92e3171e6bdd2b02077e700 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 18 Jun 2024 10:27:28 -0700 Subject: [PATCH 04/34] Fix bad merge --- griptape/drivers/prompt/anthropic_prompt_driver.py | 11 ++++------- tests/mocks/mock_prompt_driver.py | 4 ++-- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/griptape/drivers/prompt/anthropic_prompt_driver.py b/griptape/drivers/prompt/anthropic_prompt_driver.py index 5b700c35a..bc2e0fdd4 100644 --- a/griptape/drivers/prompt/anthropic_prompt_driver.py +++ b/griptape/drivers/prompt/anthropic_prompt_driver.py @@ -125,20 +125,17 @@ def __prompt_stack_content_message_content(self, content: BasePromptStackContent raise ValueError(f"Unsupported prompt content type: {type(content)}") def __message_content_to_prompt_stack_content(self, content: ContentBlock) -> BasePromptStackContent: - content_type = content.type - - if content_type == "text": + if content.type == "text": return TextPromptStackContent(TextArtifact(content.text)) else: - raise ValueError(f"Unsupported message content type: {content_type}") + raise ValueError(f"Unsupported message content type: {content.type}") def __message_content_delta_to_prompt_stack_content_delta( self, content_delta: ContentBlockDeltaEvent ) -> BaseDeltaPromptStackContent: index = content_delta.index - delta_type = content_delta.delta.type - if delta_type == "text_delta": + if content_delta.delta.type == "text_delta": return DeltaTextPromptStackContent(content_delta.delta.text, index=index) else: - raise ValueError(f"Unsupported message content delta type : {delta_type}") + raise ValueError(f"Unsupported message content delta type : {content_delta.delta.type}") diff --git a/tests/mocks/mock_prompt_driver.py b/tests/mocks/mock_prompt_driver.py index 7f59e6bd8..0c1f9e00c 100644 --- a/tests/mocks/mock_prompt_driver.py +++ b/tests/mocks/mock_prompt_driver.py @@ -28,7 +28,7 @@ class MockPromptDriver(BasePromptDriver): mock_output: str | Callable[[PromptStack], str] = field(default="mock output", kw_only=True) def try_run(self, prompt_stack: PromptStack) -> PromptStackElement: - output = self.mock_output() if isinstance(self.mock_output, Callable) else self.mock_output + output = self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output return PromptStackElement( content=[TextPromptStackContent(TextArtifact(output))], @@ -37,7 +37,7 @@ def try_run(self, prompt_stack: PromptStack) -> PromptStackElement: ) def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackElement | BaseDeltaPromptStackContent]: - output = self.mock_output() if isinstance(self.mock_output, Callable) else self.mock_output + output = self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output yield DeltaTextPromptStackContent(output) yield DeltaPromptStackElement( From 2380b37604e68b277f725d67a2a12acd50f48308 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 18 Jun 2024 13:55:40 -0700 Subject: [PATCH 05/34] Rename Prompt Stack Element to Prompt Stack Message --- docs/griptape-framework/misc/events.md | 2 +- griptape/common/__init__.py | 12 ++--- .../base_prompt_stack_message.py} | 2 +- .../delta_prompt_stack_message.py} | 8 ++-- .../prompt_stack_message.py} | 6 +-- griptape/common/prompt_stack/prompt_stack.py | 44 +++++++++---------- .../prompt/amazon_bedrock_prompt_driver.py | 30 ++++++------- ...mazon_sagemaker_jumpstart_prompt_driver.py | 16 +++---- .../drivers/prompt/anthropic_prompt_driver.py | 30 ++++++------- griptape/drivers/prompt/base_prompt_driver.py | 26 +++++------ .../drivers/prompt/cohere_prompt_driver.py | 32 +++++++------- .../drivers/prompt/dummy_prompt_driver.py | 10 ++--- .../drivers/prompt/google_prompt_driver.py | 28 ++++++------ .../prompt/huggingface_hub_prompt_driver.py | 22 +++++----- .../huggingface_pipeline_prompt_driver.py | 16 +++---- .../prompt/openai_chat_prompt_driver.py | 24 +++++----- .../extraction/csv_extraction_engine.py | 6 +-- .../extraction/json_extraction_engine.py | 6 +-- griptape/engines/query/vector_query_engine.py | 14 +++--- .../engines/summary/prompt_summary_engine.py | 14 +++--- .../structure/base_conversation_memory.py | 14 +++--- .../memory/structure/conversation_memory.py | 4 +- .../structure/summary_conversation_memory.py | 10 ++--- griptape/schemas/base_schema.py | 4 +- griptape/tasks/prompt_task.py | 6 +-- griptape/tasks/toolkit_task.py | 10 ++--- griptape/utils/conversation.py | 2 +- tests/mocks/mock_failing_prompt_driver.py | 18 ++++---- tests/mocks/mock_prompt_driver.py | 18 ++++---- .../test_amazon_bedrock_prompt_driver.py | 6 +-- ...mazon_sagemaker_jumpstart_prompt_driver.py | 6 +-- .../prompt/test_anthropic_prompt_driver.py | 14 +++--- .../drivers/prompt/test_base_prompt_driver.py | 2 +- .../prompt/test_cohere_prompt_driver.py | 6 +-- .../prompt/test_google_prompt_driver.py | 12 ++--- .../test_hugging_face_hub_prompt_driver.py | 6 +-- ...est_hugging_face_pipeline_prompt_driver.py | 6 +-- .../prompt/test_openai_chat_prompt_driver.py | 8 ++-- .../summary/test_prompt_summary_engine.py | 4 +- tests/unit/events/test_base_event.py | 14 +++--- tests/unit/events/test_start_prompt_event.py | 12 ++--- .../structure/test_conversation_memory.py | 44 +++++++++---------- .../test_summary_conversation_memory.py | 6 +-- tests/unit/structures/test_agent.py | 12 ++--- tests/unit/structures/test_pipeline.py | 24 +++++----- .../unit/tokenizers/test_google_tokenizer.py | 6 ++- tests/unit/utils/test_prompt_stack.py | 32 +++++++------- 47 files changed, 328 insertions(+), 326 deletions(-) rename griptape/common/prompt_stack/{elements/base_prompt_stack_element.py => messages/base_prompt_stack_message.py} (90%) rename griptape/common/prompt_stack/{elements/delta_prompt_stack_element.py => messages/delta_prompt_stack_message.py} (80%) rename griptape/common/prompt_stack/{elements/prompt_stack_element.py => messages/prompt_stack_message.py} (90%) diff --git a/docs/griptape-framework/misc/events.md b/docs/griptape-framework/misc/events.md index 0921676fb..226e96741 100644 --- a/docs/griptape-framework/misc/events.md +++ b/docs/griptape-framework/misc/events.md @@ -245,7 +245,7 @@ from griptape.events import BaseEvent, StartPromptEvent, EventListener def handler(event: BaseEvent): if isinstance(event, StartPromptEvent): print("Prompt Stack Inputs:") - for input in event.prompt_stack.inputs: + for input in event.prompt_stack.messages: print(f"{input.role}: {input.content}") print("Final Prompt String:") print(event.prompt) diff --git a/griptape/common/__init__.py b/griptape/common/__init__.py index 303c52db6..db5e011e1 100644 --- a/griptape/common/__init__.py +++ b/griptape/common/__init__.py @@ -4,18 +4,18 @@ from .prompt_stack.contents.text_prompt_stack_content import TextPromptStackContent from .prompt_stack.contents.image_prompt_stack_content import ImagePromptStackContent -from .prompt_stack.elements.base_prompt_stack_element import BasePromptStackElement -from .prompt_stack.elements.delta_prompt_stack_element import DeltaPromptStackElement -from .prompt_stack.elements.prompt_stack_element import PromptStackElement +from .prompt_stack.messages.base_prompt_stack_message import BasePromptStackMessage +from .prompt_stack.messages.delta_prompt_stack_message import DeltaPromptStackMessage +from .prompt_stack.messages.prompt_stack_message import PromptStackMessage from .prompt_stack.prompt_stack import PromptStack __all__ = [ - "BasePromptStackElement", + "BasePromptStackMessage", "BaseDeltaPromptStackContent", "BasePromptStackContent", - "DeltaPromptStackElement", - "PromptStackElement", + "DeltaPromptStackMessage", + "PromptStackMessage", "DeltaTextPromptStackContent", "TextPromptStackContent", "ImagePromptStackContent", diff --git a/griptape/common/prompt_stack/elements/base_prompt_stack_element.py b/griptape/common/prompt_stack/messages/base_prompt_stack_message.py similarity index 90% rename from griptape/common/prompt_stack/elements/base_prompt_stack_element.py rename to griptape/common/prompt_stack/messages/base_prompt_stack_message.py index 83d9daaac..ac6d1365b 100644 --- a/griptape/common/prompt_stack/elements/base_prompt_stack_element.py +++ b/griptape/common/prompt_stack/messages/base_prompt_stack_message.py @@ -8,7 +8,7 @@ @define -class BasePromptStackElement(ABC, SerializableMixin): +class BasePromptStackMessage(ABC, SerializableMixin): USER_ROLE = "user" ASSISTANT_ROLE = "assistant" SYSTEM_ROLE = "system" diff --git a/griptape/common/prompt_stack/elements/delta_prompt_stack_element.py b/griptape/common/prompt_stack/messages/delta_prompt_stack_message.py similarity index 80% rename from griptape/common/prompt_stack/elements/delta_prompt_stack_element.py rename to griptape/common/prompt_stack/messages/delta_prompt_stack_message.py index be2dbd500..bbc18a024 100644 --- a/griptape/common/prompt_stack/elements/delta_prompt_stack_element.py +++ b/griptape/common/prompt_stack/messages/delta_prompt_stack_message.py @@ -6,11 +6,11 @@ from griptape.common.prompt_stack.contents.delta_text_prompt_stack_content import DeltaTextPromptStackContent -from .base_prompt_stack_element import BasePromptStackElement +from .base_prompt_stack_message import BasePromptStackMessage @define -class DeltaPromptStackElement(BasePromptStackElement): +class DeltaPromptStackMessage(BasePromptStackMessage): @define class DeltaUsage: input_tokens: Optional[float] = field(kw_only=True, default=None, metadata={"serializable": True}) @@ -20,8 +20,8 @@ class DeltaUsage: def total_tokens(self) -> float: return (self.input_tokens or 0) + (self.output_tokens or 0) - def __add__(self, other: DeltaPromptStackElement.DeltaUsage) -> DeltaPromptStackElement.DeltaUsage: - return DeltaPromptStackElement.DeltaUsage( + def __add__(self, other: DeltaPromptStackMessage.DeltaUsage) -> DeltaPromptStackMessage.DeltaUsage: + return DeltaPromptStackMessage.DeltaUsage( input_tokens=(self.input_tokens or 0) + (other.input_tokens or 0), output_tokens=(self.output_tokens or 0) + (other.output_tokens or 0), ) diff --git a/griptape/common/prompt_stack/elements/prompt_stack_element.py b/griptape/common/prompt_stack/messages/prompt_stack_message.py similarity index 90% rename from griptape/common/prompt_stack/elements/prompt_stack_element.py rename to griptape/common/prompt_stack/messages/prompt_stack_message.py index b94c8a5f4..bffde91de 100644 --- a/griptape/common/prompt_stack/elements/prompt_stack_element.py +++ b/griptape/common/prompt_stack/messages/prompt_stack_message.py @@ -8,11 +8,11 @@ from griptape.common import BasePromptStackContent, TextPromptStackContent from griptape.mixins.serializable_mixin import SerializableMixin -from .base_prompt_stack_element import BasePromptStackElement +from .base_prompt_stack_message import BasePromptStackMessage @define -class PromptStackElement(BasePromptStackElement): +class PromptStackMessage(BasePromptStackMessage): @define class Usage(SerializableMixin): input_tokens: Optional[float] = field(kw_only=True, default=None, metadata={"serializable": True}) @@ -29,7 +29,7 @@ def __init__(self, content: str | list[BasePromptStackContent], **kwargs: Any): content: list[BasePromptStackContent] = field(metadata={"serializable": True}) usage: Usage = field( - kw_only=True, default=Factory(lambda: PromptStackElement.Usage()), metadata={"serializable": True} + kw_only=True, default=Factory(lambda: PromptStackMessage.Usage()), metadata={"serializable": True} ) @property diff --git a/griptape/common/prompt_stack/prompt_stack.py b/griptape/common/prompt_stack/prompt_stack.py index a82c47216..9091ca708 100644 --- a/griptape/common/prompt_stack/prompt_stack.py +++ b/griptape/common/prompt_stack/prompt_stack.py @@ -3,42 +3,42 @@ from griptape.artifacts import TextArtifact, BaseArtifact, ListArtifact, ImageArtifact from griptape.mixins import SerializableMixin -from griptape.common import PromptStackElement, TextPromptStackContent, BasePromptStackContent, ImagePromptStackContent +from griptape.common import PromptStackMessage, TextPromptStackContent, BasePromptStackContent, ImagePromptStackContent @define class PromptStack(SerializableMixin): - inputs: list[PromptStackElement] = field(factory=list, kw_only=True, metadata={"serializable": True}) + messages: list[PromptStackMessage] = field(factory=list, kw_only=True, metadata={"serializable": True}) - def add_input(self, content: str | BaseArtifact, role: str) -> PromptStackElement: - new_content = self.__process_content(content) + def add_message(self, artifact: str | BaseArtifact, role: str) -> PromptStackMessage: + new_content = self.__process_artifact(artifact) - self.inputs.append(PromptStackElement(content=new_content, role=role)) + self.messages.append(PromptStackMessage(content=new_content, role=role)) - return self.inputs[-1] + return self.messages[-1] - def add_system_input(self, content: str | BaseArtifact) -> PromptStackElement: - return self.add_input(content, PromptStackElement.SYSTEM_ROLE) + def add_system_message(self, artifact: str | BaseArtifact) -> PromptStackMessage: + return self.add_message(artifact, PromptStackMessage.SYSTEM_ROLE) - def add_user_input(self, content: str | BaseArtifact) -> PromptStackElement: - return self.add_input(content, PromptStackElement.USER_ROLE) + def add_user_message(self, artifact: str | BaseArtifact) -> PromptStackMessage: + return self.add_message(artifact, PromptStackMessage.USER_ROLE) - def add_assistant_input(self, content: str | BaseArtifact) -> PromptStackElement: - return self.add_input(content, PromptStackElement.ASSISTANT_ROLE) + def add_assistant_message(self, artifact: str | BaseArtifact) -> PromptStackMessage: + return self.add_message(artifact, PromptStackMessage.ASSISTANT_ROLE) - def __process_content(self, content: str | BaseArtifact) -> list[BasePromptStackContent]: - if isinstance(content, str): - return [TextPromptStackContent(TextArtifact(content))] - elif isinstance(content, TextArtifact): - return [TextPromptStackContent(content)] - elif isinstance(content, ImageArtifact): - return [ImagePromptStackContent(content)] - elif isinstance(content, ListArtifact): - processed_contents = [self.__process_content(artifact) for artifact in content.value] + def __process_artifact(self, artifact: str | BaseArtifact) -> list[BasePromptStackContent]: + if isinstance(artifact, str): + return [TextPromptStackContent(TextArtifact(artifact))] + elif isinstance(artifact, TextArtifact): + return [TextPromptStackContent(artifact)] + elif isinstance(artifact, ImageArtifact): + return [ImagePromptStackContent(artifact)] + elif isinstance(artifact, ListArtifact): + processed_contents = [self.__process_artifact(artifact) for artifact in artifact.value] flattened_content = [ sub_content for processed_content in processed_contents for sub_content in processed_content ] return flattened_content else: - raise ValueError(f"Unsupported content type: {type(content)}") + raise ValueError(f"Unsupported artifact type: {type(artifact)}") diff --git a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py index d3c13a5da..4446c852f 100644 --- a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py +++ b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py @@ -8,8 +8,8 @@ from griptape.artifacts import TextArtifact from griptape.common import ( BaseDeltaPromptStackContent, - DeltaPromptStackElement, - PromptStackElement, + DeltaPromptStackMessage, + PromptStackMessage, DeltaTextPromptStackContent, BasePromptStackContent, TextPromptStackContent, @@ -36,26 +36,26 @@ class AmazonBedrockPromptDriver(BasePromptDriver): default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True), kw_only=True ) - def try_run(self, prompt_stack: PromptStack) -> PromptStackElement: + def try_run(self, prompt_stack: PromptStack) -> PromptStackMessage: response = self.bedrock_client.converse(**self._base_params(prompt_stack)) usage = response["usage"] output_message = response["output"]["message"] - return PromptStackElement( + return PromptStackMessage( content=[TextPromptStackContent(TextArtifact(content["text"])) for content in output_message["content"]], - role=PromptStackElement.ASSISTANT_ROLE, - usage=PromptStackElement.Usage(input_tokens=usage["inputTokens"], output_tokens=usage["outputTokens"]), + role=PromptStackMessage.ASSISTANT_ROLE, + usage=PromptStackMessage.Usage(input_tokens=usage["inputTokens"], output_tokens=usage["outputTokens"]), ) - def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackElement | BaseDeltaPromptStackContent]: + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage | BaseDeltaPromptStackContent]: response = self.bedrock_client.converse_stream(**self._base_params(prompt_stack)) stream = response.get("stream") if stream is not None: for event in stream: if "messageStart" in event: - yield DeltaPromptStackElement(role=event["messageStart"]["role"]) + yield DeltaPromptStackMessage(role=event["messageStart"]["role"]) elif "contentBlockDelta" in event: content_block_delta = event["contentBlockDelta"] yield DeltaTextPromptStackContent( @@ -63,15 +63,15 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackElem ) elif "metadata" in event: usage = event["metadata"]["usage"] - yield DeltaPromptStackElement( - delta_usage=DeltaPromptStackElement.DeltaUsage( + yield DeltaPromptStackMessage( + delta_usage=DeltaPromptStackMessage.DeltaUsage( input_tokens=usage["inputTokens"], output_tokens=usage["outputTokens"] ) ) else: raise Exception("model response is empty") - def _prompt_stack_elements_to_messages(self, elements: list[PromptStackElement]) -> list[dict]: + def _prompt_stack_messages_to_messages(self, elements: list[PromptStackMessage]) -> list[dict]: return [ { "role": self.__to_role(input), @@ -82,11 +82,11 @@ def _prompt_stack_elements_to_messages(self, elements: list[PromptStackElement]) def _base_params(self, prompt_stack: PromptStack) -> dict: system_messages = [ - {"text": input.to_text_artifact().to_text()} for input in prompt_stack.inputs if input.is_system() + {"text": input.to_text_artifact().to_text()} for input in prompt_stack.messages if input.is_system() ] - messages = self._prompt_stack_elements_to_messages( - [input for input in prompt_stack.inputs if not input.is_system()] + messages = self._prompt_stack_messages_to_messages( + [input for input in prompt_stack.messages if not input.is_system()] ) return { @@ -105,7 +105,7 @@ def __prompt_stack_content_message_content(self, content: BasePromptStackContent else: raise ValueError(f"Unsupported content type: {type(content)}") - def __to_role(self, input: PromptStackElement) -> str: + def __to_role(self, input: PromptStackMessage) -> str: if input.is_system(): return "system" elif input.is_assistant(): diff --git a/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py b/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py index 596606747..f1858bc46 100644 --- a/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py +++ b/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py @@ -9,9 +9,9 @@ from griptape.artifacts import TextArtifact from griptape.common import ( PromptStack, - PromptStackElement, + PromptStackMessage, TextPromptStackContent, - DeltaPromptStackElement, + DeltaPromptStackMessage, BaseDeltaPromptStackContent, ) from griptape.drivers import BasePromptDriver @@ -47,7 +47,7 @@ def validate_stream(self, _, stream): if stream: raise ValueError("streaming is not supported") - def try_run(self, prompt_stack: PromptStack) -> PromptStackElement: + def try_run(self, prompt_stack: PromptStack) -> PromptStackMessage: payload = { "inputs": self.prompt_stack_to_string(prompt_stack), "parameters": {**self._base_params(prompt_stack)}, @@ -78,13 +78,13 @@ def try_run(self, prompt_stack: PromptStack) -> PromptStackElement: input_tokens = len(self.__prompt_stack_to_tokens(prompt_stack)) output_tokens = len(self.tokenizer.tokenizer.encode(generated_text)) - return PromptStackElement( + return PromptStackMessage( content=[TextPromptStackContent(TextArtifact(generated_text))], - role=PromptStackElement.ASSISTANT_ROLE, - usage=PromptStackElement.Usage(input_tokens=input_tokens, output_tokens=output_tokens), + role=PromptStackMessage.ASSISTANT_ROLE, + usage=PromptStackMessage.Usage(input_tokens=input_tokens, output_tokens=output_tokens), ) - def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackElement | BaseDeltaPromptStackContent]: + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage | BaseDeltaPromptStackContent]: raise NotImplementedError("streaming is not supported") def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str: @@ -103,7 +103,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]: messages = [] - for input in prompt_stack.inputs: + for input in prompt_stack.messages: messages.append({"role": input.role, "content": TextPromptStackContent(input.to_text_artifact())}) return messages diff --git a/griptape/drivers/prompt/anthropic_prompt_driver.py b/griptape/drivers/prompt/anthropic_prompt_driver.py index bc2e0fdd4..5b8b7be74 100644 --- a/griptape/drivers/prompt/anthropic_prompt_driver.py +++ b/griptape/drivers/prompt/anthropic_prompt_driver.py @@ -9,11 +9,11 @@ from griptape.common import ( BaseDeltaPromptStackContent, BasePromptStackContent, - DeltaPromptStackElement, + DeltaPromptStackMessage, DeltaTextPromptStackContent, ImagePromptStackContent, PromptStack, - PromptStackElement, + PromptStackMessage, TextPromptStackContent, ) from griptape.drivers import BasePromptDriver @@ -49,40 +49,40 @@ class AnthropicPromptDriver(BasePromptDriver): top_k: int = field(default=250, kw_only=True, metadata={"serializable": True}) max_tokens: int = field(default=1000, kw_only=True, metadata={"serializable": True}) - def try_run(self, prompt_stack: PromptStack) -> PromptStackElement: + def try_run(self, prompt_stack: PromptStack) -> PromptStackMessage: response = self.client.messages.create(**self._base_params(prompt_stack)) - return PromptStackElement( + return PromptStackMessage( content=[self.__message_content_to_prompt_stack_content(content) for content in response.content], role=response.role, - usage=PromptStackElement.Usage( + usage=PromptStackMessage.Usage( input_tokens=response.usage.input_tokens, output_tokens=response.usage.output_tokens ), ) - def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackElement | BaseDeltaPromptStackContent]: + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage | BaseDeltaPromptStackContent]: events = self.client.messages.create(**self._base_params(prompt_stack), stream=True) for event in events: if event.type == "content_block_delta": yield self.__message_content_delta_to_prompt_stack_content_delta(event) elif event.type == "message_start": - yield DeltaPromptStackElement( + yield DeltaPromptStackMessage( role=event.message.role, - delta_usage=DeltaPromptStackElement.DeltaUsage(input_tokens=event.message.usage.input_tokens), + delta_usage=DeltaPromptStackMessage.DeltaUsage(input_tokens=event.message.usage.input_tokens), ) elif event.type == "message_delta": - yield DeltaPromptStackElement( - delta_usage=DeltaPromptStackElement.DeltaUsage(output_tokens=event.usage.output_tokens) + yield DeltaPromptStackMessage( + delta_usage=DeltaPromptStackMessage.DeltaUsage(output_tokens=event.usage.output_tokens) ) - def _prompt_stack_elements_to_messages(self, elements: list[PromptStackElement]) -> list[dict]: + def _prompt_stack_messages_to_messages(self, elements: list[PromptStackMessage]) -> list[dict]: return [{"role": self.__to_role(input), "content": self.__to_content(input)} for input in elements] def _base_params(self, prompt_stack: PromptStack) -> dict: - messages = self._prompt_stack_elements_to_messages([i for i in prompt_stack.inputs if not i.is_system()]) + messages = self._prompt_stack_messages_to_messages([i for i in prompt_stack.messages if not i.is_system()]) - system_element = next((i for i in prompt_stack.inputs if i.is_system()), None) + system_element = next((i for i in prompt_stack.messages if i.is_system()), None) if system_element: system_message = system_element.to_text_artifact().to_text() else: @@ -99,7 +99,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: **({"system": system_message} if system_message else {}), } - def __to_role(self, input: PromptStackElement) -> str: + def __to_role(self, input: PromptStackMessage) -> str: if input.is_system(): return "system" elif input.is_assistant(): @@ -107,7 +107,7 @@ def __to_role(self, input: PromptStackElement) -> str: else: return "user" - def __to_content(self, input: PromptStackElement) -> str | list[dict]: + def __to_content(self, input: PromptStackMessage) -> str | list[dict]: if all(isinstance(content, TextPromptStackContent) for content in input.content): return input.to_text_artifact().to_text() else: diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index 89539028e..2373aeeb8 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -9,10 +9,10 @@ from griptape.artifacts.text_artifact import TextArtifact from griptape.common import ( BaseDeltaPromptStackContent, - DeltaPromptStackElement, + DeltaPromptStackMessage, DeltaTextPromptStackContent, PromptStack, - PromptStackElement, + PromptStackMessage, TextPromptStackContent, ) from griptape.events import CompletionChunkEvent, FinishPromptEvent, StartPromptEvent @@ -52,7 +52,7 @@ def before_run(self, prompt_stack: PromptStack) -> None: if self.structure: self.structure.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack)) - def after_run(self, result: PromptStackElement) -> None: + def after_run(self, result: PromptStackMessage) -> None: if self.structure: self.structure.publish_event( FinishPromptEvent( @@ -91,7 +91,7 @@ def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str: """ prompt_lines = [] - for i in prompt_stack.inputs: + for i in prompt_stack.messages: content = i.to_text_artifact().to_text() if i.is_user(): prompt_lines.append(f"User: {content}") @@ -105,26 +105,26 @@ def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str: return "\n\n".join(prompt_lines) @abstractmethod - def try_run(self, prompt_stack: PromptStack) -> PromptStackElement: ... + def try_run(self, prompt_stack: PromptStack) -> PromptStackMessage: ... @abstractmethod def try_stream( self, prompt_stack: PromptStack - ) -> Iterator[DeltaPromptStackElement | BaseDeltaPromptStackContent]: ... + ) -> Iterator[DeltaPromptStackMessage | BaseDeltaPromptStackContent]: ... - def __process_run(self, prompt_stack: PromptStack) -> PromptStackElement: + def __process_run(self, prompt_stack: PromptStack) -> PromptStackMessage: result = self.try_run(prompt_stack) return result - def __process_stream(self, prompt_stack: PromptStack) -> PromptStackElement: + def __process_stream(self, prompt_stack: PromptStack) -> PromptStackMessage: delta_contents: dict[int, list[BaseDeltaPromptStackContent]] = {} - delta_usage = DeltaPromptStackElement.DeltaUsage() + delta_usage = DeltaPromptStackMessage.DeltaUsage() deltas = self.try_stream(prompt_stack) for delta in deltas: - if isinstance(delta, DeltaPromptStackElement): + if isinstance(delta, DeltaPromptStackMessage): delta_usage += delta.delta_usage elif isinstance(delta, BaseDeltaPromptStackContent): if delta.index in delta_contents: @@ -141,10 +141,10 @@ def __process_stream(self, prompt_stack: PromptStack) -> PromptStackElement: if text_deltas: content.append(TextPromptStackContent.from_deltas(text_deltas)) - result = PromptStackElement( + result = PromptStackMessage( content=content, - role=PromptStackElement.ASSISTANT_ROLE, - usage=PromptStackElement.Usage( + role=PromptStackMessage.ASSISTANT_ROLE, + usage=PromptStackMessage.Usage( input_tokens=delta_usage.input_tokens or 0, output_tokens=delta_usage.output_tokens or 0 ), ) diff --git a/griptape/drivers/prompt/cohere_prompt_driver.py b/griptape/drivers/prompt/cohere_prompt_driver.py index 047ea5a35..8c951587f 100644 --- a/griptape/drivers/prompt/cohere_prompt_driver.py +++ b/griptape/drivers/prompt/cohere_prompt_driver.py @@ -7,8 +7,8 @@ from griptape.tokenizers import CohereTokenizer from griptape.common import ( PromptStack, - PromptStackElement, - DeltaPromptStackElement, + PromptStackMessage, + DeltaPromptStackMessage, BaseDeltaPromptStackContent, TextPromptStackContent, BasePromptStackContent, @@ -41,17 +41,17 @@ class CoherePromptDriver(BasePromptDriver): kw_only=True, ) - def try_run(self, prompt_stack: PromptStack) -> PromptStackElement: + def try_run(self, prompt_stack: PromptStack) -> PromptStackMessage: result = self.client.chat(**self._base_params(prompt_stack)) usage = result.meta.tokens - return PromptStackElement( + return PromptStackMessage( content=[TextPromptStackContent(TextArtifact(result.text))], - role=PromptStackElement.ASSISTANT_ROLE, - usage=PromptStackElement.Usage(input_tokens=usage.input_tokens, output_tokens=usage.output_tokens), + role=PromptStackMessage.ASSISTANT_ROLE, + usage=PromptStackMessage.Usage(input_tokens=usage.input_tokens, output_tokens=usage.output_tokens), ) - def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackElement | BaseDeltaPromptStackContent]: + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage | BaseDeltaPromptStackContent]: result = self.client.chat_stream(**self._base_params(prompt_stack)) for event in result: @@ -60,14 +60,14 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackElem if event.event_type == "stream-end": usage = event.response.meta.tokens - yield DeltaPromptStackElement( - role=PromptStackElement.ASSISTANT_ROLE, - delta_usage=DeltaPromptStackElement.DeltaUsage( + yield DeltaPromptStackMessage( + role=PromptStackMessage.ASSISTANT_ROLE, + delta_usage=DeltaPromptStackMessage.DeltaUsage( input_tokens=usage.input_tokens, output_tokens=usage.output_tokens ), ) - def _prompt_stack_elements_to_messages(self, elements: list[PromptStackElement]) -> list[dict]: + def _prompt_stack_messages_to_messages(self, elements: list[PromptStackMessage]) -> list[dict]: return [ { "role": self.__to_role(input), @@ -77,17 +77,17 @@ def _prompt_stack_elements_to_messages(self, elements: list[PromptStackElement]) ] def _base_params(self, prompt_stack: PromptStack) -> dict: - last_input = prompt_stack.inputs[-1] + last_input = prompt_stack.messages[-1] if last_input is not None and len(last_input.content) == 1: user_message = last_input.content[0].artifact.to_text() else: raise ValueError("User element must have exactly one content.") - history_messages = self._prompt_stack_elements_to_messages( - [input for input in prompt_stack.inputs[:-1] if not input.is_system()] + history_messages = self._prompt_stack_messages_to_messages( + [input for input in prompt_stack.messages[:-1] if not input.is_system()] ) - system_element = next((input for input in prompt_stack.inputs if input.is_system()), None) + system_element = next((input for input in prompt_stack.messages if input.is_system()), None) if system_element is not None: if len(system_element.content) == 1: preamble = system_element.content[0].artifact.to_text() @@ -111,7 +111,7 @@ def __prompt_stack_content_message_content(self, content: BasePromptStackContent else: raise ValueError(f"Unsupported content type: {type(content)}") - def __to_role(self, input: PromptStackElement) -> str: + def __to_role(self, input: PromptStackMessage) -> str: if input.is_system(): return "SYSTEM" elif input.is_user(): diff --git a/griptape/drivers/prompt/dummy_prompt_driver.py b/griptape/drivers/prompt/dummy_prompt_driver.py index 2c9794fbb..48eb95b25 100644 --- a/griptape/drivers/prompt/dummy_prompt_driver.py +++ b/griptape/drivers/prompt/dummy_prompt_driver.py @@ -7,8 +7,8 @@ from griptape.common import ( BasePromptStackContent, PromptStack, - PromptStackElement, - DeltaPromptStackElement, + PromptStackMessage, + DeltaPromptStackMessage, BaseDeltaPromptStackContent, ) from griptape.drivers import BasePromptDriver @@ -21,13 +21,13 @@ class DummyPromptDriver(BasePromptDriver): model: None = field(init=False, default=None, kw_only=True) tokenizer: DummyTokenizer = field(default=Factory(lambda: DummyTokenizer()), kw_only=True) - def try_run(self, prompt_stack: PromptStack) -> PromptStackElement: + def try_run(self, prompt_stack: PromptStack) -> PromptStackMessage: raise DummyException(__class__.__name__, "try_run") - def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackElement | BaseDeltaPromptStackContent]: + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage | BaseDeltaPromptStackContent]: raise DummyException(__class__.__name__, "try_stream") - def _prompt_stack_input_to_message(self, prompt_input: PromptStackElement) -> dict: + def _prompt_stack_input_to_message(self, prompt_input: PromptStackMessage) -> dict: raise DummyException(__class__.__name__, "_prompt_stack_input_to_message") def _prompt_stack_content_to_message_content(self, content: BasePromptStackContent) -> Any: diff --git a/griptape/drivers/prompt/google_prompt_driver.py b/griptape/drivers/prompt/google_prompt_driver.py index 664fd1bec..6a4a99859 100644 --- a/griptape/drivers/prompt/google_prompt_driver.py +++ b/griptape/drivers/prompt/google_prompt_driver.py @@ -9,11 +9,11 @@ from griptape.common import ( BaseDeltaPromptStackContent, BasePromptStackContent, - DeltaPromptStackElement, + DeltaPromptStackMessage, DeltaTextPromptStackContent, ImagePromptStackContent, PromptStack, - PromptStackElement, + PromptStackMessage, TextPromptStackContent, ) from griptape.drivers import BasePromptDriver @@ -46,7 +46,7 @@ class GooglePromptDriver(BasePromptDriver): top_p: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True}) top_k: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True}) - def try_run(self, prompt_stack: PromptStack) -> PromptStackElement: + def try_run(self, prompt_stack: PromptStack) -> PromptStackMessage: GenerationConfig = import_optional_dependency("google.generativeai.types").GenerationConfig messages = self._prompt_stack_to_messages(prompt_stack) @@ -63,15 +63,15 @@ def try_run(self, prompt_stack: PromptStack) -> PromptStackElement: usage_metadata = response.usage_metadata - return PromptStackElement( + return PromptStackMessage( content=[TextPromptStackContent(TextArtifact(response.text))], - role=PromptStackElement.ASSISTANT_ROLE, - usage=PromptStackElement.Usage( + role=PromptStackMessage.ASSISTANT_ROLE, + usage=PromptStackMessage.Usage( input_tokens=usage_metadata.prompt_token_count, output_tokens=usage_metadata.candidates_token_count ), ) - def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackElement | BaseDeltaPromptStackContent]: + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage | BaseDeltaPromptStackContent]: GenerationConfig = import_optional_dependency("google.generativeai.types").GenerationConfig messages = self._prompt_stack_to_messages(prompt_stack) @@ -93,9 +93,9 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackElem yield DeltaTextPromptStackContent(chunk.text) # TODO: Only yield the first one - yield DeltaPromptStackElement( - role=PromptStackElement.ASSISTANT_ROLE, - delta_usage=DeltaPromptStackElement.DeltaUsage( + yield DeltaPromptStackMessage( + role=PromptStackMessage.ASSISTANT_ROLE, + delta_usage=DeltaPromptStackMessage.DeltaUsage( input_tokens=usage_metadata.prompt_token_count, output_tokens=usage_metadata.candidates_token_count ), ) @@ -109,12 +109,12 @@ def _default_model_client(self) -> GenerativeModel: def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]: inputs = [ {"role": self.__to_role(input), "parts": self.__to_content(input)} - for input in prompt_stack.inputs + for input in prompt_stack.messages if not input.is_system() ] # Gemini does not have the notion of a system message, so we insert it as part of the first message in the history. - system = next((i for i in prompt_stack.inputs if i.is_system()), None) + system = next((i for i in prompt_stack.messages if i.is_system()), None) if system is not None: inputs[0]["parts"].insert(0, "\n".join(content.to_text() for content in system.content)) @@ -130,11 +130,11 @@ def __prompt_stack_content_message_content(self, content: BasePromptStackContent else: raise ValueError(f"Unsupported content type: {type(content)}") - def __to_role(self, input: PromptStackElement) -> str: + def __to_role(self, input: PromptStackMessage) -> str: if input.is_assistant(): return "model" else: return "user" - def __to_content(self, input: PromptStackElement) -> list[ContentDict | str]: + def __to_content(self, input: PromptStackMessage) -> list[ContentDict | str]: return [self.__prompt_stack_content_message_content(content) for content in input.content] diff --git a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py index e77f1504c..296ed1b0f 100644 --- a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py +++ b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py @@ -9,8 +9,8 @@ from griptape.tokenizers import HuggingFaceTokenizer from griptape.common import ( PromptStack, - PromptStackElement, - DeltaPromptStackElement, + PromptStackMessage, + DeltaPromptStackMessage, BaseDeltaPromptStackContent, TextPromptStackContent, DeltaTextPromptStackContent, @@ -54,7 +54,7 @@ class HuggingFaceHubPromptDriver(BasePromptDriver): kw_only=True, ) - def try_run(self, prompt_stack: PromptStack) -> PromptStackElement: + def try_run(self, prompt_stack: PromptStack) -> PromptStackMessage: prompt = self.prompt_stack_to_string(prompt_stack) response = self.client.text_generation( @@ -63,13 +63,13 @@ def try_run(self, prompt_stack: PromptStack) -> PromptStackElement: input_tokens = len(self.__prompt_stack_to_tokens(prompt_stack)) output_tokens = len(self.tokenizer.tokenizer.encode(response)) - return PromptStackElement( + return PromptStackMessage( content=response, - role=PromptStackElement.ASSISTANT_ROLE, - usage=PromptStackElement.Usage(input_tokens=input_tokens, output_tokens=output_tokens), + role=PromptStackMessage.ASSISTANT_ROLE, + usage=PromptStackMessage.Usage(input_tokens=input_tokens, output_tokens=output_tokens), ) - def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackElement | BaseDeltaPromptStackContent]: + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage | BaseDeltaPromptStackContent]: prompt = self.prompt_stack_to_string(prompt_stack) response = self.client.text_generation( @@ -84,9 +84,9 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackElem yield DeltaTextPromptStackContent(token, index=0) output_tokens = len(self.tokenizer.tokenizer.encode(full_text)) - yield DeltaPromptStackElement( - role=PromptStackElement.ASSISTANT_ROLE, - delta_usage=DeltaPromptStackElement.DeltaUsage(input_tokens=input_tokens, output_tokens=output_tokens), + yield DeltaPromptStackMessage( + role=PromptStackMessage.ASSISTANT_ROLE, + delta_usage=DeltaPromptStackMessage.DeltaUsage(input_tokens=input_tokens, output_tokens=output_tokens), ) def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str: @@ -94,7 +94,7 @@ def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str: def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]: messages = [] - for i in prompt_stack.inputs: + for i in prompt_stack.messages: if len(i.content) == 1: messages.append({"role": i.role, "content": TextPromptStackContent(i.to_text_artifact())}) else: diff --git a/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py b/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py index 180ee9b45..1c88208fe 100644 --- a/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py +++ b/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py @@ -8,9 +8,9 @@ from griptape.artifacts import TextArtifact from griptape.common import ( BaseDeltaPromptStackContent, - DeltaPromptStackElement, + DeltaPromptStackMessage, PromptStack, - PromptStackElement, + PromptStackMessage, TextPromptStackContent, ) from griptape.drivers import BasePromptDriver @@ -48,7 +48,7 @@ class HuggingFacePipelinePromptDriver(BasePromptDriver): ) ) - def try_run(self, prompt_stack: PromptStack) -> PromptStackElement: + def try_run(self, prompt_stack: PromptStack) -> PromptStackMessage: messages = self._prompt_stack_to_messages(prompt_stack) result = self.pipe( @@ -62,17 +62,17 @@ def try_run(self, prompt_stack: PromptStack) -> PromptStackElement: input_tokens = len(self.__prompt_stack_to_tokens(prompt_stack)) output_tokens = len(self.tokenizer.tokenizer.encode(generated_text)) - return PromptStackElement( + return PromptStackMessage( content=[TextPromptStackContent(TextArtifact(generated_text))], - role=PromptStackElement.ASSISTANT_ROLE, - usage=PromptStackElement.Usage(input_tokens=input_tokens, output_tokens=output_tokens), + role=PromptStackMessage.ASSISTANT_ROLE, + usage=PromptStackMessage.Usage(input_tokens=input_tokens, output_tokens=output_tokens), ) else: raise Exception("completion with more than one choice is not supported yet") else: raise Exception("invalid output format") - def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackElement | BaseDeltaPromptStackContent]: + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage | BaseDeltaPromptStackContent]: raise NotImplementedError("streaming is not supported") def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str: @@ -80,7 +80,7 @@ def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str: def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]: messages = [] - for i in prompt_stack.inputs: + for i in prompt_stack.messages: if len(i.content) == 1: messages.append({"role": i.role, "content": TextPromptStackContent(i.to_text_artifact())}) else: diff --git a/griptape/drivers/prompt/openai_chat_prompt_driver.py b/griptape/drivers/prompt/openai_chat_prompt_driver.py index c35186976..5a36c4c2f 100644 --- a/griptape/drivers/prompt/openai_chat_prompt_driver.py +++ b/griptape/drivers/prompt/openai_chat_prompt_driver.py @@ -10,11 +10,11 @@ from griptape.common import ( BaseDeltaPromptStackContent, BasePromptStackContent, - DeltaPromptStackElement, + DeltaPromptStackMessage, DeltaTextPromptStackContent, ImagePromptStackContent, PromptStack, - PromptStackElement, + PromptStackMessage, TextPromptStackContent, ) from griptape.drivers import BasePromptDriver @@ -74,31 +74,31 @@ class OpenAiChatPromptDriver(BasePromptDriver): kw_only=True, ) - def try_run(self, prompt_stack: PromptStack) -> PromptStackElement: + def try_run(self, prompt_stack: PromptStack) -> PromptStackMessage: result = self.client.chat.completions.create(**self._base_params(prompt_stack)) if len(result.choices) == 1: message = result.choices[0].message - return PromptStackElement( + return PromptStackMessage( content=[self.__message_to_prompt_stack_content(message)], role=message.role, - usage=PromptStackElement.Usage( + usage=PromptStackMessage.Usage( input_tokens=result.usage.prompt_tokens, output_tokens=result.usage.completion_tokens ), ) else: raise Exception("Completion with more than one choice is not supported yet.") - def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackElement | BaseDeltaPromptStackContent]: + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage | BaseDeltaPromptStackContent]: result = self.client.chat.completions.create( **self._base_params(prompt_stack), stream=True, stream_options={"include_usage": True} ) for chunk in result: if chunk.usage is not None: - yield DeltaPromptStackElement( - delta_usage=DeltaPromptStackElement.DeltaUsage( + yield DeltaPromptStackMessage( + delta_usage=DeltaPromptStackMessage.DeltaUsage( input_tokens=chunk.usage.prompt_tokens, output_tokens=chunk.usage.completion_tokens ) ) @@ -112,7 +112,7 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackElem raise Exception("Completion with more than one choice is not supported yet.") def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]: - return [{"role": self.__to_role(input), "content": self.__to_content(input)} for input in prompt_stack.inputs] + return [{"role": self.__to_role(input), "content": self.__to_content(input)} for input in prompt_stack.messages] def _base_params(self, prompt_stack: PromptStack) -> dict: params = { @@ -127,7 +127,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: if self.response_format == "json_object": params["response_format"] = {"type": "json_object"} # JSON mode still requires a system input instructing the LLM to output JSON. - prompt_stack.add_system_input("Provide your response as a valid JSON object.") + prompt_stack.add_system_message("Provide your response as a valid JSON object.") messages = self._prompt_stack_to_messages(prompt_stack) @@ -135,7 +135,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: return params - def __to_role(self, input: PromptStackElement) -> str: + def __to_role(self, input: PromptStackMessage) -> str: if input.is_system(): return "system" elif input.is_assistant(): @@ -143,7 +143,7 @@ def __to_role(self, input: PromptStackElement) -> str: else: return "user" - def __to_content(self, input: PromptStackElement) -> str | list[dict]: + def __to_content(self, input: PromptStackMessage) -> str | list[dict]: if all(isinstance(content, TextPromptStackContent) for content in input.content): return input.to_text_artifact().to_text() else: diff --git a/griptape/engines/extraction/csv_extraction_engine.py b/griptape/engines/extraction/csv_extraction_engine.py index 481a2da4d..48eb0d392 100644 --- a/griptape/engines/extraction/csv_extraction_engine.py +++ b/griptape/engines/extraction/csv_extraction_engine.py @@ -5,7 +5,7 @@ from attrs import field, Factory, define from griptape.artifacts import TextArtifact, CsvRowArtifact, ListArtifact, ErrorArtifact from griptape.common import PromptStack -from griptape.common.prompt_stack.elements.prompt_stack_element import PromptStackElement +from griptape.common.prompt_stack.messages.prompt_stack_message import PromptStackMessage from griptape.engines import BaseExtractionEngine from griptape.utils import J2 from griptape.rules import Ruleset @@ -65,7 +65,7 @@ def _extract_rec( rows.extend( self.text_to_csv_rows( self.prompt_driver.run( - PromptStack(inputs=[PromptStackElement(full_text, role=PromptStackElement.USER_ROLE)]) + PromptStack(messages=[PromptStackMessage(full_text, role=PromptStackMessage.USER_ROLE)]) ).value, column_names, ) @@ -83,7 +83,7 @@ def _extract_rec( rows.extend( self.text_to_csv_rows( self.prompt_driver.run( - PromptStack(inputs=[PromptStackElement(partial_text, role=PromptStackElement.USER_ROLE)]) + PromptStack(messages=[PromptStackMessage(partial_text, role=PromptStackMessage.USER_ROLE)]) ).value, column_names, ) diff --git a/griptape/engines/extraction/json_extraction_engine.py b/griptape/engines/extraction/json_extraction_engine.py index cfa76f8af..4b8f45a03 100644 --- a/griptape/engines/extraction/json_extraction_engine.py +++ b/griptape/engines/extraction/json_extraction_engine.py @@ -3,7 +3,7 @@ import json from attrs import field, Factory, define from griptape.artifacts import TextArtifact, ListArtifact, ErrorArtifact -from griptape.common.prompt_stack.elements.prompt_stack_element import PromptStackElement +from griptape.common.prompt_stack.messages.prompt_stack_message import PromptStackMessage from griptape.engines import BaseExtractionEngine from griptape.utils import J2 from griptape.common import PromptStack @@ -60,7 +60,7 @@ def _extract_rec( extractions.extend( self.json_to_text_artifacts( self.prompt_driver.run( - PromptStack(inputs=[PromptStackElement(full_text, role=PromptStackElement.USER_ROLE)]) + PromptStack(messages=[PromptStackMessage(full_text, role=PromptStackMessage.USER_ROLE)]) ).value ) ) @@ -77,7 +77,7 @@ def _extract_rec( extractions.extend( self.json_to_text_artifacts( self.prompt_driver.run( - PromptStack(inputs=[PromptStackElement(partial_text, role=PromptStackElement.USER_ROLE)]) + PromptStack(messages=[PromptStackMessage(partial_text, role=PromptStackMessage.USER_ROLE)]) ).value ) ) diff --git a/griptape/engines/query/vector_query_engine.py b/griptape/engines/query/vector_query_engine.py index d4db926b0..3eb4948e7 100644 --- a/griptape/engines/query/vector_query_engine.py +++ b/griptape/engines/query/vector_query_engine.py @@ -3,7 +3,7 @@ from attrs import define, field, Factory from griptape.artifacts import TextArtifact, BaseArtifact, ListArtifact from griptape.common import PromptStack -from griptape.common.prompt_stack.elements.prompt_stack_element import PromptStackElement +from griptape.common.prompt_stack.messages.prompt_stack_message import PromptStackMessage from griptape.engines import BaseQueryEngine from griptape.utils.j2 import J2 from griptape.rules import Ruleset @@ -53,9 +53,9 @@ def query( message_token_count = self.prompt_driver.tokenizer.count_input_tokens_left( self.prompt_driver.prompt_stack_to_string( PromptStack( - inputs=[ - PromptStackElement(system_message, role=PromptStackElement.SYSTEM_ROLE), - PromptStackElement(user_message, role=PromptStackElement.USER_ROLE), + messages=[ + PromptStackMessage(system_message, role=PromptStackMessage.SYSTEM_ROLE), + PromptStackMessage(user_message, role=PromptStackMessage.USER_ROLE), ] ) ) @@ -74,9 +74,9 @@ def query( result = self.prompt_driver.run( PromptStack( - inputs=[ - PromptStackElement(system_message, role=PromptStackElement.SYSTEM_ROLE), - PromptStackElement(user_message, role=PromptStackElement.USER_ROLE), + messages=[ + PromptStackMessage(system_message, role=PromptStackMessage.SYSTEM_ROLE), + PromptStackMessage(user_message, role=PromptStackMessage.USER_ROLE), ] ) ) diff --git a/griptape/engines/summary/prompt_summary_engine.py b/griptape/engines/summary/prompt_summary_engine.py index 18b5f3a07..a76388025 100644 --- a/griptape/engines/summary/prompt_summary_engine.py +++ b/griptape/engines/summary/prompt_summary_engine.py @@ -3,7 +3,7 @@ from griptape.artifacts import TextArtifact, ListArtifact from griptape.chunkers import BaseChunker, TextChunker from griptape.common import PromptStack -from griptape.common.prompt_stack.elements.prompt_stack_element import PromptStackElement +from griptape.common.prompt_stack.messages.prompt_stack_message import PromptStackMessage from griptape.drivers import BasePromptDriver from griptape.engines import BaseSummaryEngine from griptape.utils import J2 @@ -63,9 +63,9 @@ def summarize_artifacts_rec( ): return self.prompt_driver.run( PromptStack( - inputs=[ - PromptStackElement(system_prompt, role=PromptStackElement.SYSTEM_ROLE), - PromptStackElement(user_prompt, role=PromptStackElement.USER_ROLE), + messages=[ + PromptStackMessage(system_prompt, role=PromptStackMessage.SYSTEM_ROLE), + PromptStackMessage(user_prompt, role=PromptStackMessage.USER_ROLE), ] ) ) @@ -78,9 +78,9 @@ def summarize_artifacts_rec( chunks[1:], self.prompt_driver.run( PromptStack( - inputs=[ - PromptStackElement(system_prompt, role=PromptStackElement.SYSTEM_ROLE), - PromptStackElement(partial_text, role=PromptStackElement.USER_ROLE), + messages=[ + PromptStackMessage(system_prompt, role=PromptStackMessage.SYSTEM_ROLE), + PromptStackMessage(partial_text, role=PromptStackMessage.USER_ROLE), ] ) ).value, diff --git a/griptape/memory/structure/base_conversation_memory.py b/griptape/memory/structure/base_conversation_memory.py index 503b35fa1..53a72d227 100644 --- a/griptape/memory/structure/base_conversation_memory.py +++ b/griptape/memory/structure/base_conversation_memory.py @@ -47,7 +47,7 @@ def try_add_run(self, run: Run) -> None: ... def to_prompt_stack(self, last_n: Optional[int] = None) -> PromptStack: ... def add_to_prompt_stack(self, prompt_stack: PromptStack, index: Optional[int] = None) -> PromptStack: - """Add the Conversation Memory runs to the Prompt Stack by modifying the inputs in place. + """Add the Conversation Memory runs to the Prompt Stack by modifying the messages in place. If autoprune is enabled, this will fit as many Conversation Memory runs into the Prompt Stack as possible without exceeding the token limit. @@ -67,13 +67,13 @@ def add_to_prompt_stack(self, prompt_stack: PromptStack, index: Optional[int] = # Try to determine how many Conversation Memory runs we can # fit into the Prompt Stack without exceeding the token limit. while should_prune and num_runs_to_fit_in_prompt > 0: - temp_stack.inputs = prompt_stack.inputs.copy() + temp_stack.messages = prompt_stack.messages.copy() # Add n runs from Conversation Memory. # Where we insert into the Prompt Stack doesn't matter here # since we only care about the total token count. - memory_inputs = self.to_prompt_stack(num_runs_to_fit_in_prompt).inputs - temp_stack.inputs.extend(memory_inputs) + memory_inputs = self.to_prompt_stack(num_runs_to_fit_in_prompt).messages + temp_stack.messages.extend(memory_inputs) # Convert the prompt stack into tokens left. tokens_left = prompt_driver.tokenizer.count_input_tokens_left( @@ -87,10 +87,10 @@ def add_to_prompt_stack(self, prompt_stack: PromptStack, index: Optional[int] = num_runs_to_fit_in_prompt -= 1 if num_runs_to_fit_in_prompt: - memory_inputs = self.to_prompt_stack(num_runs_to_fit_in_prompt).inputs + memory_inputs = self.to_prompt_stack(num_runs_to_fit_in_prompt).messages if index: - prompt_stack.inputs[index:index] = memory_inputs + prompt_stack.messages[index:index] = memory_inputs else: - prompt_stack.inputs.extend(memory_inputs) + prompt_stack.messages.extend(memory_inputs) return prompt_stack diff --git a/griptape/memory/structure/conversation_memory.py b/griptape/memory/structure/conversation_memory.py index b4043c33c..42d160abd 100644 --- a/griptape/memory/structure/conversation_memory.py +++ b/griptape/memory/structure/conversation_memory.py @@ -18,6 +18,6 @@ def to_prompt_stack(self, last_n: Optional[int] = None) -> PromptStack: prompt_stack = PromptStack() runs = self.runs[-last_n:] if last_n else self.runs for run in runs: - prompt_stack.add_user_input(run.input) - prompt_stack.add_assistant_input(run.output) + prompt_stack.add_user_message(run.input) + prompt_stack.add_assistant_message(run.output) return prompt_stack diff --git a/griptape/memory/structure/summary_conversation_memory.py b/griptape/memory/structure/summary_conversation_memory.py index 4b5db6def..5a4ff084d 100644 --- a/griptape/memory/structure/summary_conversation_memory.py +++ b/griptape/memory/structure/summary_conversation_memory.py @@ -2,7 +2,7 @@ import logging from typing import TYPE_CHECKING, Optional from attrs import define, field, Factory -from griptape.common.prompt_stack.elements.prompt_stack_element import PromptStackElement +from griptape.common.prompt_stack.messages.prompt_stack_message import PromptStackMessage from griptape.utils import J2 from griptape.common import PromptStack from griptape.memory.structure import ConversationMemory @@ -39,11 +39,11 @@ def prompt_driver(self, value: BasePromptDriver) -> None: def to_prompt_stack(self, last_n: Optional[int] = None) -> PromptStack: stack = PromptStack() if self.summary: - stack.add_user_input(self.summary_template_generator.render(summary=self.summary)) + stack.add_user_message(self.summary_template_generator.render(summary=self.summary)) for r in self.unsummarized_runs(last_n): - stack.add_user_input(r.input) - stack.add_assistant_input(r.output) + stack.add_user_message(r.input) + stack.add_assistant_message(r.output) return stack @@ -75,7 +75,7 @@ def summarize_runs(self, previous_summary: str | None, runs: list[Run]) -> str | if len(runs) > 0: summary = self.summarize_conversation_template_generator.render(summary=previous_summary, runs=runs) return self.prompt_driver.run( - prompt_stack=PromptStack(inputs=[PromptStackElement(summary, role=PromptStackElement.USER_ROLE)]) + prompt_stack=PromptStack(messages=[PromptStackMessage(summary, role=PromptStackMessage.USER_ROLE)]) ).to_text() else: return previous_summary diff --git a/griptape/schemas/base_schema.py b/griptape/schemas/base_schema.py index e61307c64..b0df7d806 100644 --- a/griptape/schemas/base_schema.py +++ b/griptape/schemas/base_schema.py @@ -105,7 +105,7 @@ def _resolve_types(cls, attrs_cls: type) -> None: # These modules are required to avoid `NameError`s when resolving types. from griptape.drivers import BaseConversationMemoryDriver, BasePromptDriver from griptape.structures import Structure - from griptape.common import PromptStack, PromptStackElement + from griptape.common import PromptStack, PromptStackMessage from griptape.tokenizers.base_tokenizer import BaseTokenizer from typing import Any @@ -116,7 +116,7 @@ def _resolve_types(cls, attrs_cls: type) -> None: attrs_cls, localns={ "PromptStack": PromptStack, - "Usage": PromptStackElement.Usage, + "Usage": PromptStackMessage.Usage, "Structure": Structure, "BaseConversationMemoryDriver": BaseConversationMemoryDriver, "BasePromptDriver": BasePromptDriver, diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index 4fd8a34db..60a5f6a45 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -28,12 +28,12 @@ def prompt_stack(self) -> PromptStack: stack = PromptStack() memory = self.structure.conversation_memory - stack.add_system_input(self.generate_system_template(self)) + stack.add_system_message(self.generate_system_template(self)) - stack.add_user_input(self.input) + stack.add_user_message(self.input) if self.output: - stack.add_assistant_input(self.output) + stack.add_assistant_message(self.output) if memory: # inserting at index 1 to place memory right after system prompt diff --git a/griptape/tasks/toolkit_task.py b/griptape/tasks/toolkit_task.py index 60eebc405..58300b529 100644 --- a/griptape/tasks/toolkit_task.py +++ b/griptape/tasks/toolkit_task.py @@ -65,16 +65,16 @@ def prompt_stack(self) -> PromptStack: stack = PromptStack() memory = self.structure.conversation_memory - stack.add_system_input(self.generate_system_template(self)) + stack.add_system_message(self.generate_system_template(self)) - stack.add_user_input(self.input) + stack.add_user_message(self.input) if self.output: - stack.add_assistant_input(self.output.to_text()) + stack.add_assistant_message(self.output.to_text()) else: for s in self.subtasks: - stack.add_assistant_input(self.generate_assistant_subtask_template(s)) - stack.add_user_input(self.generate_user_subtask_template(s)) + stack.add_assistant_message(self.generate_assistant_subtask_template(s)) + stack.add_user_message(self.generate_user_subtask_template(s)) if memory: # inserting at index 1 to place memory right after system prompt diff --git a/griptape/utils/conversation.py b/griptape/utils/conversation.py index ef076b168..634ad3715 100644 --- a/griptape/utils/conversation.py +++ b/griptape/utils/conversation.py @@ -22,7 +22,7 @@ def lines(self) -> list[str]: def prompt_stack(self) -> list[str]: lines = [] - for stack in self.memory.to_prompt_stack().inputs: + for stack in self.memory.to_prompt_stack().messages: lines.append(f"{stack.role}: {stack.to_text_artifact().to_text()}") return lines diff --git a/tests/mocks/mock_failing_prompt_driver.py b/tests/mocks/mock_failing_prompt_driver.py index 33127bf4a..0cf7b7df3 100644 --- a/tests/mocks/mock_failing_prompt_driver.py +++ b/tests/mocks/mock_failing_prompt_driver.py @@ -5,9 +5,9 @@ from griptape.artifacts import TextArtifact from griptape.common import ( PromptStack, - PromptStackElement, + PromptStackMessage, TextPromptStackContent, - DeltaPromptStackElement, + DeltaPromptStackMessage, DeltaTextPromptStackContent, BaseDeltaPromptStackContent, ) @@ -22,25 +22,25 @@ class MockFailingPromptDriver(BasePromptDriver): model: str = "test-model" tokenizer: BaseTokenizer = OpenAiTokenizer(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL) - def try_run(self, prompt_stack: PromptStack) -> PromptStackElement: + def try_run(self, prompt_stack: PromptStack) -> PromptStackMessage: if self.current_attempt < self.max_failures: self.current_attempt += 1 raise Exception("failed attempt") else: - return PromptStackElement( + return PromptStackMessage( content=[TextPromptStackContent(TextArtifact("success"))], - role=PromptStackElement.ASSISTANT_ROLE, - usage=PromptStackElement.Usage(input_tokens=100, output_tokens=100), + role=PromptStackMessage.ASSISTANT_ROLE, + usage=PromptStackMessage.Usage(input_tokens=100, output_tokens=100), ) - def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackElement | BaseDeltaPromptStackContent]: + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage | BaseDeltaPromptStackContent]: if self.current_attempt < self.max_failures: self.current_attempt += 1 raise Exception("failed attempt") else: - yield DeltaPromptStackElement( + yield DeltaPromptStackMessage( delta_content=DeltaTextPromptStackContent("success"), - delta_usage=DeltaPromptStackElement.DeltaUsage(input_tokens=100, output_tokens=100), + delta_usage=DeltaPromptStackMessage.DeltaUsage(input_tokens=100, output_tokens=100), ) diff --git a/tests/mocks/mock_prompt_driver.py b/tests/mocks/mock_prompt_driver.py index 0c1f9e00c..8fcbb5ab6 100644 --- a/tests/mocks/mock_prompt_driver.py +++ b/tests/mocks/mock_prompt_driver.py @@ -8,8 +8,8 @@ from griptape.artifacts import TextArtifact from griptape.common import ( PromptStack, - PromptStackElement, - DeltaPromptStackElement, + PromptStackMessage, + DeltaPromptStackMessage, BaseDeltaPromptStackContent, TextPromptStackContent, DeltaTextPromptStackContent, @@ -27,19 +27,19 @@ class MockPromptDriver(BasePromptDriver): mock_input: str | Callable[[], str] = field(default="mock input", kw_only=True) mock_output: str | Callable[[PromptStack], str] = field(default="mock output", kw_only=True) - def try_run(self, prompt_stack: PromptStack) -> PromptStackElement: + def try_run(self, prompt_stack: PromptStack) -> PromptStackMessage: output = self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output - return PromptStackElement( + return PromptStackMessage( content=[TextPromptStackContent(TextArtifact(output))], - role=PromptStackElement.ASSISTANT_ROLE, - usage=PromptStackElement.Usage(input_tokens=100, output_tokens=100), + role=PromptStackMessage.ASSISTANT_ROLE, + usage=PromptStackMessage.Usage(input_tokens=100, output_tokens=100), ) - def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackElement | BaseDeltaPromptStackContent]: + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage | BaseDeltaPromptStackContent]: output = self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output yield DeltaTextPromptStackContent(output) - yield DeltaPromptStackElement( - delta_usage=DeltaPromptStackElement.DeltaUsage(input_tokens=100, output_tokens=100) + yield DeltaPromptStackMessage( + delta_usage=DeltaPromptStackMessage.DeltaUsage(input_tokens=100, output_tokens=100) ) diff --git a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py index 7c7112ce0..6e4e3b4a9 100644 --- a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py @@ -30,9 +30,9 @@ def mock_converse_stream(self, mocker): @pytest.fixture def prompt_stack(self): prompt_stack = PromptStack() - prompt_stack.add_system_input("system-input") - prompt_stack.add_user_input("user-input") - prompt_stack.add_assistant_input("assistant-input") + prompt_stack.add_system_message("system-input") + prompt_stack.add_user_message("user-input") + prompt_stack.add_assistant_message("assistant-input") return prompt_stack diff --git a/tests/unit/drivers/prompt/test_amazon_sagemaker_jumpstart_prompt_driver.py b/tests/unit/drivers/prompt/test_amazon_sagemaker_jumpstart_prompt_driver.py index 1f3a3963b..318017c02 100644 --- a/tests/unit/drivers/prompt/test_amazon_sagemaker_jumpstart_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_amazon_sagemaker_jumpstart_prompt_driver.py @@ -36,7 +36,7 @@ def test_try_run(self, mock_client): # Given driver = AmazonSageMakerJumpstartPromptDriver(endpoint="model", model="model") prompt_stack = PromptStack() - prompt_stack.add_user_input("prompt-stack") + prompt_stack.add_user_message("prompt-stack") # When response_body = [{"generated_text": "foobar"}] @@ -98,7 +98,7 @@ def test_try_stream(self, mock_client): # Given driver = AmazonSageMakerJumpstartPromptDriver(endpoint="model", model="model") prompt_stack = PromptStack() - prompt_stack.add_user_input("prompt-stack") + prompt_stack.add_user_message("prompt-stack") # When with pytest.raises(NotImplementedError) as e: @@ -123,7 +123,7 @@ def test_try_run_throws_on_empty_response(self, mock_client): driver = AmazonSageMakerJumpstartPromptDriver(endpoint="model", model="model") mock_client.invoke_endpoint.return_value = {"Body": to_streaming_body([])} prompt_stack = PromptStack() - prompt_stack.add_user_input("prompt-stack") + prompt_stack.add_user_message("prompt-stack") # When with pytest.raises(Exception) as e: diff --git a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py index 9e2e82534..17c3e97e2 100644 --- a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py @@ -49,9 +49,9 @@ def test_try_run(self, mock_client, model, system_enabled): # Given prompt_stack = PromptStack() if system_enabled: - prompt_stack.add_system_input("system-input") - prompt_stack.add_user_input("user-input") - prompt_stack.add_assistant_input("assistant-input") + prompt_stack.add_system_message("system-input") + prompt_stack.add_user_message("user-input") + prompt_stack.add_assistant_message("assistant-input") driver = AnthropicPromptDriver(model=model, api_key="api-key") expected_messages = [ {"role": "user", "content": "user-input"}, @@ -90,9 +90,9 @@ def test_try_stream_run(self, mock_stream_client, model, system_enabled): # Given prompt_stack = PromptStack() if system_enabled: - prompt_stack.add_system_input("system-input") - prompt_stack.add_user_input("user-input") - prompt_stack.add_assistant_input("assistant-input") + prompt_stack.add_system_message("system-input") + prompt_stack.add_user_message("user-input") + prompt_stack.add_assistant_message("assistant-input") expected_messages = [ {"role": "user", "content": "user-input"}, {"role": "assistant", "content": "assistant-input"}, @@ -127,4 +127,4 @@ def test_try_run_throws_when_prompt_stack_is_string(self): driver.try_run(prompt_stack) # pyright: ignore # Then - assert e.value.args[0] == "'str' object has no attribute 'inputs'" + assert e.value.args[0] == "'str' object has no attribute 'messages'" diff --git a/tests/unit/drivers/prompt/test_base_prompt_driver.py b/tests/unit/drivers/prompt/test_base_prompt_driver.py index 794df1186..0f0d0ee89 100644 --- a/tests/unit/drivers/prompt/test_base_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_base_prompt_driver.py @@ -37,7 +37,7 @@ def test_run_via_pipeline_publishes_events(self, mocker): assert instance_count(events, FinishPromptEvent) == 1 def test_run(self): - assert isinstance(MockPromptDriver().run(PromptStack(inputs=[])), TextArtifact) + assert isinstance(MockPromptDriver().run(PromptStack(messages=[])), TextArtifact) def instance_count(instances, clazz): diff --git a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py index ac434996c..964113ec5 100644 --- a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py @@ -29,9 +29,9 @@ def mock_tokenizer(self, mocker): @pytest.fixture def prompt_stack(self): prompt_stack = PromptStack() - prompt_stack.add_system_input("system-input") - prompt_stack.add_user_input("user-input") - prompt_stack.add_assistant_input("assistant-input") + prompt_stack.add_system_message("system-input") + prompt_stack.add_user_message("user-input") + prompt_stack.add_assistant_message("assistant-input") return prompt_stack def test_init(self): diff --git a/tests/unit/drivers/prompt/test_google_prompt_driver.py b/tests/unit/drivers/prompt/test_google_prompt_driver.py index d9eb9313b..91bcd0da7 100644 --- a/tests/unit/drivers/prompt/test_google_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_google_prompt_driver.py @@ -28,9 +28,9 @@ def test_init(self): def test_try_run(self, mock_generative_model): # Given prompt_stack = PromptStack() - prompt_stack.add_system_input("system-input") - prompt_stack.add_user_input("user-input") - prompt_stack.add_assistant_input("assistant-input") + prompt_stack.add_system_message("system-input") + prompt_stack.add_user_message("user-input") + prompt_stack.add_assistant_message("assistant-input") driver = GooglePromptDriver(model="gemini-pro", api_key="api-key", top_p=0.5, top_k=50) # When @@ -51,9 +51,9 @@ def test_try_run(self, mock_generative_model): def test_try_stream(self, mock_stream_generative_model): # Given prompt_stack = PromptStack() - prompt_stack.add_system_input("system-input") - prompt_stack.add_user_input("user-input") - prompt_stack.add_assistant_input("assistant-input") + prompt_stack.add_system_message("system-input") + prompt_stack.add_user_message("user-input") + prompt_stack.add_assistant_message("assistant-input") driver = GooglePromptDriver(model="gemini-pro", api_key="api-key", stream=True, top_p=0.5, top_k=50) # When diff --git a/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py b/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py index eed4f0922..259830ada 100644 --- a/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py @@ -28,9 +28,9 @@ def mock_client_stream(self, mocker): @pytest.fixture def prompt_stack(self): prompt_stack = PromptStack() - prompt_stack.add_system_input("system-input") - prompt_stack.add_user_input("user-input") - prompt_stack.add_assistant_input("assistant-input") + prompt_stack.add_system_message("system-input") + prompt_stack.add_user_message("user-input") + prompt_stack.add_assistant_message("assistant-input") return prompt_stack @pytest.fixture(autouse=True) diff --git a/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py b/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py index 16691b474..defb53056 100644 --- a/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py @@ -27,9 +27,9 @@ def mock_autotokenizer(self, mocker): @pytest.fixture def prompt_stack(self): prompt_stack = PromptStack() - prompt_stack.add_system_input("system-input") - prompt_stack.add_user_input("user-input") - prompt_stack.add_assistant_input("assistant-input") + prompt_stack.add_system_message("system-input") + prompt_stack.add_user_message("user-input") + prompt_stack.add_assistant_message("assistant-input") return prompt_stack def test_init(self): diff --git a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py index 80825efc6..db329785b 100644 --- a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py @@ -29,9 +29,9 @@ def mock_chat_completion_stream_create(self, mocker): @pytest.fixture def prompt_stack(self): prompt_stack = PromptStack() - prompt_stack.add_system_input("system-input") - prompt_stack.add_user_input("user-input") - prompt_stack.add_assistant_input("assistant-input") + prompt_stack.add_system_message("system-input") + prompt_stack.add_user_message("user-input") + prompt_stack.add_assistant_message("assistant-input") return prompt_stack @pytest.fixture @@ -161,7 +161,7 @@ def test_try_run_throws_when_prompt_stack_is_string(self): driver.try_run("prompt-stack") # pyright: ignore # Then - assert e.value.args[0] == "'str' object has no attribute 'inputs'" + assert e.value.args[0] == "'str' object has no attribute 'messages'" @pytest.mark.parametrize("choices", [[], [1, 2]]) def test_try_run_throws_when_multiple_choices_returned(self, choices, mock_chat_completion_create, prompt_stack): diff --git a/tests/unit/engines/summary/test_prompt_summary_engine.py b/tests/unit/engines/summary/test_prompt_summary_engine.py index 59b36f48e..34c6e3563 100644 --- a/tests/unit/engines/summary/test_prompt_summary_engine.py +++ b/tests/unit/engines/summary/test_prompt_summary_engine.py @@ -1,7 +1,7 @@ import pytest from griptape.artifacts import TextArtifact, ListArtifact from griptape.engines import PromptSummaryEngine -from griptape.utils import PromptStack +from griptape.common import PromptStack from tests.mocks.mock_prompt_driver import MockPromptDriver import os @@ -28,7 +28,7 @@ def test_max_token_multiplier_invalid(self, engine): def test_chunked_summary(self, engine): def smaller_input(prompt_stack: PromptStack): - return prompt_stack.inputs[0].content[: (len(prompt_stack.inputs[0].content) // 2)] + return prompt_stack.messages[0].content[: (len(prompt_stack.messages[0].content) // 2)] engine = PromptSummaryEngine(prompt_driver=MockPromptDriver(mock_output="smaller_input")) diff --git a/tests/unit/events/test_base_event.py b/tests/unit/events/test_base_event.py index 6e2166f6a..e3ed5aa0e 100644 --- a/tests/unit/events/test_base_event.py +++ b/tests/unit/events/test_base_event.py @@ -34,9 +34,9 @@ def test_start_prompt_event_from_dict(self): "model": "foo bar", "prompt_stack": { "type": "PromptStack", - "inputs": [ + "messages": [ { - "type": "PromptStackElement", + "type": "PromptStackMessage", "role": "user", "content": [ {"type": "TextPromptStackContent", "artifact": {"type": "TextArtifact", "value": "foo"}} @@ -44,7 +44,7 @@ def test_start_prompt_event_from_dict(self): "usage": {"type": "Usage", "input_tokens": None, "output_tokens": None}, }, { - "type": "PromptStackElement", + "type": "PromptStackMessage", "role": "system", "content": [ {"type": "TextPromptStackContent", "artifact": {"type": "TextArtifact", "value": "bar"}} @@ -59,10 +59,10 @@ def test_start_prompt_event_from_dict(self): assert isinstance(event, StartPromptEvent) assert event.timestamp == 123 - assert event.prompt_stack.inputs[0].content[0].artifact.value == "foo" - assert event.prompt_stack.inputs[0].role == "user" - assert event.prompt_stack.inputs[1].content[0].artifact.value == "bar" - assert event.prompt_stack.inputs[1].role == "system" + assert event.prompt_stack.messages[0].content[0].artifact.value == "foo" + assert event.prompt_stack.messages[0].role == "user" + assert event.prompt_stack.messages[1].content[0].artifact.value == "bar" + assert event.prompt_stack.messages[1].role == "system" assert event.model == "foo bar" def test_finish_prompt_event_from_dict(self): diff --git a/tests/unit/events/test_start_prompt_event.py b/tests/unit/events/test_start_prompt_event.py index 6f9268a63..4ef08ec5c 100644 --- a/tests/unit/events/test_start_prompt_event.py +++ b/tests/unit/events/test_start_prompt_event.py @@ -7,16 +7,16 @@ class TestStartPromptEvent: @pytest.fixture def start_prompt_event(self): prompt_stack = PromptStack() - prompt_stack.add_user_input("foo") - prompt_stack.add_system_input("bar") + prompt_stack.add_user_message("foo") + prompt_stack.add_system_message("bar") return StartPromptEvent(prompt_stack=prompt_stack, model="foo bar") def test_to_dict(self, start_prompt_event): assert "timestamp" in start_prompt_event.to_dict() - assert start_prompt_event.to_dict()["prompt_stack"]["inputs"][0]["content"][0]["artifact"]["value"] == "foo" - assert start_prompt_event.to_dict()["prompt_stack"]["inputs"][0]["role"] == "user" - assert start_prompt_event.to_dict()["prompt_stack"]["inputs"][1]["content"][0]["artifact"]["value"] == "bar" - assert start_prompt_event.to_dict()["prompt_stack"]["inputs"][1]["role"] == "system" + assert start_prompt_event.to_dict()["prompt_stack"]["messages"][0]["content"][0]["artifact"]["value"] == "foo" + assert start_prompt_event.to_dict()["prompt_stack"]["messages"][0]["role"] == "user" + assert start_prompt_event.to_dict()["prompt_stack"]["messages"][1]["content"][0]["artifact"]["value"] == "bar" + assert start_prompt_event.to_dict()["prompt_stack"]["messages"][1]["role"] == "system" assert start_prompt_event.to_dict()["model"] == "foo bar" diff --git a/tests/unit/memory/structure/test_conversation_memory.py b/tests/unit/memory/structure/test_conversation_memory.py index 642ca4354..30685d863 100644 --- a/tests/unit/memory/structure/test_conversation_memory.py +++ b/tests/unit/memory/structure/test_conversation_memory.py @@ -38,8 +38,8 @@ def test_to_prompt_stack(self): prompt_stack = memory.to_prompt_stack() - assert prompt_stack.inputs[0].content[0].artifact.value == "foo" - assert prompt_stack.inputs[1].content[0].artifact.value == "bar" + assert prompt_stack.messages[0].content[0].artifact.value == "foo" + assert prompt_stack.messages[1].content[0].artifact.value == "bar" def test_from_dict(self): memory = ConversationMemory() @@ -88,11 +88,11 @@ def test_add_to_prompt_stack_autopruing_disabled(self): ) memory.structure = agent prompt_stack = PromptStack() - prompt_stack.add_user_input(TextArtifact("foo")) - prompt_stack.add_assistant_input("bar") + prompt_stack.add_user_message(TextArtifact("foo")) + prompt_stack.add_assistant_message("bar") memory.add_to_prompt_stack(prompt_stack) - assert len(prompt_stack.inputs) == 12 + assert len(prompt_stack.messages) == 12 def test_add_to_prompt_stack_autopruning_enabled(self): # All memory is pruned. @@ -109,12 +109,12 @@ def test_add_to_prompt_stack_autopruning_enabled(self): ) memory.structure = agent prompt_stack = PromptStack() - prompt_stack.add_system_input("fizz") - prompt_stack.add_user_input("foo") - prompt_stack.add_assistant_input("bar") + prompt_stack.add_system_message("fizz") + prompt_stack.add_user_message("foo") + prompt_stack.add_assistant_message("bar") memory.add_to_prompt_stack(prompt_stack) - assert len(prompt_stack.inputs) == 3 + assert len(prompt_stack.messages) == 3 # No memory is pruned. agent = Agent(prompt_driver=MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=1000))) @@ -130,12 +130,12 @@ def test_add_to_prompt_stack_autopruning_enabled(self): ) memory.structure = agent prompt_stack = PromptStack() - prompt_stack.add_system_input("fizz") - prompt_stack.add_user_input("foo") - prompt_stack.add_assistant_input("bar") + prompt_stack.add_system_message("fizz") + prompt_stack.add_user_message("foo") + prompt_stack.add_assistant_message("bar") memory.add_to_prompt_stack(prompt_stack) - assert len(prompt_stack.inputs) == 13 + assert len(prompt_stack.messages) == 13 # One memory is pruned. # MockTokenizer's max_input_tokens set to one below the sum of memory + system prompt tokens @@ -155,15 +155,15 @@ def test_add_to_prompt_stack_autopruning_enabled(self): memory.structure = agent prompt_stack = PromptStack() # And then another 6 tokens from fizz for a total of 161 tokens. - prompt_stack.add_system_input("fizz") - prompt_stack.add_user_input("foo") - prompt_stack.add_assistant_input("bar") + prompt_stack.add_system_message("fizz") + prompt_stack.add_user_message("foo") + prompt_stack.add_assistant_message("bar") memory.add_to_prompt_stack(prompt_stack, 1) # We expect one run (2 prompt stack inputs) to be pruned. - assert len(prompt_stack.inputs) == 11 - assert prompt_stack.inputs[0].content[0].artifact.value == "fizz" - assert prompt_stack.inputs[1].content[0].artifact.value == "foo2" - assert prompt_stack.inputs[2].content[0].artifact.value == "bar2" - assert prompt_stack.inputs[-2].content[0].artifact.value == "foo" - assert prompt_stack.inputs[-1].content[0].artifact.value == "bar" + assert len(prompt_stack.messages) == 11 + assert prompt_stack.messages[0].content[0].artifact.value == "fizz" + assert prompt_stack.messages[1].content[0].artifact.value == "foo2" + assert prompt_stack.messages[2].content[0].artifact.value == "bar2" + assert prompt_stack.messages[-2].content[0].artifact.value == "foo" + assert prompt_stack.messages[-1].content[0].artifact.value == "bar" diff --git a/tests/unit/memory/structure/test_summary_conversation_memory.py b/tests/unit/memory/structure/test_summary_conversation_memory.py index 5ca99f07a..e625ac6c6 100644 --- a/tests/unit/memory/structure/test_summary_conversation_memory.py +++ b/tests/unit/memory/structure/test_summary_conversation_memory.py @@ -59,9 +59,9 @@ def test_to_prompt_stack(self): prompt_stack = memory.to_prompt_stack() - assert prompt_stack.inputs[0].content[0].artifact.value == "Summary of the conversation so far: foobar" - assert prompt_stack.inputs[1].content[0].artifact.value == "foo" - assert prompt_stack.inputs[2].content[0].artifact.value == "bar" + assert prompt_stack.messages[0].content[0].artifact.value == "Summary of the conversation so far: foobar" + assert prompt_stack.messages[1].content[0].artifact.value == "foo" + assert prompt_stack.messages[2].content[0].artifact.value == "bar" def test_from_dict(self): memory = SummaryConversationMemory() diff --git a/tests/unit/structures/test_agent.py b/tests/unit/structures/test_agent.py index e6c2a1f01..0d5d8a565 100644 --- a/tests/unit/structures/test_agent.py +++ b/tests/unit/structures/test_agent.py @@ -164,15 +164,15 @@ def test_prompt_stack_without_memory(self): agent.add_task(task1) - assert len(task1.prompt_stack.inputs) == 2 + assert len(task1.prompt_stack.messages) == 2 agent.run() - assert len(task1.prompt_stack.inputs) == 3 + assert len(task1.prompt_stack.messages) == 3 agent.run() - assert len(task1.prompt_stack.inputs) == 3 + assert len(task1.prompt_stack.messages) == 3 def test_prompt_stack_with_memory(self): agent = Agent(prompt_driver=MockPromptDriver(), conversation_memory=ConversationMemory()) @@ -181,15 +181,15 @@ def test_prompt_stack_with_memory(self): agent.add_task(task1) - assert len(task1.prompt_stack.inputs) == 2 + assert len(task1.prompt_stack.messages) == 2 agent.run() - assert len(task1.prompt_stack.inputs) == 5 + assert len(task1.prompt_stack.messages) == 5 agent.run() - assert len(task1.prompt_stack.inputs) == 7 + assert len(task1.prompt_stack.messages) == 7 def test_run(self): task = PromptTask("test") diff --git a/tests/unit/structures/test_pipeline.py b/tests/unit/structures/test_pipeline.py index 8c1cd8511..5131ed728 100644 --- a/tests/unit/structures/test_pipeline.py +++ b/tests/unit/structures/test_pipeline.py @@ -259,18 +259,18 @@ def test_prompt_stack_without_memory(self): pipeline.add_tasks(task1, task2) - assert len(task1.prompt_stack.inputs) == 2 - assert len(task2.prompt_stack.inputs) == 2 + assert len(task1.prompt_stack.messages) == 2 + assert len(task2.prompt_stack.messages) == 2 pipeline.run() - assert len(task1.prompt_stack.inputs) == 3 - assert len(task2.prompt_stack.inputs) == 3 + assert len(task1.prompt_stack.messages) == 3 + assert len(task2.prompt_stack.messages) == 3 pipeline.run() - assert len(task1.prompt_stack.inputs) == 3 - assert len(task2.prompt_stack.inputs) == 3 + assert len(task1.prompt_stack.messages) == 3 + assert len(task2.prompt_stack.messages) == 3 def test_prompt_stack_with_memory(self): pipeline = Pipeline(prompt_driver=MockPromptDriver()) @@ -280,18 +280,18 @@ def test_prompt_stack_with_memory(self): pipeline.add_tasks(task1, task2) - assert len(task1.prompt_stack.inputs) == 2 - assert len(task2.prompt_stack.inputs) == 2 + assert len(task1.prompt_stack.messages) == 2 + assert len(task2.prompt_stack.messages) == 2 pipeline.run() - assert len(task1.prompt_stack.inputs) == 5 - assert len(task2.prompt_stack.inputs) == 5 + assert len(task1.prompt_stack.messages) == 5 + assert len(task2.prompt_stack.messages) == 5 pipeline.run() - assert len(task1.prompt_stack.inputs) == 7 - assert len(task2.prompt_stack.inputs) == 7 + assert len(task1.prompt_stack.messages) == 7 + assert len(task2.prompt_stack.messages) == 7 def test_text_artifact_token_count(self): text = "foobar" diff --git a/tests/unit/tokenizers/test_google_tokenizer.py b/tests/unit/tokenizers/test_google_tokenizer.py index 70b441000..0f940b06d 100644 --- a/tests/unit/tokenizers/test_google_tokenizer.py +++ b/tests/unit/tokenizers/test_google_tokenizer.py @@ -1,7 +1,7 @@ import pytest from unittest.mock import Mock from griptape.common import PromptStack -from griptape.common.prompt_stack.elements.prompt_stack_element import PromptStackElement +from griptape.common.prompt_stack.messages.prompt_stack_message import PromptStackMessage from griptape.tokenizers import GoogleTokenizer @@ -20,7 +20,9 @@ def tokenizer(self, request): @pytest.mark.parametrize("tokenizer,expected", [("gemini-pro", 5)], indirect=["tokenizer"]) def test_token_count(self, tokenizer, expected): assert tokenizer.count_tokens("foo bar huzzah") == expected - assert tokenizer.count_tokens(PromptStack(inputs=[PromptStackElement(content="foo", role="user")])) == expected + assert ( + tokenizer.count_tokens(PromptStack(messages=[PromptStackMessage(content="foo", role="user")])) == expected + ) assert tokenizer.count_tokens(["foo", "bar", "huzzah"]) == expected @pytest.mark.parametrize("tokenizer,expected", [("gemini-pro", 30715)], indirect=["tokenizer"]) diff --git a/tests/unit/utils/test_prompt_stack.py b/tests/unit/utils/test_prompt_stack.py index 0732bc733..87976c02d 100644 --- a/tests/unit/utils/test_prompt_stack.py +++ b/tests/unit/utils/test_prompt_stack.py @@ -11,26 +11,26 @@ def prompt_stack(self): def test_init(self): assert PromptStack() - def test_add_input(self, prompt_stack): - prompt_stack.add_input("foo", "role") + def test_add_message(self, prompt_stack): + prompt_stack.add_message("foo", "role") - assert prompt_stack.inputs[0].role == "role" - assert prompt_stack.inputs[0].content[0].artifact.value == "foo" + assert prompt_stack.messages[0].role == "role" + assert prompt_stack.messages[0].content[0].artifact.value == "foo" - def test_add_system_input(self, prompt_stack): - prompt_stack.add_system_input("foo") + def test_add_system_message(self, prompt_stack): + prompt_stack.add_system_message("foo") - assert prompt_stack.inputs[0].role == "system" - assert prompt_stack.inputs[0].content[0].artifact.value == "foo" + assert prompt_stack.messages[0].role == "system" + assert prompt_stack.messages[0].content[0].artifact.value == "foo" - def test_add_user_input(self, prompt_stack): - prompt_stack.add_user_input("foo") + def test_add_user_message(self, prompt_stack): + prompt_stack.add_user_message("foo") - assert prompt_stack.inputs[0].role == "user" - assert prompt_stack.inputs[0].content[0].artifact.value == "foo" + assert prompt_stack.messages[0].role == "user" + assert prompt_stack.messages[0].content[0].artifact.value == "foo" - def test_add_assistant_input(self, prompt_stack): - prompt_stack.add_assistant_input("foo") + def test_add_assistant_message(self, prompt_stack): + prompt_stack.add_assistant_message("foo") - assert prompt_stack.inputs[0].role == "assistant" - assert prompt_stack.inputs[0].content[0].artifact.value == "foo" + assert prompt_stack.messages[0].role == "assistant" + assert prompt_stack.messages[0].content[0].artifact.value == "foo" From 1e4aac4d01b4492e85029ba11f345e193bff4922 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Wed, 19 Jun 2024 12:45:36 -0700 Subject: [PATCH 06/34] Fix Ollama --- .../drivers/prompt/ollama_prompt_driver.py | 24 ++++++++++++++----- .../prompt/test_ollama_prompt_driver.py | 22 ++++++++--------- 2 files changed, 28 insertions(+), 18 deletions(-) diff --git a/griptape/drivers/prompt/ollama_prompt_driver.py b/griptape/drivers/prompt/ollama_prompt_driver.py index b21176e82..3820272ed 100644 --- a/griptape/drivers/prompt/ollama_prompt_driver.py +++ b/griptape/drivers/prompt/ollama_prompt_driver.py @@ -5,8 +5,15 @@ from griptape.artifacts import TextArtifact from griptape.drivers import BasePromptDriver from griptape.tokenizers.base_tokenizer import BaseTokenizer -from griptape.utils import PromptStack, import_optional_dependency +from griptape.common import PromptStack, TextPromptStackContent +from griptape.utils import import_optional_dependency from griptape.tokenizers import SimpleTokenizer +from griptape.common import ( + PromptStackMessage, + BaseDeltaPromptStackContent, + DeltaPromptStackMessage, + DeltaTextPromptStackContent, +) if TYPE_CHECKING: from ollama import Client @@ -46,24 +53,29 @@ class OllamaPromptDriver(BasePromptDriver): kw_only=True, ) - def try_run(self, prompt_stack: PromptStack) -> TextArtifact: + def try_run(self, prompt_stack: PromptStack) -> PromptStackMessage: response = self.client.chat(**self._base_params(prompt_stack)) if isinstance(response, dict): - return TextArtifact(value=response["message"]["content"]) + return PromptStackMessage( + content=[TextPromptStackContent(TextArtifact(value=response["message"]["content"]))], + role=PromptStackMessage.ASSISTANT_ROLE, + ) else: raise Exception("invalid model response") - def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]: + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage | BaseDeltaPromptStackContent]: stream = self.client.chat(**self._base_params(prompt_stack), stream=True) if isinstance(stream, Iterator): for chunk in stream: - yield TextArtifact(value=chunk["message"]["content"]) + yield DeltaTextPromptStackContent(chunk["message"]["content"], role=PromptStackMessage.ASSISTANT_ROLE) else: raise Exception("invalid model response") def _base_params(self, prompt_stack: PromptStack) -> dict: - messages = [{"role": input.role, "content": input.content} for input in prompt_stack.inputs] + messages = [ + {"role": message.role, "content": message.to_text_artifact().to_text()} for message in prompt_stack.messages + ] return {"messages": messages, "model": self.model, "options": self.options} diff --git a/tests/unit/drivers/prompt/test_ollama_prompt_driver.py b/tests/unit/drivers/prompt/test_ollama_prompt_driver.py index d42a8b45d..4a52e9b9c 100644 --- a/tests/unit/drivers/prompt/test_ollama_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_ollama_prompt_driver.py @@ -1,5 +1,6 @@ +from griptape.common.prompt_stack.contents.delta_text_prompt_stack_content import DeltaTextPromptStackContent from griptape.drivers import OllamaPromptDriver -from griptape.utils import PromptStack +from griptape.common import PromptStack import pytest @@ -25,13 +26,11 @@ def test_init(self): def test_try_run(self, mock_client): # Given prompt_stack = PromptStack() - prompt_stack.add_generic_input("generic-input") - prompt_stack.add_system_input("system-input") - prompt_stack.add_user_input("user-input") - prompt_stack.add_assistant_input("assistant-input") + prompt_stack.add_system_message("system-input") + prompt_stack.add_user_message("user-input") + prompt_stack.add_assistant_message("assistant-input") driver = OllamaPromptDriver(model="llama") expected_messages = [ - {"role": "generic", "content": "generic-input"}, {"role": "system", "content": "system-input"}, {"role": "user", "content": "user-input"}, {"role": "assistant", "content": "assistant-input"}, @@ -61,12 +60,10 @@ def test_try_run_bad_response(self, mock_client): def test_try_stream_run(self, mock_stream_client): # Given prompt_stack = PromptStack() - prompt_stack.add_generic_input("generic-input") - prompt_stack.add_system_input("system-input") - prompt_stack.add_user_input("user-input") - prompt_stack.add_assistant_input("assistant-input") + prompt_stack.add_system_message("system-input") + prompt_stack.add_user_message("user-input") + prompt_stack.add_assistant_message("assistant-input") expected_messages = [ - {"role": "generic", "content": "generic-input"}, {"role": "system", "content": "system-input"}, {"role": "user", "content": "user-input"}, {"role": "assistant", "content": "assistant-input"}, @@ -83,7 +80,8 @@ def test_try_stream_run(self, mock_stream_client): options={"temperature": driver.temperature, "stop": [], "num_predict": driver.max_tokens}, stream=True, ) - assert text_artifact.value == "model-output" + if isinstance(text_artifact, DeltaTextPromptStackContent): + assert text_artifact.text == "model-output" def test_try_stream_bad_response(self, mock_stream_client): # Given From 9da420d306410c1441d7d80b58b73f1123b6ed78 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Wed, 19 Jun 2024 13:19:25 -0700 Subject: [PATCH 07/34] Clean up roles --- griptape/drivers/prompt/amazon_bedrock_prompt_driver.py | 4 +--- griptape/drivers/prompt/anthropic_prompt_driver.py | 5 ++--- griptape/drivers/prompt/cohere_prompt_driver.py | 3 +-- griptape/drivers/prompt/google_prompt_driver.py | 3 +-- griptape/drivers/prompt/huggingface_hub_prompt_driver.py | 3 +-- griptape/drivers/prompt/ollama_prompt_driver.py | 2 +- griptape/drivers/prompt/openai_chat_prompt_driver.py | 8 ++++---- 7 files changed, 11 insertions(+), 17 deletions(-) diff --git a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py index 4446c852f..8adf67c05 100644 --- a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py +++ b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py @@ -54,9 +54,7 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMess stream = response.get("stream") if stream is not None: for event in stream: - if "messageStart" in event: - yield DeltaPromptStackMessage(role=event["messageStart"]["role"]) - elif "contentBlockDelta" in event: + if "contentBlockDelta" in event: content_block_delta = event["contentBlockDelta"] yield DeltaTextPromptStackContent( content_block_delta["delta"]["text"], index=content_block_delta["contentBlockIndex"] diff --git a/griptape/drivers/prompt/anthropic_prompt_driver.py b/griptape/drivers/prompt/anthropic_prompt_driver.py index 5b8b7be74..57a7e9ba8 100644 --- a/griptape/drivers/prompt/anthropic_prompt_driver.py +++ b/griptape/drivers/prompt/anthropic_prompt_driver.py @@ -54,7 +54,7 @@ def try_run(self, prompt_stack: PromptStack) -> PromptStackMessage: return PromptStackMessage( content=[self.__message_content_to_prompt_stack_content(content) for content in response.content], - role=response.role, + role=PromptStackMessage.ASSISTANT_ROLE, usage=PromptStackMessage.Usage( input_tokens=response.usage.input_tokens, output_tokens=response.usage.output_tokens ), @@ -68,8 +68,7 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMess yield self.__message_content_delta_to_prompt_stack_content_delta(event) elif event.type == "message_start": yield DeltaPromptStackMessage( - role=event.message.role, - delta_usage=DeltaPromptStackMessage.DeltaUsage(input_tokens=event.message.usage.input_tokens), + delta_usage=DeltaPromptStackMessage.DeltaUsage(input_tokens=event.message.usage.input_tokens) ) elif event.type == "message_delta": yield DeltaPromptStackMessage( diff --git a/griptape/drivers/prompt/cohere_prompt_driver.py b/griptape/drivers/prompt/cohere_prompt_driver.py index 8c951587f..f3eee8f67 100644 --- a/griptape/drivers/prompt/cohere_prompt_driver.py +++ b/griptape/drivers/prompt/cohere_prompt_driver.py @@ -61,10 +61,9 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMess usage = event.response.meta.tokens yield DeltaPromptStackMessage( - role=PromptStackMessage.ASSISTANT_ROLE, delta_usage=DeltaPromptStackMessage.DeltaUsage( input_tokens=usage.input_tokens, output_tokens=usage.output_tokens - ), + ) ) def _prompt_stack_messages_to_messages(self, elements: list[PromptStackMessage]) -> list[dict]: diff --git a/griptape/drivers/prompt/google_prompt_driver.py b/griptape/drivers/prompt/google_prompt_driver.py index 6a4a99859..a45ec0c07 100644 --- a/griptape/drivers/prompt/google_prompt_driver.py +++ b/griptape/drivers/prompt/google_prompt_driver.py @@ -94,10 +94,9 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMess # TODO: Only yield the first one yield DeltaPromptStackMessage( - role=PromptStackMessage.ASSISTANT_ROLE, delta_usage=DeltaPromptStackMessage.DeltaUsage( input_tokens=usage_metadata.prompt_token_count, output_tokens=usage_metadata.candidates_token_count - ), + ) ) def _default_model_client(self) -> GenerativeModel: diff --git a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py index 296ed1b0f..31f1cec29 100644 --- a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py +++ b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py @@ -85,8 +85,7 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMess output_tokens = len(self.tokenizer.tokenizer.encode(full_text)) yield DeltaPromptStackMessage( - role=PromptStackMessage.ASSISTANT_ROLE, - delta_usage=DeltaPromptStackMessage.DeltaUsage(input_tokens=input_tokens, output_tokens=output_tokens), + delta_usage=DeltaPromptStackMessage.DeltaUsage(input_tokens=input_tokens, output_tokens=output_tokens) ) def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str: diff --git a/griptape/drivers/prompt/ollama_prompt_driver.py b/griptape/drivers/prompt/ollama_prompt_driver.py index 3820272ed..46730d64b 100644 --- a/griptape/drivers/prompt/ollama_prompt_driver.py +++ b/griptape/drivers/prompt/ollama_prompt_driver.py @@ -69,7 +69,7 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMess if isinstance(stream, Iterator): for chunk in stream: - yield DeltaTextPromptStackContent(chunk["message"]["content"], role=PromptStackMessage.ASSISTANT_ROLE) + yield DeltaTextPromptStackContent(chunk["message"]["content"]) else: raise Exception("invalid model response") diff --git a/griptape/drivers/prompt/openai_chat_prompt_driver.py b/griptape/drivers/prompt/openai_chat_prompt_driver.py index 5a36c4c2f..4957b3407 100644 --- a/griptape/drivers/prompt/openai_chat_prompt_driver.py +++ b/griptape/drivers/prompt/openai_chat_prompt_driver.py @@ -167,9 +167,9 @@ def __message_to_prompt_stack_content(self, message: ChatCompletionMessage) -> B raise ValueError(f"Unsupported message type: {message}") def __message_delta_to_prompt_stack_content_delta(self, content_delta: ChoiceDelta) -> BaseDeltaPromptStackContent: - if content_delta.content is not None: + if content_delta.content is None: + return DeltaTextPromptStackContent("") + else: delta_content = content_delta.content - return DeltaTextPromptStackContent(delta_content, role=content_delta.role) - else: - return DeltaTextPromptStackContent("", role=content_delta.role) + return DeltaTextPromptStackContent(delta_content) From 0caf02855b36612fdc6331e5a6ac2f304a981a15 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Wed, 19 Jun 2024 13:37:26 -0700 Subject: [PATCH 08/34] Rename deltas --- griptape/common/__init__.py | 4 ++-- .../base_delta_prompt_stack_content.py | 2 -- ....py => text_delta_prompt_stack_content.py} | 2 +- .../contents/text_prompt_stack_content.py | 4 ++-- .../messages/base_prompt_stack_message.py | 23 ++++++++++++++++++- .../messages/delta_prompt_stack_message.py | 22 ++---------------- .../messages/prompt_stack_message.py | 17 ++------------ .../prompt/amazon_bedrock_prompt_driver.py | 6 ++--- .../drivers/prompt/anthropic_prompt_driver.py | 8 +++---- griptape/drivers/prompt/base_prompt_driver.py | 12 +++++----- .../drivers/prompt/cohere_prompt_driver.py | 6 ++--- .../drivers/prompt/google_prompt_driver.py | 6 ++--- .../prompt/huggingface_hub_prompt_driver.py | 6 ++--- .../drivers/prompt/ollama_prompt_driver.py | 4 ++-- .../prompt/openai_chat_prompt_driver.py | 8 +++---- tests/mocks/mock_failing_prompt_driver.py | 6 ++--- tests/mocks/mock_prompt_driver.py | 8 +++---- .../test_amazon_bedrock_prompt_driver.py | 4 ++-- .../prompt/test_anthropic_prompt_driver.py | 4 ++-- .../test_azure_openai_chat_prompt_driver.py | 4 ++-- .../prompt/test_cohere_prompt_driver.py | 4 ++-- .../prompt/test_google_prompt_driver.py | 4 ++-- .../test_hugging_face_hub_prompt_driver.py | 4 ++-- .../prompt/test_ollama_prompt_driver.py | 4 ++-- .../prompt/test_openai_chat_prompt_driver.py | 4 ++-- 25 files changed, 81 insertions(+), 95 deletions(-) rename griptape/common/prompt_stack/contents/{delta_text_prompt_stack_content.py => text_delta_prompt_stack_content.py} (74%) diff --git a/griptape/common/__init__.py b/griptape/common/__init__.py index db5e011e1..ff598d638 100644 --- a/griptape/common/__init__.py +++ b/griptape/common/__init__.py @@ -1,6 +1,6 @@ from .prompt_stack.contents.base_prompt_stack_content import BasePromptStackContent from .prompt_stack.contents.base_delta_prompt_stack_content import BaseDeltaPromptStackContent -from .prompt_stack.contents.delta_text_prompt_stack_content import DeltaTextPromptStackContent +from .prompt_stack.contents.text_delta_prompt_stack_content import TextDeltaPromptStackContent from .prompt_stack.contents.text_prompt_stack_content import TextPromptStackContent from .prompt_stack.contents.image_prompt_stack_content import ImagePromptStackContent @@ -16,7 +16,7 @@ "BasePromptStackContent", "DeltaPromptStackMessage", "PromptStackMessage", - "DeltaTextPromptStackContent", + "TextDeltaPromptStackContent", "TextPromptStackContent", "ImagePromptStackContent", "PromptStack", diff --git a/griptape/common/prompt_stack/contents/base_delta_prompt_stack_content.py b/griptape/common/prompt_stack/contents/base_delta_prompt_stack_content.py index 5e06f4ee9..8f0cc7ae9 100644 --- a/griptape/common/prompt_stack/contents/base_delta_prompt_stack_content.py +++ b/griptape/common/prompt_stack/contents/base_delta_prompt_stack_content.py @@ -1,7 +1,6 @@ from __future__ import annotations from abc import ABC -from typing import Optional from attrs import define, field @@ -11,4 +10,3 @@ @define class BaseDeltaPromptStackContent(ABC, SerializableMixin): index: int = field(kw_only=True, default=0, metadata={"serializable": True}) - role: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": True}) diff --git a/griptape/common/prompt_stack/contents/delta_text_prompt_stack_content.py b/griptape/common/prompt_stack/contents/text_delta_prompt_stack_content.py similarity index 74% rename from griptape/common/prompt_stack/contents/delta_text_prompt_stack_content.py rename to griptape/common/prompt_stack/contents/text_delta_prompt_stack_content.py index 25b6c25b3..05fdc1b45 100644 --- a/griptape/common/prompt_stack/contents/delta_text_prompt_stack_content.py +++ b/griptape/common/prompt_stack/contents/text_delta_prompt_stack_content.py @@ -5,5 +5,5 @@ @define -class DeltaTextPromptStackContent(BaseDeltaPromptStackContent): +class TextDeltaPromptStackContent(BaseDeltaPromptStackContent): text: str = field(metadata={"serializable": True}) diff --git a/griptape/common/prompt_stack/contents/text_prompt_stack_content.py b/griptape/common/prompt_stack/contents/text_prompt_stack_content.py index b82f2fb1f..93cebd25b 100644 --- a/griptape/common/prompt_stack/contents/text_prompt_stack_content.py +++ b/griptape/common/prompt_stack/contents/text_prompt_stack_content.py @@ -4,7 +4,7 @@ from collections.abc import Sequence from griptape.artifacts import TextArtifact -from griptape.common import BasePromptStackContent, BaseDeltaPromptStackContent, DeltaTextPromptStackContent +from griptape.common import BasePromptStackContent, BaseDeltaPromptStackContent, TextDeltaPromptStackContent @define @@ -13,7 +13,7 @@ class TextPromptStackContent(BasePromptStackContent): @classmethod def from_deltas(cls, deltas: Sequence[BaseDeltaPromptStackContent]) -> TextPromptStackContent: - text_deltas = [delta for delta in deltas if isinstance(delta, DeltaTextPromptStackContent)] + text_deltas = [delta for delta in deltas if isinstance(delta, TextDeltaPromptStackContent)] artifact = TextArtifact(value="".join(delta.text for delta in text_deltas)) diff --git a/griptape/common/prompt_stack/messages/base_prompt_stack_message.py b/griptape/common/prompt_stack/messages/base_prompt_stack_message.py index ac6d1365b..c6a77012d 100644 --- a/griptape/common/prompt_stack/messages/base_prompt_stack_message.py +++ b/griptape/common/prompt_stack/messages/base_prompt_stack_message.py @@ -1,19 +1,40 @@ from __future__ import annotations from abc import ABC +from typing import Optional, Union +from attrs import Factory, define, field -from attrs import define, field +from griptape.common import BasePromptStackContent, BaseDeltaPromptStackContent from griptape.mixins import SerializableMixin @define class BasePromptStackMessage(ABC, SerializableMixin): + @define + class Usage(SerializableMixin): + input_tokens: Optional[float] = field(kw_only=True, default=None, metadata={"serializable": True}) + output_tokens: Optional[float] = field(kw_only=True, default=None, metadata={"serializable": True}) + + @property + def total_tokens(self) -> float: + return (self.input_tokens or 0) + (self.output_tokens or 0) + + def __add__(self, other: BasePromptStackMessage.Usage) -> BasePromptStackMessage.Usage: + return BasePromptStackMessage.Usage( + input_tokens=(self.input_tokens or 0) + (other.input_tokens or 0), + output_tokens=(self.output_tokens or 0) + (other.output_tokens or 0), + ) + USER_ROLE = "user" ASSISTANT_ROLE = "assistant" SYSTEM_ROLE = "system" + content: list[Union[BasePromptStackContent, BaseDeltaPromptStackContent]] = field(metadata={"serializable": True}) role: str = field(kw_only=True, metadata={"serializable": True}) + usage: Usage = field( + kw_only=True, default=Factory(lambda: BasePromptStackMessage.Usage()), metadata={"serializable": True} + ) def is_system(self) -> bool: return self.role == self.SYSTEM_ROLE diff --git a/griptape/common/prompt_stack/messages/delta_prompt_stack_message.py b/griptape/common/prompt_stack/messages/delta_prompt_stack_message.py index bbc18a024..cf0799193 100644 --- a/griptape/common/prompt_stack/messages/delta_prompt_stack_message.py +++ b/griptape/common/prompt_stack/messages/delta_prompt_stack_message.py @@ -3,7 +3,7 @@ from attrs import define, field -from griptape.common.prompt_stack.contents.delta_text_prompt_stack_content import DeltaTextPromptStackContent +from griptape.common.prompt_stack.contents.text_delta_prompt_stack_content import TextDeltaPromptStackContent from .base_prompt_stack_message import BasePromptStackMessage @@ -11,23 +11,5 @@ @define class DeltaPromptStackMessage(BasePromptStackMessage): - @define - class DeltaUsage: - input_tokens: Optional[float] = field(kw_only=True, default=None, metadata={"serializable": True}) - output_tokens: Optional[float] = field(kw_only=True, default=None, metadata={"serializable": True}) - - @property - def total_tokens(self) -> float: - return (self.input_tokens or 0) + (self.output_tokens or 0) - - def __add__(self, other: DeltaPromptStackMessage.DeltaUsage) -> DeltaPromptStackMessage.DeltaUsage: - return DeltaPromptStackMessage.DeltaUsage( - input_tokens=(self.input_tokens or 0) + (other.input_tokens or 0), - output_tokens=(self.output_tokens or 0) + (other.output_tokens or 0), - ) - role: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": True}) - delta_content: Optional[DeltaTextPromptStackContent] = field( - kw_only=True, default=None, metadata={"serializable": True} - ) - delta_usage: DeltaUsage = field(kw_only=True, default=DeltaUsage(), metadata={"serializable": True}) + content: Optional[TextDeltaPromptStackContent] = field(kw_only=True, default=None, metadata={"serializable": True}) diff --git a/griptape/common/prompt_stack/messages/prompt_stack_message.py b/griptape/common/prompt_stack/messages/prompt_stack_message.py index bffde91de..4b393f570 100644 --- a/griptape/common/prompt_stack/messages/prompt_stack_message.py +++ b/griptape/common/prompt_stack/messages/prompt_stack_message.py @@ -1,36 +1,23 @@ from __future__ import annotations -from typing import Any, Optional +from typing import Any -from attrs import Factory, define, field +from attrs import define, field from griptape.artifacts import TextArtifact from griptape.common import BasePromptStackContent, TextPromptStackContent -from griptape.mixins.serializable_mixin import SerializableMixin from .base_prompt_stack_message import BasePromptStackMessage @define class PromptStackMessage(BasePromptStackMessage): - @define - class Usage(SerializableMixin): - input_tokens: Optional[float] = field(kw_only=True, default=None, metadata={"serializable": True}) - output_tokens: Optional[float] = field(kw_only=True, default=None, metadata={"serializable": True}) - - @property - def total_tokens(self) -> float: - return (self.input_tokens or 0) + (self.output_tokens or 0) - def __init__(self, content: str | list[BasePromptStackContent], **kwargs: Any): if isinstance(content, str): content = [TextPromptStackContent(TextArtifact(value=content))] self.__attrs_init__(content, **kwargs) # pyright: ignore[reportAttributeAccessIssue] content: list[BasePromptStackContent] = field(metadata={"serializable": True}) - usage: Usage = field( - kw_only=True, default=Factory(lambda: PromptStackMessage.Usage()), metadata={"serializable": True} - ) @property def value(self) -> Any: diff --git a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py index 8adf67c05..d598ca0a7 100644 --- a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py +++ b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py @@ -10,7 +10,7 @@ BaseDeltaPromptStackContent, DeltaPromptStackMessage, PromptStackMessage, - DeltaTextPromptStackContent, + TextDeltaPromptStackContent, BasePromptStackContent, TextPromptStackContent, ImagePromptStackContent, @@ -56,13 +56,13 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMess for event in stream: if "contentBlockDelta" in event: content_block_delta = event["contentBlockDelta"] - yield DeltaTextPromptStackContent( + yield TextDeltaPromptStackContent( content_block_delta["delta"]["text"], index=content_block_delta["contentBlockIndex"] ) elif "metadata" in event: usage = event["metadata"]["usage"] yield DeltaPromptStackMessage( - delta_usage=DeltaPromptStackMessage.DeltaUsage( + usage=DeltaPromptStackMessage.Usage( input_tokens=usage["inputTokens"], output_tokens=usage["outputTokens"] ) ) diff --git a/griptape/drivers/prompt/anthropic_prompt_driver.py b/griptape/drivers/prompt/anthropic_prompt_driver.py index 57a7e9ba8..a253591d1 100644 --- a/griptape/drivers/prompt/anthropic_prompt_driver.py +++ b/griptape/drivers/prompt/anthropic_prompt_driver.py @@ -10,7 +10,7 @@ BaseDeltaPromptStackContent, BasePromptStackContent, DeltaPromptStackMessage, - DeltaTextPromptStackContent, + TextDeltaPromptStackContent, ImagePromptStackContent, PromptStack, PromptStackMessage, @@ -68,11 +68,11 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMess yield self.__message_content_delta_to_prompt_stack_content_delta(event) elif event.type == "message_start": yield DeltaPromptStackMessage( - delta_usage=DeltaPromptStackMessage.DeltaUsage(input_tokens=event.message.usage.input_tokens) + usage=DeltaPromptStackMessage.Usage(input_tokens=event.message.usage.input_tokens) ) elif event.type == "message_delta": yield DeltaPromptStackMessage( - delta_usage=DeltaPromptStackMessage.DeltaUsage(output_tokens=event.usage.output_tokens) + usage=DeltaPromptStackMessage.Usage(output_tokens=event.usage.output_tokens) ) def _prompt_stack_messages_to_messages(self, elements: list[PromptStackMessage]) -> list[dict]: @@ -135,6 +135,6 @@ def __message_content_delta_to_prompt_stack_content_delta( index = content_delta.index if content_delta.delta.type == "text_delta": - return DeltaTextPromptStackContent(content_delta.delta.text, index=index) + return TextDeltaPromptStackContent(content_delta.delta.text, index=index) else: raise ValueError(f"Unsupported message content delta type : {content_delta.delta.type}") diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index 2373aeeb8..d6bcbe383 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -10,7 +10,7 @@ from griptape.common import ( BaseDeltaPromptStackContent, DeltaPromptStackMessage, - DeltaTextPromptStackContent, + TextDeltaPromptStackContent, PromptStack, PromptStackMessage, TextPromptStackContent, @@ -119,25 +119,25 @@ def __process_run(self, prompt_stack: PromptStack) -> PromptStackMessage: def __process_stream(self, prompt_stack: PromptStack) -> PromptStackMessage: delta_contents: dict[int, list[BaseDeltaPromptStackContent]] = {} - delta_usage = DeltaPromptStackMessage.DeltaUsage() + usage = DeltaPromptStackMessage.Usage() deltas = self.try_stream(prompt_stack) for delta in deltas: if isinstance(delta, DeltaPromptStackMessage): - delta_usage += delta.delta_usage + usage += delta.usage elif isinstance(delta, BaseDeltaPromptStackContent): if delta.index in delta_contents: delta_contents[delta.index].append(delta) else: delta_contents[delta.index] = [delta] - if isinstance(delta, DeltaTextPromptStackContent): + if isinstance(delta, TextDeltaPromptStackContent): self.structure.publish_event(CompletionChunkEvent(token=delta.text)) content = [] for index, deltas in delta_contents.items(): - text_deltas = [delta for delta in deltas if isinstance(delta, DeltaTextPromptStackContent)] + text_deltas = [delta for delta in deltas if isinstance(delta, TextDeltaPromptStackContent)] if text_deltas: content.append(TextPromptStackContent.from_deltas(text_deltas)) @@ -145,7 +145,7 @@ def __process_stream(self, prompt_stack: PromptStack) -> PromptStackMessage: content=content, role=PromptStackMessage.ASSISTANT_ROLE, usage=PromptStackMessage.Usage( - input_tokens=delta_usage.input_tokens or 0, output_tokens=delta_usage.output_tokens or 0 + input_tokens=usage.input_tokens or 0, output_tokens=usage.output_tokens or 0 ), ) diff --git a/griptape/drivers/prompt/cohere_prompt_driver.py b/griptape/drivers/prompt/cohere_prompt_driver.py index f3eee8f67..9c36609b1 100644 --- a/griptape/drivers/prompt/cohere_prompt_driver.py +++ b/griptape/drivers/prompt/cohere_prompt_driver.py @@ -12,7 +12,7 @@ BaseDeltaPromptStackContent, TextPromptStackContent, BasePromptStackContent, - DeltaTextPromptStackContent, + TextDeltaPromptStackContent, ) from griptape.utils import import_optional_dependency from griptape.tokenizers import BaseTokenizer @@ -56,12 +56,12 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMess for event in result: if event.event_type == "text-generation": - yield DeltaTextPromptStackContent(event.text, index=0) + yield TextDeltaPromptStackContent(event.text, index=0) if event.event_type == "stream-end": usage = event.response.meta.tokens yield DeltaPromptStackMessage( - delta_usage=DeltaPromptStackMessage.DeltaUsage( + usage=DeltaPromptStackMessage.Usage( input_tokens=usage.input_tokens, output_tokens=usage.output_tokens ) ) diff --git a/griptape/drivers/prompt/google_prompt_driver.py b/griptape/drivers/prompt/google_prompt_driver.py index a45ec0c07..2dd649c5f 100644 --- a/griptape/drivers/prompt/google_prompt_driver.py +++ b/griptape/drivers/prompt/google_prompt_driver.py @@ -10,7 +10,7 @@ BaseDeltaPromptStackContent, BasePromptStackContent, DeltaPromptStackMessage, - DeltaTextPromptStackContent, + TextDeltaPromptStackContent, ImagePromptStackContent, PromptStack, PromptStackMessage, @@ -90,11 +90,11 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMess for chunk in response: usage_metadata = chunk.usage_metadata - yield DeltaTextPromptStackContent(chunk.text) + yield TextDeltaPromptStackContent(chunk.text) # TODO: Only yield the first one yield DeltaPromptStackMessage( - delta_usage=DeltaPromptStackMessage.DeltaUsage( + usage=DeltaPromptStackMessage.Usage( input_tokens=usage_metadata.prompt_token_count, output_tokens=usage_metadata.candidates_token_count ) ) diff --git a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py index 31f1cec29..45ae5eb87 100644 --- a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py +++ b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py @@ -13,7 +13,7 @@ DeltaPromptStackMessage, BaseDeltaPromptStackContent, TextPromptStackContent, - DeltaTextPromptStackContent, + TextDeltaPromptStackContent, ) from griptape.utils import import_optional_dependency @@ -81,11 +81,11 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMess full_text = "" for token in response: full_text += token - yield DeltaTextPromptStackContent(token, index=0) + yield TextDeltaPromptStackContent(token, index=0) output_tokens = len(self.tokenizer.tokenizer.encode(full_text)) yield DeltaPromptStackMessage( - delta_usage=DeltaPromptStackMessage.DeltaUsage(input_tokens=input_tokens, output_tokens=output_tokens) + usage=DeltaPromptStackMessage.Usage(input_tokens=input_tokens, output_tokens=output_tokens) ) def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str: diff --git a/griptape/drivers/prompt/ollama_prompt_driver.py b/griptape/drivers/prompt/ollama_prompt_driver.py index 46730d64b..a913e9477 100644 --- a/griptape/drivers/prompt/ollama_prompt_driver.py +++ b/griptape/drivers/prompt/ollama_prompt_driver.py @@ -12,7 +12,7 @@ PromptStackMessage, BaseDeltaPromptStackContent, DeltaPromptStackMessage, - DeltaTextPromptStackContent, + TextDeltaPromptStackContent, ) if TYPE_CHECKING: @@ -69,7 +69,7 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMess if isinstance(stream, Iterator): for chunk in stream: - yield DeltaTextPromptStackContent(chunk["message"]["content"]) + yield TextDeltaPromptStackContent(chunk["message"]["content"]) else: raise Exception("invalid model response") diff --git a/griptape/drivers/prompt/openai_chat_prompt_driver.py b/griptape/drivers/prompt/openai_chat_prompt_driver.py index 4957b3407..391425c38 100644 --- a/griptape/drivers/prompt/openai_chat_prompt_driver.py +++ b/griptape/drivers/prompt/openai_chat_prompt_driver.py @@ -11,7 +11,7 @@ BaseDeltaPromptStackContent, BasePromptStackContent, DeltaPromptStackMessage, - DeltaTextPromptStackContent, + TextDeltaPromptStackContent, ImagePromptStackContent, PromptStack, PromptStackMessage, @@ -98,7 +98,7 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMess for chunk in result: if chunk.usage is not None: yield DeltaPromptStackMessage( - delta_usage=DeltaPromptStackMessage.DeltaUsage( + usage=DeltaPromptStackMessage.Usage( input_tokens=chunk.usage.prompt_tokens, output_tokens=chunk.usage.completion_tokens ) ) @@ -168,8 +168,8 @@ def __message_to_prompt_stack_content(self, message: ChatCompletionMessage) -> B def __message_delta_to_prompt_stack_content_delta(self, content_delta: ChoiceDelta) -> BaseDeltaPromptStackContent: if content_delta.content is None: - return DeltaTextPromptStackContent("") + return TextDeltaPromptStackContent("") else: delta_content = content_delta.content - return DeltaTextPromptStackContent(delta_content) + return TextDeltaPromptStackContent(delta_content) diff --git a/tests/mocks/mock_failing_prompt_driver.py b/tests/mocks/mock_failing_prompt_driver.py index 0cf7b7df3..80684378d 100644 --- a/tests/mocks/mock_failing_prompt_driver.py +++ b/tests/mocks/mock_failing_prompt_driver.py @@ -8,7 +8,7 @@ PromptStackMessage, TextPromptStackContent, DeltaPromptStackMessage, - DeltaTextPromptStackContent, + TextDeltaPromptStackContent, BaseDeltaPromptStackContent, ) from griptape.drivers import BasePromptDriver @@ -41,6 +41,6 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMess raise Exception("failed attempt") else: yield DeltaPromptStackMessage( - delta_content=DeltaTextPromptStackContent("success"), - delta_usage=DeltaPromptStackMessage.DeltaUsage(input_tokens=100, output_tokens=100), + delta_content=TextDeltaPromptStackContent("success"), + usage=DeltaPromptStackMessage.Usage(input_tokens=100, output_tokens=100), ) diff --git a/tests/mocks/mock_prompt_driver.py b/tests/mocks/mock_prompt_driver.py index 8fcbb5ab6..1b1e8b38a 100644 --- a/tests/mocks/mock_prompt_driver.py +++ b/tests/mocks/mock_prompt_driver.py @@ -12,7 +12,7 @@ DeltaPromptStackMessage, BaseDeltaPromptStackContent, TextPromptStackContent, - DeltaTextPromptStackContent, + TextDeltaPromptStackContent, ) from griptape.drivers import BasePromptDriver from griptape.tokenizers import BaseTokenizer @@ -39,7 +39,5 @@ def try_run(self, prompt_stack: PromptStack) -> PromptStackMessage: def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage | BaseDeltaPromptStackContent]: output = self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output - yield DeltaTextPromptStackContent(output) - yield DeltaPromptStackMessage( - delta_usage=DeltaPromptStackMessage.DeltaUsage(input_tokens=100, output_tokens=100) - ) + yield TextDeltaPromptStackContent(output) + yield DeltaPromptStackMessage(usage=DeltaPromptStackMessage.Usage(input_tokens=100, output_tokens=100)) diff --git a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py index 6e4e3b4a9..783e940f9 100644 --- a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py @@ -1,7 +1,7 @@ import pytest from griptape.common import PromptStack -from griptape.common.prompt_stack.contents.delta_text_prompt_stack_content import DeltaTextPromptStackContent +from griptape.common.prompt_stack.contents.text_delta_prompt_stack_content import TextDeltaPromptStackContent from griptape.drivers import AmazonBedrockPromptDriver @@ -83,5 +83,5 @@ def test_try_stream_run(self, mock_converse_stream, prompt_stack, messages): additionalModelRequestFields={}, ) - if isinstance(text_artifact, DeltaTextPromptStackContent): + if isinstance(text_artifact, TextDeltaPromptStackContent): assert text_artifact.text == "model-output" diff --git a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py index 17c3e97e2..205d66ffa 100644 --- a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py @@ -1,4 +1,4 @@ -from griptape.common.prompt_stack.contents.delta_text_prompt_stack_content import DeltaTextPromptStackContent +from griptape.common.prompt_stack.contents.text_delta_prompt_stack_content import TextDeltaPromptStackContent from griptape.drivers import AnthropicPromptDriver from griptape.common import PromptStack from unittest.mock import Mock @@ -114,7 +114,7 @@ def test_try_stream_run(self, mock_stream_client, model, system_enabled): top_k=250, **{"system": "system-input"} if system_enabled else {}, ) - if isinstance(text_artifact, DeltaTextPromptStackContent): + if isinstance(text_artifact, TextDeltaPromptStackContent): assert text_artifact.text == "model-output" def test_try_run_throws_when_prompt_stack_is_string(self): diff --git a/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py b/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py index b2b3be062..93d0e165a 100644 --- a/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py @@ -1,6 +1,6 @@ import pytest from unittest.mock import Mock -from griptape.common.prompt_stack.contents.delta_text_prompt_stack_content import DeltaTextPromptStackContent +from griptape.common.prompt_stack.contents.text_delta_prompt_stack_content import TextDeltaPromptStackContent from griptape.drivers import AzureOpenAiChatPromptDriver from tests.unit.drivers.prompt.test_openai_chat_prompt_driver import TestOpenAiChatPromptDriverFixtureMixin @@ -61,5 +61,5 @@ def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, stream_options={"include_usage": True}, ) - if isinstance(text_artifact, DeltaTextPromptStackContent): + if isinstance(text_artifact, TextDeltaPromptStackContent): assert text_artifact == "model-output" diff --git a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py index 964113ec5..7f6c6f400 100644 --- a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py @@ -2,7 +2,7 @@ import pytest -from griptape.common import DeltaTextPromptStackContent, PromptStack +from griptape.common import TextDeltaPromptStackContent, PromptStack from griptape.drivers import CoherePromptDriver @@ -55,5 +55,5 @@ def test_try_stream_run(self, mock_stream_client, prompt_stack): # pyright: ign text_artifact = next(driver.try_stream(prompt_stack)) # Then - if isinstance(text_artifact, DeltaTextPromptStackContent): + if isinstance(text_artifact, TextDeltaPromptStackContent): assert text_artifact.text == "model-output" diff --git a/tests/unit/drivers/prompt/test_google_prompt_driver.py b/tests/unit/drivers/prompt/test_google_prompt_driver.py index 91bcd0da7..0ee048338 100644 --- a/tests/unit/drivers/prompt/test_google_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_google_prompt_driver.py @@ -1,5 +1,5 @@ from google.generativeai.types import GenerationConfig -from griptape.common.prompt_stack.contents.delta_text_prompt_stack_content import DeltaTextPromptStackContent +from griptape.common.prompt_stack.contents.text_delta_prompt_stack_content import TextDeltaPromptStackContent from griptape.drivers import GooglePromptDriver from griptape.common import PromptStack from unittest.mock import Mock @@ -69,5 +69,5 @@ def test_try_stream(self, mock_stream_generative_model): stream=True, generation_config=GenerationConfig(temperature=0.1, top_p=0.5, top_k=50, stop_sequences=[]), ) - if isinstance(text_artifact, DeltaTextPromptStackContent): + if isinstance(text_artifact, TextDeltaPromptStackContent): assert text_artifact.text == "model-output" diff --git a/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py b/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py index 259830ada..6e7367a10 100644 --- a/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py @@ -1,4 +1,4 @@ -from griptape.common.prompt_stack.contents.delta_text_prompt_stack_content import DeltaTextPromptStackContent +from griptape.common.prompt_stack.contents.text_delta_prompt_stack_content import TextDeltaPromptStackContent from griptape.drivers import HuggingFaceHubPromptDriver from griptape.common import PromptStack import pytest @@ -60,5 +60,5 @@ def test_try_stream(self, prompt_stack, mock_client_stream): text_artifact = next(driver.try_stream(prompt_stack)) # Then - if isinstance(text_artifact, DeltaTextPromptStackContent): + if isinstance(text_artifact, TextDeltaPromptStackContent): assert text_artifact.text == "model-output" diff --git a/tests/unit/drivers/prompt/test_ollama_prompt_driver.py b/tests/unit/drivers/prompt/test_ollama_prompt_driver.py index 4a52e9b9c..31ee3fec7 100644 --- a/tests/unit/drivers/prompt/test_ollama_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_ollama_prompt_driver.py @@ -1,4 +1,4 @@ -from griptape.common.prompt_stack.contents.delta_text_prompt_stack_content import DeltaTextPromptStackContent +from griptape.common.prompt_stack.contents.text_delta_prompt_stack_content import TextDeltaPromptStackContent from griptape.drivers import OllamaPromptDriver from griptape.common import PromptStack import pytest @@ -80,7 +80,7 @@ def test_try_stream_run(self, mock_stream_client): options={"temperature": driver.temperature, "stop": [], "num_predict": driver.max_tokens}, stream=True, ) - if isinstance(text_artifact, DeltaTextPromptStackContent): + if isinstance(text_artifact, TextDeltaPromptStackContent): assert text_artifact.text == "model-output" def test_try_stream_bad_response(self, mock_stream_client): diff --git a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py index db329785b..6d8d2cb51 100644 --- a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py @@ -1,5 +1,5 @@ from griptape.drivers import OpenAiChatPromptDriver -from griptape.common import PromptStack, DeltaTextPromptStackContent +from griptape.common import PromptStack, TextDeltaPromptStackContent from griptape.tokenizers import OpenAiTokenizer from unittest.mock import Mock from tests.mocks.mock_tokenizer import MockTokenizer @@ -131,7 +131,7 @@ def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, stream_options={"include_usage": True}, ) - if isinstance(text_artifact, DeltaTextPromptStackContent): + if isinstance(text_artifact, TextDeltaPromptStackContent): assert text_artifact.text == "model-output" def test_try_run_with_max_tokens(self, mock_chat_completion_create, prompt_stack, messages): From 784aafcf86baf99c90e800fb5986341bf4f78a5e Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Wed, 19 Jun 2024 15:18:55 -0700 Subject: [PATCH 09/34] PR cleanup --- griptape/config/google_structure_config.py | 2 +- griptape/drivers/prompt/base_prompt_driver.py | 7 +++---- .../drivers/prompt/dummy_prompt_driver.py | 18 +---------------- .../drivers/prompt/google_prompt_driver.py | 20 ++++++++++++++----- griptape/schemas/base_schema.py | 2 +- 5 files changed, 21 insertions(+), 28 deletions(-) diff --git a/griptape/config/google_structure_config.py b/griptape/config/google_structure_config.py index 76f55d3ef..fc0548ff7 100644 --- a/griptape/config/google_structure_config.py +++ b/griptape/config/google_structure_config.py @@ -14,7 +14,7 @@ @define class GoogleStructureConfig(StructureConfig): prompt_driver: BasePromptDriver = field( - default=Factory(lambda: GooglePromptDriver(model="gemini-1.5-flash")), + default=Factory(lambda: GooglePromptDriver(model="gemini-1.5-pro")), kw_only=True, metadata={"serializable": True}, ) diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index d6bcbe383..0a08eceb0 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -121,8 +121,8 @@ def __process_stream(self, prompt_stack: PromptStack) -> PromptStackMessage: delta_contents: dict[int, list[BaseDeltaPromptStackContent]] = {} usage = DeltaPromptStackMessage.Usage() + # Aggregate all content deltas from the stream deltas = self.try_stream(prompt_stack) - for delta in deltas: if isinstance(delta, DeltaPromptStackMessage): usage += delta.usage @@ -135,6 +135,7 @@ def __process_stream(self, prompt_stack: PromptStack) -> PromptStackMessage: if isinstance(delta, TextDeltaPromptStackContent): self.structure.publish_event(CompletionChunkEvent(token=delta.text)) + # Build a complete content from the content deltas content = [] for index, deltas in delta_contents.items(): text_deltas = [delta for delta in deltas if isinstance(delta, TextDeltaPromptStackContent)] @@ -144,9 +145,7 @@ def __process_stream(self, prompt_stack: PromptStack) -> PromptStackMessage: result = PromptStackMessage( content=content, role=PromptStackMessage.ASSISTANT_ROLE, - usage=PromptStackMessage.Usage( - input_tokens=usage.input_tokens or 0, output_tokens=usage.output_tokens or 0 - ), + usage=PromptStackMessage.Usage(input_tokens=usage.input_tokens, output_tokens=usage.output_tokens), ) return result diff --git a/griptape/drivers/prompt/dummy_prompt_driver.py b/griptape/drivers/prompt/dummy_prompt_driver.py index 48eb95b25..bb825566a 100644 --- a/griptape/drivers/prompt/dummy_prompt_driver.py +++ b/griptape/drivers/prompt/dummy_prompt_driver.py @@ -1,16 +1,9 @@ from __future__ import annotations from collections.abc import Iterator -from typing import Any from attrs import Factory, define, field -from griptape.common import ( - BasePromptStackContent, - PromptStack, - PromptStackMessage, - DeltaPromptStackMessage, - BaseDeltaPromptStackContent, -) +from griptape.common import PromptStack, PromptStackMessage, DeltaPromptStackMessage, BaseDeltaPromptStackContent from griptape.drivers import BasePromptDriver from griptape.exceptions import DummyException from griptape.tokenizers import DummyTokenizer @@ -26,12 +19,3 @@ def try_run(self, prompt_stack: PromptStack) -> PromptStackMessage: def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage | BaseDeltaPromptStackContent]: raise DummyException(__class__.__name__, "try_stream") - - def _prompt_stack_input_to_message(self, prompt_input: PromptStackMessage) -> dict: - raise DummyException(__class__.__name__, "_prompt_stack_input_to_message") - - def _prompt_stack_content_to_message_content(self, content: BasePromptStackContent) -> Any: - raise DummyException(__class__.__name__, "_prompt_stack_content_to_message_content") - - def _message_content_to_prompt_stack_content(self, message_content: Any) -> BasePromptStackContent: - raise DummyException(__class__.__name__, "_message_content_to_prompt_stack_content") diff --git a/griptape/drivers/prompt/google_prompt_driver.py b/griptape/drivers/prompt/google_prompt_driver.py index 2dd649c5f..dc4f2d591 100644 --- a/griptape/drivers/prompt/google_prompt_driver.py +++ b/griptape/drivers/prompt/google_prompt_driver.py @@ -87,17 +87,27 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMess ), ) + prompt_token_count = None for chunk in response: usage_metadata = chunk.usage_metadata yield TextDeltaPromptStackContent(chunk.text) - # TODO: Only yield the first one - yield DeltaPromptStackMessage( - usage=DeltaPromptStackMessage.Usage( - input_tokens=usage_metadata.prompt_token_count, output_tokens=usage_metadata.candidates_token_count + # Only want to output the prompt token count once since it is static each chunk + if prompt_token_count is None: + prompt_token_count = usage_metadata.prompt_token_count + yield DeltaPromptStackMessage( + role=PromptStackMessage.ASSISTANT_ROLE, + usage=DeltaPromptStackMessage.Usage( + input_tokens=usage_metadata.prompt_token_count, + output_tokens=usage_metadata.candidates_token_count, + ), + ) + else: + yield DeltaPromptStackMessage( + role=PromptStackMessage.ASSISTANT_ROLE, + usage=DeltaPromptStackMessage.Usage(output_tokens=usage_metadata.candidates_token_count), ) - ) def _default_model_client(self) -> GenerativeModel: genai = import_optional_dependency("google.generativeai") diff --git a/griptape/schemas/base_schema.py b/griptape/schemas/base_schema.py index b0df7d806..e309a5eab 100644 --- a/griptape/schemas/base_schema.py +++ b/griptape/schemas/base_schema.py @@ -35,7 +35,7 @@ def make_obj(self, data, **kwargs): cls._resolve_types(attrs_cls) return SubSchema.from_dict( { - a.alias or a.name: cls._get_field_for_type(a.type) + a.name: cls._get_field_for_type(a.type) for a in attrs.fields(attrs_cls) if a.metadata.get("serializable") }, From 52c1a04b6992c9de9397127eef443d92b33a8d7a Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Wed, 19 Jun 2024 15:30:36 -0700 Subject: [PATCH 10/34] Change task hierarchy --- griptape/tasks/actions_subtask.py | 21 +++++--- griptape/tasks/base_text_input_task.py | 54 +++++-------------- griptape/tasks/prompt_task.py | 54 ++++++++++++++++++- .../config/test_google_structure_config.py | 2 +- 4 files changed, 78 insertions(+), 53 deletions(-) diff --git a/griptape/tasks/actions_subtask.py b/griptape/tasks/actions_subtask.py index 4aa2b2783..1546a825d 100644 --- a/griptape/tasks/actions_subtask.py +++ b/griptape/tasks/actions_subtask.py @@ -38,6 +38,19 @@ class Action: _input: Optional[str | TextArtifact | Callable[[BaseTask], TextArtifact]] = field(default=None) _memory: Optional[TaskMemory] = None + @property + def input(self) -> TextArtifact: + if isinstance(self._input, TextArtifact): + return self._input + elif isinstance(self._input, Callable): + return self._input(self) + else: + return TextArtifact(self._input) + + @input.setter + def input(self, value: str | TextArtifact | Callable[[BaseTask], TextArtifact]) -> None: + self._input = value + @property def origin_task(self) -> BaseTask: if self.parent_task_id: @@ -165,14 +178,6 @@ def actions_to_dicts(self) -> list[dict]: def actions_to_json(self) -> str: return json.dumps(self.actions_to_dicts()) - def _process_task_input( - self, task_input: str | list | BaseArtifact | Callable[[BaseTask], BaseArtifact] - ) -> BaseArtifact: - if isinstance(task_input, TextArtifact): - return task_input - else: - return super()._process_task_input(task_input) - def __init_from_prompt(self, value: str) -> None: thought_matches = re.findall(self.THOUGHT_PATTERN, value, re.MULTILINE) actions_matches = re.findall(self.ACTIONS_PATTERN, value, re.DOTALL) diff --git a/griptape/tasks/base_text_input_task.py b/griptape/tasks/base_text_input_task.py index f0b5bd9e0..c5641bb14 100644 --- a/griptape/tasks/base_text_input_task.py +++ b/griptape/tasks/base_text_input_task.py @@ -2,11 +2,10 @@ from abc import ABC from typing import Callable -from collections.abc import Sequence from attrs import define, field -from griptape.artifacts import TextArtifact, BaseArtifact, ListArtifact +from griptape.artifacts import TextArtifact from griptape.mixins.rule_mixin import RuleMixin from griptape.tasks import BaseTask from griptape.utils import J2 @@ -14,23 +13,23 @@ @define class BaseTextInputTask(RuleMixin, BaseTask, ABC): - _input: str | list | tuple | BaseArtifact | Callable[[BaseTask], BaseArtifact] = field( - default=lambda task: task.full_context["args"][0] if task.full_context["args"] else TextArtifact(value=""), - alias="input", + DEFAULT_INPUT_TEMPLATE = "{{ args[0] }}" + + _input: str | TextArtifact | Callable[[BaseTask], TextArtifact] = field( + default=DEFAULT_INPUT_TEMPLATE, alias="input" ) @property - def input(self) -> BaseArtifact: - if isinstance(self._input, list) or isinstance(self._input, tuple): - artifacts = [self._process_task_input(input) for input in self._input] - flattened_artifacts = self.__flatten_artifacts(artifacts) - - return ListArtifact(flattened_artifacts) + def input(self) -> TextArtifact: + if isinstance(self._input, TextArtifact): + return self._input + elif isinstance(self._input, Callable): + return self._input(self) else: - return self._process_task_input(self._input) + return TextArtifact(J2().render_from_string(self._input, **self.full_context)) @input.setter - def input(self, value: str | list | tuple | BaseArtifact | Callable[[BaseTask], BaseArtifact]) -> None: + def input(self, value: str | TextArtifact | Callable[[BaseTask], TextArtifact]) -> None: self._input = value def before_run(self) -> None: @@ -42,32 +41,3 @@ def after_run(self) -> None: super().after_run() self.structure.logger.info(f"{self.__class__.__name__} {self.id}\nOutput: {self.output.to_text()}") - - def _process_task_input( - self, task_input: str | list | BaseArtifact | Callable[[BaseTask], BaseArtifact] - ) -> BaseArtifact: - if isinstance(task_input, TextArtifact): - task_input.value = J2().render_from_string(task_input.value, **self.full_context) - - return task_input - elif isinstance(task_input, Callable): - return self._process_task_input(task_input(self)) - elif isinstance(task_input, str): - return self._process_task_input(TextArtifact(task_input)) - elif isinstance(task_input, BaseArtifact): - return task_input - elif isinstance(task_input, list): - return ListArtifact([self._process_task_input(elem) for elem in task_input]) - else: - raise ValueError(f"Invalid input type: {type(task_input)} ") - - def __flatten_artifacts(self, artifacts: Sequence[BaseArtifact]) -> Sequence[BaseArtifact]: - result = [] - - for elem in artifacts: - if isinstance(elem, ListArtifact): - result.extend(self.__flatten_artifacts(elem.value)) - else: - result.append(elem) - - return result diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index 60a5f6a45..0839aff41 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -1,13 +1,16 @@ from __future__ import annotations from typing import TYPE_CHECKING, Callable, Optional +from collections.abc import Sequence from attrs import Factory, define, field from griptape.artifacts import BaseArtifact from griptape.common import PromptStack -from griptape.tasks import BaseTextInputTask +from griptape.tasks import BaseTask from griptape.utils import J2 +from griptape.artifacts import TextArtifact, ListArtifact +from griptape.mixins import RuleMixin if TYPE_CHECKING: from griptape.drivers import BasePromptDriver @@ -15,11 +18,29 @@ @define -class PromptTask(BaseTextInputTask): +class PromptTask(RuleMixin, BaseTask): _prompt_driver: Optional[BasePromptDriver] = field(default=None, kw_only=True, alias="prompt_driver") generate_system_template: Callable[[PromptTask], str] = field( default=Factory(lambda self: self.default_system_template_generator, takes_self=True), kw_only=True ) + _input: str | list | tuple | BaseArtifact | Callable[[BaseTask], BaseArtifact] = field( + default=lambda task: task.full_context["args"][0] if task.full_context["args"] else TextArtifact(value=""), + alias="input", + ) + + @property + def input(self) -> BaseArtifact: + if isinstance(self._input, list) or isinstance(self._input, tuple): + artifacts = [self._process_task_input(input) for input in self._input] + flattened_artifacts = self.__flatten_artifacts(artifacts) + + return ListArtifact(flattened_artifacts) + else: + return self._process_task_input(self._input) + + @input.setter + def input(self, value: str | list | tuple | BaseArtifact | Callable[[BaseTask], BaseArtifact]) -> None: + self._input = value output: Optional[BaseArtifact] = field(default=None, init=False) @@ -66,3 +87,32 @@ def run(self) -> BaseArtifact: self.output = self.prompt_driver.run(self.prompt_stack) return self.output + + def _process_task_input( + self, task_input: str | list | BaseArtifact | Callable[[BaseTask], BaseArtifact] + ) -> BaseArtifact: + if isinstance(task_input, TextArtifact): + task_input.value = J2().render_from_string(task_input.value, **self.full_context) + + return task_input + elif isinstance(task_input, Callable): + return self._process_task_input(task_input(self)) + elif isinstance(task_input, str): + return self._process_task_input(TextArtifact(task_input)) + elif isinstance(task_input, BaseArtifact): + return task_input + elif isinstance(task_input, list): + return ListArtifact([self._process_task_input(elem) for elem in task_input]) + else: + raise ValueError(f"Invalid input type: {type(task_input)} ") + + def __flatten_artifacts(self, artifacts: Sequence[BaseArtifact]) -> Sequence[BaseArtifact]: + result = [] + + for elem in artifacts: + if isinstance(elem, ListArtifact): + result.extend(self.__flatten_artifacts(elem.value)) + else: + result.append(elem) + + return result diff --git a/tests/unit/config/test_google_structure_config.py b/tests/unit/config/test_google_structure_config.py index 47c46f181..ad96f2a35 100644 --- a/tests/unit/config/test_google_structure_config.py +++ b/tests/unit/config/test_google_structure_config.py @@ -19,7 +19,7 @@ def test_to_dict(self, config): "temperature": 0.1, "max_tokens": None, "stream": False, - "model": "gemini-1.5-flash", + "model": "gemini-1.5-pro", "top_p": None, "top_k": None, }, From b82ae141374e82b290939c9e42e13dd21d37a6c4 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Wed, 19 Jun 2024 15:34:38 -0700 Subject: [PATCH 11/34] Update changelog --- CHANGELOG.md | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 66104a6e1..8552a4d59 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,31 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## Unreleased +### Added +- `PromptStackMessage` for storing messages in a `PromptStack`. Messages consist of a role, content, and usage. +- `DeltaPromptStackMessage` for storing partial messages in a `PromptStack`. Multiple `DeltaPromptStackMessage` can be combined to form a `PromptStackMessage`. +- `TextPromptStackContent` for storing textual content in a `PromptStackMessage`. +- `ImagePromptStackContent` for storing image content in a `PromptStackMessage`. +- Support for adding `TextArtifact`s, `ImageArtifact`s, and `ListArtifact`s to `PromptStack`. +- Support for image inputs to `OpenAiChatPromptDriver`, `AzureOpenAiChatPromptDriver`, `AmazonBedrockPromptDriver`, `AnthropicPromptDriver`, and `GooglePromptDriver`. +- Input/output token usage metrics to all Prompt Drivers. +- `FinishPromptEvent.input_token_count` and `FinishPromptEvent.output_token_count`. +- Support for storing Artifacts as inputs/outputs in Conversation Memory Runs. +- `Agent.input` for passing Artifacts as input. +- Support for `PromptTask`s to take `TextArtifact`s, `ImageArtifact`s, and `ListArtifact`s as input. + +### Changed +- **BREAKING**: Moved `griptape.utils.PromptStack` to `griptape.common.PromptStack`. +- **BREAKING**: Renamed `PromptStack.inputs` to `PromptStack.messages`. +- **BREAKING**: Moved `PromptStack.USER_ROLE`, `PromptStack.ASSISTANT_ROLE`, and `PromptStack.SYSTEM_ROLE` to `PromptStackMessage`. +- **BREAKING**: Updated return type of `PromptDriver.try_run` from `TextArtifact` to `PromptStackMessage`. +- **BREAKING**: Updated return type of `PromptDriver.try_stream` from `Iterator[TextArtifact]` to `Iterator[DeltaPromptStackMessage | BaseDeltaPromptStackContent]`. +- **BREAKING**: Removed `BasePromptEvent.token_count` in favor of `FinishPromptEvent.input_token_count` and `FinishPromptEvent.output_token_count`. +- **BREAKING**: Removed `StartPromptEvent.prompt`. Use `StartPromptEvent.prompt_stack` instead. +- **BREAKING**: Removed `Agent.input_template` in favor of `Agent.input`. +- Default Prompt Driver model in `GoogleStructureConfig` to `gemini-1.5-pro`. + + ### Added - `GoogleWebSearchDriver` to web search with the Google Customsearch API. From fc63c54825eaa90461bed508ec585efbb11051cf Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 20 Jun 2024 09:05:54 -0700 Subject: [PATCH 12/34] Regenerate lock file --- poetry.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/poetry.lock b/poetry.lock index 477d4ad08..73d84d3f4 100644 --- a/poetry.lock +++ b/poetry.lock @@ -6392,4 +6392,4 @@ loaders-pdf = ["pypdf"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "ea8512a93be36ad1076915c5751167e3088b035deab1650182f72b842c5f8372" +content-hash = "584f05b52935d6e3bbcafc7b4eae7d1b14fe3a28f4e292e37bd307f3202cd4ff" From 1169a50889db19093669cdc393b09a1db4994658 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 20 Jun 2024 09:08:30 -0700 Subject: [PATCH 13/34] Add back missing logs --- griptape/tasks/prompt_task.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index 0839aff41..751fdab17 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -83,6 +83,16 @@ def default_system_template_generator(self, _: PromptTask) -> str: rulesets=J2("rulesets/rulesets.j2").render(rulesets=self.all_rulesets) ) + def before_run(self) -> None: + super().before_run() + + self.structure.logger.info(f"{self.__class__.__name__} {self.id}\nInput: {self.input.to_text()}") + + def after_run(self) -> None: + super().after_run() + + self.structure.logger.info(f"{self.__class__.__name__} {self.id}\nOutput: {self.output.to_text()}") + def run(self) -> BaseArtifact: self.output = self.prompt_driver.run(self.prompt_stack) From 46efbee8dfe206ba3c7dd312cc70e375bb4c4621 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 20 Jun 2024 12:39:49 -0700 Subject: [PATCH 14/34] Fix doc var names --- docs/griptape-framework/misc/events.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/griptape-framework/misc/events.md b/docs/griptape-framework/misc/events.md index 226e96741..851b5f382 100644 --- a/docs/griptape-framework/misc/events.md +++ b/docs/griptape-framework/misc/events.md @@ -244,9 +244,9 @@ from griptape.events import BaseEvent, StartPromptEvent, EventListener def handler(event: BaseEvent): if isinstance(event, StartPromptEvent): - print("Prompt Stack Inputs:") - for input in event.prompt_stack.messages: - print(f"{input.role}: {input.content}") + print("Prompt Stack Messages:") + for message in event.prompt_stack.messages: + print(f"{message.role}: {message.content}") print("Final Prompt String:") print(event.prompt) @@ -259,7 +259,7 @@ agent.run("Write me a poem.") ``` ``` ... -Prompt Stack Inputs: +Prompt Stack Messages: system: user: Write me a poem. Final Prompt String: From d071defd27d21c229e9dd53c3665a4e2c4a0e252 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 20 Jun 2024 12:44:17 -0700 Subject: [PATCH 15/34] Clean up message building --- .../drivers/prompt/huggingface_pipeline_prompt_driver.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py b/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py index 1c88208fe..9a905ae51 100644 --- a/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py +++ b/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py @@ -80,11 +80,9 @@ def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str: def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]: messages = [] + for i in prompt_stack.messages: - if len(i.content) == 1: - messages.append({"role": i.role, "content": TextPromptStackContent(i.to_text_artifact())}) - else: - raise ValueError("Invalid input content length.") + messages.append({"role": i.role, "content": i.to_text_artifact().to_text()}) return messages From 5fdf8107eb3b76e78f842d0cf9cac3537001443b Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 20 Jun 2024 15:07:22 -0700 Subject: [PATCH 16/34] Add tests --- .../prompt/amazon_bedrock_prompt_driver.py | 13 ++- ...mazon_sagemaker_jumpstart_prompt_driver.py | 10 +-- .../drivers/prompt/anthropic_prompt_driver.py | 11 +-- griptape/drivers/prompt/base_prompt_driver.py | 14 +++- .../drivers/prompt/cohere_prompt_driver.py | 7 +- .../drivers/prompt/dummy_prompt_driver.py | 4 +- .../drivers/prompt/google_prompt_driver.py | 7 +- .../prompt/huggingface_hub_prompt_driver.py | 5 +- .../huggingface_pipeline_prompt_driver.py | 10 +-- .../drivers/prompt/ollama_prompt_driver.py | 11 +-- .../prompt/openai_chat_prompt_driver.py | 7 +- tests/mocks/mock_failing_prompt_driver.py | 5 +- tests/mocks/mock_prompt_driver.py | 9 ++- .../test_amazon_bedrock_prompt_driver.py | 36 +++++---- ...mazon_sagemaker_jumpstart_prompt_driver.py | 5 +- .../prompt/test_anthropic_prompt_driver.py | 79 +++++++++++++++---- .../test_azure_openai_chat_prompt_driver.py | 33 +++++--- .../prompt/test_cohere_prompt_driver.py | 45 +++++++++-- .../prompt/test_google_prompt_driver.py | 35 ++++++-- .../test_hugging_face_hub_prompt_driver.py | 20 +++-- ...est_hugging_face_pipeline_prompt_driver.py | 7 +- .../prompt/test_ollama_prompt_driver.py | 6 +- .../prompt/test_openai_chat_prompt_driver.py | 69 +++++++++------- tests/unit/tasks/test_prompt_task.py | 59 ++++++++++++++ tests/unit/utils/test_prompt_stack.py | 21 ++++- 25 files changed, 364 insertions(+), 164 deletions(-) diff --git a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py index d598ca0a7..62e5f97eb 100644 --- a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py +++ b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py @@ -7,7 +7,6 @@ from griptape.artifacts import TextArtifact from griptape.common import ( - BaseDeltaPromptStackContent, DeltaPromptStackMessage, PromptStackMessage, TextDeltaPromptStackContent, @@ -48,7 +47,7 @@ def try_run(self, prompt_stack: PromptStack) -> PromptStackMessage: usage=PromptStackMessage.Usage(input_tokens=usage["inputTokens"], output_tokens=usage["outputTokens"]), ) - def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage | BaseDeltaPromptStackContent]: + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage]: response = self.bedrock_client.converse_stream(**self._base_params(prompt_stack)) stream = response.get("stream") @@ -56,8 +55,10 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMess for event in stream: if "contentBlockDelta" in event: content_block_delta = event["contentBlockDelta"] - yield TextDeltaPromptStackContent( - content_block_delta["delta"]["text"], index=content_block_delta["contentBlockIndex"] + yield DeltaPromptStackMessage( + content=TextDeltaPromptStackContent( + content_block_delta["delta"]["text"], index=content_block_delta["contentBlockIndex"] + ) ) elif "metadata" in event: usage = event["metadata"]["usage"] @@ -104,9 +105,7 @@ def __prompt_stack_content_message_content(self, content: BasePromptStackContent raise ValueError(f"Unsupported content type: {type(content)}") def __to_role(self, input: PromptStackMessage) -> str: - if input.is_system(): - return "system" - elif input.is_assistant(): + if input.is_assistant(): return "assistant" else: return "user" diff --git a/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py b/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py index f1858bc46..ad73670e0 100644 --- a/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py +++ b/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py @@ -7,13 +7,7 @@ from attrs import Factory, define, field from griptape.artifacts import TextArtifact -from griptape.common import ( - PromptStack, - PromptStackMessage, - TextPromptStackContent, - DeltaPromptStackMessage, - BaseDeltaPromptStackContent, -) +from griptape.common import PromptStack, PromptStackMessage, TextPromptStackContent, DeltaPromptStackMessage from griptape.drivers import BasePromptDriver from griptape.tokenizers import HuggingFaceTokenizer from griptape.utils import import_optional_dependency @@ -84,7 +78,7 @@ def try_run(self, prompt_stack: PromptStack) -> PromptStackMessage: usage=PromptStackMessage.Usage(input_tokens=input_tokens, output_tokens=output_tokens), ) - def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage | BaseDeltaPromptStackContent]: + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage]: raise NotImplementedError("streaming is not supported") def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str: diff --git a/griptape/drivers/prompt/anthropic_prompt_driver.py b/griptape/drivers/prompt/anthropic_prompt_driver.py index a253591d1..be7c48c17 100644 --- a/griptape/drivers/prompt/anthropic_prompt_driver.py +++ b/griptape/drivers/prompt/anthropic_prompt_driver.py @@ -7,7 +7,6 @@ from griptape.artifacts import TextArtifact from griptape.common import ( - BaseDeltaPromptStackContent, BasePromptStackContent, DeltaPromptStackMessage, TextDeltaPromptStackContent, @@ -60,12 +59,12 @@ def try_run(self, prompt_stack: PromptStack) -> PromptStackMessage: ), ) - def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage | BaseDeltaPromptStackContent]: + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage]: events = self.client.messages.create(**self._base_params(prompt_stack), stream=True) for event in events: if event.type == "content_block_delta": - yield self.__message_content_delta_to_prompt_stack_content_delta(event) + yield DeltaPromptStackMessage(content=self.__message_content_delta_to_prompt_stack_content_delta(event)) elif event.type == "message_start": yield DeltaPromptStackMessage( usage=DeltaPromptStackMessage.Usage(input_tokens=event.message.usage.input_tokens) @@ -99,9 +98,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: } def __to_role(self, input: PromptStackMessage) -> str: - if input.is_system(): - return "system" - elif input.is_assistant(): + if input.is_assistant(): return "assistant" else: return "user" @@ -131,7 +128,7 @@ def __message_content_to_prompt_stack_content(self, content: ContentBlock) -> Ba def __message_content_delta_to_prompt_stack_content_delta( self, content_delta: ContentBlockDeltaEvent - ) -> BaseDeltaPromptStackContent: + ) -> TextDeltaPromptStackContent: index = content_delta.index if content_delta.delta.type == "text_delta": diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index 0a08eceb0..48a7063f2 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -108,9 +108,7 @@ def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str: def try_run(self, prompt_stack: PromptStack) -> PromptStackMessage: ... @abstractmethod - def try_stream( - self, prompt_stack: PromptStack - ) -> Iterator[DeltaPromptStackMessage | BaseDeltaPromptStackContent]: ... + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage]: ... def __process_run(self, prompt_stack: PromptStack) -> PromptStackMessage: result = self.try_run(prompt_stack) @@ -126,6 +124,16 @@ def __process_stream(self, prompt_stack: PromptStack) -> PromptStackMessage: for delta in deltas: if isinstance(delta, DeltaPromptStackMessage): usage += delta.usage + + if delta.content is not None: + if delta.content.index in delta_contents: + delta_contents[delta.content.index].append(delta.content) + else: + delta_contents[delta.content.index] = [delta.content] + + if isinstance(delta, TextDeltaPromptStackContent): + self.structure.publish_event(CompletionChunkEvent(token=delta.text)) + elif isinstance(delta, BaseDeltaPromptStackContent): if delta.index in delta_contents: delta_contents[delta.index].append(delta) diff --git a/griptape/drivers/prompt/cohere_prompt_driver.py b/griptape/drivers/prompt/cohere_prompt_driver.py index 9c36609b1..36a2427a6 100644 --- a/griptape/drivers/prompt/cohere_prompt_driver.py +++ b/griptape/drivers/prompt/cohere_prompt_driver.py @@ -9,7 +9,6 @@ PromptStack, PromptStackMessage, DeltaPromptStackMessage, - BaseDeltaPromptStackContent, TextPromptStackContent, BasePromptStackContent, TextDeltaPromptStackContent, @@ -51,13 +50,13 @@ def try_run(self, prompt_stack: PromptStack) -> PromptStackMessage: usage=PromptStackMessage.Usage(input_tokens=usage.input_tokens, output_tokens=usage.output_tokens), ) - def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage | BaseDeltaPromptStackContent]: + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage]: result = self.client.chat_stream(**self._base_params(prompt_stack)) for event in result: if event.event_type == "text-generation": - yield TextDeltaPromptStackContent(event.text, index=0) - if event.event_type == "stream-end": + yield DeltaPromptStackMessage(content=TextDeltaPromptStackContent(event.text, index=0)) + elif event.event_type == "stream-end": usage = event.response.meta.tokens yield DeltaPromptStackMessage( diff --git a/griptape/drivers/prompt/dummy_prompt_driver.py b/griptape/drivers/prompt/dummy_prompt_driver.py index bb825566a..5f0557869 100644 --- a/griptape/drivers/prompt/dummy_prompt_driver.py +++ b/griptape/drivers/prompt/dummy_prompt_driver.py @@ -3,7 +3,7 @@ from attrs import Factory, define, field -from griptape.common import PromptStack, PromptStackMessage, DeltaPromptStackMessage, BaseDeltaPromptStackContent +from griptape.common import PromptStack, PromptStackMessage, DeltaPromptStackMessage from griptape.drivers import BasePromptDriver from griptape.exceptions import DummyException from griptape.tokenizers import DummyTokenizer @@ -17,5 +17,5 @@ class DummyPromptDriver(BasePromptDriver): def try_run(self, prompt_stack: PromptStack) -> PromptStackMessage: raise DummyException(__class__.__name__, "try_run") - def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage | BaseDeltaPromptStackContent]: + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage]: raise DummyException(__class__.__name__, "try_stream") diff --git a/griptape/drivers/prompt/google_prompt_driver.py b/griptape/drivers/prompt/google_prompt_driver.py index dc4f2d591..2154567b4 100644 --- a/griptape/drivers/prompt/google_prompt_driver.py +++ b/griptape/drivers/prompt/google_prompt_driver.py @@ -7,7 +7,6 @@ from griptape.artifacts import TextArtifact from griptape.common import ( - BaseDeltaPromptStackContent, BasePromptStackContent, DeltaPromptStackMessage, TextDeltaPromptStackContent, @@ -71,7 +70,7 @@ def try_run(self, prompt_stack: PromptStack) -> PromptStackMessage: ), ) - def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage | BaseDeltaPromptStackContent]: + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage]: GenerationConfig = import_optional_dependency("google.generativeai.types").GenerationConfig messages = self._prompt_stack_to_messages(prompt_stack) @@ -91,12 +90,11 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMess for chunk in response: usage_metadata = chunk.usage_metadata - yield TextDeltaPromptStackContent(chunk.text) - # Only want to output the prompt token count once since it is static each chunk if prompt_token_count is None: prompt_token_count = usage_metadata.prompt_token_count yield DeltaPromptStackMessage( + content=TextDeltaPromptStackContent(chunk.text), role=PromptStackMessage.ASSISTANT_ROLE, usage=DeltaPromptStackMessage.Usage( input_tokens=usage_metadata.prompt_token_count, @@ -105,6 +103,7 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMess ) else: yield DeltaPromptStackMessage( + content=TextDeltaPromptStackContent(chunk.text), role=PromptStackMessage.ASSISTANT_ROLE, usage=DeltaPromptStackMessage.Usage(output_tokens=usage_metadata.candidates_token_count), ) diff --git a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py index 45ae5eb87..87cd03b10 100644 --- a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py +++ b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py @@ -11,7 +11,6 @@ PromptStack, PromptStackMessage, DeltaPromptStackMessage, - BaseDeltaPromptStackContent, TextPromptStackContent, TextDeltaPromptStackContent, ) @@ -69,7 +68,7 @@ def try_run(self, prompt_stack: PromptStack) -> PromptStackMessage: usage=PromptStackMessage.Usage(input_tokens=input_tokens, output_tokens=output_tokens), ) - def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage | BaseDeltaPromptStackContent]: + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage]: prompt = self.prompt_stack_to_string(prompt_stack) response = self.client.text_generation( @@ -81,7 +80,7 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMess full_text = "" for token in response: full_text += token - yield TextDeltaPromptStackContent(token, index=0) + yield DeltaPromptStackMessage(content=TextDeltaPromptStackContent(token, index=0)) output_tokens = len(self.tokenizer.tokenizer.encode(full_text)) yield DeltaPromptStackMessage( diff --git a/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py b/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py index 9a905ae51..4003340d1 100644 --- a/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py +++ b/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py @@ -6,13 +6,7 @@ from attrs import Factory, define, field from griptape.artifacts import TextArtifact -from griptape.common import ( - BaseDeltaPromptStackContent, - DeltaPromptStackMessage, - PromptStack, - PromptStackMessage, - TextPromptStackContent, -) +from griptape.common import DeltaPromptStackMessage, PromptStack, PromptStackMessage, TextPromptStackContent from griptape.drivers import BasePromptDriver from griptape.tokenizers import HuggingFaceTokenizer from griptape.utils import import_optional_dependency @@ -72,7 +66,7 @@ def try_run(self, prompt_stack: PromptStack) -> PromptStackMessage: else: raise Exception("invalid output format") - def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage | BaseDeltaPromptStackContent]: + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage]: raise NotImplementedError("streaming is not supported") def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str: diff --git a/griptape/drivers/prompt/ollama_prompt_driver.py b/griptape/drivers/prompt/ollama_prompt_driver.py index a913e9477..afb43e8bf 100644 --- a/griptape/drivers/prompt/ollama_prompt_driver.py +++ b/griptape/drivers/prompt/ollama_prompt_driver.py @@ -8,12 +8,7 @@ from griptape.common import PromptStack, TextPromptStackContent from griptape.utils import import_optional_dependency from griptape.tokenizers import SimpleTokenizer -from griptape.common import ( - PromptStackMessage, - BaseDeltaPromptStackContent, - DeltaPromptStackMessage, - TextDeltaPromptStackContent, -) +from griptape.common import PromptStackMessage, DeltaPromptStackMessage, TextDeltaPromptStackContent if TYPE_CHECKING: from ollama import Client @@ -64,12 +59,12 @@ def try_run(self, prompt_stack: PromptStack) -> PromptStackMessage: else: raise Exception("invalid model response") - def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage | BaseDeltaPromptStackContent]: + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage]: stream = self.client.chat(**self._base_params(prompt_stack), stream=True) if isinstance(stream, Iterator): for chunk in stream: - yield TextDeltaPromptStackContent(chunk["message"]["content"]) + yield DeltaPromptStackMessage(content=TextDeltaPromptStackContent(chunk["message"]["content"])) else: raise Exception("invalid model response") diff --git a/griptape/drivers/prompt/openai_chat_prompt_driver.py b/griptape/drivers/prompt/openai_chat_prompt_driver.py index 391425c38..b28c5e644 100644 --- a/griptape/drivers/prompt/openai_chat_prompt_driver.py +++ b/griptape/drivers/prompt/openai_chat_prompt_driver.py @@ -8,7 +8,6 @@ from griptape.artifacts import TextArtifact from griptape.common import ( - BaseDeltaPromptStackContent, BasePromptStackContent, DeltaPromptStackMessage, TextDeltaPromptStackContent, @@ -90,7 +89,7 @@ def try_run(self, prompt_stack: PromptStack) -> PromptStackMessage: else: raise Exception("Completion with more than one choice is not supported yet.") - def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage | BaseDeltaPromptStackContent]: + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage]: result = self.client.chat.completions.create( **self._base_params(prompt_stack), stream=True, stream_options={"include_usage": True} ) @@ -107,7 +106,7 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMess choice = chunk.choices[0] delta = choice.delta - yield self.__message_delta_to_prompt_stack_content_delta(delta) + yield DeltaPromptStackMessage(content=self.__message_delta_to_prompt_stack_content_delta(delta)) else: raise Exception("Completion with more than one choice is not supported yet.") @@ -166,7 +165,7 @@ def __message_to_prompt_stack_content(self, message: ChatCompletionMessage) -> B else: raise ValueError(f"Unsupported message type: {message}") - def __message_delta_to_prompt_stack_content_delta(self, content_delta: ChoiceDelta) -> BaseDeltaPromptStackContent: + def __message_delta_to_prompt_stack_content_delta(self, content_delta: ChoiceDelta) -> TextDeltaPromptStackContent: if content_delta.content is None: return TextDeltaPromptStackContent("") else: diff --git a/tests/mocks/mock_failing_prompt_driver.py b/tests/mocks/mock_failing_prompt_driver.py index 80684378d..e7ff7ea66 100644 --- a/tests/mocks/mock_failing_prompt_driver.py +++ b/tests/mocks/mock_failing_prompt_driver.py @@ -9,7 +9,6 @@ TextPromptStackContent, DeltaPromptStackMessage, TextDeltaPromptStackContent, - BaseDeltaPromptStackContent, ) from griptape.drivers import BasePromptDriver from griptape.tokenizers import BaseTokenizer, OpenAiTokenizer @@ -34,13 +33,13 @@ def try_run(self, prompt_stack: PromptStack) -> PromptStackMessage: usage=PromptStackMessage.Usage(input_tokens=100, output_tokens=100), ) - def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage | BaseDeltaPromptStackContent]: + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage]: if self.current_attempt < self.max_failures: self.current_attempt += 1 raise Exception("failed attempt") else: yield DeltaPromptStackMessage( - delta_content=TextDeltaPromptStackContent("success"), + content=TextDeltaPromptStackContent("success"), usage=DeltaPromptStackMessage.Usage(input_tokens=100, output_tokens=100), ) diff --git a/tests/mocks/mock_prompt_driver.py b/tests/mocks/mock_prompt_driver.py index 1b1e8b38a..16ba38abe 100644 --- a/tests/mocks/mock_prompt_driver.py +++ b/tests/mocks/mock_prompt_driver.py @@ -10,7 +10,6 @@ PromptStack, PromptStackMessage, DeltaPromptStackMessage, - BaseDeltaPromptStackContent, TextPromptStackContent, TextDeltaPromptStackContent, ) @@ -36,8 +35,10 @@ def try_run(self, prompt_stack: PromptStack) -> PromptStackMessage: usage=PromptStackMessage.Usage(input_tokens=100, output_tokens=100), ) - def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage | BaseDeltaPromptStackContent]: + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage]: output = self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output - yield TextDeltaPromptStackContent(output) - yield DeltaPromptStackMessage(usage=DeltaPromptStackMessage.Usage(input_tokens=100, output_tokens=100)) + yield DeltaPromptStackMessage( + content=TextDeltaPromptStackContent(output), + usage=DeltaPromptStackMessage.Usage(input_tokens=100, output_tokens=100), + ) diff --git a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py index 783e940f9..f4f4eda11 100644 --- a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py @@ -1,7 +1,7 @@ import pytest +from griptape.artifacts import ImageArtifact, TextArtifact from griptape.common import PromptStack -from griptape.common.prompt_stack.contents.text_delta_prompt_stack_content import TextDeltaPromptStackContent from griptape.drivers import AmazonBedrockPromptDriver @@ -12,7 +12,7 @@ def mock_converse(self, mocker): mock_converse.return_value = { "output": {"message": {"content": [{"text": "model-output"}]}}, - "usage": {"inputTokens": 100, "outputTokens": 100}, + "usage": {"inputTokens": 5, "outputTokens": 10}, } return mock_converse @@ -22,7 +22,10 @@ def mock_converse_stream(self, mocker): mock_converse_stream = mocker.patch("boto3.Session").return_value.client.return_value.converse_stream mock_converse_stream.return_value = { - "stream": [{"contentBlockDelta": {"contentBlockIndex": 0, "delta": {"text": "model-output"}}}] + "stream": [ + {"contentBlockDelta": {"contentBlockIndex": 0, "delta": {"text": "model-output"}}}, + {"metadata": {"usage": {"inputTokens": 5, "outputTokens": 10}}}, + ] } return mock_converse_stream @@ -32,6 +35,8 @@ def prompt_stack(self): prompt_stack = PromptStack() prompt_stack.add_system_message("system-input") prompt_stack.add_user_message("user-input") + prompt_stack.add_user_message(TextArtifact("user-input")) + prompt_stack.add_user_message(ImageArtifact(value=b"image-data", format="png", width=100, height=100)) prompt_stack.add_assistant_message("assistant-input") return prompt_stack @@ -39,8 +44,9 @@ def prompt_stack(self): @pytest.fixture def messages(self): return [ - {"role": "system", "content": [{"text": "system-input"}]}, {"role": "user", "content": [{"text": "user-input"}]}, + {"role": "user", "content": [{"text": "user-input"}]}, + {"role": "user", "content": [{"image": {"format": "png", "source": {"bytes": b"image-data"}}}]}, {"role": "assistant", "content": [{"text": "assistant-input"}]}, ] @@ -54,34 +60,34 @@ def test_try_run(self, mock_converse, prompt_stack, messages): # Then mock_converse.assert_called_once_with( modelId=driver.model, - messages=[ - {"role": "user", "content": [{"text": "user-input"}]}, - {"role": "assistant", "content": [{"text": "assistant-input"}]}, - ], + messages=messages, system=[{"text": "system-input"}], inferenceConfig={"temperature": driver.temperature}, additionalModelRequestFields={}, ) assert text_artifact.value == "model-output" + assert text_artifact.usage.input_tokens == 5 + assert text_artifact.usage.output_tokens == 10 def test_try_stream_run(self, mock_converse_stream, prompt_stack, messages): # Given driver = AmazonBedrockPromptDriver(model="ai21.j2", stream=True) # When - text_artifact = next(driver.try_stream(prompt_stack)) + stream = driver.try_stream(prompt_stack) + event = next(stream) # Then mock_converse_stream.assert_called_once_with( modelId=driver.model, - messages=[ - {"role": "user", "content": [{"text": "user-input"}]}, - {"role": "assistant", "content": [{"text": "assistant-input"}]}, - ], + messages=messages, system=[{"text": "system-input"}], inferenceConfig={"temperature": driver.temperature}, additionalModelRequestFields={}, ) - if isinstance(text_artifact, TextDeltaPromptStackContent): - assert text_artifact.text == "model-output" + assert event.content.text == "model-output" + + event = next(stream) + assert event.usage.input_tokens == 5 + assert event.usage.output_tokens == 10 diff --git a/tests/unit/drivers/prompt/test_amazon_sagemaker_jumpstart_prompt_driver.py b/tests/unit/drivers/prompt/test_amazon_sagemaker_jumpstart_prompt_driver.py index 318017c02..a75fc6ed0 100644 --- a/tests/unit/drivers/prompt/test_amazon_sagemaker_jumpstart_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_amazon_sagemaker_jumpstart_prompt_driver.py @@ -19,7 +19,8 @@ class TestAmazonSageMakerJumpstartPromptDriver: def tokenizer(self, mocker): from_pretrained = mocker.patch("transformers.AutoTokenizer").from_pretrained from_pretrained.return_value.decode.return_value = "foo\n\nUser: bar" - from_pretrained.return_value.apply_chat_template.return_value = ["foo", "\nbar"] + from_pretrained.return_value.apply_chat_template.return_value = [1, 2, 3] + from_pretrained.return_value.encode.return_value = [1, 2, 3] from_pretrained.return_value.model_max_length = 8000 from_pretrained.return_value.eos_token_id = 1 @@ -65,6 +66,8 @@ def test_try_run(self, mock_client): ) assert text_artifact.value == "foobar" + assert text_artifact.usage.input_tokens == 3 + assert text_artifact.usage.output_tokens == 3 # When response_body = {"generated_text": "foobar"} diff --git a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py index 205d66ffa..aa85021a0 100644 --- a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py @@ -1,6 +1,6 @@ -from griptape.common.prompt_stack.contents.text_delta_prompt_stack_content import TextDeltaPromptStackContent from griptape.drivers import AnthropicPromptDriver from griptape.common import PromptStack +from griptape.artifacts import TextArtifact, ImageArtifact from unittest.mock import Mock import pytest @@ -9,23 +9,36 @@ class TestAnthropicPromptDriver: @pytest.fixture def mock_client(self, mocker): mock_client = mocker.patch("anthropic.Anthropic") - mock_content = Mock() - mock_content.type = "text" - mock_content.text = "model-output" - mock_client.return_value.messages.create.return_value.content = [mock_content] - mock_client.return_value.count_tokens.return_value = 5 + + mock_client.return_value = Mock( + messages=Mock( + create=Mock( + return_value=Mock( + usage=Mock(input_tokens=5, output_tokens=10), content=[Mock(type="text", text="model-output")] + ) + ) + ) + ) return mock_client @pytest.fixture def mock_stream_client(self, mocker): mock_stream_client = mocker.patch("anthropic.Anthropic") - mock_chunk = Mock() - mock_chunk.type = "content_block_delta" - mock_chunk.delta.type = "text_delta" - mock_chunk.delta.text = "model-output" - mock_stream_client.return_value.messages.create.return_value = iter([mock_chunk]) - mock_stream_client.return_value.count_tokens.return_value = 5 + + mock_stream_client.return_value = Mock( + messages=Mock( + create=Mock( + return_value=iter( + [ + Mock(type="message_start", message=Mock(usage=Mock(input_tokens=5))), + Mock(type="content_block_delta", delta=Mock(type="text_delta", text="model-output")), + Mock(type="message_delta", usage=Mock(output_tokens=10)), + ] + ) + ) + ) + ) return mock_stream_client @@ -51,15 +64,27 @@ def test_try_run(self, mock_client, model, system_enabled): if system_enabled: prompt_stack.add_system_message("system-input") prompt_stack.add_user_message("user-input") + prompt_stack.add_user_message(TextArtifact("user-input")) + prompt_stack.add_user_message(ImageArtifact(value=b"image-data", format="png", width=100, height=100)) prompt_stack.add_assistant_message("assistant-input") driver = AnthropicPromptDriver(model=model, api_key="api-key") expected_messages = [ {"role": "user", "content": "user-input"}, + {"role": "user", "content": "user-input"}, + { + "content": [ + { + "source": {"data": "aW1hZ2UtZGF0YQ==", "media_type": "image/png", "type": "base64"}, + "type": "image", + } + ], + "role": "user", + }, {"role": "assistant", "content": "assistant-input"}, ] # When - text_artifact = driver.try_run(prompt_stack) + message = driver.try_run(prompt_stack) # Then mock_client.return_value.messages.create.assert_called_once_with( @@ -72,7 +97,9 @@ def test_try_run(self, mock_client, model, system_enabled): top_k=250, **{"system": "system-input"} if system_enabled else {}, ) - assert text_artifact.value == "model-output" + assert message.value == "model-output" + assert message.usage.input_tokens == 5 + assert message.usage.output_tokens == 10 @pytest.mark.parametrize( "model", @@ -92,15 +119,28 @@ def test_try_stream_run(self, mock_stream_client, model, system_enabled): if system_enabled: prompt_stack.add_system_message("system-input") prompt_stack.add_user_message("user-input") + prompt_stack.add_user_message(TextArtifact("user-input")) + prompt_stack.add_user_message(ImageArtifact(value=b"image-data", format="png", width=100, height=100)) prompt_stack.add_assistant_message("assistant-input") expected_messages = [ {"role": "user", "content": "user-input"}, + {"role": "user", "content": "user-input"}, + { + "content": [ + { + "source": {"data": "aW1hZ2UtZGF0YQ==", "media_type": "image/png", "type": "base64"}, + "type": "image", + } + ], + "role": "user", + }, {"role": "assistant", "content": "assistant-input"}, ] driver = AnthropicPromptDriver(model=model, api_key="api-key", stream=True) # When - text_artifact = next(driver.try_stream(prompt_stack)) + stream = driver.try_stream(prompt_stack) + event = next(stream) # Then mock_stream_client.return_value.messages.create.assert_called_once_with( @@ -114,8 +154,13 @@ def test_try_stream_run(self, mock_stream_client, model, system_enabled): top_k=250, **{"system": "system-input"} if system_enabled else {}, ) - if isinstance(text_artifact, TextDeltaPromptStackContent): - assert text_artifact.text == "model-output" + assert event.usage.input_tokens == 5 + + event = next(stream) + assert event.content.text == "model-output" + + event = next(stream) + assert event.usage.output_tokens == 10 def test_try_run_throws_when_prompt_stack_is_string(self): # Given diff --git a/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py b/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py index 93d0e165a..92544a74e 100644 --- a/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py @@ -1,6 +1,5 @@ import pytest from unittest.mock import Mock -from griptape.common.prompt_stack.contents.text_delta_prompt_stack_content import TextDeltaPromptStackContent from griptape.drivers import AzureOpenAiChatPromptDriver from tests.unit.drivers.prompt.test_openai_chat_prompt_driver import TestOpenAiChatPromptDriverFixtureMixin @@ -9,20 +8,22 @@ class TestAzureOpenAiChatPromptDriver(TestOpenAiChatPromptDriverFixtureMixin): @pytest.fixture def mock_chat_completion_create(self, mocker): mock_chat_create = mocker.patch("openai.AzureOpenAI").return_value.chat.completions.create - mock_choice = Mock() - mock_choice.message.content = "model-output" - mock_chat_create.return_value.headers = {} - mock_chat_create.return_value.choices = [mock_choice] + mock_chat_create.return_value = Mock( + headers={}, + choices=[Mock(message=Mock(content="model-output"))], + usage=Mock(prompt_tokens=5, completion_tokens=10), + ) return mock_chat_create @pytest.fixture def mock_chat_completion_stream_create(self, mocker): mock_chat_create = mocker.patch("openai.AzureOpenAI").return_value.chat.completions.create - mock_chunk = Mock() - mock_choice = Mock() - mock_choice.delta.content = "model-output" - mock_chunk.choices = [mock_choice] - mock_chat_create.return_value = iter([mock_chunk]) + mock_chat_create.return_value = iter( + [ + Mock(choices=[Mock(delta=Mock(content="model-output"))], usage=None), + Mock(choices=None, usage=Mock(prompt_tokens=5, completion_tokens=10)), + ] + ) return mock_chat_create def test_init(self): @@ -41,6 +42,8 @@ def test_try_run(self, mock_chat_completion_create, prompt_stack, messages): model=driver.model, temperature=driver.temperature, user=driver.user, messages=messages ) assert text_artifact.value == "model-output" + assert text_artifact.usage.input_tokens == 5 + assert text_artifact.usage.output_tokens == 10 def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, messages): # Given @@ -49,7 +52,8 @@ def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, ) # When - text_artifact = next(driver.try_stream(prompt_stack)) + stream = driver.try_stream(prompt_stack) + event = next(stream) # Then mock_chat_completion_stream_create.assert_called_once_with( @@ -61,5 +65,8 @@ def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, stream_options={"include_usage": True}, ) - if isinstance(text_artifact, TextDeltaPromptStackContent): - assert text_artifact == "model-output" + assert event.content.text == "model-output" + + event = next(stream) + assert event.usage.input_tokens == 5 + assert event.usage.output_tokens == 10 diff --git a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py index 7f6c6f400..cf556fc1e 100644 --- a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py @@ -2,7 +2,7 @@ import pytest -from griptape.common import TextDeltaPromptStackContent, PromptStack +from griptape.common import PromptStack from griptape.drivers import CoherePromptDriver @@ -10,15 +10,21 @@ class TestCoherePromptDriver: @pytest.fixture def mock_client(self, mocker): mock_client = mocker.patch("cohere.Client").return_value - mock_client.chat.return_value = Mock(text="model-output") + mock_client.chat.return_value = Mock( + text="model-output", meta=Mock(tokens=Mock(input_tokens=5, output_tokens=10)) + ) return mock_client @pytest.fixture def mock_stream_client(self, mocker): mock_client = mocker.patch("cohere.Client").return_value - mock_chunk = Mock(text="model-output", event_type="text-generation") - mock_client.chat_stream.return_value = iter([mock_chunk]) + mock_client.chat_stream.return_value = iter( + [ + Mock(text="model-output", event_type="text-generation"), + Mock(response=Mock(meta=Mock(tokens=Mock(input_tokens=5, output_tokens=10))), event_type="stream-end"), + ] + ) return mock_client @@ -45,15 +51,40 @@ def test_try_run(self, mock_client, prompt_stack): # pyright: ignore text_artifact = driver.try_run(prompt_stack) # Then + mock_client.chat.assert_called_once_with( + chat_history=[{"content": [{"text": "user-input"}], "role": "USER"}], + max_tokens=None, + message="assistant-input", + preamble="system-input", + stop_sequences=[], + temperature=0.1, + ) + assert text_artifact.value == "model-output" + assert text_artifact.usage.input_tokens == 5 + assert text_artifact.usage.output_tokens == 10 def test_try_stream_run(self, mock_stream_client, prompt_stack): # pyright: ignore # Given driver = CoherePromptDriver(model="command", api_key="api-key", stream=True) # When - text_artifact = next(driver.try_stream(prompt_stack)) + stream = driver.try_stream(prompt_stack) + event = next(stream) # Then - if isinstance(text_artifact, TextDeltaPromptStackContent): - assert text_artifact.text == "model-output" + + mock_stream_client.chat_stream.assert_called_once_with( + chat_history=[{"content": [{"text": "user-input"}], "role": "USER"}], + max_tokens=None, + message="assistant-input", + preamble="system-input", + stop_sequences=[], + temperature=0.1, + ) + + assert event.content.text == "model-output" + + event = next(stream) + assert event.usage.input_tokens == 5 + assert event.usage.output_tokens == 10 diff --git a/tests/unit/drivers/prompt/test_google_prompt_driver.py b/tests/unit/drivers/prompt/test_google_prompt_driver.py index 0ee048338..9a454a563 100644 --- a/tests/unit/drivers/prompt/test_google_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_google_prompt_driver.py @@ -1,5 +1,5 @@ from google.generativeai.types import GenerationConfig -from griptape.common.prompt_stack.contents.text_delta_prompt_stack_content import TextDeltaPromptStackContent +from griptape.artifacts import TextArtifact, ImageArtifact from griptape.drivers import GooglePromptDriver from griptape.common import PromptStack from unittest.mock import Mock @@ -10,14 +10,21 @@ class TestGooglePromptDriver: @pytest.fixture def mock_generative_model(self, mocker): mock_generative_model = mocker.patch("google.generativeai.GenerativeModel") - mock_generative_model.return_value.generate_content.return_value = Mock(text="model-output") + mock_generative_model.return_value.generate_content.return_value = Mock( + text="model-output", usage_metadata=Mock(prompt_token_count=5, candidates_token_count=10) + ) return mock_generative_model @pytest.fixture def mock_stream_generative_model(self, mocker): mock_generative_model = mocker.patch("google.generativeai.GenerativeModel") - mock_generative_model.return_value.generate_content.return_value = iter([Mock(text="model-output")]) + mock_generative_model.return_value.generate_content.return_value = iter( + [ + Mock(text="model-output", usage_metadata=Mock(prompt_token_count=5, candidates_token_count=5)), + Mock(text="model-output", usage_metadata=Mock(prompt_token_count=5, candidates_token_count=5)), + ] + ) return mock_generative_model @@ -30,6 +37,8 @@ def test_try_run(self, mock_generative_model): prompt_stack = PromptStack() prompt_stack.add_system_message("system-input") prompt_stack.add_user_message("user-input") + prompt_stack.add_user_message(TextArtifact("user-input")) + prompt_stack.add_user_message(ImageArtifact(value=b"image-data", format="png", width=100, height=100)) prompt_stack.add_assistant_message("assistant-input") driver = GooglePromptDriver(model="gemini-pro", api_key="api-key", top_p=0.5, top_k=50) @@ -40,6 +49,8 @@ def test_try_run(self, mock_generative_model): mock_generative_model.return_value.generate_content.assert_called_once_with( [ {"parts": ["system-input", "user-input"], "role": "user"}, + {"parts": ["user-input"], "role": "user"}, + {"parts": [{"data": b"image-data", "mime_type": "image/png"}], "role": "user"}, {"parts": ["assistant-input"], "role": "model"}, ], generation_config=GenerationConfig( @@ -47,27 +58,37 @@ def test_try_run(self, mock_generative_model): ), ) assert text_artifact.value == "model-output" + assert text_artifact.usage.input_tokens == 5 + assert text_artifact.usage.output_tokens == 10 def test_try_stream(self, mock_stream_generative_model): # Given prompt_stack = PromptStack() prompt_stack.add_system_message("system-input") prompt_stack.add_user_message("user-input") + prompt_stack.add_user_message(TextArtifact("user-input")) + prompt_stack.add_user_message(ImageArtifact(value=b"image-data", format="png", width=100, height=100)) prompt_stack.add_assistant_message("assistant-input") driver = GooglePromptDriver(model="gemini-pro", api_key="api-key", stream=True, top_p=0.5, top_k=50) # When - text_artifact_stream = driver.try_stream(prompt_stack) + stream = driver.try_stream(prompt_stack) # Then - text_artifact = next(text_artifact_stream) + event = next(stream) mock_stream_generative_model.return_value.generate_content.assert_called_once_with( [ {"parts": ["system-input", "user-input"], "role": "user"}, + {"parts": ["user-input"], "role": "user"}, + {"parts": [{"data": b"image-data", "mime_type": "image/png"}], "role": "user"}, {"parts": ["assistant-input"], "role": "model"}, ], stream=True, generation_config=GenerationConfig(temperature=0.1, top_p=0.5, top_k=50, stop_sequences=[]), ) - if isinstance(text_artifact, TextDeltaPromptStackContent): - assert text_artifact.text == "model-output" + assert event.content.text == "model-output" + assert event.usage.input_tokens == 5 + assert event.usage.output_tokens == 5 + + event = next(stream) + assert event.usage.output_tokens == 5 diff --git a/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py b/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py index 6e7367a10..4618e1de3 100644 --- a/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py @@ -1,4 +1,3 @@ -from griptape.common.prompt_stack.contents.text_delta_prompt_stack_content import TextDeltaPromptStackContent from griptape.drivers import HuggingFaceHubPromptDriver from griptape.common import PromptStack import pytest @@ -8,6 +7,7 @@ class TestHuggingFaceHubPromptDriver: @pytest.fixture def mock_client(self, mocker): mock_client = mocker.patch("huggingface_hub.InferenceClient").return_value + mock_client.text_generation.return_value = "model-output" return mock_client @@ -15,6 +15,8 @@ def mock_client(self, mocker): def tokenizer(self, mocker): from_pretrained = tokenizer = mocker.patch("transformers.AutoTokenizer").from_pretrained from_pretrained.return_value.apply_chat_template.return_value = [1, 2, 3] + from_pretrained.return_value.decode.return_value = "foo\n\nUser: bar" + from_pretrained.return_value.encode.return_value = [1, 2, 3] return tokenizer @@ -47,18 +49,24 @@ def test_try_run(self, prompt_stack, mock_client): driver = HuggingFaceHubPromptDriver(api_token="api-token", model="repo-id") # When - text_artifact = driver.try_run(prompt_stack) + message = driver.try_run(prompt_stack) # Then - assert text_artifact.value == "model-output" + assert message.value == "model-output" + assert message.usage.input_tokens == 3 + assert message.usage.output_tokens == 3 def test_try_stream(self, prompt_stack, mock_client_stream): # Given driver = HuggingFaceHubPromptDriver(api_token="api-token", model="repo-id", stream=True) # When - text_artifact = next(driver.try_stream(prompt_stack)) + stream = driver.try_stream(prompt_stack) + event = next(stream) # Then - if isinstance(text_artifact, TextDeltaPromptStackContent): - assert text_artifact.text == "model-output" + assert event.content.text == "model-output" + + event = next(stream) + assert event.usage.input_tokens == 3 + assert event.usage.output_tokens == 3 diff --git a/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py b/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py index defb53056..a63d697fb 100644 --- a/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py @@ -22,6 +22,7 @@ def mock_autotokenizer(self, mocker): mock_autotokenizer.model_max_length = 42 mock_autotokenizer.apply_chat_template.return_value = [1, 2, 3] mock_autotokenizer.decode.return_value = "model-output" + mock_autotokenizer.encode.return_value = [1, 2, 3] return mock_autotokenizer @pytest.fixture @@ -40,10 +41,12 @@ def test_try_run(self, prompt_stack): driver = HuggingFacePipelinePromptDriver(model="foo", max_tokens=42) # When - text_artifact = driver.try_run(prompt_stack) + message = driver.try_run(prompt_stack) # Then - assert text_artifact.value == "model-output" + assert message.value == "model-output" + assert message.usage.input_tokens == 3 + assert message.usage.output_tokens == 3 def test_try_stream(self, prompt_stack): # Given diff --git a/tests/unit/drivers/prompt/test_ollama_prompt_driver.py b/tests/unit/drivers/prompt/test_ollama_prompt_driver.py index 31ee3fec7..5cfa1bc5d 100644 --- a/tests/unit/drivers/prompt/test_ollama_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_ollama_prompt_driver.py @@ -37,7 +37,7 @@ def test_try_run(self, mock_client): ] # When - text_artifact = driver.try_run(prompt_stack) + message = driver.try_run(prompt_stack) # Then mock_client.return_value.chat.assert_called_once_with( @@ -45,7 +45,9 @@ def test_try_run(self, mock_client): model=driver.model, options={"temperature": driver.temperature, "stop": [], "num_predict": driver.max_tokens}, ) - assert text_artifact.value == "model-output" + assert message.value == "model-output" + assert message.usage.input_tokens is None + assert message.usage.output_tokens is None def test_try_run_bad_response(self, mock_client): # Given diff --git a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py index 6d8d2cb51..6ccb4d4bb 100644 --- a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py @@ -1,5 +1,7 @@ +from griptape.artifacts.image_artifact import ImageArtifact +from griptape.artifacts.text_artifact import TextArtifact from griptape.drivers import OpenAiChatPromptDriver -from griptape.common import PromptStack, TextDeltaPromptStackContent +from griptape.common import PromptStack from griptape.tokenizers import OpenAiTokenizer from unittest.mock import Mock from tests.mocks.mock_tokenizer import MockTokenizer @@ -10,20 +12,23 @@ class TestOpenAiChatPromptDriverFixtureMixin: @pytest.fixture def mock_chat_completion_create(self, mocker): mock_chat_create = mocker.patch("openai.OpenAI").return_value.chat.completions.create - mock_choice = Mock() - mock_choice.message.content = "model-output" - mock_chat_create.return_value.headers = {} - mock_chat_create.return_value.choices = [mock_choice] + mock_chat_create.return_value = Mock( + headers={}, + choices=[Mock(message=Mock(content="model-output"))], + usage=Mock(prompt_tokens=5, completion_tokens=10), + ) + return mock_chat_create @pytest.fixture def mock_chat_completion_stream_create(self, mocker): mock_chat_create = mocker.patch("openai.OpenAI").return_value.chat.completions.create - mock_chunk = Mock() - mock_choice = Mock() - mock_choice.delta.content = "model-output" - mock_chunk.choices = [mock_choice] - mock_chat_create.return_value = iter([mock_chunk]) + mock_chat_create.return_value = iter( + [ + Mock(choices=[Mock(delta=Mock(content="model-output"))], usage=None), + Mock(choices=None, usage=Mock(prompt_tokens=5, completion_tokens=10)), + ] + ) return mock_chat_create @pytest.fixture @@ -31,6 +36,8 @@ def prompt_stack(self): prompt_stack = PromptStack() prompt_stack.add_system_message("system-input") prompt_stack.add_user_message("user-input") + prompt_stack.add_user_message(TextArtifact("user-input")) + prompt_stack.add_user_message(ImageArtifact(value=b"image-data", format="png", width=100, height=100)) prompt_stack.add_assistant_message("assistant-input") return prompt_stack @@ -39,6 +46,11 @@ def messages(self): return [ {"role": "system", "content": "system-input"}, {"role": "user", "content": "user-input"}, + {"role": "user", "content": "user-input"}, + { + "role": "user", + "content": [{"type": "image_url", "image_url": {"url": "data:image/png;base64,aW1hZ2UtZGF0YQ=="}}], + }, {"role": "assistant", "content": "assistant-input"}, ] @@ -85,13 +97,13 @@ def test_try_run(self, mock_chat_completion_create, prompt_stack, messages): driver = OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL) # When - text_artifact = driver.try_run(prompt_stack) + event = driver.try_run(prompt_stack) # Then mock_chat_completion_create.assert_called_once_with( model=driver.model, temperature=driver.temperature, user=driver.user, messages=messages, seed=driver.seed ) - assert text_artifact.value == "model-output" + assert event.value == "model-output" def test_try_run_response_format(self, mock_chat_completion_create, prompt_stack, messages): # Given @@ -112,13 +124,16 @@ def test_try_run_response_format(self, mock_chat_completion_create, prompt_stack response_format={"type": "json_object"}, ) assert element.value == "model-output" + assert element.usage.input_tokens == 5 + assert element.usage.output_tokens == 10 def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, messages): # Given driver = OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, stream=True) # When - text_artifact = next(driver.try_stream(prompt_stack)) + stream = driver.try_stream(prompt_stack) + event = next(stream) # Then mock_chat_completion_stream_create.assert_called_once_with( @@ -131,15 +146,18 @@ def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, stream_options={"include_usage": True}, ) - if isinstance(text_artifact, TextDeltaPromptStackContent): - assert text_artifact.text == "model-output" + assert event.content.text == "model-output" + + event = next(stream) + assert event.usage.input_tokens == 5 + assert event.usage.output_tokens == 10 def test_try_run_with_max_tokens(self, mock_chat_completion_create, prompt_stack, messages): # Given driver = OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, max_tokens=1) # When - text_artifact = driver.try_run(prompt_stack) + event = driver.try_run(prompt_stack) # Then mock_chat_completion_create.assert_called_once_with( @@ -150,7 +168,7 @@ def test_try_run_with_max_tokens(self, mock_chat_completion_create, prompt_stack max_tokens=1, seed=driver.seed, ) - assert text_artifact.value == "model-output" + assert event.value == "model-output" def test_try_run_throws_when_prompt_stack_is_string(self): # Given @@ -163,18 +181,17 @@ def test_try_run_throws_when_prompt_stack_is_string(self): # Then assert e.value.args[0] == "'str' object has no attribute 'messages'" - @pytest.mark.parametrize("choices", [[], [1, 2]]) - def test_try_run_throws_when_multiple_choices_returned(self, choices, mock_chat_completion_create, prompt_stack): + def test_try_run_throws_when_multiple_choices_returned(self, mock_chat_completion_create, prompt_stack): # Given driver = OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, api_key="api-key") - mock_chat_completion_create.return_value.choices = [choices] + mock_chat_completion_create.return_value.choices = [Mock(message=Mock(content="model-output"))] * 10 # When with pytest.raises(Exception) as e: driver.try_run(prompt_stack) # Then - e.value.args[0] == "Completion with more than one choice is not supported yet." + assert e.value.args[0] == "Completion with more than one choice is not supported yet." def test_custom_tokenizer(self, mock_chat_completion_create, prompt_stack, messages): driver = OpenAiChatPromptDriver( @@ -184,7 +201,7 @@ def test_custom_tokenizer(self, mock_chat_completion_create, prompt_stack, messa ) # When - text_artifact = driver.try_run(prompt_stack) + event = driver.try_run(prompt_stack) # Then mock_chat_completion_create.assert_called_once_with( @@ -192,12 +209,8 @@ def test_custom_tokenizer(self, mock_chat_completion_create, prompt_stack, messa temperature=driver.temperature, stop=driver.tokenizer.stop_sequences, user=driver.user, - messages=[ - {"role": "system", "content": "system-input"}, - {"role": "user", "content": "user-input"}, - {"role": "assistant", "content": "assistant-input"}, - ], + messages=messages, seed=driver.seed, max_tokens=1, ) - assert text_artifact.value == "model-output" + assert event.value == "model-output" diff --git a/tests/unit/tasks/test_prompt_task.py b/tests/unit/tasks/test_prompt_task.py index 1dd45ab64..a0fb1fd59 100644 --- a/tests/unit/tasks/test_prompt_task.py +++ b/tests/unit/tasks/test_prompt_task.py @@ -1,4 +1,7 @@ import pytest +from griptape.artifacts.image_artifact import ImageArtifact +from griptape.artifacts.list_artifact import ListArtifact +from griptape.artifacts.text_artifact import TextArtifact from tests.mocks.mock_structure_config import MockStructureConfig from griptape.tasks import PromptTask from tests.mocks.mock_prompt_driver import MockPromptDriver @@ -32,3 +35,59 @@ def test_missing_prompt_driver(self): with pytest.raises(ValueError): task.prompt_driver + + def test_input(self): + task = PromptTask("test") + + assert task.input.value == "test" + + task = PromptTask(["test1", "test2"]) + + assert task.input.value[0].value == "test1" + assert task.input.value[1].value == "test2" + + task = PromptTask(("test1", "test2")) + + assert task.input.value[0].value == "test1" + assert task.input.value[1].value == "test2" + + task = PromptTask(ImageArtifact(b"image-data", format="png", width=100, height=100)) + + assert isinstance(task.input, ImageArtifact) + assert task.input.value == b"image-data" + assert task.input.format == "png" + assert task.input.width == 100 + assert task.input.height == 100 + + task = PromptTask(["foo", ImageArtifact(b"image-data", format="png", width=100, height=100)]) + + assert isinstance(task.input, ListArtifact) + assert task.input.value[0].value == "foo" + assert isinstance(task.input.value[1], ImageArtifact) + assert task.input.value[1].value == b"image-data" + assert task.input.value[1].format == "png" + assert task.input.value[1].width == 100 + + task = PromptTask( + ListArtifact([TextArtifact("foo"), ImageArtifact(b"image-data", format="png", width=100, height=100)]) + ) + + assert isinstance(task.input, ListArtifact) + assert task.input.value[0].value == "foo" + assert isinstance(task.input.value[1], ImageArtifact) + assert task.input.value[1].value == b"image-data" + assert task.input.value[1].format == "png" + assert task.input.value[1].width == 100 + + task = PromptTask( + lambda _: ListArtifact( + [TextArtifact("foo"), ImageArtifact(b"image-data", format="png", width=100, height=100)] + ) + ) + + assert isinstance(task.input, ListArtifact) + assert task.input.value[0].value == "foo" + assert isinstance(task.input.value[1], ImageArtifact) + assert task.input.value[1].value == b"image-data" + assert task.input.value[1].format == "png" + assert task.input.value[1].width == 100 diff --git a/tests/unit/utils/test_prompt_stack.py b/tests/unit/utils/test_prompt_stack.py index 87976c02d..98c9f48ff 100644 --- a/tests/unit/utils/test_prompt_stack.py +++ b/tests/unit/utils/test_prompt_stack.py @@ -1,6 +1,7 @@ import pytest -from griptape.common import PromptStack +from griptape.artifacts import ImageArtifact, ListArtifact, TextArtifact +from griptape.common import ImagePromptStackContent, PromptStack, TextPromptStackContent class TestPromptStack: @@ -13,10 +14,28 @@ def test_init(self): def test_add_message(self, prompt_stack): prompt_stack.add_message("foo", "role") + prompt_stack.add_message(TextArtifact("foo"), "role") + prompt_stack.add_message(ImageArtifact(b"foo", format="png", width=100, height=100), "role") + prompt_stack.add_message(ListArtifact([TextArtifact("foo"), TextArtifact("bar")]), "role") assert prompt_stack.messages[0].role == "role" + assert isinstance(prompt_stack.messages[0].content[0], TextPromptStackContent) assert prompt_stack.messages[0].content[0].artifact.value == "foo" + assert prompt_stack.messages[1].role == "role" + assert isinstance(prompt_stack.messages[1].content[0], TextPromptStackContent) + assert prompt_stack.messages[1].content[0].artifact.value == "foo" + + assert prompt_stack.messages[2].role == "role" + assert isinstance(prompt_stack.messages[2].content[0], ImagePromptStackContent) + assert prompt_stack.messages[2].content[0].artifact.value == b"foo" + + assert prompt_stack.messages[3].role == "role" + assert isinstance(prompt_stack.messages[3].content[0], TextPromptStackContent) + assert prompt_stack.messages[3].content[0].artifact.value == "foo" + assert isinstance(prompt_stack.messages[3].content[1], TextPromptStackContent) + assert prompt_stack.messages[3].content[1].artifact.value == "bar" + def test_add_system_message(self, prompt_stack): prompt_stack.add_system_message("foo") From 43826d388b5f52f80ab8924f4f1333d772e985e4 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 20 Jun 2024 15:42:44 -0700 Subject: [PATCH 17/34] Add image input support to ollama --- .../messages/prompt_stack_message.py | 11 ++++---- .../drivers/prompt/ollama_prompt_driver.py | 25 ++++++++++++++++--- .../prompt/test_ollama_prompt_driver.py | 13 ++++++++++ 3 files changed, 40 insertions(+), 9 deletions(-) diff --git a/griptape/common/prompt_stack/messages/prompt_stack_message.py b/griptape/common/prompt_stack/messages/prompt_stack_message.py index 4b393f570..d67da48bc 100644 --- a/griptape/common/prompt_stack/messages/prompt_stack_message.py +++ b/griptape/common/prompt_stack/messages/prompt_stack_message.py @@ -33,9 +33,8 @@ def to_text(self) -> str: return self.to_text_artifact().to_text() def to_text_artifact(self) -> TextArtifact: - artifact = TextArtifact(value="") - - for content in self.content: - artifact.value += content.artifact.to_text() - - return artifact + return TextArtifact( + "".join( + [content.artifact.to_text() for content in self.content if isinstance(content, TextPromptStackContent)] + ) + ) diff --git a/griptape/drivers/prompt/ollama_prompt_driver.py b/griptape/drivers/prompt/ollama_prompt_driver.py index afb43e8bf..d6a1083fb 100644 --- a/griptape/drivers/prompt/ollama_prompt_driver.py +++ b/griptape/drivers/prompt/ollama_prompt_driver.py @@ -9,6 +9,7 @@ from griptape.utils import import_optional_dependency from griptape.tokenizers import SimpleTokenizer from griptape.common import PromptStackMessage, DeltaPromptStackMessage, TextDeltaPromptStackContent +from griptape.common import ImagePromptStackContent if TYPE_CHECKING: from ollama import Client @@ -69,8 +70,26 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMess raise Exception("invalid model response") def _base_params(self, prompt_stack: PromptStack) -> dict: - messages = [ - {"role": message.role, "content": message.to_text_artifact().to_text()} for message in prompt_stack.messages - ] + messages = self._prompt_stack_to_messages(prompt_stack) return {"messages": messages, "model": self.model, "options": self.options} + + def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]: + return [ + { + "role": message.role, + "content": message.to_text_artifact().to_text(), + **( + { + "images": [ + content.artifact.base64 + for content in message.content + if isinstance(content, ImagePromptStackContent) + ] + } + if any(isinstance(content, ImagePromptStackContent) for content in message.content) + else {} + ), + } + for message in prompt_stack.messages + ] diff --git a/tests/unit/drivers/prompt/test_ollama_prompt_driver.py b/tests/unit/drivers/prompt/test_ollama_prompt_driver.py index 5cfa1bc5d..e737aeaeb 100644 --- a/tests/unit/drivers/prompt/test_ollama_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_ollama_prompt_driver.py @@ -1,6 +1,7 @@ from griptape.common.prompt_stack.contents.text_delta_prompt_stack_content import TextDeltaPromptStackContent from griptape.drivers import OllamaPromptDriver from griptape.common import PromptStack +from griptape.artifacts import ImageArtifact, ListArtifact, TextArtifact import pytest @@ -28,11 +29,17 @@ def test_try_run(self, mock_client): prompt_stack = PromptStack() prompt_stack.add_system_message("system-input") prompt_stack.add_user_message("user-input") + prompt_stack.add_user_message( + ListArtifact( + [TextArtifact("user-input"), ImageArtifact(value=b"image-data", format="png", width=100, height=100)] + ) + ) prompt_stack.add_assistant_message("assistant-input") driver = OllamaPromptDriver(model="llama") expected_messages = [ {"role": "system", "content": "system-input"}, {"role": "user", "content": "user-input"}, + {"role": "user", "content": "user-input", "images": ["aW1hZ2UtZGF0YQ=="]}, {"role": "assistant", "content": "assistant-input"}, ] @@ -64,10 +71,16 @@ def test_try_stream_run(self, mock_stream_client): prompt_stack = PromptStack() prompt_stack.add_system_message("system-input") prompt_stack.add_user_message("user-input") + prompt_stack.add_user_message( + ListArtifact( + [TextArtifact("user-input"), ImageArtifact(value=b"image-data", format="png", width=100, height=100)] + ) + ) prompt_stack.add_assistant_message("assistant-input") expected_messages = [ {"role": "system", "content": "system-input"}, {"role": "user", "content": "user-input"}, + {"role": "user", "content": "user-input", "images": ["aW1hZ2UtZGF0YQ=="]}, {"role": "assistant", "content": "assistant-input"}, ] driver = OllamaPromptDriver(model="llama", stream=True) From c98163786fb91a343352f448a7186a49782082e1 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Thu, 20 Jun 2024 15:44:26 -0700 Subject: [PATCH 18/34] Fix tests --- griptape/drivers/prompt/base_prompt_driver.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index 48a7063f2..064ae697c 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -131,17 +131,8 @@ def __process_stream(self, prompt_stack: PromptStack) -> PromptStackMessage: else: delta_contents[delta.content.index] = [delta.content] - if isinstance(delta, TextDeltaPromptStackContent): - self.structure.publish_event(CompletionChunkEvent(token=delta.text)) - - elif isinstance(delta, BaseDeltaPromptStackContent): - if delta.index in delta_contents: - delta_contents[delta.index].append(delta) - else: - delta_contents[delta.index] = [delta] - - if isinstance(delta, TextDeltaPromptStackContent): - self.structure.publish_event(CompletionChunkEvent(token=delta.text)) + if isinstance(delta.content, TextDeltaPromptStackContent): + self.structure.publish_event(CompletionChunkEvent(token=delta.content.text)) # Build a complete content from the content deltas content = [] From 828e892de69950a199a3f413639008880b9f2b85 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 21 Jun 2024 08:43:38 -0700 Subject: [PATCH 19/34] Rename inputs to messages --- .../prompt/amazon_bedrock_prompt_driver.py | 16 +++++------ ...mazon_sagemaker_jumpstart_prompt_driver.py | 4 +-- .../drivers/prompt/anthropic_prompt_driver.py | 26 +++++++++-------- griptape/drivers/prompt/base_prompt_driver.py | 17 ++++++----- .../drivers/prompt/cohere_prompt_driver.py | 28 +++++++++---------- .../drivers/prompt/google_prompt_driver.py | 14 +++++----- .../prompt/openai_chat_prompt_driver.py | 21 ++++++++------ .../prompt/test_openai_chat_prompt_driver.py | 8 +++--- 8 files changed, 69 insertions(+), 65 deletions(-) diff --git a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py index 62e5f97eb..18018857e 100644 --- a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py +++ b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py @@ -70,22 +70,22 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMess else: raise Exception("model response is empty") - def _prompt_stack_messages_to_messages(self, elements: list[PromptStackMessage]) -> list[dict]: + def _prompt_stack_messages_to_messages(self, messages: list[PromptStackMessage]) -> list[dict]: return [ { - "role": self.__to_role(input), - "content": [self.__prompt_stack_content_message_content(content) for content in input.content], + "role": self.__to_role(message), + "content": [self.__prompt_stack_content_message_content(content) for content in message.content], } - for input in elements + for message in messages ] def _base_params(self, prompt_stack: PromptStack) -> dict: system_messages = [ - {"text": input.to_text_artifact().to_text()} for input in prompt_stack.messages if input.is_system() + {"text": message.to_text_artifact().to_text()} for message in prompt_stack.messages if message.is_system() ] messages = self._prompt_stack_messages_to_messages( - [input for input in prompt_stack.messages if not input.is_system()] + [message for message in prompt_stack.messages if not message.is_system()] ) return { @@ -104,8 +104,8 @@ def __prompt_stack_content_message_content(self, content: BasePromptStackContent else: raise ValueError(f"Unsupported content type: {type(content)}") - def __to_role(self, input: PromptStackMessage) -> str: - if input.is_assistant(): + def __to_role(self, message: PromptStackMessage) -> str: + if message.is_assistant(): return "assistant" else: return "user" diff --git a/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py b/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py index ad73670e0..0be2d1f58 100644 --- a/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py +++ b/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py @@ -97,8 +97,8 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]: messages = [] - for input in prompt_stack.messages: - messages.append({"role": input.role, "content": TextPromptStackContent(input.to_text_artifact())}) + for message in prompt_stack.messages: + messages.append({"role": message.role, "content": TextPromptStackContent(message.to_text_artifact())}) return messages diff --git a/griptape/drivers/prompt/anthropic_prompt_driver.py b/griptape/drivers/prompt/anthropic_prompt_driver.py index be7c48c17..90e7f8d52 100644 --- a/griptape/drivers/prompt/anthropic_prompt_driver.py +++ b/griptape/drivers/prompt/anthropic_prompt_driver.py @@ -74,15 +74,17 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMess usage=DeltaPromptStackMessage.Usage(output_tokens=event.usage.output_tokens) ) - def _prompt_stack_messages_to_messages(self, elements: list[PromptStackMessage]) -> list[dict]: - return [{"role": self.__to_role(input), "content": self.__to_content(input)} for input in elements] + def _prompt_stack_messages_to_messages(self, messages: list[PromptStackMessage]) -> list[dict]: + return [{"role": self.__to_role(message), "content": self.__to_content(message)} for message in messages] def _base_params(self, prompt_stack: PromptStack) -> dict: - messages = self._prompt_stack_messages_to_messages([i for i in prompt_stack.messages if not i.is_system()]) + messages = self._prompt_stack_messages_to_messages( + [message for message in prompt_stack.messages if not message.is_system()] + ) - system_element = next((i for i in prompt_stack.messages if i.is_system()), None) - if system_element: - system_message = system_element.to_text_artifact().to_text() + system_message = next((message for message in prompt_stack.messages if message.is_system()), None) + if system_message: + system_message = system_message.to_text_artifact().to_text() else: system_message = None @@ -97,17 +99,17 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: **({"system": system_message} if system_message else {}), } - def __to_role(self, input: PromptStackMessage) -> str: - if input.is_assistant(): + def __to_role(self, message: PromptStackMessage) -> str: + if message.is_assistant(): return "assistant" else: return "user" - def __to_content(self, input: PromptStackMessage) -> str | list[dict]: - if all(isinstance(content, TextPromptStackContent) for content in input.content): - return input.to_text_artifact().to_text() + def __to_content(self, message: PromptStackMessage) -> str | list[dict]: + if all(isinstance(content, TextPromptStackContent) for content in message.content): + return message.to_text_artifact().to_text() else: - return [self.__prompt_stack_content_message_content(content) for content in input.content] + return [self.__prompt_stack_content_message_content(content) for content in message.content] def __prompt_stack_content_message_content(self, content: BasePromptStackContent) -> dict: if isinstance(content, TextPromptStackContent): diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index 064ae697c..d7f9463bc 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -122,17 +122,16 @@ def __process_stream(self, prompt_stack: PromptStack) -> PromptStackMessage: # Aggregate all content deltas from the stream deltas = self.try_stream(prompt_stack) for delta in deltas: - if isinstance(delta, DeltaPromptStackMessage): - usage += delta.usage + usage += delta.usage - if delta.content is not None: - if delta.content.index in delta_contents: - delta_contents[delta.content.index].append(delta.content) - else: - delta_contents[delta.content.index] = [delta.content] + if delta.content is not None: + if delta.content.index in delta_contents: + delta_contents[delta.content.index].append(delta.content) + else: + delta_contents[delta.content.index] = [delta.content] - if isinstance(delta.content, TextDeltaPromptStackContent): - self.structure.publish_event(CompletionChunkEvent(token=delta.content.text)) + if isinstance(delta.content, TextDeltaPromptStackContent): + self.structure.publish_event(CompletionChunkEvent(token=delta.content.text)) # Build a complete content from the content deltas content = [] diff --git a/griptape/drivers/prompt/cohere_prompt_driver.py b/griptape/drivers/prompt/cohere_prompt_driver.py index 36a2427a6..4d28184a6 100644 --- a/griptape/drivers/prompt/cohere_prompt_driver.py +++ b/griptape/drivers/prompt/cohere_prompt_driver.py @@ -65,13 +65,13 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMess ) ) - def _prompt_stack_messages_to_messages(self, elements: list[PromptStackMessage]) -> list[dict]: + def _prompt_stack_messages_to_messages(self, messages: list[PromptStackMessage]) -> list[dict]: return [ { - "role": self.__to_role(input), - "content": [self.__prompt_stack_content_message_content(content) for content in input.content], + "role": self.__to_role(message), + "content": [self.__prompt_stack_content_message_content(content) for content in message.content], } - for input in elements + for message in messages ] def _base_params(self, prompt_stack: PromptStack) -> dict: @@ -79,18 +79,18 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: if last_input is not None and len(last_input.content) == 1: user_message = last_input.content[0].artifact.to_text() else: - raise ValueError("User element must have exactly one content.") + raise ValueError("User message must have exactly one content.") history_messages = self._prompt_stack_messages_to_messages( - [input for input in prompt_stack.messages[:-1] if not input.is_system()] + [message for message in prompt_stack.messages[:-1] if not message.is_system()] ) - system_element = next((input for input in prompt_stack.messages if input.is_system()), None) - if system_element is not None: - if len(system_element.content) == 1: - preamble = system_element.content[0].artifact.to_text() + system_message = next((message for message in prompt_stack.messages if message.is_system()), None) + if system_message is not None: + if len(system_message.content) == 1: + preamble = system_message.content[0].artifact.to_text() else: - raise ValueError("System element must have exactly one content.") + raise ValueError("System message must have exactly one content.") else: preamble = None @@ -109,10 +109,10 @@ def __prompt_stack_content_message_content(self, content: BasePromptStackContent else: raise ValueError(f"Unsupported content type: {type(content)}") - def __to_role(self, input: PromptStackMessage) -> str: - if input.is_system(): + def __to_role(self, message: PromptStackMessage) -> str: + if message.is_system(): return "SYSTEM" - elif input.is_user(): + elif message.is_user(): return "USER" else: return "CHATBOT" diff --git a/griptape/drivers/prompt/google_prompt_driver.py b/griptape/drivers/prompt/google_prompt_driver.py index 2154567b4..8c0867f10 100644 --- a/griptape/drivers/prompt/google_prompt_driver.py +++ b/griptape/drivers/prompt/google_prompt_driver.py @@ -116,9 +116,9 @@ def _default_model_client(self) -> GenerativeModel: def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]: inputs = [ - {"role": self.__to_role(input), "parts": self.__to_content(input)} - for input in prompt_stack.messages - if not input.is_system() + {"role": self.__to_role(message), "parts": self.__to_content(message)} + for message in prompt_stack.messages + if not message.is_system() ] # Gemini does not have the notion of a system message, so we insert it as part of the first message in the history. @@ -138,11 +138,11 @@ def __prompt_stack_content_message_content(self, content: BasePromptStackContent else: raise ValueError(f"Unsupported content type: {type(content)}") - def __to_role(self, input: PromptStackMessage) -> str: - if input.is_assistant(): + def __to_role(self, message: PromptStackMessage) -> str: + if message.is_assistant(): return "model" else: return "user" - def __to_content(self, input: PromptStackMessage) -> list[ContentDict | str]: - return [self.__prompt_stack_content_message_content(content) for content in input.content] + def __to_content(self, message: PromptStackMessage) -> list[ContentDict | str]: + return [self.__prompt_stack_content_message_content(content) for content in message.content] diff --git a/griptape/drivers/prompt/openai_chat_prompt_driver.py b/griptape/drivers/prompt/openai_chat_prompt_driver.py index b28c5e644..d7dafbfc6 100644 --- a/griptape/drivers/prompt/openai_chat_prompt_driver.py +++ b/griptape/drivers/prompt/openai_chat_prompt_driver.py @@ -111,7 +111,10 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMess raise Exception("Completion with more than one choice is not supported yet.") def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]: - return [{"role": self.__to_role(input), "content": self.__to_content(input)} for input in prompt_stack.messages] + return [ + {"role": self.__to_role(message), "content": self.__to_content(message)} + for message in prompt_stack.messages + ] def _base_params(self, prompt_stack: PromptStack) -> dict: params = { @@ -125,7 +128,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: if self.response_format == "json_object": params["response_format"] = {"type": "json_object"} - # JSON mode still requires a system input instructing the LLM to output JSON. + # JSON mode still requires a system message instructing the LLM to output JSON. prompt_stack.add_system_message("Provide your response as a valid JSON object.") messages = self._prompt_stack_to_messages(prompt_stack) @@ -134,19 +137,19 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: return params - def __to_role(self, input: PromptStackMessage) -> str: - if input.is_system(): + def __to_role(self, message: PromptStackMessage) -> str: + if message.is_system(): return "system" - elif input.is_assistant(): + elif message.is_assistant(): return "assistant" else: return "user" - def __to_content(self, input: PromptStackMessage) -> str | list[dict]: - if all(isinstance(content, TextPromptStackContent) for content in input.content): - return input.to_text_artifact().to_text() + def __to_content(self, message: PromptStackMessage) -> str | list[dict]: + if all(isinstance(content, TextPromptStackContent) for content in message.content): + return message.to_text_artifact().to_text() else: - return [self.__prompt_stack_content_message_content(content) for content in input.content] + return [self.__prompt_stack_content_message_content(content) for content in message.content] def __prompt_stack_content_message_content(self, content: BasePromptStackContent) -> dict: if isinstance(content, TextPromptStackContent): diff --git a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py index 6ccb4d4bb..01de35028 100644 --- a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py @@ -112,7 +112,7 @@ def test_try_run_response_format(self, mock_chat_completion_create, prompt_stack ) # When - element = driver.try_run(prompt_stack) + message = driver.try_run(prompt_stack) # Then mock_chat_completion_create.assert_called_once_with( @@ -123,9 +123,9 @@ def test_try_run_response_format(self, mock_chat_completion_create, prompt_stack seed=driver.seed, response_format={"type": "json_object"}, ) - assert element.value == "model-output" - assert element.usage.input_tokens == 5 - assert element.usage.output_tokens == 10 + assert message.value == "model-output" + assert message.usage.input_tokens == 5 + assert message.usage.output_tokens == 10 def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, messages): # Given From 7af4e4e8ea3ad633d943d72c46097ecc02daca55 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 21 Jun 2024 09:21:43 -0700 Subject: [PATCH 20/34] Big rename --- .../drivers/prompt-drivers.md | 20 ++--- docs/griptape-framework/misc/events.md | 8 +- griptape/common/__init__.py | 36 ++++----- .../__init__.py | 0 .../contents/__init__.py | 0 .../contents/base_delta_message_content.py} | 2 +- .../contents/base_message_content.py} | 6 +- .../contents/image_message_content.py} | 6 +- .../contents/text_delta_message_content.py} | 4 +- .../contents/text_message_content.py} | 8 +- .../message_stack.py} | 30 ++++---- .../messages/base_message.py} | 14 ++-- .../message_stack/messages/delta_message.py | 15 ++++ .../messages/message.py} | 16 ++-- .../messages/delta_prompt_stack_message.py | 15 ---- .../prompt/amazon_bedrock_prompt_driver.py | 60 +++++++-------- ...mazon_sagemaker_jumpstart_prompt_driver.py | 38 +++++----- .../drivers/prompt/anthropic_prompt_driver.py | 76 +++++++++---------- .../prompt/azure_openai_chat_prompt_driver.py | 6 +- griptape/drivers/prompt/base_prompt_driver.py | 66 ++++++++-------- .../drivers/prompt/cohere_prompt_driver.py | 56 +++++++------- .../drivers/prompt/dummy_prompt_driver.py | 6 +- .../drivers/prompt/google_prompt_driver.py | 64 ++++++++-------- .../prompt/huggingface_hub_prompt_driver.py | 46 +++++------ .../huggingface_pipeline_prompt_driver.py | 30 ++++---- .../drivers/prompt/ollama_prompt_driver.py | 34 ++++----- .../prompt/openai_chat_prompt_driver.py | 68 ++++++++--------- .../extraction/csv_extraction_engine.py | 10 +-- .../extraction/json_extraction_engine.py | 12 +-- griptape/engines/query/vector_query_engine.py | 18 ++--- .../engines/summary/prompt_summary_engine.py | 16 ++-- griptape/events/start_prompt_event.py | 4 +- .../structure/base_conversation_memory.py | 36 ++++----- .../memory/structure/conversation_memory.py | 12 +-- .../structure/summary_conversation_memory.py | 10 +-- griptape/schemas/base_schema.py | 6 +- griptape/tasks/prompt_task.py | 10 +-- griptape/tasks/tool_task.py | 2 +- griptape/tasks/toolkit_task.py | 12 +-- griptape/utils/conversation.py | 4 +- tests/mocks/mock_failing_prompt_driver.py | 26 +++---- tests/mocks/mock_prompt_driver.py | 31 +++----- .../test_amazon_bedrock_prompt_driver.py | 26 +++---- ...mazon_sagemaker_jumpstart_prompt_driver.py | 22 +++--- .../prompt/test_anthropic_prompt_driver.py | 36 ++++----- .../test_azure_openai_chat_prompt_driver.py | 8 +- .../drivers/prompt/test_base_prompt_driver.py | 4 +- .../prompt/test_cohere_prompt_driver.py | 22 +++--- .../prompt/test_google_prompt_driver.py | 30 ++++---- .../test_hugging_face_hub_prompt_driver.py | 22 +++--- ...est_hugging_face_pipeline_prompt_driver.py | 34 ++++----- .../prompt/test_ollama_prompt_driver.py | 38 +++++----- .../prompt/test_openai_chat_prompt_driver.py | 44 +++++------ .../summary/test_prompt_summary_engine.py | 6 +- tests/unit/events/test_base_event.py | 20 ++--- tests/unit/events/test_start_prompt_event.py | 18 ++--- .../structure/test_conversation_memory.py | 74 +++++++++--------- .../test_summary_conversation_memory.py | 10 +-- tests/unit/structures/test_agent.py | 16 ++-- tests/unit/structures/test_pipeline.py | 28 +++---- .../unit/tokenizers/test_google_tokenizer.py | 8 +- tests/unit/utils/test_conversation.py | 8 +- tests/unit/utils/test_message_stack.py | 55 ++++++++++++++ tests/unit/utils/test_prompt_stack.py | 55 -------------- 64 files changed, 740 insertions(+), 783 deletions(-) rename griptape/common/{prompt_stack => message_stack}/__init__.py (100%) rename griptape/common/{prompt_stack => message_stack}/contents/__init__.py (100%) rename griptape/common/{prompt_stack/contents/base_delta_prompt_stack_content.py => message_stack/contents/base_delta_message_content.py} (80%) rename griptape/common/{prompt_stack/contents/base_prompt_stack_content.py => message_stack/contents/base_message_content.py} (72%) rename griptape/common/{prompt_stack/contents/image_prompt_stack_content.py => message_stack/contents/image_message_content.py} (54%) rename griptape/common/{prompt_stack/contents/text_delta_prompt_stack_content.py => message_stack/contents/text_delta_message_content.py} (52%) rename griptape/common/{prompt_stack/contents/text_prompt_stack_content.py => message_stack/contents/text_message_content.py} (59%) rename griptape/common/{prompt_stack/prompt_stack.py => message_stack/message_stack.py} (57%) rename griptape/common/{prompt_stack/messages/base_prompt_stack_message.py => message_stack/messages/base_message.py} (68%) create mode 100644 griptape/common/message_stack/messages/delta_message.py rename griptape/common/{prompt_stack/messages/prompt_stack_message.py => message_stack/messages/message.py} (55%) delete mode 100644 griptape/common/prompt_stack/messages/delta_prompt_stack_message.py create mode 100644 tests/unit/utils/test_message_stack.py delete mode 100644 tests/unit/utils/test_prompt_stack.py diff --git a/docs/griptape-framework/drivers/prompt-drivers.md b/docs/griptape-framework/drivers/prompt-drivers.md index 883a4f4b4..bceda20cb 100644 --- a/docs/griptape-framework/drivers/prompt-drivers.md +++ b/docs/griptape-framework/drivers/prompt-drivers.md @@ -28,23 +28,23 @@ agent.run("I loved the new Batman movie!") Or use them independently: ```python -from griptape.common import PromptStack +from griptape.common import MessageStack from griptape.drivers import OpenAiChatPromptDriver -stack = PromptStack() +stack = MessageStack() stack.add_system_input( "You will be provided with Python code, and your task is to calculate its time complexity." ) stack.add_user_input( -""" -def foo(n, k): - accum = 0 - for i in range(n): - for l in range(k): - accum += i - return accum -""" + """ + def foo(n, k): + accum = 0 + for i in range(n): + for l in range(k): + accum += i + return accum + """ ) result = OpenAiChatPromptDriver(model="gpt-3.5-turbo-16k", temperature=0).run(stack) diff --git a/docs/griptape-framework/misc/events.md b/docs/griptape-framework/misc/events.md index 851b5f382..bea54f761 100644 --- a/docs/griptape-framework/misc/events.md +++ b/docs/griptape-framework/misc/events.md @@ -235,7 +235,7 @@ total tokens: 273 ## Inspecting Payloads -You can use the [StartPromptEvent](../../reference/griptape/events/start_prompt_event.md) to inspect the Prompt Stack and final prompt string before it is sent to the LLM. +You can use the [StartPromptEvent](../../reference/griptape/events/start_prompt_event.md) to inspect the Message Stack and final prompt string before it is sent to the LLM. ```python from griptape.structures import Agent @@ -244,8 +244,8 @@ from griptape.events import BaseEvent, StartPromptEvent, EventListener def handler(event: BaseEvent): if isinstance(event, StartPromptEvent): - print("Prompt Stack Messages:") - for message in event.prompt_stack.messages: + print("Message Stack MessageStack:") + for message in event.message_stack.messages: print(f"{message.role}: {message.content}") print("Final Prompt String:") print(event.prompt) @@ -259,7 +259,7 @@ agent.run("Write me a poem.") ``` ``` ... -Prompt Stack Messages: +Message Stack Messages: system: user: Write me a poem. Final Prompt String: diff --git a/griptape/common/__init__.py b/griptape/common/__init__.py index ff598d638..87c81c2ee 100644 --- a/griptape/common/__init__.py +++ b/griptape/common/__init__.py @@ -1,23 +1,23 @@ -from .prompt_stack.contents.base_prompt_stack_content import BasePromptStackContent -from .prompt_stack.contents.base_delta_prompt_stack_content import BaseDeltaPromptStackContent -from .prompt_stack.contents.text_delta_prompt_stack_content import TextDeltaPromptStackContent -from .prompt_stack.contents.text_prompt_stack_content import TextPromptStackContent -from .prompt_stack.contents.image_prompt_stack_content import ImagePromptStackContent +from .message_stack.contents.base_message_content import BaseMessageContent +from .message_stack.contents.base_delta_message_content import BaseDeltaMessageContent +from .message_stack.contents.text_delta_message_content import TextDeltaMessageContent +from .message_stack.contents.text_message_content import TextMessageContent +from .message_stack.contents.image_message_content import ImageMessageContent -from .prompt_stack.messages.base_prompt_stack_message import BasePromptStackMessage -from .prompt_stack.messages.delta_prompt_stack_message import DeltaPromptStackMessage -from .prompt_stack.messages.prompt_stack_message import PromptStackMessage +from .message_stack.messages.base_message import BaseMessage +from .message_stack.messages.delta_message import DeltaMessage +from .message_stack.messages.message import Message -from .prompt_stack.prompt_stack import PromptStack +from .message_stack.message_stack import MessageStack __all__ = [ - "BasePromptStackMessage", - "BaseDeltaPromptStackContent", - "BasePromptStackContent", - "DeltaPromptStackMessage", - "PromptStackMessage", - "TextDeltaPromptStackContent", - "TextPromptStackContent", - "ImagePromptStackContent", - "PromptStack", + "BaseMessage", + "BaseDeltaMessageContent", + "BaseMessageContent", + "DeltaMessage", + "Message", + "TextDeltaMessageContent", + "TextMessageContent", + "ImageMessageContent", + "MessageStack", ] diff --git a/griptape/common/prompt_stack/__init__.py b/griptape/common/message_stack/__init__.py similarity index 100% rename from griptape/common/prompt_stack/__init__.py rename to griptape/common/message_stack/__init__.py diff --git a/griptape/common/prompt_stack/contents/__init__.py b/griptape/common/message_stack/contents/__init__.py similarity index 100% rename from griptape/common/prompt_stack/contents/__init__.py rename to griptape/common/message_stack/contents/__init__.py diff --git a/griptape/common/prompt_stack/contents/base_delta_prompt_stack_content.py b/griptape/common/message_stack/contents/base_delta_message_content.py similarity index 80% rename from griptape/common/prompt_stack/contents/base_delta_prompt_stack_content.py rename to griptape/common/message_stack/contents/base_delta_message_content.py index 8f0cc7ae9..344f0bb7a 100644 --- a/griptape/common/prompt_stack/contents/base_delta_prompt_stack_content.py +++ b/griptape/common/message_stack/contents/base_delta_message_content.py @@ -8,5 +8,5 @@ @define -class BaseDeltaPromptStackContent(ABC, SerializableMixin): +class BaseDeltaMessageContent(ABC, SerializableMixin): index: int = field(kw_only=True, default=0, metadata={"serializable": True}) diff --git a/griptape/common/prompt_stack/contents/base_prompt_stack_content.py b/griptape/common/message_stack/contents/base_message_content.py similarity index 72% rename from griptape/common/prompt_stack/contents/base_prompt_stack_content.py rename to griptape/common/message_stack/contents/base_message_content.py index 74d94bc6d..cd4d5bd92 100644 --- a/griptape/common/prompt_stack/contents/base_prompt_stack_content.py +++ b/griptape/common/message_stack/contents/base_message_content.py @@ -8,11 +8,11 @@ from griptape.artifacts.base_artifact import BaseArtifact from griptape.mixins import SerializableMixin -from .base_delta_prompt_stack_content import BaseDeltaPromptStackContent +from .base_delta_message_content import BaseDeltaMessageContent @define -class BasePromptStackContent(ABC, SerializableMixin): +class BaseMessageContent(ABC, SerializableMixin): artifact: BaseArtifact = field(metadata={"serializable": True}) def to_text(self) -> str: @@ -28,4 +28,4 @@ def __len__(self) -> int: return len(self.artifact) @classmethod - def from_deltas(cls, deltas: Sequence[BaseDeltaPromptStackContent]) -> BasePromptStackContent: ... + def from_deltas(cls, deltas: Sequence[BaseDeltaMessageContent]) -> BaseMessageContent: ... diff --git a/griptape/common/prompt_stack/contents/image_prompt_stack_content.py b/griptape/common/message_stack/contents/image_message_content.py similarity index 54% rename from griptape/common/prompt_stack/contents/image_prompt_stack_content.py rename to griptape/common/message_stack/contents/image_message_content.py index ab61b3c19..0192a2cb4 100644 --- a/griptape/common/prompt_stack/contents/image_prompt_stack_content.py +++ b/griptape/common/message_stack/contents/image_message_content.py @@ -5,13 +5,13 @@ from attrs import define, field from griptape.artifacts import ImageArtifact -from griptape.common import BaseDeltaPromptStackContent, BasePromptStackContent +from griptape.common import BaseDeltaMessageContent, BaseMessageContent @define -class ImagePromptStackContent(BasePromptStackContent): +class ImageMessageContent(BaseMessageContent): artifact: ImageArtifact = field(metadata={"serializable": True}) @classmethod - def from_deltas(cls, deltas: Sequence[BaseDeltaPromptStackContent]) -> ImagePromptStackContent: + def from_deltas(cls, deltas: Sequence[BaseDeltaMessageContent]) -> ImageMessageContent: raise NotImplementedError() diff --git a/griptape/common/prompt_stack/contents/text_delta_prompt_stack_content.py b/griptape/common/message_stack/contents/text_delta_message_content.py similarity index 52% rename from griptape/common/prompt_stack/contents/text_delta_prompt_stack_content.py rename to griptape/common/message_stack/contents/text_delta_message_content.py index 05fdc1b45..ab5313df6 100644 --- a/griptape/common/prompt_stack/contents/text_delta_prompt_stack_content.py +++ b/griptape/common/message_stack/contents/text_delta_message_content.py @@ -1,9 +1,9 @@ from __future__ import annotations from attrs import define, field -from griptape.common import BaseDeltaPromptStackContent +from griptape.common import BaseDeltaMessageContent @define -class TextDeltaPromptStackContent(BaseDeltaPromptStackContent): +class TextDeltaMessageContent(BaseDeltaMessageContent): text: str = field(metadata={"serializable": True}) diff --git a/griptape/common/prompt_stack/contents/text_prompt_stack_content.py b/griptape/common/message_stack/contents/text_message_content.py similarity index 59% rename from griptape/common/prompt_stack/contents/text_prompt_stack_content.py rename to griptape/common/message_stack/contents/text_message_content.py index 93cebd25b..1d7b2bd5b 100644 --- a/griptape/common/prompt_stack/contents/text_prompt_stack_content.py +++ b/griptape/common/message_stack/contents/text_message_content.py @@ -4,16 +4,16 @@ from collections.abc import Sequence from griptape.artifacts import TextArtifact -from griptape.common import BasePromptStackContent, BaseDeltaPromptStackContent, TextDeltaPromptStackContent +from griptape.common import BaseMessageContent, BaseDeltaMessageContent, TextDeltaMessageContent @define -class TextPromptStackContent(BasePromptStackContent): +class TextMessageContent(BaseMessageContent): artifact: TextArtifact = field(metadata={"serializable": True}) @classmethod - def from_deltas(cls, deltas: Sequence[BaseDeltaPromptStackContent]) -> TextPromptStackContent: - text_deltas = [delta for delta in deltas if isinstance(delta, TextDeltaPromptStackContent)] + def from_deltas(cls, deltas: Sequence[BaseDeltaMessageContent]) -> TextMessageContent: + text_deltas = [delta for delta in deltas if isinstance(delta, TextDeltaMessageContent)] artifact = TextArtifact(value="".join(delta.text for delta in text_deltas)) diff --git a/griptape/common/prompt_stack/prompt_stack.py b/griptape/common/message_stack/message_stack.py similarity index 57% rename from griptape/common/prompt_stack/prompt_stack.py rename to griptape/common/message_stack/message_stack.py index 9091ca708..4a8bb6985 100644 --- a/griptape/common/prompt_stack/prompt_stack.py +++ b/griptape/common/message_stack/message_stack.py @@ -3,36 +3,36 @@ from griptape.artifacts import TextArtifact, BaseArtifact, ListArtifact, ImageArtifact from griptape.mixins import SerializableMixin -from griptape.common import PromptStackMessage, TextPromptStackContent, BasePromptStackContent, ImagePromptStackContent +from griptape.common import Message, TextMessageContent, BaseMessageContent, ImageMessageContent @define -class PromptStack(SerializableMixin): - messages: list[PromptStackMessage] = field(factory=list, kw_only=True, metadata={"serializable": True}) +class MessageStack(SerializableMixin): + messages: list[Message] = field(factory=list, kw_only=True, metadata={"serializable": True}) - def add_message(self, artifact: str | BaseArtifact, role: str) -> PromptStackMessage: + def add_message(self, artifact: str | BaseArtifact, role: str) -> Message: new_content = self.__process_artifact(artifact) - self.messages.append(PromptStackMessage(content=new_content, role=role)) + self.messages.append(Message(content=new_content, role=role)) return self.messages[-1] - def add_system_message(self, artifact: str | BaseArtifact) -> PromptStackMessage: - return self.add_message(artifact, PromptStackMessage.SYSTEM_ROLE) + def add_system_message(self, artifact: str | BaseArtifact) -> Message: + return self.add_message(artifact, Message.SYSTEM_ROLE) - def add_user_message(self, artifact: str | BaseArtifact) -> PromptStackMessage: - return self.add_message(artifact, PromptStackMessage.USER_ROLE) + def add_user_message(self, artifact: str | BaseArtifact) -> Message: + return self.add_message(artifact, Message.USER_ROLE) - def add_assistant_message(self, artifact: str | BaseArtifact) -> PromptStackMessage: - return self.add_message(artifact, PromptStackMessage.ASSISTANT_ROLE) + def add_assistant_message(self, artifact: str | BaseArtifact) -> Message: + return self.add_message(artifact, Message.ASSISTANT_ROLE) - def __process_artifact(self, artifact: str | BaseArtifact) -> list[BasePromptStackContent]: + def __process_artifact(self, artifact: str | BaseArtifact) -> list[BaseMessageContent]: if isinstance(artifact, str): - return [TextPromptStackContent(TextArtifact(artifact))] + return [TextMessageContent(TextArtifact(artifact))] elif isinstance(artifact, TextArtifact): - return [TextPromptStackContent(artifact)] + return [TextMessageContent(artifact)] elif isinstance(artifact, ImageArtifact): - return [ImagePromptStackContent(artifact)] + return [ImageMessageContent(artifact)] elif isinstance(artifact, ListArtifact): processed_contents = [self.__process_artifact(artifact) for artifact in artifact.value] flattened_content = [ diff --git a/griptape/common/prompt_stack/messages/base_prompt_stack_message.py b/griptape/common/message_stack/messages/base_message.py similarity index 68% rename from griptape/common/prompt_stack/messages/base_prompt_stack_message.py rename to griptape/common/message_stack/messages/base_message.py index c6a77012d..3cc8e532a 100644 --- a/griptape/common/prompt_stack/messages/base_prompt_stack_message.py +++ b/griptape/common/message_stack/messages/base_message.py @@ -5,12 +5,12 @@ from attrs import Factory, define, field -from griptape.common import BasePromptStackContent, BaseDeltaPromptStackContent +from griptape.common import BaseMessageContent, BaseDeltaMessageContent from griptape.mixins import SerializableMixin @define -class BasePromptStackMessage(ABC, SerializableMixin): +class BaseMessage(ABC, SerializableMixin): @define class Usage(SerializableMixin): input_tokens: Optional[float] = field(kw_only=True, default=None, metadata={"serializable": True}) @@ -20,8 +20,8 @@ class Usage(SerializableMixin): def total_tokens(self) -> float: return (self.input_tokens or 0) + (self.output_tokens or 0) - def __add__(self, other: BasePromptStackMessage.Usage) -> BasePromptStackMessage.Usage: - return BasePromptStackMessage.Usage( + def __add__(self, other: BaseMessage.Usage) -> BaseMessage.Usage: + return BaseMessage.Usage( input_tokens=(self.input_tokens or 0) + (other.input_tokens or 0), output_tokens=(self.output_tokens or 0) + (other.output_tokens or 0), ) @@ -30,11 +30,9 @@ def __add__(self, other: BasePromptStackMessage.Usage) -> BasePromptStackMessage ASSISTANT_ROLE = "assistant" SYSTEM_ROLE = "system" - content: list[Union[BasePromptStackContent, BaseDeltaPromptStackContent]] = field(metadata={"serializable": True}) + content: list[Union[BaseMessageContent, BaseDeltaMessageContent]] = field(metadata={"serializable": True}) role: str = field(kw_only=True, metadata={"serializable": True}) - usage: Usage = field( - kw_only=True, default=Factory(lambda: BasePromptStackMessage.Usage()), metadata={"serializable": True} - ) + usage: Usage = field(kw_only=True, default=Factory(lambda: BaseMessage.Usage()), metadata={"serializable": True}) def is_system(self) -> bool: return self.role == self.SYSTEM_ROLE diff --git a/griptape/common/message_stack/messages/delta_message.py b/griptape/common/message_stack/messages/delta_message.py new file mode 100644 index 000000000..f022c8e0a --- /dev/null +++ b/griptape/common/message_stack/messages/delta_message.py @@ -0,0 +1,15 @@ +from __future__ import annotations +from typing import Optional + +from attrs import define, field + +from griptape.common.message_stack.contents.text_delta_message_content import TextDeltaMessageContent + + +from .base_message import BaseMessage + + +@define +class DeltaMessage(BaseMessage): + role: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": True}) + content: Optional[TextDeltaMessageContent] = field(kw_only=True, default=None, metadata={"serializable": True}) diff --git a/griptape/common/prompt_stack/messages/prompt_stack_message.py b/griptape/common/message_stack/messages/message.py similarity index 55% rename from griptape/common/prompt_stack/messages/prompt_stack_message.py rename to griptape/common/message_stack/messages/message.py index d67da48bc..3a4e07ccb 100644 --- a/griptape/common/prompt_stack/messages/prompt_stack_message.py +++ b/griptape/common/message_stack/messages/message.py @@ -5,19 +5,19 @@ from attrs import define, field from griptape.artifacts import TextArtifact -from griptape.common import BasePromptStackContent, TextPromptStackContent +from griptape.common import BaseMessageContent, TextMessageContent -from .base_prompt_stack_message import BasePromptStackMessage +from .base_message import BaseMessage @define -class PromptStackMessage(BasePromptStackMessage): - def __init__(self, content: str | list[BasePromptStackContent], **kwargs: Any): +class Message(BaseMessage): + def __init__(self, content: str | list[BaseMessageContent], **kwargs: Any): if isinstance(content, str): - content = [TextPromptStackContent(TextArtifact(value=content))] + content = [TextMessageContent(TextArtifact(value=content))] self.__attrs_init__(content, **kwargs) # pyright: ignore[reportAttributeAccessIssue] - content: list[BasePromptStackContent] = field(metadata={"serializable": True}) + content: list[BaseMessageContent] = field(metadata={"serializable": True}) @property def value(self) -> Any: @@ -34,7 +34,5 @@ def to_text(self) -> str: def to_text_artifact(self) -> TextArtifact: return TextArtifact( - "".join( - [content.artifact.to_text() for content in self.content if isinstance(content, TextPromptStackContent)] - ) + "".join([content.artifact.to_text() for content in self.content if isinstance(content, TextMessageContent)]) ) diff --git a/griptape/common/prompt_stack/messages/delta_prompt_stack_message.py b/griptape/common/prompt_stack/messages/delta_prompt_stack_message.py deleted file mode 100644 index cf0799193..000000000 --- a/griptape/common/prompt_stack/messages/delta_prompt_stack_message.py +++ /dev/null @@ -1,15 +0,0 @@ -from __future__ import annotations -from typing import Optional - -from attrs import define, field - -from griptape.common.prompt_stack.contents.text_delta_prompt_stack_content import TextDeltaPromptStackContent - - -from .base_prompt_stack_message import BasePromptStackMessage - - -@define -class DeltaPromptStackMessage(BasePromptStackMessage): - role: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": True}) - content: Optional[TextDeltaPromptStackContent] = field(kw_only=True, default=None, metadata={"serializable": True}) diff --git a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py index 18018857e..6673a8283 100644 --- a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py +++ b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py @@ -7,12 +7,12 @@ from griptape.artifacts import TextArtifact from griptape.common import ( - DeltaPromptStackMessage, - PromptStackMessage, - TextDeltaPromptStackContent, - BasePromptStackContent, - TextPromptStackContent, - ImagePromptStackContent, + DeltaMessage, + Message, + TextDeltaMessageContent, + BaseMessageContent, + TextMessageContent, + ImageMessageContent, ) from griptape.drivers import BasePromptDriver from griptape.tokenizers import AmazonBedrockTokenizer, BaseTokenizer @@ -21,7 +21,7 @@ if TYPE_CHECKING: import boto3 - from griptape.common import PromptStack + from griptape.common import MessageStack @define @@ -35,57 +35,55 @@ class AmazonBedrockPromptDriver(BasePromptDriver): default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True), kw_only=True ) - def try_run(self, prompt_stack: PromptStack) -> PromptStackMessage: - response = self.bedrock_client.converse(**self._base_params(prompt_stack)) + def try_run(self, message_stack: MessageStack) -> Message: + response = self.bedrock_client.converse(**self._base_params(message_stack)) usage = response["usage"] output_message = response["output"]["message"] - return PromptStackMessage( - content=[TextPromptStackContent(TextArtifact(content["text"])) for content in output_message["content"]], - role=PromptStackMessage.ASSISTANT_ROLE, - usage=PromptStackMessage.Usage(input_tokens=usage["inputTokens"], output_tokens=usage["outputTokens"]), + return Message( + content=[TextMessageContent(TextArtifact(content["text"])) for content in output_message["content"]], + role=Message.ASSISTANT_ROLE, + usage=Message.Usage(input_tokens=usage["inputTokens"], output_tokens=usage["outputTokens"]), ) - def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage]: - response = self.bedrock_client.converse_stream(**self._base_params(prompt_stack)) + def try_stream(self, message_stack: MessageStack) -> Iterator[DeltaMessage]: + response = self.bedrock_client.converse_stream(**self._base_params(message_stack)) stream = response.get("stream") if stream is not None: for event in stream: if "contentBlockDelta" in event: content_block_delta = event["contentBlockDelta"] - yield DeltaPromptStackMessage( - content=TextDeltaPromptStackContent( + yield DeltaMessage( + content=TextDeltaMessageContent( content_block_delta["delta"]["text"], index=content_block_delta["contentBlockIndex"] ) ) elif "metadata" in event: usage = event["metadata"]["usage"] - yield DeltaPromptStackMessage( - usage=DeltaPromptStackMessage.Usage( - input_tokens=usage["inputTokens"], output_tokens=usage["outputTokens"] - ) + yield DeltaMessage( + usage=DeltaMessage.Usage(input_tokens=usage["inputTokens"], output_tokens=usage["outputTokens"]) ) else: raise Exception("model response is empty") - def _prompt_stack_messages_to_messages(self, messages: list[PromptStackMessage]) -> list[dict]: + def _message_stack_messages_to_messages(self, messages: list[Message]) -> list[dict]: return [ { "role": self.__to_role(message), - "content": [self.__prompt_stack_content_message_content(content) for content in message.content], + "content": [self.__message_stack_content_message_content(content) for content in message.content], } for message in messages ] - def _base_params(self, prompt_stack: PromptStack) -> dict: + def _base_params(self, message_stack: MessageStack) -> dict: system_messages = [ - {"text": message.to_text_artifact().to_text()} for message in prompt_stack.messages if message.is_system() + {"text": message.to_text_artifact().to_text()} for message in message_stack.messages if message.is_system() ] - messages = self._prompt_stack_messages_to_messages( - [message for message in prompt_stack.messages if not message.is_system()] + messages = self._message_stack_messages_to_messages( + [message for message in message_stack.messages if not message.is_system()] ) return { @@ -96,15 +94,15 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: "additionalModelRequestFields": self.additional_model_request_fields, } - def __prompt_stack_content_message_content(self, content: BasePromptStackContent) -> dict: - if isinstance(content, TextPromptStackContent): + def __message_stack_content_message_content(self, content: BaseMessageContent) -> dict: + if isinstance(content, TextMessageContent): return {"text": content.artifact.to_text()} - elif isinstance(content, ImagePromptStackContent): + elif isinstance(content, ImageMessageContent): return {"image": {"format": content.artifact.format, "source": {"bytes": content.artifact.value}}} else: raise ValueError(f"Unsupported content type: {type(content)}") - def __to_role(self, message: PromptStackMessage) -> str: + def __to_role(self, message: Message) -> str: if message.is_assistant(): return "assistant" else: diff --git a/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py b/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py index 0be2d1f58..ff59ced5a 100644 --- a/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py +++ b/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py @@ -7,7 +7,7 @@ from attrs import Factory, define, field from griptape.artifacts import TextArtifact -from griptape.common import PromptStack, PromptStackMessage, TextPromptStackContent, DeltaPromptStackMessage +from griptape.common import MessageStack, Message, TextMessageContent, DeltaMessage from griptape.drivers import BasePromptDriver from griptape.tokenizers import HuggingFaceTokenizer from griptape.utils import import_optional_dependency @@ -15,7 +15,7 @@ if TYPE_CHECKING: import boto3 - from griptape.common import PromptStack + from griptape.common import MessageStack @define @@ -41,10 +41,10 @@ def validate_stream(self, _, stream): if stream: raise ValueError("streaming is not supported") - def try_run(self, prompt_stack: PromptStack) -> PromptStackMessage: + def try_run(self, message_stack: MessageStack) -> Message: payload = { - "inputs": self.prompt_stack_to_string(prompt_stack), - "parameters": {**self._base_params(prompt_stack)}, + "inputs": self.message_stack_to_string(message_stack), + "parameters": {**self._base_params(message_stack)}, } response = self.sagemaker_client.invoke_endpoint( @@ -69,22 +69,22 @@ def try_run(self, prompt_stack: PromptStack) -> PromptStackMessage: else: generated_text = decoded_body["generated_text"] - input_tokens = len(self.__prompt_stack_to_tokens(prompt_stack)) + input_tokens = len(self.__message_stack_to_tokens(message_stack)) output_tokens = len(self.tokenizer.tokenizer.encode(generated_text)) - return PromptStackMessage( - content=[TextPromptStackContent(TextArtifact(generated_text))], - role=PromptStackMessage.ASSISTANT_ROLE, - usage=PromptStackMessage.Usage(input_tokens=input_tokens, output_tokens=output_tokens), + return Message( + content=[TextMessageContent(TextArtifact(generated_text))], + role=Message.ASSISTANT_ROLE, + usage=Message.Usage(input_tokens=input_tokens, output_tokens=output_tokens), ) - def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage]: + def try_stream(self, message_stack: MessageStack) -> Iterator[DeltaMessage]: raise NotImplementedError("streaming is not supported") - def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str: - return self.tokenizer.tokenizer.decode(self.__prompt_stack_to_tokens(prompt_stack)) + def message_stack_to_string(self, message_stack: MessageStack) -> str: + return self.tokenizer.tokenizer.decode(self.__message_stack_to_tokens(message_stack)) - def _base_params(self, prompt_stack: PromptStack) -> dict: + def _base_params(self, message_stack: MessageStack) -> dict: return { "temperature": self.temperature, "max_new_tokens": self.max_tokens, @@ -94,16 +94,16 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: "return_full_text": False, } - def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]: + def _message_stack_to_messages(self, message_stack: MessageStack) -> list[dict]: messages = [] - for message in prompt_stack.messages: - messages.append({"role": message.role, "content": TextPromptStackContent(message.to_text_artifact())}) + for message in message_stack.messages: + messages.append({"role": message.role, "content": TextMessageContent(message.to_text_artifact())}) return messages - def __prompt_stack_to_tokens(self, prompt_stack: PromptStack) -> list[int]: - messages = self._prompt_stack_to_messages(prompt_stack) + def __message_stack_to_tokens(self, message_stack: MessageStack) -> list[int]: + messages = self._message_stack_to_messages(message_stack) tokens = self.tokenizer.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True) diff --git a/griptape/drivers/prompt/anthropic_prompt_driver.py b/griptape/drivers/prompt/anthropic_prompt_driver.py index 90e7f8d52..a7d7b62bc 100644 --- a/griptape/drivers/prompt/anthropic_prompt_driver.py +++ b/griptape/drivers/prompt/anthropic_prompt_driver.py @@ -7,13 +7,13 @@ from griptape.artifacts import TextArtifact from griptape.common import ( - BasePromptStackContent, - DeltaPromptStackMessage, - TextDeltaPromptStackContent, - ImagePromptStackContent, - PromptStack, - PromptStackMessage, - TextPromptStackContent, + BaseMessageContent, + DeltaMessage, + TextDeltaMessageContent, + ImageMessageContent, + MessageStack, + Message, + TextMessageContent, ) from griptape.drivers import BasePromptDriver from griptape.tokenizers import AnthropicTokenizer, BaseTokenizer @@ -48,41 +48,35 @@ class AnthropicPromptDriver(BasePromptDriver): top_k: int = field(default=250, kw_only=True, metadata={"serializable": True}) max_tokens: int = field(default=1000, kw_only=True, metadata={"serializable": True}) - def try_run(self, prompt_stack: PromptStack) -> PromptStackMessage: - response = self.client.messages.create(**self._base_params(prompt_stack)) + def try_run(self, message_stack: MessageStack) -> Message: + response = self.client.messages.create(**self._base_params(message_stack)) - return PromptStackMessage( - content=[self.__message_content_to_prompt_stack_content(content) for content in response.content], - role=PromptStackMessage.ASSISTANT_ROLE, - usage=PromptStackMessage.Usage( - input_tokens=response.usage.input_tokens, output_tokens=response.usage.output_tokens - ), + return Message( + content=[self.__message_content_to_message_stack_content(content) for content in response.content], + role=Message.ASSISTANT_ROLE, + usage=Message.Usage(input_tokens=response.usage.input_tokens, output_tokens=response.usage.output_tokens), ) - def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage]: - events = self.client.messages.create(**self._base_params(prompt_stack), stream=True) + def try_stream(self, message_stack: MessageStack) -> Iterator[DeltaMessage]: + events = self.client.messages.create(**self._base_params(message_stack), stream=True) for event in events: if event.type == "content_block_delta": - yield DeltaPromptStackMessage(content=self.__message_content_delta_to_prompt_stack_content_delta(event)) + yield DeltaMessage(content=self.__message_content_delta_to_message_stack_content_delta(event)) elif event.type == "message_start": - yield DeltaPromptStackMessage( - usage=DeltaPromptStackMessage.Usage(input_tokens=event.message.usage.input_tokens) - ) + yield DeltaMessage(usage=DeltaMessage.Usage(input_tokens=event.message.usage.input_tokens)) elif event.type == "message_delta": - yield DeltaPromptStackMessage( - usage=DeltaPromptStackMessage.Usage(output_tokens=event.usage.output_tokens) - ) + yield DeltaMessage(usage=DeltaMessage.Usage(output_tokens=event.usage.output_tokens)) - def _prompt_stack_messages_to_messages(self, messages: list[PromptStackMessage]) -> list[dict]: + def _message_stack_messages_to_messages(self, messages: list[Message]) -> list[dict]: return [{"role": self.__to_role(message), "content": self.__to_content(message)} for message in messages] - def _base_params(self, prompt_stack: PromptStack) -> dict: - messages = self._prompt_stack_messages_to_messages( - [message for message in prompt_stack.messages if not message.is_system()] + def _base_params(self, message_stack: MessageStack) -> dict: + messages = self._message_stack_messages_to_messages( + [message for message in message_stack.messages if not message.is_system()] ) - system_message = next((message for message in prompt_stack.messages if message.is_system()), None) + system_message = next((message for message in message_stack.messages if message.is_system()), None) if system_message: system_message = system_message.to_text_artifact().to_text() else: @@ -99,22 +93,22 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: **({"system": system_message} if system_message else {}), } - def __to_role(self, message: PromptStackMessage) -> str: + def __to_role(self, message: Message) -> str: if message.is_assistant(): return "assistant" else: return "user" - def __to_content(self, message: PromptStackMessage) -> str | list[dict]: - if all(isinstance(content, TextPromptStackContent) for content in message.content): + def __to_content(self, message: Message) -> str | list[dict]: + if all(isinstance(content, TextMessageContent) for content in message.content): return message.to_text_artifact().to_text() else: - return [self.__prompt_stack_content_message_content(content) for content in message.content] + return [self.__message_stack_content_message_content(content) for content in message.content] - def __prompt_stack_content_message_content(self, content: BasePromptStackContent) -> dict: - if isinstance(content, TextPromptStackContent): + def __message_stack_content_message_content(self, content: BaseMessageContent) -> dict: + if isinstance(content, TextMessageContent): return {"type": "text", "text": content.artifact.to_text()} - elif isinstance(content, ImagePromptStackContent): + elif isinstance(content, ImageMessageContent): return { "type": "image", "source": {"type": "base64", "media_type": content.artifact.mime_type, "data": content.artifact.base64}, @@ -122,18 +116,18 @@ def __prompt_stack_content_message_content(self, content: BasePromptStackContent else: raise ValueError(f"Unsupported prompt content type: {type(content)}") - def __message_content_to_prompt_stack_content(self, content: ContentBlock) -> BasePromptStackContent: + def __message_content_to_message_stack_content(self, content: ContentBlock) -> BaseMessageContent: if content.type == "text": - return TextPromptStackContent(TextArtifact(content.text)) + return TextMessageContent(TextArtifact(content.text)) else: raise ValueError(f"Unsupported message content type: {content.type}") - def __message_content_delta_to_prompt_stack_content_delta( + def __message_content_delta_to_message_stack_content_delta( self, content_delta: ContentBlockDeltaEvent - ) -> TextDeltaPromptStackContent: + ) -> TextDeltaMessageContent: index = content_delta.index if content_delta.delta.type == "text_delta": - return TextDeltaPromptStackContent(content_delta.delta.text, index=index) + return TextDeltaMessageContent(content_delta.delta.text, index=index) else: raise ValueError(f"Unsupported message content delta type : {content_delta.delta.type}") diff --git a/griptape/drivers/prompt/azure_openai_chat_prompt_driver.py b/griptape/drivers/prompt/azure_openai_chat_prompt_driver.py index 50e9effe6..37350f0a6 100644 --- a/griptape/drivers/prompt/azure_openai_chat_prompt_driver.py +++ b/griptape/drivers/prompt/azure_openai_chat_prompt_driver.py @@ -1,6 +1,6 @@ from attrs import define, field, Factory from typing import Callable, Optional -from griptape.common import PromptStack +from griptape.common import MessageStack from griptape.drivers import OpenAiChatPromptDriver import openai @@ -41,8 +41,8 @@ class AzureOpenAiChatPromptDriver(OpenAiChatPromptDriver): ) ) - def _base_params(self, prompt_stack: PromptStack) -> dict: - params = super()._base_params(prompt_stack) + def _base_params(self, message_stack: MessageStack) -> dict: + params = super()._base_params(message_stack) # TODO: Add `seed` parameter once Azure supports it. del params["seed"] diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index d7f9463bc..5adb2b601 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -8,12 +8,12 @@ from griptape.artifacts.text_artifact import TextArtifact from griptape.common import ( - BaseDeltaPromptStackContent, - DeltaPromptStackMessage, - TextDeltaPromptStackContent, - PromptStack, - PromptStackMessage, - TextPromptStackContent, + BaseDeltaMessageContent, + DeltaMessage, + TextDeltaMessageContent, + MessageStack, + Message, + TextMessageContent, ) from griptape.events import CompletionChunkEvent, FinishPromptEvent, StartPromptEvent from griptape.mixins import ExponentialBackoffMixin, SerializableMixin @@ -31,7 +31,7 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC): temperature: The temperature to use for the completion. max_tokens: The maximum number of tokens to generate. If not specified, the value will be automatically generated based by the tokenizer. structure: An optional `Structure` to publish events to. - prompt_stack_to_string: A function that converts a `PromptStack` to a string. + message_stack_to_string: A function that converts a `MessageStack` to a string. ignored_exception_types: A tuple of exception types to ignore. model: The model name. tokenizer: An instance of `BaseTokenizer` to when calculating tokens. @@ -48,11 +48,11 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC): tokenizer: BaseTokenizer stream: bool = field(default=False, kw_only=True, metadata={"serializable": True}) - def before_run(self, prompt_stack: PromptStack) -> None: + def before_run(self, message_stack: MessageStack) -> None: if self.structure: - self.structure.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack)) + self.structure.publish_event(StartPromptEvent(model=self.model, message_stack=message_stack)) - def after_run(self, result: PromptStackMessage) -> None: + def after_run(self, result: Message) -> None: if self.structure: self.structure.publish_event( FinishPromptEvent( @@ -63,15 +63,15 @@ def after_run(self, result: PromptStackMessage) -> None: ) ) - def run(self, prompt_stack: PromptStack) -> TextArtifact: + def run(self, message_stack: MessageStack) -> TextArtifact: for attempt in self.retrying(): with attempt: - self.before_run(prompt_stack) + self.before_run(message_stack) if self.stream: - result = self.__process_stream(prompt_stack) + result = self.__process_stream(message_stack) else: - result = self.__process_run(prompt_stack) + result = self.__process_run(message_stack) self.after_run(result) @@ -79,19 +79,19 @@ def run(self, prompt_stack: PromptStack) -> TextArtifact: else: raise Exception("prompt driver failed after all retry attempts") - def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str: - """Converts a Prompt Stack to a string for token counting or model input. + def message_stack_to_string(self, message_stack: MessageStack) -> str: + """Converts a Message Stack to a string for token counting or model input. This base implementation is only a rough approximation, and should be overridden by subclasses with model-specific tokens. Args: - prompt_stack: The Prompt Stack to convert to a string. + message_stack: The Message Stack to convert to a string. Returns: - A single string representation of the Prompt Stack. + A single string representation of the Message Stack. """ prompt_lines = [] - for i in prompt_stack.messages: + for i in message_stack.messages: content = i.to_text_artifact().to_text() if i.is_user(): prompt_lines.append(f"User: {content}") @@ -105,22 +105,22 @@ def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str: return "\n\n".join(prompt_lines) @abstractmethod - def try_run(self, prompt_stack: PromptStack) -> PromptStackMessage: ... + def try_run(self, message_stack: MessageStack) -> Message: ... @abstractmethod - def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage]: ... + def try_stream(self, message_stack: MessageStack) -> Iterator[DeltaMessage]: ... - def __process_run(self, prompt_stack: PromptStack) -> PromptStackMessage: - result = self.try_run(prompt_stack) + def __process_run(self, message_stack: MessageStack) -> Message: + result = self.try_run(message_stack) return result - def __process_stream(self, prompt_stack: PromptStack) -> PromptStackMessage: - delta_contents: dict[int, list[BaseDeltaPromptStackContent]] = {} - usage = DeltaPromptStackMessage.Usage() + def __process_stream(self, message_stack: MessageStack) -> Message: + delta_contents: dict[int, list[BaseDeltaMessageContent]] = {} + usage = DeltaMessage.Usage() # Aggregate all content deltas from the stream - deltas = self.try_stream(prompt_stack) + deltas = self.try_stream(message_stack) for delta in deltas: usage += delta.usage @@ -130,20 +130,20 @@ def __process_stream(self, prompt_stack: PromptStack) -> PromptStackMessage: else: delta_contents[delta.content.index] = [delta.content] - if isinstance(delta.content, TextDeltaPromptStackContent): + if isinstance(delta.content, TextDeltaMessageContent): self.structure.publish_event(CompletionChunkEvent(token=delta.content.text)) # Build a complete content from the content deltas content = [] for index, deltas in delta_contents.items(): - text_deltas = [delta for delta in deltas if isinstance(delta, TextDeltaPromptStackContent)] + text_deltas = [delta for delta in deltas if isinstance(delta, TextDeltaMessageContent)] if text_deltas: - content.append(TextPromptStackContent.from_deltas(text_deltas)) + content.append(TextMessageContent.from_deltas(text_deltas)) - result = PromptStackMessage( + result = Message( content=content, - role=PromptStackMessage.ASSISTANT_ROLE, - usage=PromptStackMessage.Usage(input_tokens=usage.input_tokens, output_tokens=usage.output_tokens), + role=Message.ASSISTANT_ROLE, + usage=Message.Usage(input_tokens=usage.input_tokens, output_tokens=usage.output_tokens), ) return result diff --git a/griptape/drivers/prompt/cohere_prompt_driver.py b/griptape/drivers/prompt/cohere_prompt_driver.py index 4d28184a6..c9bd0e119 100644 --- a/griptape/drivers/prompt/cohere_prompt_driver.py +++ b/griptape/drivers/prompt/cohere_prompt_driver.py @@ -6,12 +6,12 @@ from griptape.drivers import BasePromptDriver from griptape.tokenizers import CohereTokenizer from griptape.common import ( - PromptStack, - PromptStackMessage, - DeltaPromptStackMessage, - TextPromptStackContent, - BasePromptStackContent, - TextDeltaPromptStackContent, + MessageStack, + Message, + DeltaMessage, + TextMessageContent, + BaseMessageContent, + TextDeltaMessageContent, ) from griptape.utils import import_optional_dependency from griptape.tokenizers import BaseTokenizer @@ -40,52 +40,50 @@ class CoherePromptDriver(BasePromptDriver): kw_only=True, ) - def try_run(self, prompt_stack: PromptStack) -> PromptStackMessage: - result = self.client.chat(**self._base_params(prompt_stack)) + def try_run(self, message_stack: MessageStack) -> Message: + result = self.client.chat(**self._base_params(message_stack)) usage = result.meta.tokens - return PromptStackMessage( - content=[TextPromptStackContent(TextArtifact(result.text))], - role=PromptStackMessage.ASSISTANT_ROLE, - usage=PromptStackMessage.Usage(input_tokens=usage.input_tokens, output_tokens=usage.output_tokens), + return Message( + content=[TextMessageContent(TextArtifact(result.text))], + role=Message.ASSISTANT_ROLE, + usage=Message.Usage(input_tokens=usage.input_tokens, output_tokens=usage.output_tokens), ) - def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage]: - result = self.client.chat_stream(**self._base_params(prompt_stack)) + def try_stream(self, message_stack: MessageStack) -> Iterator[DeltaMessage]: + result = self.client.chat_stream(**self._base_params(message_stack)) for event in result: if event.event_type == "text-generation": - yield DeltaPromptStackMessage(content=TextDeltaPromptStackContent(event.text, index=0)) + yield DeltaMessage(content=TextDeltaMessageContent(event.text, index=0)) elif event.event_type == "stream-end": usage = event.response.meta.tokens - yield DeltaPromptStackMessage( - usage=DeltaPromptStackMessage.Usage( - input_tokens=usage.input_tokens, output_tokens=usage.output_tokens - ) + yield DeltaMessage( + usage=DeltaMessage.Usage(input_tokens=usage.input_tokens, output_tokens=usage.output_tokens) ) - def _prompt_stack_messages_to_messages(self, messages: list[PromptStackMessage]) -> list[dict]: + def _message_stack_messages_to_messages(self, messages: list[Message]) -> list[dict]: return [ { "role": self.__to_role(message), - "content": [self.__prompt_stack_content_message_content(content) for content in message.content], + "content": [self.__message_stack_content_message_content(content) for content in message.content], } for message in messages ] - def _base_params(self, prompt_stack: PromptStack) -> dict: - last_input = prompt_stack.messages[-1] + def _base_params(self, message_stack: MessageStack) -> dict: + last_input = message_stack.messages[-1] if last_input is not None and len(last_input.content) == 1: user_message = last_input.content[0].artifact.to_text() else: raise ValueError("User message must have exactly one content.") - history_messages = self._prompt_stack_messages_to_messages( - [message for message in prompt_stack.messages[:-1] if not message.is_system()] + history_messages = self._message_stack_messages_to_messages( + [message for message in message_stack.messages[:-1] if not message.is_system()] ) - system_message = next((message for message in prompt_stack.messages if message.is_system()), None) + system_message = next((message for message in message_stack.messages if message.is_system()), None) if system_message is not None: if len(system_message.content) == 1: preamble = system_message.content[0].artifact.to_text() @@ -103,13 +101,13 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: **({"preamble": preamble} if preamble else {}), } - def __prompt_stack_content_message_content(self, content: BasePromptStackContent) -> dict: - if isinstance(content, TextPromptStackContent): + def __message_stack_content_message_content(self, content: BaseMessageContent) -> dict: + if isinstance(content, TextMessageContent): return {"text": content.artifact.to_text()} else: raise ValueError(f"Unsupported content type: {type(content)}") - def __to_role(self, message: PromptStackMessage) -> str: + def __to_role(self, message: Message) -> str: if message.is_system(): return "SYSTEM" elif message.is_user(): diff --git a/griptape/drivers/prompt/dummy_prompt_driver.py b/griptape/drivers/prompt/dummy_prompt_driver.py index 5f0557869..aadb0ee8c 100644 --- a/griptape/drivers/prompt/dummy_prompt_driver.py +++ b/griptape/drivers/prompt/dummy_prompt_driver.py @@ -3,7 +3,7 @@ from attrs import Factory, define, field -from griptape.common import PromptStack, PromptStackMessage, DeltaPromptStackMessage +from griptape.common import MessageStack, Message, DeltaMessage from griptape.drivers import BasePromptDriver from griptape.exceptions import DummyException from griptape.tokenizers import DummyTokenizer @@ -14,8 +14,8 @@ class DummyPromptDriver(BasePromptDriver): model: None = field(init=False, default=None, kw_only=True) tokenizer: DummyTokenizer = field(default=Factory(lambda: DummyTokenizer()), kw_only=True) - def try_run(self, prompt_stack: PromptStack) -> PromptStackMessage: + def try_run(self, message_stack: MessageStack) -> Message: raise DummyException(__class__.__name__, "try_run") - def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage]: + def try_stream(self, message_stack: MessageStack) -> Iterator[DeltaMessage]: raise DummyException(__class__.__name__, "try_stream") diff --git a/griptape/drivers/prompt/google_prompt_driver.py b/griptape/drivers/prompt/google_prompt_driver.py index 8c0867f10..83582a7be 100644 --- a/griptape/drivers/prompt/google_prompt_driver.py +++ b/griptape/drivers/prompt/google_prompt_driver.py @@ -7,13 +7,13 @@ from griptape.artifacts import TextArtifact from griptape.common import ( - BasePromptStackContent, - DeltaPromptStackMessage, - TextDeltaPromptStackContent, - ImagePromptStackContent, - PromptStack, - PromptStackMessage, - TextPromptStackContent, + BaseMessageContent, + DeltaMessage, + TextDeltaMessageContent, + ImageMessageContent, + MessageStack, + Message, + TextMessageContent, ) from griptape.drivers import BasePromptDriver from griptape.tokenizers import BaseTokenizer, GoogleTokenizer @@ -45,10 +45,10 @@ class GooglePromptDriver(BasePromptDriver): top_p: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True}) top_k: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True}) - def try_run(self, prompt_stack: PromptStack) -> PromptStackMessage: + def try_run(self, message_stack: MessageStack) -> Message: GenerationConfig = import_optional_dependency("google.generativeai.types").GenerationConfig - messages = self._prompt_stack_to_messages(prompt_stack) + messages = self._message_stack_to_messages(message_stack) response: GenerateContentResponse = self.model_client.generate_content( messages, generation_config=GenerationConfig( @@ -62,18 +62,18 @@ def try_run(self, prompt_stack: PromptStack) -> PromptStackMessage: usage_metadata = response.usage_metadata - return PromptStackMessage( - content=[TextPromptStackContent(TextArtifact(response.text))], - role=PromptStackMessage.ASSISTANT_ROLE, - usage=PromptStackMessage.Usage( + return Message( + content=[TextMessageContent(TextArtifact(response.text))], + role=Message.ASSISTANT_ROLE, + usage=Message.Usage( input_tokens=usage_metadata.prompt_token_count, output_tokens=usage_metadata.candidates_token_count ), ) - def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage]: + def try_stream(self, message_stack: MessageStack) -> Iterator[DeltaMessage]: GenerationConfig = import_optional_dependency("google.generativeai.types").GenerationConfig - messages = self._prompt_stack_to_messages(prompt_stack) + messages = self._message_stack_to_messages(message_stack) response: Iterator[GenerateContentResponse] = self.model_client.generate_content( messages, stream=True, @@ -93,19 +93,19 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMess # Only want to output the prompt token count once since it is static each chunk if prompt_token_count is None: prompt_token_count = usage_metadata.prompt_token_count - yield DeltaPromptStackMessage( - content=TextDeltaPromptStackContent(chunk.text), - role=PromptStackMessage.ASSISTANT_ROLE, - usage=DeltaPromptStackMessage.Usage( + yield DeltaMessage( + content=TextDeltaMessageContent(chunk.text), + role=Message.ASSISTANT_ROLE, + usage=DeltaMessage.Usage( input_tokens=usage_metadata.prompt_token_count, output_tokens=usage_metadata.candidates_token_count, ), ) else: - yield DeltaPromptStackMessage( - content=TextDeltaPromptStackContent(chunk.text), - role=PromptStackMessage.ASSISTANT_ROLE, - usage=DeltaPromptStackMessage.Usage(output_tokens=usage_metadata.candidates_token_count), + yield DeltaMessage( + content=TextDeltaMessageContent(chunk.text), + role=Message.ASSISTANT_ROLE, + usage=DeltaMessage.Usage(output_tokens=usage_metadata.candidates_token_count), ) def _default_model_client(self) -> GenerativeModel: @@ -114,35 +114,35 @@ def _default_model_client(self) -> GenerativeModel: return genai.GenerativeModel(self.model) - def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]: + def _message_stack_to_messages(self, message_stack: MessageStack) -> list[dict]: inputs = [ {"role": self.__to_role(message), "parts": self.__to_content(message)} - for message in prompt_stack.messages + for message in message_stack.messages if not message.is_system() ] # Gemini does not have the notion of a system message, so we insert it as part of the first message in the history. - system = next((i for i in prompt_stack.messages if i.is_system()), None) + system = next((i for i in message_stack.messages if i.is_system()), None) if system is not None: inputs[0]["parts"].insert(0, "\n".join(content.to_text() for content in system.content)) return inputs - def __prompt_stack_content_message_content(self, content: BasePromptStackContent) -> ContentDict | str: + def __message_stack_content_message_content(self, content: BaseMessageContent) -> ContentDict | str: ContentDict = import_optional_dependency("google.generativeai.types").ContentDict - if isinstance(content, TextPromptStackContent): + if isinstance(content, TextMessageContent): return content.artifact.to_text() - elif isinstance(content, ImagePromptStackContent): + elif isinstance(content, ImageMessageContent): return ContentDict(mime_type=content.artifact.mime_type, data=content.artifact.value) else: raise ValueError(f"Unsupported content type: {type(content)}") - def __to_role(self, message: PromptStackMessage) -> str: + def __to_role(self, message: Message) -> str: if message.is_assistant(): return "model" else: return "user" - def __to_content(self, message: PromptStackMessage) -> list[ContentDict | str]: - return [self.__prompt_stack_content_message_content(content) for content in message.content] + def __to_content(self, message: Message) -> list[ContentDict | str]: + return [self.__message_stack_content_message_content(content) for content in message.content] diff --git a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py index 87cd03b10..74b1b68d2 100644 --- a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py +++ b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py @@ -7,13 +7,7 @@ from griptape.drivers import BasePromptDriver from griptape.tokenizers import HuggingFaceTokenizer -from griptape.common import ( - PromptStack, - PromptStackMessage, - DeltaPromptStackMessage, - TextPromptStackContent, - TextDeltaPromptStackContent, -) +from griptape.common import MessageStack, Message, DeltaMessage, TextMessageContent, TextDeltaMessageContent from griptape.utils import import_optional_dependency if TYPE_CHECKING: @@ -53,55 +47,53 @@ class HuggingFaceHubPromptDriver(BasePromptDriver): kw_only=True, ) - def try_run(self, prompt_stack: PromptStack) -> PromptStackMessage: - prompt = self.prompt_stack_to_string(prompt_stack) + def try_run(self, message_stack: MessageStack) -> Message: + prompt = self.message_stack_to_string(message_stack) response = self.client.text_generation( prompt, return_full_text=False, max_new_tokens=self.max_tokens, **self.params ) - input_tokens = len(self.__prompt_stack_to_tokens(prompt_stack)) + input_tokens = len(self.__message_stack_to_tokens(message_stack)) output_tokens = len(self.tokenizer.tokenizer.encode(response)) - return PromptStackMessage( + return Message( content=response, - role=PromptStackMessage.ASSISTANT_ROLE, - usage=PromptStackMessage.Usage(input_tokens=input_tokens, output_tokens=output_tokens), + role=Message.ASSISTANT_ROLE, + usage=Message.Usage(input_tokens=input_tokens, output_tokens=output_tokens), ) - def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage]: - prompt = self.prompt_stack_to_string(prompt_stack) + def try_stream(self, message_stack: MessageStack) -> Iterator[DeltaMessage]: + prompt = self.message_stack_to_string(message_stack) response = self.client.text_generation( prompt, return_full_text=False, max_new_tokens=self.max_tokens, stream=True, **self.params ) - input_tokens = len(self.__prompt_stack_to_tokens(prompt_stack)) + input_tokens = len(self.__message_stack_to_tokens(message_stack)) full_text = "" for token in response: full_text += token - yield DeltaPromptStackMessage(content=TextDeltaPromptStackContent(token, index=0)) + yield DeltaMessage(content=TextDeltaMessageContent(token, index=0)) output_tokens = len(self.tokenizer.tokenizer.encode(full_text)) - yield DeltaPromptStackMessage( - usage=DeltaPromptStackMessage.Usage(input_tokens=input_tokens, output_tokens=output_tokens) - ) + yield DeltaMessage(usage=DeltaMessage.Usage(input_tokens=input_tokens, output_tokens=output_tokens)) - def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str: - return self.tokenizer.tokenizer.decode(self.__prompt_stack_to_tokens(prompt_stack)) + def message_stack_to_string(self, message_stack: MessageStack) -> str: + return self.tokenizer.tokenizer.decode(self.__message_stack_to_tokens(message_stack)) - def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]: + def _message_stack_to_messages(self, message_stack: MessageStack) -> list[dict]: messages = [] - for i in prompt_stack.messages: + for i in message_stack.messages: if len(i.content) == 1: - messages.append({"role": i.role, "content": TextPromptStackContent(i.to_text_artifact())}) + messages.append({"role": i.role, "content": TextMessageContent(i.to_text_artifact())}) else: raise ValueError("Invalid input content length.") return messages - def __prompt_stack_to_tokens(self, prompt_stack: PromptStack) -> list[int]: - messages = self._prompt_stack_to_messages(prompt_stack) + def __message_stack_to_tokens(self, message_stack: MessageStack) -> list[int]: + messages = self._message_stack_to_messages(message_stack) tokens = self.tokenizer.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True) if isinstance(tokens, list): diff --git a/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py b/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py index 4003340d1..155d0c488 100644 --- a/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py +++ b/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py @@ -6,7 +6,7 @@ from attrs import Factory, define, field from griptape.artifacts import TextArtifact -from griptape.common import DeltaPromptStackMessage, PromptStack, PromptStackMessage, TextPromptStackContent +from griptape.common import DeltaMessage, MessageStack, Message, TextMessageContent from griptape.drivers import BasePromptDriver from griptape.tokenizers import HuggingFaceTokenizer from griptape.utils import import_optional_dependency @@ -42,8 +42,8 @@ class HuggingFacePipelinePromptDriver(BasePromptDriver): ) ) - def try_run(self, prompt_stack: PromptStack) -> PromptStackMessage: - messages = self._prompt_stack_to_messages(prompt_stack) + def try_run(self, message_stack: MessageStack) -> Message: + messages = self._message_stack_to_messages(message_stack) result = self.pipe( messages, max_new_tokens=self.max_tokens, temperature=self.temperature, do_sample=True, **self.params @@ -53,35 +53,35 @@ def try_run(self, prompt_stack: PromptStack) -> PromptStackMessage: if len(result) == 1: generated_text = result[0]["generated_text"][-1]["content"] - input_tokens = len(self.__prompt_stack_to_tokens(prompt_stack)) + input_tokens = len(self.__message_stack_to_tokens(message_stack)) output_tokens = len(self.tokenizer.tokenizer.encode(generated_text)) - return PromptStackMessage( - content=[TextPromptStackContent(TextArtifact(generated_text))], - role=PromptStackMessage.ASSISTANT_ROLE, - usage=PromptStackMessage.Usage(input_tokens=input_tokens, output_tokens=output_tokens), + return Message( + content=[TextMessageContent(TextArtifact(generated_text))], + role=Message.ASSISTANT_ROLE, + usage=Message.Usage(input_tokens=input_tokens, output_tokens=output_tokens), ) else: raise Exception("completion with more than one choice is not supported yet") else: raise Exception("invalid output format") - def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage]: + def try_stream(self, message_stack: MessageStack) -> Iterator[DeltaMessage]: raise NotImplementedError("streaming is not supported") - def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str: - return self.tokenizer.tokenizer.decode(self.__prompt_stack_to_tokens(prompt_stack)) + def message_stack_to_string(self, message_stack: MessageStack) -> str: + return self.tokenizer.tokenizer.decode(self.__message_stack_to_tokens(message_stack)) - def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]: + def _message_stack_to_messages(self, message_stack: MessageStack) -> list[dict]: messages = [] - for i in prompt_stack.messages: + for i in message_stack.messages: messages.append({"role": i.role, "content": i.to_text_artifact().to_text()}) return messages - def __prompt_stack_to_tokens(self, prompt_stack: PromptStack) -> list[int]: - messages = self._prompt_stack_to_messages(prompt_stack) + def __message_stack_to_tokens(self, message_stack: MessageStack) -> list[int]: + messages = self._message_stack_to_messages(message_stack) tokens = self.tokenizer.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True) if isinstance(tokens, list): diff --git a/griptape/drivers/prompt/ollama_prompt_driver.py b/griptape/drivers/prompt/ollama_prompt_driver.py index d6a1083fb..a93270a1c 100644 --- a/griptape/drivers/prompt/ollama_prompt_driver.py +++ b/griptape/drivers/prompt/ollama_prompt_driver.py @@ -5,11 +5,11 @@ from griptape.artifacts import TextArtifact from griptape.drivers import BasePromptDriver from griptape.tokenizers.base_tokenizer import BaseTokenizer -from griptape.common import PromptStack, TextPromptStackContent +from griptape.common import MessageStack, TextMessageContent from griptape.utils import import_optional_dependency from griptape.tokenizers import SimpleTokenizer -from griptape.common import PromptStackMessage, DeltaPromptStackMessage, TextDeltaPromptStackContent -from griptape.common import ImagePromptStackContent +from griptape.common import Message, DeltaMessage, TextDeltaMessageContent +from griptape.common import ImageMessageContent if TYPE_CHECKING: from ollama import Client @@ -49,32 +49,32 @@ class OllamaPromptDriver(BasePromptDriver): kw_only=True, ) - def try_run(self, prompt_stack: PromptStack) -> PromptStackMessage: - response = self.client.chat(**self._base_params(prompt_stack)) + def try_run(self, message_stack: MessageStack) -> Message: + response = self.client.chat(**self._base_params(message_stack)) if isinstance(response, dict): - return PromptStackMessage( - content=[TextPromptStackContent(TextArtifact(value=response["message"]["content"]))], - role=PromptStackMessage.ASSISTANT_ROLE, + return Message( + content=[TextMessageContent(TextArtifact(value=response["message"]["content"]))], + role=Message.ASSISTANT_ROLE, ) else: raise Exception("invalid model response") - def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage]: - stream = self.client.chat(**self._base_params(prompt_stack), stream=True) + def try_stream(self, message_stack: MessageStack) -> Iterator[DeltaMessage]: + stream = self.client.chat(**self._base_params(message_stack), stream=True) if isinstance(stream, Iterator): for chunk in stream: - yield DeltaPromptStackMessage(content=TextDeltaPromptStackContent(chunk["message"]["content"])) + yield DeltaMessage(content=TextDeltaMessageContent(chunk["message"]["content"])) else: raise Exception("invalid model response") - def _base_params(self, prompt_stack: PromptStack) -> dict: - messages = self._prompt_stack_to_messages(prompt_stack) + def _base_params(self, message_stack: MessageStack) -> dict: + messages = self._message_stack_to_messages(message_stack) return {"messages": messages, "model": self.model, "options": self.options} - def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]: + def _message_stack_to_messages(self, message_stack: MessageStack) -> list[dict]: return [ { "role": message.role, @@ -84,12 +84,12 @@ def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]: "images": [ content.artifact.base64 for content in message.content - if isinstance(content, ImagePromptStackContent) + if isinstance(content, ImageMessageContent) ] } - if any(isinstance(content, ImagePromptStackContent) for content in message.content) + if any(isinstance(content, ImageMessageContent) for content in message.content) else {} ), } - for message in prompt_stack.messages + for message in message_stack.messages ] diff --git a/griptape/drivers/prompt/openai_chat_prompt_driver.py b/griptape/drivers/prompt/openai_chat_prompt_driver.py index d7dafbfc6..e181502da 100644 --- a/griptape/drivers/prompt/openai_chat_prompt_driver.py +++ b/griptape/drivers/prompt/openai_chat_prompt_driver.py @@ -8,13 +8,13 @@ from griptape.artifacts import TextArtifact from griptape.common import ( - BasePromptStackContent, - DeltaPromptStackMessage, - TextDeltaPromptStackContent, - ImagePromptStackContent, - PromptStack, - PromptStackMessage, - TextPromptStackContent, + BaseMessageContent, + DeltaMessage, + TextDeltaMessageContent, + ImageMessageContent, + MessageStack, + Message, + TextMessageContent, ) from griptape.drivers import BasePromptDriver from griptape.tokenizers import BaseTokenizer, OpenAiTokenizer @@ -73,31 +73,31 @@ class OpenAiChatPromptDriver(BasePromptDriver): kw_only=True, ) - def try_run(self, prompt_stack: PromptStack) -> PromptStackMessage: - result = self.client.chat.completions.create(**self._base_params(prompt_stack)) + def try_run(self, message_stack: MessageStack) -> Message: + result = self.client.chat.completions.create(**self._base_params(message_stack)) if len(result.choices) == 1: message = result.choices[0].message - return PromptStackMessage( - content=[self.__message_to_prompt_stack_content(message)], + return Message( + content=[self.__message_to_message_stack_content(message)], role=message.role, - usage=PromptStackMessage.Usage( + usage=Message.Usage( input_tokens=result.usage.prompt_tokens, output_tokens=result.usage.completion_tokens ), ) else: raise Exception("Completion with more than one choice is not supported yet.") - def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage]: + def try_stream(self, message_stack: MessageStack) -> Iterator[DeltaMessage]: result = self.client.chat.completions.create( - **self._base_params(prompt_stack), stream=True, stream_options={"include_usage": True} + **self._base_params(message_stack), stream=True, stream_options={"include_usage": True} ) for chunk in result: if chunk.usage is not None: - yield DeltaPromptStackMessage( - usage=DeltaPromptStackMessage.Usage( + yield DeltaMessage( + usage=DeltaMessage.Usage( input_tokens=chunk.usage.prompt_tokens, output_tokens=chunk.usage.completion_tokens ) ) @@ -106,17 +106,17 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMess choice = chunk.choices[0] delta = choice.delta - yield DeltaPromptStackMessage(content=self.__message_delta_to_prompt_stack_content_delta(delta)) + yield DeltaMessage(content=self.__message_delta_to_message_stack_content_delta(delta)) else: raise Exception("Completion with more than one choice is not supported yet.") - def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]: + def _message_stack_to_messages(self, message_stack: MessageStack) -> list[dict]: return [ {"role": self.__to_role(message), "content": self.__to_content(message)} - for message in prompt_stack.messages + for message in message_stack.messages ] - def _base_params(self, prompt_stack: PromptStack) -> dict: + def _base_params(self, message_stack: MessageStack) -> dict: params = { "model": self.model, "temperature": self.temperature, @@ -129,15 +129,15 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: if self.response_format == "json_object": params["response_format"] = {"type": "json_object"} # JSON mode still requires a system message instructing the LLM to output JSON. - prompt_stack.add_system_message("Provide your response as a valid JSON object.") + message_stack.add_system_message("Provide your response as a valid JSON object.") - messages = self._prompt_stack_to_messages(prompt_stack) + messages = self._message_stack_to_messages(message_stack) params["messages"] = messages return params - def __to_role(self, message: PromptStackMessage) -> str: + def __to_role(self, message: Message) -> str: if message.is_system(): return "system" elif message.is_assistant(): @@ -145,16 +145,16 @@ def __to_role(self, message: PromptStackMessage) -> str: else: return "user" - def __to_content(self, message: PromptStackMessage) -> str | list[dict]: - if all(isinstance(content, TextPromptStackContent) for content in message.content): + def __to_content(self, message: Message) -> str | list[dict]: + if all(isinstance(content, TextMessageContent) for content in message.content): return message.to_text_artifact().to_text() else: - return [self.__prompt_stack_content_message_content(content) for content in message.content] + return [self.__message_stack_content_message_content(content) for content in message.content] - def __prompt_stack_content_message_content(self, content: BasePromptStackContent) -> dict: - if isinstance(content, TextPromptStackContent): + def __message_stack_content_message_content(self, content: BaseMessageContent) -> dict: + if isinstance(content, TextMessageContent): return {"type": "text", "text": content.artifact.to_text()} - elif isinstance(content, ImagePromptStackContent): + elif isinstance(content, ImageMessageContent): return { "type": "image_url", "image_url": {"url": f"data:{content.artifact.mime_type};base64,{content.artifact.base64}"}, @@ -162,16 +162,16 @@ def __prompt_stack_content_message_content(self, content: BasePromptStackContent else: raise ValueError(f"Unsupported content type: {type(content)}") - def __message_to_prompt_stack_content(self, message: ChatCompletionMessage) -> BasePromptStackContent: + def __message_to_message_stack_content(self, message: ChatCompletionMessage) -> BaseMessageContent: if message.content is not None: - return TextPromptStackContent(TextArtifact(message.content)) + return TextMessageContent(TextArtifact(message.content)) else: raise ValueError(f"Unsupported message type: {message}") - def __message_delta_to_prompt_stack_content_delta(self, content_delta: ChoiceDelta) -> TextDeltaPromptStackContent: + def __message_delta_to_message_stack_content_delta(self, content_delta: ChoiceDelta) -> TextDeltaMessageContent: if content_delta.content is None: - return TextDeltaPromptStackContent("") + return TextDeltaMessageContent("") else: delta_content = content_delta.content - return TextDeltaPromptStackContent(delta_content) + return TextDeltaMessageContent(delta_content) diff --git a/griptape/engines/extraction/csv_extraction_engine.py b/griptape/engines/extraction/csv_extraction_engine.py index 48eb0d392..e59b6ec23 100644 --- a/griptape/engines/extraction/csv_extraction_engine.py +++ b/griptape/engines/extraction/csv_extraction_engine.py @@ -4,8 +4,8 @@ import io from attrs import field, Factory, define from griptape.artifacts import TextArtifact, CsvRowArtifact, ListArtifact, ErrorArtifact -from griptape.common import PromptStack -from griptape.common.prompt_stack.messages.prompt_stack_message import PromptStackMessage +from griptape.common import MessageStack +from griptape.common.message_stack.messages.message import Message from griptape.engines import BaseExtractionEngine from griptape.utils import J2 from griptape.rules import Ruleset @@ -64,9 +64,7 @@ def _extract_rec( if self.prompt_driver.tokenizer.count_input_tokens_left(full_text) >= self.min_response_tokens: rows.extend( self.text_to_csv_rows( - self.prompt_driver.run( - PromptStack(messages=[PromptStackMessage(full_text, role=PromptStackMessage.USER_ROLE)]) - ).value, + self.prompt_driver.run(MessageStack(messages=[Message(full_text, role=Message.USER_ROLE)])).value, column_names, ) ) @@ -83,7 +81,7 @@ def _extract_rec( rows.extend( self.text_to_csv_rows( self.prompt_driver.run( - PromptStack(messages=[PromptStackMessage(partial_text, role=PromptStackMessage.USER_ROLE)]) + MessageStack(messages=[Message(partial_text, role=Message.USER_ROLE)]) ).value, column_names, ) diff --git a/griptape/engines/extraction/json_extraction_engine.py b/griptape/engines/extraction/json_extraction_engine.py index 4b8f45a03..744cab563 100644 --- a/griptape/engines/extraction/json_extraction_engine.py +++ b/griptape/engines/extraction/json_extraction_engine.py @@ -3,10 +3,10 @@ import json from attrs import field, Factory, define from griptape.artifacts import TextArtifact, ListArtifact, ErrorArtifact -from griptape.common.prompt_stack.messages.prompt_stack_message import PromptStackMessage +from griptape.common.message_stack.messages.message import Message from griptape.engines import BaseExtractionEngine from griptape.utils import J2 -from griptape.common import PromptStack +from griptape.common import MessageStack from griptape.rules import Ruleset @@ -59,9 +59,7 @@ def _extract_rec( if self.prompt_driver.tokenizer.count_input_tokens_left(full_text) >= self.min_response_tokens: extractions.extend( self.json_to_text_artifacts( - self.prompt_driver.run( - PromptStack(messages=[PromptStackMessage(full_text, role=PromptStackMessage.USER_ROLE)]) - ).value + self.prompt_driver.run(MessageStack(messages=[Message(full_text, role=Message.USER_ROLE)])).value ) ) @@ -76,9 +74,7 @@ def _extract_rec( extractions.extend( self.json_to_text_artifacts( - self.prompt_driver.run( - PromptStack(messages=[PromptStackMessage(partial_text, role=PromptStackMessage.USER_ROLE)]) - ).value + self.prompt_driver.run(MessageStack(messages=[Message(partial_text, role=Message.USER_ROLE)])).value ) ) diff --git a/griptape/engines/query/vector_query_engine.py b/griptape/engines/query/vector_query_engine.py index 3eb4948e7..08253b73f 100644 --- a/griptape/engines/query/vector_query_engine.py +++ b/griptape/engines/query/vector_query_engine.py @@ -2,8 +2,8 @@ from typing import TYPE_CHECKING, Optional from attrs import define, field, Factory from griptape.artifacts import TextArtifact, BaseArtifact, ListArtifact -from griptape.common import PromptStack -from griptape.common.prompt_stack.messages.prompt_stack_message import PromptStackMessage +from griptape.common import MessageStack +from griptape.common.message_stack.messages.message import Message from griptape.engines import BaseQueryEngine from griptape.utils.j2 import J2 from griptape.rules import Ruleset @@ -51,11 +51,11 @@ def query( user_message = self.user_template_generator.render(query=query) message_token_count = self.prompt_driver.tokenizer.count_input_tokens_left( - self.prompt_driver.prompt_stack_to_string( - PromptStack( + self.prompt_driver.message_stack_to_string( + MessageStack( messages=[ - PromptStackMessage(system_message, role=PromptStackMessage.SYSTEM_ROLE), - PromptStackMessage(user_message, role=PromptStackMessage.USER_ROLE), + Message(system_message, role=Message.SYSTEM_ROLE), + Message(user_message, role=Message.USER_ROLE), ] ) ) @@ -73,10 +73,10 @@ def query( break result = self.prompt_driver.run( - PromptStack( + MessageStack( messages=[ - PromptStackMessage(system_message, role=PromptStackMessage.SYSTEM_ROLE), - PromptStackMessage(user_message, role=PromptStackMessage.USER_ROLE), + Message(system_message, role=Message.SYSTEM_ROLE), + Message(user_message, role=Message.USER_ROLE), ] ) ) diff --git a/griptape/engines/summary/prompt_summary_engine.py b/griptape/engines/summary/prompt_summary_engine.py index a76388025..13958ee9c 100644 --- a/griptape/engines/summary/prompt_summary_engine.py +++ b/griptape/engines/summary/prompt_summary_engine.py @@ -2,8 +2,8 @@ from attrs import define, Factory, field from griptape.artifacts import TextArtifact, ListArtifact from griptape.chunkers import BaseChunker, TextChunker -from griptape.common import PromptStack -from griptape.common.prompt_stack.messages.prompt_stack_message import PromptStackMessage +from griptape.common import MessageStack +from griptape.common.message_stack.messages.message import Message from griptape.drivers import BasePromptDriver from griptape.engines import BaseSummaryEngine from griptape.utils import J2 @@ -62,10 +62,10 @@ def summarize_artifacts_rec( >= self.min_response_tokens ): return self.prompt_driver.run( - PromptStack( + MessageStack( messages=[ - PromptStackMessage(system_prompt, role=PromptStackMessage.SYSTEM_ROLE), - PromptStackMessage(user_prompt, role=PromptStackMessage.USER_ROLE), + Message(system_prompt, role=Message.SYSTEM_ROLE), + Message(user_prompt, role=Message.USER_ROLE), ] ) ) @@ -77,10 +77,10 @@ def summarize_artifacts_rec( return self.summarize_artifacts_rec( chunks[1:], self.prompt_driver.run( - PromptStack( + MessageStack( messages=[ - PromptStackMessage(system_prompt, role=PromptStackMessage.SYSTEM_ROLE), - PromptStackMessage(partial_text, role=PromptStackMessage.USER_ROLE), + Message(system_prompt, role=Message.SYSTEM_ROLE), + Message(partial_text, role=Message.USER_ROLE), ] ) ).value, diff --git a/griptape/events/start_prompt_event.py b/griptape/events/start_prompt_event.py index 35dae95d6..c6cd4aab4 100644 --- a/griptape/events/start_prompt_event.py +++ b/griptape/events/start_prompt_event.py @@ -5,9 +5,9 @@ from griptape.events.base_prompt_event import BasePromptEvent if TYPE_CHECKING: - from griptape.common import PromptStack + from griptape.common import MessageStack @define class StartPromptEvent(BasePromptEvent): - prompt_stack: PromptStack = field(kw_only=True, metadata={"serializable": True}) + message_stack: MessageStack = field(kw_only=True, metadata={"serializable": True}) diff --git a/griptape/memory/structure/base_conversation_memory.py b/griptape/memory/structure/base_conversation_memory.py index 53a72d227..4e29d9925 100644 --- a/griptape/memory/structure/base_conversation_memory.py +++ b/griptape/memory/structure/base_conversation_memory.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Optional from attrs import define, field from griptape.memory.structure import Run -from griptape.common import PromptStack +from griptape.common import MessageStack from griptape.mixins import SerializableMixin from abc import ABC, abstractmethod @@ -44,40 +44,40 @@ def after_add_run(self) -> None: def try_add_run(self, run: Run) -> None: ... @abstractmethod - def to_prompt_stack(self, last_n: Optional[int] = None) -> PromptStack: ... + def to_message_stack(self, last_n: Optional[int] = None) -> MessageStack: ... - def add_to_prompt_stack(self, prompt_stack: PromptStack, index: Optional[int] = None) -> PromptStack: - """Add the Conversation Memory runs to the Prompt Stack by modifying the messages in place. + def add_to_message_stack(self, message_stack: MessageStack, index: Optional[int] = None) -> MessageStack: + """Add the Conversation Memory runs to the Message Stack by modifying the messages in place. - If autoprune is enabled, this will fit as many Conversation Memory runs into the Prompt Stack + If autoprune is enabled, this will fit as many Conversation Memory runs into the Message Stack as possible without exceeding the token limit. Args: - prompt_stack: The Prompt Stack to add the Conversation Memory to. + message_stack: The Message Stack to add the Conversation Memory to. index: Optional index to insert the Conversation Memory runs at. - Defaults to appending to the end of the Prompt Stack. + Defaults to appending to the end of the Message Stack. """ num_runs_to_fit_in_prompt = len(self.runs) if self.autoprune and hasattr(self, "structure"): should_prune = True prompt_driver = self.structure.config.prompt_driver - temp_stack = PromptStack() + temp_stack = MessageStack() # Try to determine how many Conversation Memory runs we can - # fit into the Prompt Stack without exceeding the token limit. + # fit into the Message Stack without exceeding the token limit. while should_prune and num_runs_to_fit_in_prompt > 0: - temp_stack.messages = prompt_stack.messages.copy() + temp_stack.messages = message_stack.messages.copy() # Add n runs from Conversation Memory. - # Where we insert into the Prompt Stack doesn't matter here + # Where we insert into the Message Stack doesn't matter here # since we only care about the total token count. - memory_inputs = self.to_prompt_stack(num_runs_to_fit_in_prompt).messages + memory_inputs = self.to_message_stack(num_runs_to_fit_in_prompt).messages temp_stack.messages.extend(memory_inputs) - # Convert the prompt stack into tokens left. + # Convert the Message Stack into tokens left. tokens_left = prompt_driver.tokenizer.count_input_tokens_left( - prompt_driver.prompt_stack_to_string(temp_stack) + prompt_driver.message_stack_to_string(temp_stack) ) if tokens_left > 0: # There are still tokens left, no need to prune. @@ -87,10 +87,10 @@ def add_to_prompt_stack(self, prompt_stack: PromptStack, index: Optional[int] = num_runs_to_fit_in_prompt -= 1 if num_runs_to_fit_in_prompt: - memory_inputs = self.to_prompt_stack(num_runs_to_fit_in_prompt).messages + memory_inputs = self.to_message_stack(num_runs_to_fit_in_prompt).messages if index: - prompt_stack.messages[index:index] = memory_inputs + message_stack.messages[index:index] = memory_inputs else: - prompt_stack.messages.extend(memory_inputs) + message_stack.messages.extend(memory_inputs) - return prompt_stack + return message_stack diff --git a/griptape/memory/structure/conversation_memory.py b/griptape/memory/structure/conversation_memory.py index 42d160abd..b1401c1e4 100644 --- a/griptape/memory/structure/conversation_memory.py +++ b/griptape/memory/structure/conversation_memory.py @@ -2,7 +2,7 @@ from attrs import define from typing import Optional from griptape.memory.structure import Run, BaseConversationMemory -from griptape.common import PromptStack +from griptape.common import MessageStack @define @@ -14,10 +14,10 @@ def try_add_run(self, run: Run) -> None: while len(self.runs) > self.max_runs: self.runs.pop(0) - def to_prompt_stack(self, last_n: Optional[int] = None) -> PromptStack: - prompt_stack = PromptStack() + def to_message_stack(self, last_n: Optional[int] = None) -> MessageStack: + message_stack = MessageStack() runs = self.runs[-last_n:] if last_n else self.runs for run in runs: - prompt_stack.add_user_message(run.input) - prompt_stack.add_assistant_message(run.output) - return prompt_stack + message_stack.add_user_message(run.input) + message_stack.add_assistant_message(run.output) + return message_stack diff --git a/griptape/memory/structure/summary_conversation_memory.py b/griptape/memory/structure/summary_conversation_memory.py index 5a4ff084d..de1a198d9 100644 --- a/griptape/memory/structure/summary_conversation_memory.py +++ b/griptape/memory/structure/summary_conversation_memory.py @@ -2,9 +2,9 @@ import logging from typing import TYPE_CHECKING, Optional from attrs import define, field, Factory -from griptape.common.prompt_stack.messages.prompt_stack_message import PromptStackMessage +from griptape.common.message_stack.messages.message import Message from griptape.utils import J2 -from griptape.common import PromptStack +from griptape.common import MessageStack from griptape.memory.structure import ConversationMemory if TYPE_CHECKING: @@ -36,8 +36,8 @@ def prompt_driver(self) -> BasePromptDriver: def prompt_driver(self, value: BasePromptDriver) -> None: self._prompt_driver = value - def to_prompt_stack(self, last_n: Optional[int] = None) -> PromptStack: - stack = PromptStack() + def to_message_stack(self, last_n: Optional[int] = None) -> MessageStack: + stack = MessageStack() if self.summary: stack.add_user_message(self.summary_template_generator.render(summary=self.summary)) @@ -75,7 +75,7 @@ def summarize_runs(self, previous_summary: str | None, runs: list[Run]) -> str | if len(runs) > 0: summary = self.summarize_conversation_template_generator.render(summary=previous_summary, runs=runs) return self.prompt_driver.run( - prompt_stack=PromptStack(messages=[PromptStackMessage(summary, role=PromptStackMessage.USER_ROLE)]) + message_stack=MessageStack(messages=[Message(summary, role=Message.USER_ROLE)]) ).to_text() else: return previous_summary diff --git a/griptape/schemas/base_schema.py b/griptape/schemas/base_schema.py index e309a5eab..1026ade1a 100644 --- a/griptape/schemas/base_schema.py +++ b/griptape/schemas/base_schema.py @@ -105,7 +105,7 @@ def _resolve_types(cls, attrs_cls: type) -> None: # These modules are required to avoid `NameError`s when resolving types. from griptape.drivers import BaseConversationMemoryDriver, BasePromptDriver from griptape.structures import Structure - from griptape.common import PromptStack, PromptStackMessage + from griptape.common import MessageStack, Message from griptape.tokenizers.base_tokenizer import BaseTokenizer from typing import Any @@ -115,8 +115,8 @@ def _resolve_types(cls, attrs_cls: type) -> None: attrs.resolve_types( attrs_cls, localns={ - "PromptStack": PromptStack, - "Usage": PromptStackMessage.Usage, + "MessageStack": MessageStack, + "Usage": Message.Usage, "Structure": Structure, "BaseConversationMemoryDriver": BaseConversationMemoryDriver, "BasePromptDriver": BasePromptDriver, diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index 751fdab17..b0743303a 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -6,7 +6,7 @@ from attrs import Factory, define, field from griptape.artifacts import BaseArtifact -from griptape.common import PromptStack +from griptape.common import MessageStack from griptape.tasks import BaseTask from griptape.utils import J2 from griptape.artifacts import TextArtifact, ListArtifact @@ -45,8 +45,8 @@ def input(self, value: str | list | tuple | BaseArtifact | Callable[[BaseTask], output: Optional[BaseArtifact] = field(default=None, init=False) @property - def prompt_stack(self) -> PromptStack: - stack = PromptStack() + def message_stack(self) -> MessageStack: + stack = MessageStack() memory = self.structure.conversation_memory stack.add_system_message(self.generate_system_template(self)) @@ -58,7 +58,7 @@ def prompt_stack(self) -> PromptStack: if memory: # inserting at index 1 to place memory right after system prompt - memory.add_to_prompt_stack(stack, 1) + memory.add_to_message_stack(stack, 1) return stack @@ -94,7 +94,7 @@ def after_run(self) -> None: self.structure.logger.info(f"{self.__class__.__name__} {self.id}\nOutput: {self.output.to_text()}") def run(self) -> BaseArtifact: - self.output = self.prompt_driver.run(self.prompt_stack) + self.output = self.prompt_driver.run(self.message_stack) return self.output diff --git a/griptape/tasks/tool_task.py b/griptape/tasks/tool_task.py index edd90c26e..aec246da9 100644 --- a/griptape/tasks/tool_task.py +++ b/griptape/tasks/tool_task.py @@ -48,7 +48,7 @@ def actions_schema(self) -> Schema: return self._actions_schema_for_tools([self.tool]) def run(self) -> BaseArtifact: - prompt_output = self.prompt_driver.run(prompt_stack=self.prompt_stack).to_text() + prompt_output = self.prompt_driver.run(message_stack=self.message_stack).to_text() action_matches = re.findall(self.ACTION_PATTERN, prompt_output, re.DOTALL) if action_matches: diff --git a/griptape/tasks/toolkit_task.py b/griptape/tasks/toolkit_task.py index 58300b529..9c0fe33bb 100644 --- a/griptape/tasks/toolkit_task.py +++ b/griptape/tasks/toolkit_task.py @@ -10,7 +10,7 @@ from griptape.tasks import ActionsSubtask from griptape.tasks import PromptTask from griptape.utils import J2 -from griptape.common import PromptStack +from griptape.common import MessageStack if TYPE_CHECKING: from griptape.tools import BaseTool @@ -61,8 +61,8 @@ def tool_output_memory(self) -> list[TaskMemory]: return list(unique_memory_dict.values()) @property - def prompt_stack(self) -> PromptStack: - stack = PromptStack() + def message_stack(self) -> MessageStack: + stack = MessageStack() memory = self.structure.conversation_memory stack.add_system_message(self.generate_system_template(self)) @@ -78,7 +78,7 @@ def prompt_stack(self) -> PromptStack: if memory: # inserting at index 1 to place memory right after system prompt - memory.add_to_prompt_stack(stack, 1) + memory.add_to_message_stack(stack, 1) return stack @@ -131,7 +131,7 @@ def run(self) -> BaseArtifact: self.subtasks.clear() self.prompt_driver.tokenizer.stop_sequences.extend([self.response_stop_sequence]) - subtask = self.add_subtask(ActionsSubtask(self.prompt_driver.run(prompt_stack=self.prompt_stack).to_text())) + subtask = self.add_subtask(ActionsSubtask(self.prompt_driver.run(message_stack=self.message_stack).to_text())) while True: if subtask.output is None: @@ -146,7 +146,7 @@ def run(self) -> BaseArtifact: subtask.after_run() subtask = self.add_subtask( - ActionsSubtask(self.prompt_driver.run(prompt_stack=self.prompt_stack).to_text()) + ActionsSubtask(self.prompt_driver.run(message_stack=self.message_stack).to_text()) ) else: break diff --git a/griptape/utils/conversation.py b/griptape/utils/conversation.py index 634ad3715..0bdc078dd 100644 --- a/griptape/utils/conversation.py +++ b/griptape/utils/conversation.py @@ -19,10 +19,10 @@ def lines(self) -> list[str]: return lines - def prompt_stack(self) -> list[str]: + def message_stack(self) -> list[str]: lines = [] - for stack in self.memory.to_prompt_stack().messages: + for stack in self.memory.to_message_stack().messages: lines.append(f"{stack.role}: {stack.to_text_artifact().to_text()}") return lines diff --git a/tests/mocks/mock_failing_prompt_driver.py b/tests/mocks/mock_failing_prompt_driver.py index e7ff7ea66..d376d1981 100644 --- a/tests/mocks/mock_failing_prompt_driver.py +++ b/tests/mocks/mock_failing_prompt_driver.py @@ -3,13 +3,7 @@ from attrs import define from griptape.artifacts import TextArtifact -from griptape.common import ( - PromptStack, - PromptStackMessage, - TextPromptStackContent, - DeltaPromptStackMessage, - TextDeltaPromptStackContent, -) +from griptape.common import MessageStack, Message, TextMessageContent, DeltaMessage, TextDeltaMessageContent from griptape.drivers import BasePromptDriver from griptape.tokenizers import BaseTokenizer, OpenAiTokenizer @@ -21,25 +15,25 @@ class MockFailingPromptDriver(BasePromptDriver): model: str = "test-model" tokenizer: BaseTokenizer = OpenAiTokenizer(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL) - def try_run(self, prompt_stack: PromptStack) -> PromptStackMessage: + def try_run(self, message_stack: MessageStack) -> Message: if self.current_attempt < self.max_failures: self.current_attempt += 1 raise Exception("failed attempt") else: - return PromptStackMessage( - content=[TextPromptStackContent(TextArtifact("success"))], - role=PromptStackMessage.ASSISTANT_ROLE, - usage=PromptStackMessage.Usage(input_tokens=100, output_tokens=100), + return Message( + content=[TextMessageContent(TextArtifact("success"))], + role=Message.ASSISTANT_ROLE, + usage=Message.Usage(input_tokens=100, output_tokens=100), ) - def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage]: + def try_stream(self, message_stack: MessageStack) -> Iterator[DeltaMessage]: if self.current_attempt < self.max_failures: self.current_attempt += 1 raise Exception("failed attempt") else: - yield DeltaPromptStackMessage( - content=TextDeltaPromptStackContent("success"), - usage=DeltaPromptStackMessage.Usage(input_tokens=100, output_tokens=100), + yield DeltaMessage( + content=TextDeltaMessageContent("success"), + usage=DeltaMessage.Usage(input_tokens=100, output_tokens=100), ) diff --git a/tests/mocks/mock_prompt_driver.py b/tests/mocks/mock_prompt_driver.py index 16ba38abe..40782dc39 100644 --- a/tests/mocks/mock_prompt_driver.py +++ b/tests/mocks/mock_prompt_driver.py @@ -6,13 +6,7 @@ from attrs import define, field from griptape.artifacts import TextArtifact -from griptape.common import ( - PromptStack, - PromptStackMessage, - DeltaPromptStackMessage, - TextPromptStackContent, - TextDeltaPromptStackContent, -) +from griptape.common import MessageStack, Message, DeltaMessage, TextMessageContent, TextDeltaMessageContent from griptape.drivers import BasePromptDriver from griptape.tokenizers import BaseTokenizer @@ -24,21 +18,20 @@ class MockPromptDriver(BasePromptDriver): model: str = "test-model" tokenizer: BaseTokenizer = MockTokenizer(model="test-model", max_input_tokens=4096, max_output_tokens=4096) mock_input: str | Callable[[], str] = field(default="mock input", kw_only=True) - mock_output: str | Callable[[PromptStack], str] = field(default="mock output", kw_only=True) + mock_output: str | Callable[[MessageStack], str] = field(default="mock output", kw_only=True) - def try_run(self, prompt_stack: PromptStack) -> PromptStackMessage: - output = self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output + def try_run(self, message_stack: MessageStack) -> Message: + output = self.mock_output(message_stack) if isinstance(self.mock_output, Callable) else self.mock_output - return PromptStackMessage( - content=[TextPromptStackContent(TextArtifact(output))], - role=PromptStackMessage.ASSISTANT_ROLE, - usage=PromptStackMessage.Usage(input_tokens=100, output_tokens=100), + return Message( + content=[TextMessageContent(TextArtifact(output))], + role=Message.ASSISTANT_ROLE, + usage=Message.Usage(input_tokens=100, output_tokens=100), ) - def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaPromptStackMessage]: - output = self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output + def try_stream(self, message_stack: MessageStack) -> Iterator[DeltaMessage]: + output = self.mock_output(message_stack) if isinstance(self.mock_output, Callable) else self.mock_output - yield DeltaPromptStackMessage( - content=TextDeltaPromptStackContent(output), - usage=DeltaPromptStackMessage.Usage(input_tokens=100, output_tokens=100), + yield DeltaMessage( + content=TextDeltaMessageContent(output), usage=DeltaMessage.Usage(input_tokens=100, output_tokens=100) ) diff --git a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py index f4f4eda11..816692075 100644 --- a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py @@ -1,7 +1,7 @@ import pytest from griptape.artifacts import ImageArtifact, TextArtifact -from griptape.common import PromptStack +from griptape.common import MessageStack from griptape.drivers import AmazonBedrockPromptDriver @@ -31,15 +31,15 @@ def mock_converse_stream(self, mocker): return mock_converse_stream @pytest.fixture - def prompt_stack(self): - prompt_stack = PromptStack() - prompt_stack.add_system_message("system-input") - prompt_stack.add_user_message("user-input") - prompt_stack.add_user_message(TextArtifact("user-input")) - prompt_stack.add_user_message(ImageArtifact(value=b"image-data", format="png", width=100, height=100)) - prompt_stack.add_assistant_message("assistant-input") + def message_stack(self): + message_stack = MessageStack() + message_stack.add_system_message("system-input") + message_stack.add_user_message("user-input") + message_stack.add_user_message(TextArtifact("user-input")) + message_stack.add_user_message(ImageArtifact(value=b"image-data", format="png", width=100, height=100)) + message_stack.add_assistant_message("assistant-input") - return prompt_stack + return message_stack @pytest.fixture def messages(self): @@ -50,12 +50,12 @@ def messages(self): {"role": "assistant", "content": [{"text": "assistant-input"}]}, ] - def test_try_run(self, mock_converse, prompt_stack, messages): + def test_try_run(self, mock_converse, message_stack, messages): # Given driver = AmazonBedrockPromptDriver(model="ai21.j2") # When - text_artifact = driver.try_run(prompt_stack) + text_artifact = driver.try_run(message_stack) # Then mock_converse.assert_called_once_with( @@ -69,12 +69,12 @@ def test_try_run(self, mock_converse, prompt_stack, messages): assert text_artifact.usage.input_tokens == 5 assert text_artifact.usage.output_tokens == 10 - def test_try_stream_run(self, mock_converse_stream, prompt_stack, messages): + def test_try_stream_run(self, mock_converse_stream, message_stack, messages): # Given driver = AmazonBedrockPromptDriver(model="ai21.j2", stream=True) # When - stream = driver.try_stream(prompt_stack) + stream = driver.try_stream(message_stack) event = next(stream) # Then diff --git a/tests/unit/drivers/prompt/test_amazon_sagemaker_jumpstart_prompt_driver.py b/tests/unit/drivers/prompt/test_amazon_sagemaker_jumpstart_prompt_driver.py index a75fc6ed0..5d70f3aea 100644 --- a/tests/unit/drivers/prompt/test_amazon_sagemaker_jumpstart_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_amazon_sagemaker_jumpstart_prompt_driver.py @@ -2,7 +2,7 @@ from botocore.response import StreamingBody from griptape.tokenizers import HuggingFaceTokenizer from griptape.drivers.prompt.amazon_sagemaker_jumpstart_prompt_driver import AmazonSageMakerJumpstartPromptDriver -from griptape.common import PromptStack +from griptape.common import MessageStack from io import BytesIO import json import pytest @@ -36,13 +36,13 @@ def test_init(self): def test_try_run(self, mock_client): # Given driver = AmazonSageMakerJumpstartPromptDriver(endpoint="model", model="model") - prompt_stack = PromptStack() - prompt_stack.add_user_message("prompt-stack") + message_stack = MessageStack() + message_stack.add_user_message("prompt-stack") # When response_body = [{"generated_text": "foobar"}] mock_client.invoke_endpoint.return_value = {"Body": to_streaming_body(response_body)} - text_artifact = driver.try_run(prompt_stack) + text_artifact = driver.try_run(message_stack) assert isinstance(driver.tokenizer, HuggingFaceTokenizer) # Then @@ -72,7 +72,7 @@ def test_try_run(self, mock_client): # When response_body = {"generated_text": "foobar"} mock_client.invoke_endpoint.return_value = {"Body": to_streaming_body(response_body)} - text_artifact = driver.try_run(prompt_stack) + text_artifact = driver.try_run(message_stack) assert isinstance(driver.tokenizer, HuggingFaceTokenizer) # Then @@ -100,12 +100,12 @@ def test_try_run(self, mock_client): def test_try_stream(self, mock_client): # Given driver = AmazonSageMakerJumpstartPromptDriver(endpoint="model", model="model") - prompt_stack = PromptStack() - prompt_stack.add_user_message("prompt-stack") + message_stack = MessageStack() + message_stack.add_user_message("prompt-stack") # When with pytest.raises(NotImplementedError) as e: - driver.try_stream(prompt_stack) + driver.try_stream(message_stack) # Then assert e.value.args[0] == "streaming is not supported" @@ -125,12 +125,12 @@ def test_try_run_throws_on_empty_response(self, mock_client): # Given driver = AmazonSageMakerJumpstartPromptDriver(endpoint="model", model="model") mock_client.invoke_endpoint.return_value = {"Body": to_streaming_body([])} - prompt_stack = PromptStack() - prompt_stack.add_user_message("prompt-stack") + message_stack = MessageStack() + message_stack.add_user_message("prompt-stack") # When with pytest.raises(Exception) as e: - driver.try_run(prompt_stack) + driver.try_run(message_stack) # Then assert e.value.args[0] == "model response is empty" diff --git a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py index aa85021a0..7f118d151 100644 --- a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py @@ -1,5 +1,5 @@ from griptape.drivers import AnthropicPromptDriver -from griptape.common import PromptStack +from griptape.common import MessageStack from griptape.artifacts import TextArtifact, ImageArtifact from unittest.mock import Mock import pytest @@ -60,13 +60,13 @@ def test_init(self, model): @pytest.mark.parametrize("system_enabled", [True, False]) def test_try_run(self, mock_client, model, system_enabled): # Given - prompt_stack = PromptStack() + message_stack = MessageStack() if system_enabled: - prompt_stack.add_system_message("system-input") - prompt_stack.add_user_message("user-input") - prompt_stack.add_user_message(TextArtifact("user-input")) - prompt_stack.add_user_message(ImageArtifact(value=b"image-data", format="png", width=100, height=100)) - prompt_stack.add_assistant_message("assistant-input") + message_stack.add_system_message("system-input") + message_stack.add_user_message("user-input") + message_stack.add_user_message(TextArtifact("user-input")) + message_stack.add_user_message(ImageArtifact(value=b"image-data", format="png", width=100, height=100)) + message_stack.add_assistant_message("assistant-input") driver = AnthropicPromptDriver(model=model, api_key="api-key") expected_messages = [ {"role": "user", "content": "user-input"}, @@ -84,7 +84,7 @@ def test_try_run(self, mock_client, model, system_enabled): ] # When - message = driver.try_run(prompt_stack) + message = driver.try_run(message_stack) # Then mock_client.return_value.messages.create.assert_called_once_with( @@ -115,13 +115,13 @@ def test_try_run(self, mock_client, model, system_enabled): @pytest.mark.parametrize("system_enabled", [True, False]) def test_try_stream_run(self, mock_stream_client, model, system_enabled): # Given - prompt_stack = PromptStack() + message_stack = MessageStack() if system_enabled: - prompt_stack.add_system_message("system-input") - prompt_stack.add_user_message("user-input") - prompt_stack.add_user_message(TextArtifact("user-input")) - prompt_stack.add_user_message(ImageArtifact(value=b"image-data", format="png", width=100, height=100)) - prompt_stack.add_assistant_message("assistant-input") + message_stack.add_system_message("system-input") + message_stack.add_user_message("user-input") + message_stack.add_user_message(TextArtifact("user-input")) + message_stack.add_user_message(ImageArtifact(value=b"image-data", format="png", width=100, height=100)) + message_stack.add_assistant_message("assistant-input") expected_messages = [ {"role": "user", "content": "user-input"}, {"role": "user", "content": "user-input"}, @@ -139,7 +139,7 @@ def test_try_stream_run(self, mock_stream_client, model, system_enabled): driver = AnthropicPromptDriver(model=model, api_key="api-key", stream=True) # When - stream = driver.try_stream(prompt_stack) + stream = driver.try_stream(message_stack) event = next(stream) # Then @@ -162,14 +162,14 @@ def test_try_stream_run(self, mock_stream_client, model, system_enabled): event = next(stream) assert event.usage.output_tokens == 10 - def test_try_run_throws_when_prompt_stack_is_string(self): + def test_try_run_throws_when_message_stack_is_string(self): # Given - prompt_stack = "prompt-stack" + message_stack = "prompt-stack" driver = AnthropicPromptDriver(model="claude", api_key="api-key") # When with pytest.raises(Exception) as e: - driver.try_run(prompt_stack) # pyright: ignore + driver.try_run(message_stack) # pyright: ignore # Then assert e.value.args[0] == "'str' object has no attribute 'messages'" diff --git a/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py b/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py index 92544a74e..dedd2b3f6 100644 --- a/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py @@ -30,12 +30,12 @@ def test_init(self): assert AzureOpenAiChatPromptDriver(azure_endpoint="foobar", azure_deployment="foobar", model="gpt-4") assert AzureOpenAiChatPromptDriver(azure_endpoint="foobar", model="gpt-4").azure_deployment == "gpt-4" - def test_try_run(self, mock_chat_completion_create, prompt_stack, messages): + def test_try_run(self, mock_chat_completion_create, message_stack, messages): # Given driver = AzureOpenAiChatPromptDriver(azure_endpoint="endpoint", azure_deployment="deployment-id", model="gpt-4") # When - text_artifact = driver.try_run(prompt_stack) + text_artifact = driver.try_run(message_stack) # Then mock_chat_completion_create.assert_called_once_with( @@ -45,14 +45,14 @@ def test_try_run(self, mock_chat_completion_create, prompt_stack, messages): assert text_artifact.usage.input_tokens == 5 assert text_artifact.usage.output_tokens == 10 - def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, messages): + def test_try_stream_run(self, mock_chat_completion_stream_create, message_stack, messages): # Given driver = AzureOpenAiChatPromptDriver( azure_endpoint="endpoint", azure_deployment="deployment-id", model="gpt-4", stream=True ) # When - stream = driver.try_stream(prompt_stack) + stream = driver.try_stream(message_stack) event = next(stream) # Then diff --git a/tests/unit/drivers/prompt/test_base_prompt_driver.py b/tests/unit/drivers/prompt/test_base_prompt_driver.py index 0f0d0ee89..78936748e 100644 --- a/tests/unit/drivers/prompt/test_base_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_base_prompt_driver.py @@ -1,5 +1,5 @@ from griptape.events import FinishPromptEvent, StartPromptEvent -from griptape.common import PromptStack +from griptape.common import MessageStack from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_failing_prompt_driver import MockFailingPromptDriver from griptape.artifacts import ErrorArtifact, TextArtifact @@ -37,7 +37,7 @@ def test_run_via_pipeline_publishes_events(self, mocker): assert instance_count(events, FinishPromptEvent) == 1 def test_run(self): - assert isinstance(MockPromptDriver().run(PromptStack(messages=[])), TextArtifact) + assert isinstance(MockPromptDriver().run(MessageStack(messages=[])), TextArtifact) def instance_count(instances, clazz): diff --git a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py index cf556fc1e..666226c20 100644 --- a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py @@ -2,7 +2,7 @@ import pytest -from griptape.common import PromptStack +from griptape.common import MessageStack from griptape.drivers import CoherePromptDriver @@ -33,22 +33,22 @@ def mock_tokenizer(self, mocker): return mocker.patch("griptape.tokenizers.CohereTokenizer").return_value @pytest.fixture - def prompt_stack(self): - prompt_stack = PromptStack() - prompt_stack.add_system_message("system-input") - prompt_stack.add_user_message("user-input") - prompt_stack.add_assistant_message("assistant-input") - return prompt_stack + def message_stack(self): + message_stack = MessageStack() + message_stack.add_system_message("system-input") + message_stack.add_user_message("user-input") + message_stack.add_assistant_message("assistant-input") + return message_stack def test_init(self): assert CoherePromptDriver(model="command", api_key="foobar") - def test_try_run(self, mock_client, prompt_stack): # pyright: ignore + def test_try_run(self, mock_client, message_stack): # pyright: ignore # Given driver = CoherePromptDriver(model="command", api_key="api-key") # When - text_artifact = driver.try_run(prompt_stack) + text_artifact = driver.try_run(message_stack) # Then mock_client.chat.assert_called_once_with( @@ -64,12 +64,12 @@ def test_try_run(self, mock_client, prompt_stack): # pyright: ignore assert text_artifact.usage.input_tokens == 5 assert text_artifact.usage.output_tokens == 10 - def test_try_stream_run(self, mock_stream_client, prompt_stack): # pyright: ignore + def test_try_stream_run(self, mock_stream_client, message_stack): # pyright: ignore # Given driver = CoherePromptDriver(model="command", api_key="api-key", stream=True) # When - stream = driver.try_stream(prompt_stack) + stream = driver.try_stream(message_stack) event = next(stream) # Then diff --git a/tests/unit/drivers/prompt/test_google_prompt_driver.py b/tests/unit/drivers/prompt/test_google_prompt_driver.py index 9a454a563..7ef4beb08 100644 --- a/tests/unit/drivers/prompt/test_google_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_google_prompt_driver.py @@ -1,7 +1,7 @@ from google.generativeai.types import GenerationConfig from griptape.artifacts import TextArtifact, ImageArtifact from griptape.drivers import GooglePromptDriver -from griptape.common import PromptStack +from griptape.common import MessageStack from unittest.mock import Mock import pytest @@ -34,16 +34,16 @@ def test_init(self): def test_try_run(self, mock_generative_model): # Given - prompt_stack = PromptStack() - prompt_stack.add_system_message("system-input") - prompt_stack.add_user_message("user-input") - prompt_stack.add_user_message(TextArtifact("user-input")) - prompt_stack.add_user_message(ImageArtifact(value=b"image-data", format="png", width=100, height=100)) - prompt_stack.add_assistant_message("assistant-input") + message_stack = MessageStack() + message_stack.add_system_message("system-input") + message_stack.add_user_message("user-input") + message_stack.add_user_message(TextArtifact("user-input")) + message_stack.add_user_message(ImageArtifact(value=b"image-data", format="png", width=100, height=100)) + message_stack.add_assistant_message("assistant-input") driver = GooglePromptDriver(model="gemini-pro", api_key="api-key", top_p=0.5, top_k=50) # When - text_artifact = driver.try_run(prompt_stack) + text_artifact = driver.try_run(message_stack) # Then mock_generative_model.return_value.generate_content.assert_called_once_with( @@ -63,16 +63,16 @@ def test_try_run(self, mock_generative_model): def test_try_stream(self, mock_stream_generative_model): # Given - prompt_stack = PromptStack() - prompt_stack.add_system_message("system-input") - prompt_stack.add_user_message("user-input") - prompt_stack.add_user_message(TextArtifact("user-input")) - prompt_stack.add_user_message(ImageArtifact(value=b"image-data", format="png", width=100, height=100)) - prompt_stack.add_assistant_message("assistant-input") + message_stack = MessageStack() + message_stack.add_system_message("system-input") + message_stack.add_user_message("user-input") + message_stack.add_user_message(TextArtifact("user-input")) + message_stack.add_user_message(ImageArtifact(value=b"image-data", format="png", width=100, height=100)) + message_stack.add_assistant_message("assistant-input") driver = GooglePromptDriver(model="gemini-pro", api_key="api-key", stream=True, top_p=0.5, top_k=50) # When - stream = driver.try_stream(prompt_stack) + stream = driver.try_stream(message_stack) # Then event = next(stream) diff --git a/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py b/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py index 4618e1de3..8079ffe13 100644 --- a/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py @@ -1,5 +1,5 @@ from griptape.drivers import HuggingFaceHubPromptDriver -from griptape.common import PromptStack +from griptape.common import MessageStack import pytest @@ -28,12 +28,12 @@ def mock_client_stream(self, mocker): return mock_client @pytest.fixture - def prompt_stack(self): - prompt_stack = PromptStack() - prompt_stack.add_system_message("system-input") - prompt_stack.add_user_message("user-input") - prompt_stack.add_assistant_message("assistant-input") - return prompt_stack + def message_stack(self): + message_stack = MessageStack() + message_stack.add_system_message("system-input") + message_stack.add_user_message("user-input") + message_stack.add_assistant_message("assistant-input") + return message_stack @pytest.fixture(autouse=True) def mock_autotokenizer(self, mocker): @@ -44,24 +44,24 @@ def mock_autotokenizer(self, mocker): def test_init(self): assert HuggingFaceHubPromptDriver(api_token="foobar", model="gpt2") - def test_try_run(self, prompt_stack, mock_client): + def test_try_run(self, message_stack, mock_client): # Given driver = HuggingFaceHubPromptDriver(api_token="api-token", model="repo-id") # When - message = driver.try_run(prompt_stack) + message = driver.try_run(message_stack) # Then assert message.value == "model-output" assert message.usage.input_tokens == 3 assert message.usage.output_tokens == 3 - def test_try_stream(self, prompt_stack, mock_client_stream): + def test_try_stream(self, message_stack, mock_client_stream): # Given driver = HuggingFaceHubPromptDriver(api_token="api-token", model="repo-id", stream=True) # When - stream = driver.try_stream(prompt_stack) + stream = driver.try_stream(message_stack) event = next(stream) # Then diff --git a/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py b/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py index a63d697fb..5b31d8fe5 100644 --- a/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py @@ -1,5 +1,5 @@ from griptape.drivers import HuggingFacePipelinePromptDriver -from griptape.common import PromptStack +from griptape.common import MessageStack import pytest @@ -26,69 +26,69 @@ def mock_autotokenizer(self, mocker): return mock_autotokenizer @pytest.fixture - def prompt_stack(self): - prompt_stack = PromptStack() - prompt_stack.add_system_message("system-input") - prompt_stack.add_user_message("user-input") - prompt_stack.add_assistant_message("assistant-input") - return prompt_stack + def message_stack(self): + message_stack = MessageStack() + message_stack.add_system_message("system-input") + message_stack.add_user_message("user-input") + message_stack.add_assistant_message("assistant-input") + return message_stack def test_init(self): assert HuggingFacePipelinePromptDriver(model="gpt2", max_tokens=42) - def test_try_run(self, prompt_stack): + def test_try_run(self, message_stack): # Given driver = HuggingFacePipelinePromptDriver(model="foo", max_tokens=42) # When - message = driver.try_run(prompt_stack) + message = driver.try_run(message_stack) # Then assert message.value == "model-output" assert message.usage.input_tokens == 3 assert message.usage.output_tokens == 3 - def test_try_stream(self, prompt_stack): + def test_try_stream(self, message_stack): # Given driver = HuggingFacePipelinePromptDriver(model="foo", max_tokens=42) # When with pytest.raises(Exception) as e: - driver.try_stream(prompt_stack) + driver.try_stream(message_stack) assert e.value.args[0] == "streaming is not supported" @pytest.mark.parametrize("choices", [[], [1, 2]]) - def test_try_run_throws_when_multiple_choices_returned(self, choices, mock_generator, prompt_stack): + def test_try_run_throws_when_multiple_choices_returned(self, choices, mock_generator, message_stack): # Given driver = HuggingFacePipelinePromptDriver(model="foo", max_tokens=42) mock_generator.return_value = choices # When with pytest.raises(Exception) as e: - driver.try_run(prompt_stack) + driver.try_run(message_stack) # Then assert e.value.args[0] == "completion with more than one choice is not supported yet" - def test_try_run_throws_when_non_list(self, mock_generator, prompt_stack): + def test_try_run_throws_when_non_list(self, mock_generator, message_stack): # Given driver = HuggingFacePipelinePromptDriver(model="foo", max_tokens=42) mock_generator.return_value = {} # When with pytest.raises(Exception) as e: - driver.try_run(prompt_stack) + driver.try_run(message_stack) # Then assert e.value.args[0] == "invalid output format" - def test_prompt_stack_to_string(self, prompt_stack): + def test_message_stack_to_string(self, message_stack): # Given driver = HuggingFacePipelinePromptDriver(model="foo", max_tokens=42) # When - result = driver.prompt_stack_to_string(prompt_stack) + result = driver.message_stack_to_string(message_stack) # Then assert result == "model-output" diff --git a/tests/unit/drivers/prompt/test_ollama_prompt_driver.py b/tests/unit/drivers/prompt/test_ollama_prompt_driver.py index e737aeaeb..c13aa2e60 100644 --- a/tests/unit/drivers/prompt/test_ollama_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_ollama_prompt_driver.py @@ -1,6 +1,6 @@ -from griptape.common.prompt_stack.contents.text_delta_prompt_stack_content import TextDeltaPromptStackContent +from griptape.common.message_stack.contents.text_delta_message_content import TextDeltaMessageContent from griptape.drivers import OllamaPromptDriver -from griptape.common import PromptStack +from griptape.common import MessageStack from griptape.artifacts import ImageArtifact, ListArtifact, TextArtifact import pytest @@ -26,15 +26,15 @@ def test_init(self): def test_try_run(self, mock_client): # Given - prompt_stack = PromptStack() - prompt_stack.add_system_message("system-input") - prompt_stack.add_user_message("user-input") - prompt_stack.add_user_message( + message_stack = MessageStack() + message_stack.add_system_message("system-input") + message_stack.add_user_message("user-input") + message_stack.add_user_message( ListArtifact( [TextArtifact("user-input"), ImageArtifact(value=b"image-data", format="png", width=100, height=100)] ) ) - prompt_stack.add_assistant_message("assistant-input") + message_stack.add_assistant_message("assistant-input") driver = OllamaPromptDriver(model="llama") expected_messages = [ {"role": "system", "content": "system-input"}, @@ -44,7 +44,7 @@ def test_try_run(self, mock_client): ] # When - message = driver.try_run(prompt_stack) + message = driver.try_run(message_stack) # Then mock_client.return_value.chat.assert_called_once_with( @@ -58,25 +58,25 @@ def test_try_run(self, mock_client): def test_try_run_bad_response(self, mock_client): # Given - prompt_stack = PromptStack() + message_stack = MessageStack() driver = OllamaPromptDriver(model="llama") mock_client.return_value.chat.return_value = "bad-response" # When/Then with pytest.raises(Exception, match="invalid model response"): - driver.try_run(prompt_stack) + driver.try_run(message_stack) def test_try_stream_run(self, mock_stream_client): # Given - prompt_stack = PromptStack() - prompt_stack.add_system_message("system-input") - prompt_stack.add_user_message("user-input") - prompt_stack.add_user_message( + message_stack = MessageStack() + message_stack.add_system_message("system-input") + message_stack.add_user_message("user-input") + message_stack.add_user_message( ListArtifact( [TextArtifact("user-input"), ImageArtifact(value=b"image-data", format="png", width=100, height=100)] ) ) - prompt_stack.add_assistant_message("assistant-input") + message_stack.add_assistant_message("assistant-input") expected_messages = [ {"role": "system", "content": "system-input"}, {"role": "user", "content": "user-input"}, @@ -86,7 +86,7 @@ def test_try_stream_run(self, mock_stream_client): driver = OllamaPromptDriver(model="llama", stream=True) # When - text_artifact = next(driver.try_stream(prompt_stack)) + text_artifact = next(driver.try_stream(message_stack)) # Then mock_stream_client.return_value.chat.assert_called_once_with( @@ -95,15 +95,15 @@ def test_try_stream_run(self, mock_stream_client): options={"temperature": driver.temperature, "stop": [], "num_predict": driver.max_tokens}, stream=True, ) - if isinstance(text_artifact, TextDeltaPromptStackContent): + if isinstance(text_artifact, TextDeltaMessageContent): assert text_artifact.text == "model-output" def test_try_stream_bad_response(self, mock_stream_client): # Given - prompt_stack = PromptStack() + message_stack = MessageStack() driver = OllamaPromptDriver(model="llama", stream=True) mock_stream_client.return_value.chat.return_value = "bad-response" # When/Then with pytest.raises(Exception, match="invalid model response"): - next(driver.try_stream(prompt_stack)) + next(driver.try_stream(message_stack)) diff --git a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py index 01de35028..6aab853dc 100644 --- a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py @@ -1,7 +1,7 @@ from griptape.artifacts.image_artifact import ImageArtifact from griptape.artifacts.text_artifact import TextArtifact from griptape.drivers import OpenAiChatPromptDriver -from griptape.common import PromptStack +from griptape.common import MessageStack from griptape.tokenizers import OpenAiTokenizer from unittest.mock import Mock from tests.mocks.mock_tokenizer import MockTokenizer @@ -32,14 +32,14 @@ def mock_chat_completion_stream_create(self, mocker): return mock_chat_create @pytest.fixture - def prompt_stack(self): - prompt_stack = PromptStack() - prompt_stack.add_system_message("system-input") - prompt_stack.add_user_message("user-input") - prompt_stack.add_user_message(TextArtifact("user-input")) - prompt_stack.add_user_message(ImageArtifact(value=b"image-data", format="png", width=100, height=100)) - prompt_stack.add_assistant_message("assistant-input") - return prompt_stack + def message_stack(self): + message_stack = MessageStack() + message_stack.add_system_message("system-input") + message_stack.add_user_message("user-input") + message_stack.add_user_message(TextArtifact("user-input")) + message_stack.add_user_message(ImageArtifact(value=b"image-data", format="png", width=100, height=100)) + message_stack.add_assistant_message("assistant-input") + return message_stack @pytest.fixture def messages(self): @@ -92,12 +92,12 @@ class TestOpenAiChatPromptDriver(TestOpenAiChatPromptDriverFixtureMixin): def test_init(self): assert OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_4_MODEL) - def test_try_run(self, mock_chat_completion_create, prompt_stack, messages): + def test_try_run(self, mock_chat_completion_create, message_stack, messages): # Given driver = OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL) # When - event = driver.try_run(prompt_stack) + event = driver.try_run(message_stack) # Then mock_chat_completion_create.assert_called_once_with( @@ -105,14 +105,14 @@ def test_try_run(self, mock_chat_completion_create, prompt_stack, messages): ) assert event.value == "model-output" - def test_try_run_response_format(self, mock_chat_completion_create, prompt_stack, messages): + def test_try_run_response_format(self, mock_chat_completion_create, message_stack, messages): # Given driver = OpenAiChatPromptDriver( model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, response_format="json_object" ) # When - message = driver.try_run(prompt_stack) + message = driver.try_run(message_stack) # Then mock_chat_completion_create.assert_called_once_with( @@ -127,12 +127,12 @@ def test_try_run_response_format(self, mock_chat_completion_create, prompt_stack assert message.usage.input_tokens == 5 assert message.usage.output_tokens == 10 - def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, messages): + def test_try_stream_run(self, mock_chat_completion_stream_create, message_stack, messages): # Given driver = OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, stream=True) # When - stream = driver.try_stream(prompt_stack) + stream = driver.try_stream(message_stack) event = next(stream) # Then @@ -152,12 +152,12 @@ def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, assert event.usage.input_tokens == 5 assert event.usage.output_tokens == 10 - def test_try_run_with_max_tokens(self, mock_chat_completion_create, prompt_stack, messages): + def test_try_run_with_max_tokens(self, mock_chat_completion_create, message_stack, messages): # Given driver = OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, max_tokens=1) # When - event = driver.try_run(prompt_stack) + event = driver.try_run(message_stack) # Then mock_chat_completion_create.assert_called_once_with( @@ -170,7 +170,7 @@ def test_try_run_with_max_tokens(self, mock_chat_completion_create, prompt_stack ) assert event.value == "model-output" - def test_try_run_throws_when_prompt_stack_is_string(self): + def test_try_run_throws_when_message_stack_is_string(self): # Given driver = OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL) @@ -181,19 +181,19 @@ def test_try_run_throws_when_prompt_stack_is_string(self): # Then assert e.value.args[0] == "'str' object has no attribute 'messages'" - def test_try_run_throws_when_multiple_choices_returned(self, mock_chat_completion_create, prompt_stack): + def test_try_run_throws_when_multiple_choices_returned(self, mock_chat_completion_create, message_stack): # Given driver = OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, api_key="api-key") mock_chat_completion_create.return_value.choices = [Mock(message=Mock(content="model-output"))] * 10 # When with pytest.raises(Exception) as e: - driver.try_run(prompt_stack) + driver.try_run(message_stack) # Then assert e.value.args[0] == "Completion with more than one choice is not supported yet." - def test_custom_tokenizer(self, mock_chat_completion_create, prompt_stack, messages): + def test_custom_tokenizer(self, mock_chat_completion_create, message_stack, messages): driver = OpenAiChatPromptDriver( model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, tokenizer=MockTokenizer(model="mock-model", stop_sequences=["mock-stop"]), @@ -201,7 +201,7 @@ def test_custom_tokenizer(self, mock_chat_completion_create, prompt_stack, messa ) # When - event = driver.try_run(prompt_stack) + event = driver.try_run(message_stack) # Then mock_chat_completion_create.assert_called_once_with( diff --git a/tests/unit/engines/summary/test_prompt_summary_engine.py b/tests/unit/engines/summary/test_prompt_summary_engine.py index 34c6e3563..5de0016a9 100644 --- a/tests/unit/engines/summary/test_prompt_summary_engine.py +++ b/tests/unit/engines/summary/test_prompt_summary_engine.py @@ -1,7 +1,7 @@ import pytest from griptape.artifacts import TextArtifact, ListArtifact from griptape.engines import PromptSummaryEngine -from griptape.common import PromptStack +from griptape.common import MessageStack from tests.mocks.mock_prompt_driver import MockPromptDriver import os @@ -27,8 +27,8 @@ def test_max_token_multiplier_invalid(self, engine): PromptSummaryEngine(prompt_driver=MockPromptDriver(), max_token_multiplier=10000) def test_chunked_summary(self, engine): - def smaller_input(prompt_stack: PromptStack): - return prompt_stack.messages[0].content[: (len(prompt_stack.messages[0].content) // 2)] + def smaller_input(message_stack: MessageStack): + return message_stack.messages[0].content[: (len(message_stack.messages[0].content) // 2)] engine = PromptSummaryEngine(prompt_driver=MockPromptDriver(mock_output="smaller_input")) diff --git a/tests/unit/events/test_base_event.py b/tests/unit/events/test_base_event.py index e3ed5aa0e..c4393fa3d 100644 --- a/tests/unit/events/test_base_event.py +++ b/tests/unit/events/test_base_event.py @@ -32,22 +32,22 @@ def test_start_prompt_event_from_dict(self): "id": "917298d4bf894b0a824a8fdb26717a0c", "timestamp": 123, "model": "foo bar", - "prompt_stack": { - "type": "PromptStack", + "message_stack": { + "type": "MessageStack", "messages": [ { - "type": "PromptStackMessage", + "type": "Message", "role": "user", "content": [ - {"type": "TextPromptStackContent", "artifact": {"type": "TextArtifact", "value": "foo"}} + {"type": "TextMessageContent", "artifact": {"type": "TextArtifact", "value": "foo"}} ], "usage": {"type": "Usage", "input_tokens": None, "output_tokens": None}, }, { - "type": "PromptStackMessage", + "type": "Message", "role": "system", "content": [ - {"type": "TextPromptStackContent", "artifact": {"type": "TextArtifact", "value": "bar"}} + {"type": "TextMessageContent", "artifact": {"type": "TextArtifact", "value": "bar"}} ], "usage": {"type": "Usage", "input_tokens": None, "output_tokens": None}, }, @@ -59,10 +59,10 @@ def test_start_prompt_event_from_dict(self): assert isinstance(event, StartPromptEvent) assert event.timestamp == 123 - assert event.prompt_stack.messages[0].content[0].artifact.value == "foo" - assert event.prompt_stack.messages[0].role == "user" - assert event.prompt_stack.messages[1].content[0].artifact.value == "bar" - assert event.prompt_stack.messages[1].role == "system" + assert event.message_stack.messages[0].content[0].artifact.value == "foo" + assert event.message_stack.messages[0].role == "user" + assert event.message_stack.messages[1].content[0].artifact.value == "bar" + assert event.message_stack.messages[1].role == "system" assert event.model == "foo bar" def test_finish_prompt_event_from_dict(self): diff --git a/tests/unit/events/test_start_prompt_event.py b/tests/unit/events/test_start_prompt_event.py index 4ef08ec5c..51e609458 100644 --- a/tests/unit/events/test_start_prompt_event.py +++ b/tests/unit/events/test_start_prompt_event.py @@ -1,22 +1,22 @@ import pytest from griptape.events import StartPromptEvent -from griptape.common import PromptStack +from griptape.common import MessageStack class TestStartPromptEvent: @pytest.fixture def start_prompt_event(self): - prompt_stack = PromptStack() - prompt_stack.add_user_message("foo") - prompt_stack.add_system_message("bar") - return StartPromptEvent(prompt_stack=prompt_stack, model="foo bar") + message_stack = MessageStack() + message_stack.add_user_message("foo") + message_stack.add_system_message("bar") + return StartPromptEvent(message_stack=message_stack, model="foo bar") def test_to_dict(self, start_prompt_event): assert "timestamp" in start_prompt_event.to_dict() - assert start_prompt_event.to_dict()["prompt_stack"]["messages"][0]["content"][0]["artifact"]["value"] == "foo" - assert start_prompt_event.to_dict()["prompt_stack"]["messages"][0]["role"] == "user" - assert start_prompt_event.to_dict()["prompt_stack"]["messages"][1]["content"][0]["artifact"]["value"] == "bar" - assert start_prompt_event.to_dict()["prompt_stack"]["messages"][1]["role"] == "system" + assert start_prompt_event.to_dict()["message_stack"]["messages"][0]["content"][0]["artifact"]["value"] == "foo" + assert start_prompt_event.to_dict()["message_stack"]["messages"][0]["role"] == "user" + assert start_prompt_event.to_dict()["message_stack"]["messages"][1]["content"][0]["artifact"]["value"] == "bar" + assert start_prompt_event.to_dict()["message_stack"]["messages"][1]["role"] == "system" assert start_prompt_event.to_dict()["model"] == "foo bar" diff --git a/tests/unit/memory/structure/test_conversation_memory.py b/tests/unit/memory/structure/test_conversation_memory.py index 30685d863..82df4fc4d 100644 --- a/tests/unit/memory/structure/test_conversation_memory.py +++ b/tests/unit/memory/structure/test_conversation_memory.py @@ -1,6 +1,6 @@ import json from griptape.structures import Agent -from griptape.common import PromptStack +from griptape.common import MessageStack from griptape.memory.structure import ConversationMemory, Run, BaseConversationMemory from griptape.structures import Pipeline from tests.mocks.mock_prompt_driver import MockPromptDriver @@ -32,14 +32,14 @@ def test_to_dict(self): assert memory.to_dict()["type"] == "ConversationMemory" assert memory.to_dict()["runs"][0]["input"]["value"] == "foo" - def test_to_prompt_stack(self): + def test_to_message_stack(self): memory = ConversationMemory() memory.add_run(Run(input=TextArtifact("foo"), output=TextArtifact("bar"))) - prompt_stack = memory.to_prompt_stack() + message_stack = memory.to_message_stack() - assert prompt_stack.messages[0].content[0].artifact.value == "foo" - assert prompt_stack.messages[1].content[0].artifact.value == "bar" + assert message_stack.messages[0].content[0].artifact.value == "foo" + assert message_stack.messages[1].content[0].artifact.value == "bar" def test_from_dict(self): memory = ConversationMemory() @@ -74,7 +74,7 @@ def test_buffering(self): assert pipeline.conversation_memory.runs[0].input.value == "run4" assert pipeline.conversation_memory.runs[1].input.value == "run5" - def test_add_to_prompt_stack_autopruing_disabled(self): + def test_add_to_message_stack_autopruing_disabled(self): agent = Agent(prompt_driver=MockPromptDriver()) memory = ConversationMemory( autoprune=False, @@ -87,14 +87,14 @@ def test_add_to_prompt_stack_autopruing_disabled(self): ], ) memory.structure = agent - prompt_stack = PromptStack() - prompt_stack.add_user_message(TextArtifact("foo")) - prompt_stack.add_assistant_message("bar") - memory.add_to_prompt_stack(prompt_stack) + message_stack = MessageStack() + message_stack.add_user_message(TextArtifact("foo")) + message_stack.add_assistant_message("bar") + memory.add_to_message_stack(message_stack) - assert len(prompt_stack.messages) == 12 + assert len(message_stack.messages) == 12 - def test_add_to_prompt_stack_autopruning_enabled(self): + def test_add_to_message_stack_autopruning_enabled(self): # All memory is pruned. agent = Agent(prompt_driver=MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=0))) memory = ConversationMemory( @@ -108,13 +108,13 @@ def test_add_to_prompt_stack_autopruning_enabled(self): ], ) memory.structure = agent - prompt_stack = PromptStack() - prompt_stack.add_system_message("fizz") - prompt_stack.add_user_message("foo") - prompt_stack.add_assistant_message("bar") - memory.add_to_prompt_stack(prompt_stack) + message_stack = MessageStack() + message_stack.add_system_message("fizz") + message_stack.add_user_message("foo") + message_stack.add_assistant_message("bar") + memory.add_to_message_stack(message_stack) - assert len(prompt_stack.messages) == 3 + assert len(message_stack.messages) == 3 # No memory is pruned. agent = Agent(prompt_driver=MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=1000))) @@ -129,13 +129,13 @@ def test_add_to_prompt_stack_autopruning_enabled(self): ], ) memory.structure = agent - prompt_stack = PromptStack() - prompt_stack.add_system_message("fizz") - prompt_stack.add_user_message("foo") - prompt_stack.add_assistant_message("bar") - memory.add_to_prompt_stack(prompt_stack) + message_stack = MessageStack() + message_stack.add_system_message("fizz") + message_stack.add_user_message("foo") + message_stack.add_assistant_message("bar") + memory.add_to_message_stack(message_stack) - assert len(prompt_stack.messages) == 13 + assert len(message_stack.messages) == 13 # One memory is pruned. # MockTokenizer's max_input_tokens set to one below the sum of memory + system prompt tokens @@ -153,17 +153,17 @@ def test_add_to_prompt_stack_autopruning_enabled(self): ], ) memory.structure = agent - prompt_stack = PromptStack() + message_stack = MessageStack() # And then another 6 tokens from fizz for a total of 161 tokens. - prompt_stack.add_system_message("fizz") - prompt_stack.add_user_message("foo") - prompt_stack.add_assistant_message("bar") - memory.add_to_prompt_stack(prompt_stack, 1) - - # We expect one run (2 prompt stack inputs) to be pruned. - assert len(prompt_stack.messages) == 11 - assert prompt_stack.messages[0].content[0].artifact.value == "fizz" - assert prompt_stack.messages[1].content[0].artifact.value == "foo2" - assert prompt_stack.messages[2].content[0].artifact.value == "bar2" - assert prompt_stack.messages[-2].content[0].artifact.value == "foo" - assert prompt_stack.messages[-1].content[0].artifact.value == "bar" + message_stack.add_system_message("fizz") + message_stack.add_user_message("foo") + message_stack.add_assistant_message("bar") + memory.add_to_message_stack(message_stack, 1) + + # We expect one run (2 Message Stack inputs) to be pruned. + assert len(message_stack.messages) == 11 + assert message_stack.messages[0].content[0].artifact.value == "fizz" + assert message_stack.messages[1].content[0].artifact.value == "foo2" + assert message_stack.messages[2].content[0].artifact.value == "bar2" + assert message_stack.messages[-2].content[0].artifact.value == "foo" + assert message_stack.messages[-1].content[0].artifact.value == "bar" diff --git a/tests/unit/memory/structure/test_summary_conversation_memory.py b/tests/unit/memory/structure/test_summary_conversation_memory.py index e625ac6c6..e98fb6724 100644 --- a/tests/unit/memory/structure/test_summary_conversation_memory.py +++ b/tests/unit/memory/structure/test_summary_conversation_memory.py @@ -53,15 +53,15 @@ def test_to_dict(self): assert memory.to_dict()["type"] == "SummaryConversationMemory" assert memory.to_dict()["runs"][0]["input"]["value"] == "foo" - def test_to_prompt_stack(self): + def test_to_message_stack(self): memory = SummaryConversationMemory(summary="foobar") memory.add_run(Run(input=TextArtifact("foo"), output=TextArtifact("bar"))) - prompt_stack = memory.to_prompt_stack() + message_stack = memory.to_message_stack() - assert prompt_stack.messages[0].content[0].artifact.value == "Summary of the conversation so far: foobar" - assert prompt_stack.messages[1].content[0].artifact.value == "foo" - assert prompt_stack.messages[2].content[0].artifact.value == "bar" + assert message_stack.messages[0].content[0].artifact.value == "Summary of the conversation so far: foobar" + assert message_stack.messages[1].content[0].artifact.value == "foo" + assert message_stack.messages[2].content[0].artifact.value == "bar" def test_from_dict(self): memory = SummaryConversationMemory() diff --git a/tests/unit/structures/test_agent.py b/tests/unit/structures/test_agent.py index 0d5d8a565..0fbc3b93e 100644 --- a/tests/unit/structures/test_agent.py +++ b/tests/unit/structures/test_agent.py @@ -157,39 +157,39 @@ def test_add_tasks(self): except ValueError: assert True - def test_prompt_stack_without_memory(self): + def test_message_stack_without_memory(self): agent = Agent(prompt_driver=MockPromptDriver(), conversation_memory=None) task1 = PromptTask("test") agent.add_task(task1) - assert len(task1.prompt_stack.messages) == 2 + assert len(task1.message_stack.messages) == 2 agent.run() - assert len(task1.prompt_stack.messages) == 3 + assert len(task1.message_stack.messages) == 3 agent.run() - assert len(task1.prompt_stack.messages) == 3 + assert len(task1.message_stack.messages) == 3 - def test_prompt_stack_with_memory(self): + def test_message_stack_with_memory(self): agent = Agent(prompt_driver=MockPromptDriver(), conversation_memory=ConversationMemory()) task1 = PromptTask("test") agent.add_task(task1) - assert len(task1.prompt_stack.messages) == 2 + assert len(task1.message_stack.messages) == 2 agent.run() - assert len(task1.prompt_stack.messages) == 5 + assert len(task1.message_stack.messages) == 5 agent.run() - assert len(task1.prompt_stack.messages) == 7 + assert len(task1.message_stack.messages) == 7 def test_run(self): task = PromptTask("test") diff --git a/tests/unit/structures/test_pipeline.py b/tests/unit/structures/test_pipeline.py index 5131ed728..1d9ed10a8 100644 --- a/tests/unit/structures/test_pipeline.py +++ b/tests/unit/structures/test_pipeline.py @@ -251,7 +251,7 @@ def test_insert_task_at_end(self): assert [parent.id for parent in third_task.parents] == ["test2"] assert [child.id for child in third_task.children] == [] - def test_prompt_stack_without_memory(self): + def test_message_stack_without_memory(self): pipeline = Pipeline(conversation_memory=None, prompt_driver=MockPromptDriver()) task1 = PromptTask("test") @@ -259,20 +259,20 @@ def test_prompt_stack_without_memory(self): pipeline.add_tasks(task1, task2) - assert len(task1.prompt_stack.messages) == 2 - assert len(task2.prompt_stack.messages) == 2 + assert len(task1.message_stack.messages) == 2 + assert len(task2.message_stack.messages) == 2 pipeline.run() - assert len(task1.prompt_stack.messages) == 3 - assert len(task2.prompt_stack.messages) == 3 + assert len(task1.message_stack.messages) == 3 + assert len(task2.message_stack.messages) == 3 pipeline.run() - assert len(task1.prompt_stack.messages) == 3 - assert len(task2.prompt_stack.messages) == 3 + assert len(task1.message_stack.messages) == 3 + assert len(task2.message_stack.messages) == 3 - def test_prompt_stack_with_memory(self): + def test_message_stack_with_memory(self): pipeline = Pipeline(prompt_driver=MockPromptDriver()) task1 = PromptTask("test") @@ -280,18 +280,18 @@ def test_prompt_stack_with_memory(self): pipeline.add_tasks(task1, task2) - assert len(task1.prompt_stack.messages) == 2 - assert len(task2.prompt_stack.messages) == 2 + assert len(task1.message_stack.messages) == 2 + assert len(task2.message_stack.messages) == 2 pipeline.run() - assert len(task1.prompt_stack.messages) == 5 - assert len(task2.prompt_stack.messages) == 5 + assert len(task1.message_stack.messages) == 5 + assert len(task2.message_stack.messages) == 5 pipeline.run() - assert len(task1.prompt_stack.messages) == 7 - assert len(task2.prompt_stack.messages) == 7 + assert len(task1.message_stack.messages) == 7 + assert len(task2.message_stack.messages) == 7 def test_text_artifact_token_count(self): text = "foobar" diff --git a/tests/unit/tokenizers/test_google_tokenizer.py b/tests/unit/tokenizers/test_google_tokenizer.py index 0f940b06d..f8ca60452 100644 --- a/tests/unit/tokenizers/test_google_tokenizer.py +++ b/tests/unit/tokenizers/test_google_tokenizer.py @@ -1,7 +1,7 @@ import pytest from unittest.mock import Mock -from griptape.common import PromptStack -from griptape.common.prompt_stack.messages.prompt_stack_message import PromptStackMessage +from griptape.common import MessageStack +from griptape.common.message_stack.messages.message import Message from griptape.tokenizers import GoogleTokenizer @@ -20,9 +20,7 @@ def tokenizer(self, request): @pytest.mark.parametrize("tokenizer,expected", [("gemini-pro", 5)], indirect=["tokenizer"]) def test_token_count(self, tokenizer, expected): assert tokenizer.count_tokens("foo bar huzzah") == expected - assert ( - tokenizer.count_tokens(PromptStack(messages=[PromptStackMessage(content="foo", role="user")])) == expected - ) + assert tokenizer.count_tokens(MessageStack(messages=[Message(content="foo", role="user")])) == expected assert tokenizer.count_tokens(["foo", "bar", "huzzah"]) == expected @pytest.mark.parametrize("tokenizer,expected", [("gemini-pro", 30715)], indirect=["tokenizer"]) diff --git a/tests/unit/utils/test_conversation.py b/tests/unit/utils/test_conversation.py index cce067f73..963903cc6 100644 --- a/tests/unit/utils/test_conversation.py +++ b/tests/unit/utils/test_conversation.py @@ -21,7 +21,7 @@ def test_lines(self): assert lines[2] == "Q: question 1" assert lines[3] == "A: mock output" - def test_prompt_stack_conversation_memory(self): + def test_message_stack_conversation_memory(self): pipeline = Pipeline(prompt_driver=MockPromptDriver(), conversation_memory=ConversationMemory()) pipeline.add_tasks(PromptTask("question 1")) @@ -29,12 +29,12 @@ def test_prompt_stack_conversation_memory(self): pipeline.run() pipeline.run() - lines = Conversation(pipeline.conversation_memory).prompt_stack() + lines = Conversation(pipeline.conversation_memory).message_stack() assert lines[0] == "user: question 1" assert lines[1] == "assistant: mock output" - def test_prompt_stack_summary_conversation_memory(self): + def test_message_stack_summary_conversation_memory(self): pipeline = Pipeline( prompt_driver=MockPromptDriver(), conversation_memory=SummaryConversationMemory(summary="foobar", prompt_driver=MockPromptDriver()), @@ -45,7 +45,7 @@ def test_prompt_stack_summary_conversation_memory(self): pipeline.run() pipeline.run() - lines = Conversation(pipeline.conversation_memory).prompt_stack() + lines = Conversation(pipeline.conversation_memory).message_stack() assert lines[0] == "user: Summary of the conversation so far: mock output" assert lines[1] == "user: question 1" diff --git a/tests/unit/utils/test_message_stack.py b/tests/unit/utils/test_message_stack.py new file mode 100644 index 000000000..9bab66d0b --- /dev/null +++ b/tests/unit/utils/test_message_stack.py @@ -0,0 +1,55 @@ +import pytest + +from griptape.artifacts import ImageArtifact, ListArtifact, TextArtifact +from griptape.common import ImageMessageContent, MessageStack, TextMessageContent + + +class TestPromptStack: + @pytest.fixture + def message_stack(self): + return MessageStack() + + def test_init(self): + assert MessageStack() + + def test_add_message(self, message_stack): + message_stack.add_message("foo", "role") + message_stack.add_message(TextArtifact("foo"), "role") + message_stack.add_message(ImageArtifact(b"foo", format="png", width=100, height=100), "role") + message_stack.add_message(ListArtifact([TextArtifact("foo"), TextArtifact("bar")]), "role") + + assert message_stack.messages[0].role == "role" + assert isinstance(message_stack.messages[0].content[0], TextMessageContent) + assert message_stack.messages[0].content[0].artifact.value == "foo" + + assert message_stack.messages[1].role == "role" + assert isinstance(message_stack.messages[1].content[0], TextMessageContent) + assert message_stack.messages[1].content[0].artifact.value == "foo" + + assert message_stack.messages[2].role == "role" + assert isinstance(message_stack.messages[2].content[0], ImageMessageContent) + assert message_stack.messages[2].content[0].artifact.value == b"foo" + + assert message_stack.messages[3].role == "role" + assert isinstance(message_stack.messages[3].content[0], TextMessageContent) + assert message_stack.messages[3].content[0].artifact.value == "foo" + assert isinstance(message_stack.messages[3].content[1], TextMessageContent) + assert message_stack.messages[3].content[1].artifact.value == "bar" + + def test_add_system_message(self, message_stack): + message_stack.add_system_message("foo") + + assert message_stack.messages[0].role == "system" + assert message_stack.messages[0].content[0].artifact.value == "foo" + + def test_add_user_message(self, message_stack): + message_stack.add_user_message("foo") + + assert message_stack.messages[0].role == "user" + assert message_stack.messages[0].content[0].artifact.value == "foo" + + def test_add_assistant_message(self, message_stack): + message_stack.add_assistant_message("foo") + + assert message_stack.messages[0].role == "assistant" + assert message_stack.messages[0].content[0].artifact.value == "foo" diff --git a/tests/unit/utils/test_prompt_stack.py b/tests/unit/utils/test_prompt_stack.py deleted file mode 100644 index 98c9f48ff..000000000 --- a/tests/unit/utils/test_prompt_stack.py +++ /dev/null @@ -1,55 +0,0 @@ -import pytest - -from griptape.artifacts import ImageArtifact, ListArtifact, TextArtifact -from griptape.common import ImagePromptStackContent, PromptStack, TextPromptStackContent - - -class TestPromptStack: - @pytest.fixture - def prompt_stack(self): - return PromptStack() - - def test_init(self): - assert PromptStack() - - def test_add_message(self, prompt_stack): - prompt_stack.add_message("foo", "role") - prompt_stack.add_message(TextArtifact("foo"), "role") - prompt_stack.add_message(ImageArtifact(b"foo", format="png", width=100, height=100), "role") - prompt_stack.add_message(ListArtifact([TextArtifact("foo"), TextArtifact("bar")]), "role") - - assert prompt_stack.messages[0].role == "role" - assert isinstance(prompt_stack.messages[0].content[0], TextPromptStackContent) - assert prompt_stack.messages[0].content[0].artifact.value == "foo" - - assert prompt_stack.messages[1].role == "role" - assert isinstance(prompt_stack.messages[1].content[0], TextPromptStackContent) - assert prompt_stack.messages[1].content[0].artifact.value == "foo" - - assert prompt_stack.messages[2].role == "role" - assert isinstance(prompt_stack.messages[2].content[0], ImagePromptStackContent) - assert prompt_stack.messages[2].content[0].artifact.value == b"foo" - - assert prompt_stack.messages[3].role == "role" - assert isinstance(prompt_stack.messages[3].content[0], TextPromptStackContent) - assert prompt_stack.messages[3].content[0].artifact.value == "foo" - assert isinstance(prompt_stack.messages[3].content[1], TextPromptStackContent) - assert prompt_stack.messages[3].content[1].artifact.value == "bar" - - def test_add_system_message(self, prompt_stack): - prompt_stack.add_system_message("foo") - - assert prompt_stack.messages[0].role == "system" - assert prompt_stack.messages[0].content[0].artifact.value == "foo" - - def test_add_user_message(self, prompt_stack): - prompt_stack.add_user_message("foo") - - assert prompt_stack.messages[0].role == "user" - assert prompt_stack.messages[0].content[0].artifact.value == "foo" - - def test_add_assistant_message(self, prompt_stack): - prompt_stack.add_assistant_message("foo") - - assert prompt_stack.messages[0].role == "assistant" - assert prompt_stack.messages[0].content[0].artifact.value == "foo" From 1297528798d0303e83ee290a0ef9dde6c315d2a9 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 21 Jun 2024 09:48:07 -0700 Subject: [PATCH 21/34] Improve test coverage --- griptape/drivers/prompt/base_prompt_driver.py | 2 +- .../drivers/prompt/cohere_prompt_driver.py | 9 +++------ griptape/tasks/prompt_task.py | 6 +----- tests/mocks/mock_prompt_driver.py | 6 +++--- .../prompt/test_anthropic_prompt_driver.py | 13 +++++++----- .../drivers/prompt/test_base_prompt_driver.py | 6 ++++++ .../prompt/test_cohere_prompt_driver.py | 14 +++++++++++-- .../prompt/test_openai_chat_prompt_driver.py | 20 +++++++++++++------ 8 files changed, 48 insertions(+), 28 deletions(-) diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index 5adb2b601..3f5551d07 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -135,7 +135,7 @@ def __process_stream(self, message_stack: MessageStack) -> Message: # Build a complete content from the content deltas content = [] - for index, deltas in delta_contents.items(): + for deltas in delta_contents.values(): text_deltas = [delta for delta in deltas if isinstance(delta, TextDeltaMessageContent)] if text_deltas: content.append(TextMessageContent.from_deltas(text_deltas)) diff --git a/griptape/drivers/prompt/cohere_prompt_driver.py b/griptape/drivers/prompt/cohere_prompt_driver.py index c9bd0e119..3ee27d56f 100644 --- a/griptape/drivers/prompt/cohere_prompt_driver.py +++ b/griptape/drivers/prompt/cohere_prompt_driver.py @@ -84,13 +84,10 @@ def _base_params(self, message_stack: MessageStack) -> dict: ) system_message = next((message for message in message_stack.messages if message.is_system()), None) - if system_message is not None: - if len(system_message.content) == 1: - preamble = system_message.content[0].artifact.to_text() - else: - raise ValueError("System message must have exactly one content.") - else: + if system_message is None: preamble = None + else: + preamble = system_message.to_text_artifact().to_text() return { "message": user_message, diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index b0743303a..65369100a 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -98,9 +98,7 @@ def run(self) -> BaseArtifact: return self.output - def _process_task_input( - self, task_input: str | list | BaseArtifact | Callable[[BaseTask], BaseArtifact] - ) -> BaseArtifact: + def _process_task_input(self, task_input: str | BaseArtifact | Callable[[BaseTask], BaseArtifact]) -> BaseArtifact: if isinstance(task_input, TextArtifact): task_input.value = J2().render_from_string(task_input.value, **self.full_context) @@ -111,8 +109,6 @@ def _process_task_input( return self._process_task_input(TextArtifact(task_input)) elif isinstance(task_input, BaseArtifact): return task_input - elif isinstance(task_input, list): - return ListArtifact([self._process_task_input(elem) for elem in task_input]) else: raise ValueError(f"Invalid input type: {type(task_input)} ") diff --git a/tests/mocks/mock_prompt_driver.py b/tests/mocks/mock_prompt_driver.py index 40782dc39..0a64daa8f 100644 --- a/tests/mocks/mock_prompt_driver.py +++ b/tests/mocks/mock_prompt_driver.py @@ -32,6 +32,6 @@ def try_run(self, message_stack: MessageStack) -> Message: def try_stream(self, message_stack: MessageStack) -> Iterator[DeltaMessage]: output = self.mock_output(message_stack) if isinstance(self.mock_output, Callable) else self.mock_output - yield DeltaMessage( - content=TextDeltaMessageContent(output), usage=DeltaMessage.Usage(input_tokens=100, output_tokens=100) - ) + yield DeltaMessage(content=TextDeltaMessageContent(output)) + + yield DeltaMessage(usage=DeltaMessage.Usage(input_tokens=100, output_tokens=100)) diff --git a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py index 7f118d151..4328c7f1e 100644 --- a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py @@ -1,6 +1,6 @@ from griptape.drivers import AnthropicPromptDriver from griptape.common import MessageStack -from griptape.artifacts import TextArtifact, ImageArtifact +from griptape.artifacts import TextArtifact, ImageArtifact, ListArtifact from unittest.mock import Mock import pytest @@ -119,18 +119,21 @@ def test_try_stream_run(self, mock_stream_client, model, system_enabled): if system_enabled: message_stack.add_system_message("system-input") message_stack.add_user_message("user-input") - message_stack.add_user_message(TextArtifact("user-input")) - message_stack.add_user_message(ImageArtifact(value=b"image-data", format="png", width=100, height=100)) + message_stack.add_user_message( + ListArtifact( + [TextArtifact("user-input"), ImageArtifact(value=b"image-data", format="png", width=100, height=100)] + ) + ) message_stack.add_assistant_message("assistant-input") expected_messages = [ - {"role": "user", "content": "user-input"}, {"role": "user", "content": "user-input"}, { "content": [ + {"type": "text", "text": "user-input"}, { "source": {"data": "aW1hZ2UtZGF0YQ==", "media_type": "image/png", "type": "base64"}, "type": "image", - } + }, ], "role": "user", }, diff --git a/tests/unit/drivers/prompt/test_base_prompt_driver.py b/tests/unit/drivers/prompt/test_base_prompt_driver.py index 78936748e..5b3d938b6 100644 --- a/tests/unit/drivers/prompt/test_base_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_base_prompt_driver.py @@ -39,6 +39,12 @@ def test_run_via_pipeline_publishes_events(self, mocker): def test_run(self): assert isinstance(MockPromptDriver().run(MessageStack(messages=[])), TextArtifact) + def test_run_with_stream(self): + pipeline = Pipeline() + result = MockPromptDriver(stream=True, structure=pipeline).run(MessageStack(messages=[])) + assert isinstance(result, TextArtifact) + assert result.value == "mock output" + def instance_count(instances, clazz): return len([instance for instance in instances if isinstance(instance, clazz)]) diff --git a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py index 666226c20..45cecff64 100644 --- a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py @@ -38,6 +38,8 @@ def message_stack(self): message_stack.add_system_message("system-input") message_stack.add_user_message("user-input") message_stack.add_assistant_message("assistant-input") + message_stack.add_user_message("user-input") + message_stack.add_assistant_message("assistant-input") return message_stack def test_init(self): @@ -52,7 +54,11 @@ def test_try_run(self, mock_client, message_stack): # pyright: ignore # Then mock_client.chat.assert_called_once_with( - chat_history=[{"content": [{"text": "user-input"}], "role": "USER"}], + chat_history=[ + {"content": [{"text": "user-input"}], "role": "USER"}, + {"content": [{"text": "assistant-input"}], "role": "CHATBOT"}, + {"content": [{"text": "user-input"}], "role": "USER"}, + ], max_tokens=None, message="assistant-input", preamble="system-input", @@ -75,7 +81,11 @@ def test_try_stream_run(self, mock_stream_client, message_stack): # pyright: ig # Then mock_stream_client.chat_stream.assert_called_once_with( - chat_history=[{"content": [{"text": "user-input"}], "role": "USER"}], + chat_history=[ + {"content": [{"text": "user-input"}], "role": "USER"}, + {"content": [{"text": "assistant-input"}], "role": "CHATBOT"}, + {"content": [{"text": "user-input"}], "role": "USER"}, + ], max_tokens=None, message="assistant-input", preamble="system-input", diff --git a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py index 6aab853dc..8bd190c3a 100644 --- a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py @@ -1,5 +1,5 @@ -from griptape.artifacts.image_artifact import ImageArtifact -from griptape.artifacts.text_artifact import TextArtifact +from griptape.artifacts import ImageArtifact, ListArtifact +from griptape.artifacts import TextArtifact from griptape.drivers import OpenAiChatPromptDriver from griptape.common import MessageStack from griptape.tokenizers import OpenAiTokenizer @@ -27,6 +27,7 @@ def mock_chat_completion_stream_create(self, mocker): [ Mock(choices=[Mock(delta=Mock(content="model-output"))], usage=None), Mock(choices=None, usage=Mock(prompt_tokens=5, completion_tokens=10)), + Mock(choices=[Mock(delta=Mock(content=None))], usage=None), ] ) return mock_chat_create @@ -36,8 +37,11 @@ def message_stack(self): message_stack = MessageStack() message_stack.add_system_message("system-input") message_stack.add_user_message("user-input") - message_stack.add_user_message(TextArtifact("user-input")) - message_stack.add_user_message(ImageArtifact(value=b"image-data", format="png", width=100, height=100)) + message_stack.add_user_message( + ListArtifact( + [TextArtifact("user-input"), ImageArtifact(value=b"image-data", format="png", width=100, height=100)] + ) + ) message_stack.add_assistant_message("assistant-input") return message_stack @@ -46,10 +50,12 @@ def messages(self): return [ {"role": "system", "content": "system-input"}, {"role": "user", "content": "user-input"}, - {"role": "user", "content": "user-input"}, { "role": "user", - "content": [{"type": "image_url", "image_url": {"url": "data:image/png;base64,aW1hZ2UtZGF0YQ=="}}], + "content": [ + {"type": "text", "text": "user-input"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,aW1hZ2UtZGF0YQ=="}}, + ], }, {"role": "assistant", "content": "assistant-input"}, ] @@ -151,6 +157,8 @@ def test_try_stream_run(self, mock_chat_completion_stream_create, message_stack, event = next(stream) assert event.usage.input_tokens == 5 assert event.usage.output_tokens == 10 + event = next(stream) + assert event.content.text == "" def test_try_run_with_max_tokens(self, mock_chat_completion_create, message_stack, messages): # Given From 2b7fe92b6020f941570ebf6477c0bbc18e3b5f45 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 21 Jun 2024 09:53:57 -0700 Subject: [PATCH 22/34] Add missing module --- griptape/common/message_stack/messages/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 griptape/common/message_stack/messages/__init__.py diff --git a/griptape/common/message_stack/messages/__init__.py b/griptape/common/message_stack/messages/__init__.py new file mode 100644 index 000000000..e69de29bb From 85b8cd076ed9b70a9d432f7b21c5f84d5b1d1b44 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 21 Jun 2024 10:04:18 -0700 Subject: [PATCH 23/34] Update docs --- README.md | 2 +- .../drivers/prompt-drivers.md | 4 +-- docs/griptape-framework/structures/agents.md | 2 +- docs/griptape-framework/structures/tasks.md | 27 ++++++++++++++++++- 4 files changed, 30 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index b5e4f7bcb..3a08dc3ea 100644 --- a/README.md +++ b/README.md @@ -92,7 +92,7 @@ from griptape.structures import Agent from griptape.tools import WebScraper, FileManager, TaskMemoryClient agent = Agent( - input_template="Load {{ args[0] }}, summarize it, and store it in a file called {{ args[1] }}.", + input="Load {{ args[0] }}, summarize it, and store it in a file called {{ args[1] }}.", tools=[ WebScraper(off_prompt=True), TaskMemoryClient(off_prompt=True), diff --git a/docs/griptape-framework/drivers/prompt-drivers.md b/docs/griptape-framework/drivers/prompt-drivers.md index bceda20cb..e2773e863 100644 --- a/docs/griptape-framework/drivers/prompt-drivers.md +++ b/docs/griptape-framework/drivers/prompt-drivers.md @@ -14,7 +14,7 @@ agent = Agent( config=StructureConfig( prompt_driver=OpenAiChatPromptDriver(model="gpt-4o", temperature=0.3), ), - input_template="You will be provided with a tweet, and your task is to classify its sentiment as positive, neutral, or negative. Tweet: {{ args[0] }}", + input="You will be provided with a tweet, and your task is to classify its sentiment as positive, neutral, or negative. Tweet: {{ args[0] }}", rules=[ Rule( value="Output only the sentiment." @@ -80,7 +80,7 @@ agent = Agent( seed=42, ) ), - input_template="You will be provided with a description of a mood, and your task is to generate the CSS code for a color that matches it. Description: {{ args[0] }}", + input="You will be provided with a description of a mood, and your task is to generate the CSS code for a color that matches it. Description: {{ args[0] }}", rules=[ Rule( value='Write your output in json with a single key called "css_code".' diff --git a/docs/griptape-framework/structures/agents.md b/docs/griptape-framework/structures/agents.md index 8737d2a59..b36db72d3 100644 --- a/docs/griptape-framework/structures/agents.md +++ b/docs/griptape-framework/structures/agents.md @@ -15,7 +15,7 @@ from griptape.structures import Agent agent = Agent( - input_template="Calculate the following: {{ args[0] }}", + input="Calculate the following: {{ args[0] }}", tools=[Calculator()] ) diff --git a/docs/griptape-framework/structures/tasks.md b/docs/griptape-framework/structures/tasks.md index c43603b1b..f46cc45b9 100644 --- a/docs/griptape-framework/structures/tasks.md +++ b/docs/griptape-framework/structures/tasks.md @@ -88,6 +88,31 @@ agent.run("Write me a haiku") Day begins anew. ``` +If the model supports it, you can also pass image inputs: + +```python +from griptape.structures import Agent +from griptape.loaders import ImageLoader + +agent = Agent() +with open("assets/mountain.jpg", "rb") as f: + image_artifact = ImageLoader().load(f.read()) + +agent.run(["What's in this image?", image_artifact]) +``` + +``` +[06/21/24 10:01:08] INFO PromptTask c229d1792da34ab1a7c45768270aada9 + Input: What's in this image? + + Media, type: image/jpeg, size: 82351 bytes +[06/21/24 10:01:12] INFO PromptTask c229d1792da34ab1a7c45768270aada9 + Output: The image depicts a stunning mountain landscape at sunrise or sunset. The sun is partially visible on the left side of the image, + casting a warm golden light over the scene. The mountains are covered with snow at their peaks, and a layer of clouds or fog is settled in the + valleys between them. The sky is a mix of warm colors near the horizon, transitioning to cooler blues higher up, with some scattered clouds + adding texture to the sky. The overall scene is serene and majestic, highlighting the natural beauty of the mountainous terrain. +``` + ## Toolkit Task To use [Griptape Tools](../../griptape-framework/tools/index.md), use a [Toolkit Task](../../reference/griptape/tasks/toolkit_task.md). @@ -742,7 +767,7 @@ def build_researcher(): def build_writer(): writer = Agent( - input_template="Instructions: {{args[0]}}\nContext: {{args[1]}}", + input="Instructions: {{args[0]}}\nContext: {{args[1]}}", rulesets=[ Ruleset( name="Position", From 635249c81d1df82204a7d15aa89d78c0f39079ce Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 21 Jun 2024 10:12:04 -0700 Subject: [PATCH 24/34] Simplify inputs --- griptape/tasks/prompt_task.py | 26 ++++++-------------------- tests/unit/tasks/test_prompt_task.py | 27 +++++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 20 deletions(-) diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index 65369100a..281910ffb 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -1,7 +1,6 @@ from __future__ import annotations from typing import TYPE_CHECKING, Callable, Optional -from collections.abc import Sequence from attrs import Factory, define, field @@ -30,13 +29,7 @@ class PromptTask(RuleMixin, BaseTask): @property def input(self) -> BaseArtifact: - if isinstance(self._input, list) or isinstance(self._input, tuple): - artifacts = [self._process_task_input(input) for input in self._input] - flattened_artifacts = self.__flatten_artifacts(artifacts) - - return ListArtifact(flattened_artifacts) - else: - return self._process_task_input(self._input) + return self._process_task_input(self._input) @input.setter def input(self, value: str | list | tuple | BaseArtifact | Callable[[BaseTask], BaseArtifact]) -> None: @@ -98,7 +91,9 @@ def run(self) -> BaseArtifact: return self.output - def _process_task_input(self, task_input: str | BaseArtifact | Callable[[BaseTask], BaseArtifact]) -> BaseArtifact: + def _process_task_input( + self, task_input: str | tuple | list | BaseArtifact | Callable[[BaseTask], BaseArtifact] + ) -> BaseArtifact: if isinstance(task_input, TextArtifact): task_input.value = J2().render_from_string(task_input.value, **self.full_context) @@ -109,16 +104,7 @@ def _process_task_input(self, task_input: str | BaseArtifact | Callable[[BaseTas return self._process_task_input(TextArtifact(task_input)) elif isinstance(task_input, BaseArtifact): return task_input + elif isinstance(task_input, list) or isinstance(task_input, tuple): + return ListArtifact([self._process_task_input(elem) for elem in task_input]) else: raise ValueError(f"Invalid input type: {type(task_input)} ") - - def __flatten_artifacts(self, artifacts: Sequence[BaseArtifact]) -> Sequence[BaseArtifact]: - result = [] - - for elem in artifacts: - if isinstance(elem, ListArtifact): - result.extend(self.__flatten_artifacts(elem.value)) - else: - result.append(elem) - - return result diff --git a/tests/unit/tasks/test_prompt_task.py b/tests/unit/tasks/test_prompt_task.py index a0fb1fd59..bcef8bce1 100644 --- a/tests/unit/tasks/test_prompt_task.py +++ b/tests/unit/tasks/test_prompt_task.py @@ -37,20 +37,24 @@ def test_missing_prompt_driver(self): task.prompt_driver def test_input(self): + # Str task = PromptTask("test") assert task.input.value == "test" + # List of strs task = PromptTask(["test1", "test2"]) assert task.input.value[0].value == "test1" assert task.input.value[1].value == "test2" + # Tuple of strs task = PromptTask(("test1", "test2")) assert task.input.value[0].value == "test1" assert task.input.value[1].value == "test2" + # Image artifact task = PromptTask(ImageArtifact(b"image-data", format="png", width=100, height=100)) assert isinstance(task.input, ImageArtifact) @@ -59,6 +63,7 @@ def test_input(self): assert task.input.width == 100 assert task.input.height == 100 + # List of str and image artifact task = PromptTask(["foo", ImageArtifact(b"image-data", format="png", width=100, height=100)]) assert isinstance(task.input, ListArtifact) @@ -68,6 +73,27 @@ def test_input(self): assert task.input.value[1].format == "png" assert task.input.value[1].width == 100 + # List of str and nested image artifact + task = PromptTask(["foo", [ImageArtifact(b"image-data", format="png", width=100, height=100)]]) + assert isinstance(task.input, ListArtifact) + assert task.input.value[0].value == "foo" + assert isinstance(task.input.value[1], ListArtifact) + assert isinstance(task.input.value[1].value[0], ImageArtifact) + assert task.input.value[1].value[0].value == b"image-data" + assert task.input.value[1].value[0].format == "png" + assert task.input.value[1].value[0].width == 100 + + # Tuple of str and image artifact + task = PromptTask(("foo", ImageArtifact(b"image-data", format="png", width=100, height=100))) + + assert isinstance(task.input, ListArtifact) + assert task.input.value[0].value == "foo" + assert isinstance(task.input.value[1], ImageArtifact) + assert task.input.value[1].value == b"image-data" + assert task.input.value[1].format == "png" + assert task.input.value[1].width == 100 + + # Lambda returning list of str and image artifact task = PromptTask( ListArtifact([TextArtifact("foo"), ImageArtifact(b"image-data", format="png", width=100, height=100)]) ) @@ -79,6 +105,7 @@ def test_input(self): assert task.input.value[1].format == "png" assert task.input.value[1].width == 100 + # Lambda returning list of str and image artifact task = PromptTask( lambda _: ListArtifact( [TextArtifact("foo"), ImageArtifact(b"image-data", format="png", width=100, height=100)] From 52ea784307ba1a2e9478ba71bb085dfb6b0da0c1 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 21 Jun 2024 11:01:33 -0700 Subject: [PATCH 25/34] Regenerate lock file --- poetry.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/poetry.lock b/poetry.lock index 73d84d3f4..438824ae7 100644 --- a/poetry.lock +++ b/poetry.lock @@ -6392,4 +6392,4 @@ loaders-pdf = ["pypdf"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "584f05b52935d6e3bbcafc7b4eae7d1b14fe3a28f4e292e37bd307f3202cd4ff" +content-hash = "007e18c26b43deba75c94988cf55186309f21f689f75ef4db585aa743a237eaa" From 75a5c1ae1a0b217d6c6ea9e70e7e61e34fbb2970 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 21 Jun 2024 11:21:43 -0700 Subject: [PATCH 26/34] Update changelog --- CHANGELOG.md | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8552a4d59..506d7be69 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,11 +6,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased ### Added -- `PromptStackMessage` for storing messages in a `PromptStack`. Messages consist of a role, content, and usage. -- `DeltaPromptStackMessage` for storing partial messages in a `PromptStack`. Multiple `DeltaPromptStackMessage` can be combined to form a `PromptStackMessage`. -- `TextPromptStackContent` for storing textual content in a `PromptStackMessage`. -- `ImagePromptStackContent` for storing image content in a `PromptStackMessage`. -- Support for adding `TextArtifact`s, `ImageArtifact`s, and `ListArtifact`s to `PromptStack`. +- `Message` for storing messages in a `MessageStack`. Messages consist of a role, content, and usage. +- `DeltaMessage` for storing partial messages in a `MessageStack`. Multiple `DeltaMessage` can be combined to form a `Message`. +- `TextMessageContent` for storing textual content in a `Message`. +- `ImageMessageContent` for storing image content in a `Message`. +- Support for adding `TextArtifact`s, `ImageArtifact`s, and `ListArtifact`s to `MessageStack`. - Support for image inputs to `OpenAiChatPromptDriver`, `AzureOpenAiChatPromptDriver`, `AmazonBedrockPromptDriver`, `AnthropicPromptDriver`, and `GooglePromptDriver`. - Input/output token usage metrics to all Prompt Drivers. - `FinishPromptEvent.input_token_count` and `FinishPromptEvent.output_token_count`. @@ -19,13 +19,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Support for `PromptTask`s to take `TextArtifact`s, `ImageArtifact`s, and `ListArtifact`s as input. ### Changed -- **BREAKING**: Moved `griptape.utils.PromptStack` to `griptape.common.PromptStack`. +- **BREAKING**: Moved/renamed `griptape.utils.PromptStack` to `griptape.common.MessageStack`. - **BREAKING**: Renamed `PromptStack.inputs` to `PromptStack.messages`. -- **BREAKING**: Moved `PromptStack.USER_ROLE`, `PromptStack.ASSISTANT_ROLE`, and `PromptStack.SYSTEM_ROLE` to `PromptStackMessage`. -- **BREAKING**: Updated return type of `PromptDriver.try_run` from `TextArtifact` to `PromptStackMessage`. -- **BREAKING**: Updated return type of `PromptDriver.try_stream` from `Iterator[TextArtifact]` to `Iterator[DeltaPromptStackMessage | BaseDeltaPromptStackContent]`. +- **BREAKING**: Moved `PromptStack.USER_ROLE`, `PromptStack.ASSISTANT_ROLE`, and `PromptStack.SYSTEM_ROLE` to `Message`. +- **BREAKING**: Updated return type of `PromptDriver.try_run` from `TextArtifact` to `Message`. +- **BREAKING**: Updated return type of `PromptDriver.try_stream` from `Iterator[TextArtifact]` to `Iterator[DeltaMessage]`. - **BREAKING**: Removed `BasePromptEvent.token_count` in favor of `FinishPromptEvent.input_token_count` and `FinishPromptEvent.output_token_count`. -- **BREAKING**: Removed `StartPromptEvent.prompt`. Use `StartPromptEvent.prompt_stack` instead. +- **BREAKING**: Removed `StartPromptEvent.prompt`. Use `StartPromptEvent.message_stack` instead. - **BREAKING**: Removed `Agent.input_template` in favor of `Agent.input`. - Default Prompt Driver model in `GoogleStructureConfig` to `gemini-1.5-pro`. From 06cb2a33f3578cdf007225af80dd566f935b105c Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 21 Jun 2024 11:22:28 -0700 Subject: [PATCH 27/34] Fix test --- tests/unit/config/test_amazon_bedrock_structure_config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/config/test_amazon_bedrock_structure_config.py b/tests/unit/config/test_amazon_bedrock_structure_config.py index d13f103cb..d75684829 100644 --- a/tests/unit/config/test_amazon_bedrock_structure_config.py +++ b/tests/unit/config/test_amazon_bedrock_structure_config.py @@ -92,13 +92,13 @@ def test_to_dict_with_values(self, config_with_values): }, "image_query_driver": { "type": "AmazonBedrockImageQueryDriver", - "model": "anthropic.claude-3-sonnet-20240229-v1:0", + "model": "anthropic.claude-3-5-sonnet-20240620-v1:0", "max_tokens": 256, "image_query_model_driver": {"type": "BedrockClaudeImageQueryModelDriver"}, }, "prompt_driver": { "max_tokens": None, - "model": "anthropic.claude-3-sonnet-20240229-v1:0", + "model": "anthropic.claude-3-5-sonnet-20240620-v1:0", "stream": False, "temperature": 0.1, "type": "AmazonBedrockPromptDriver", From 5c797adca53dd8eb56f1e048c9eb85af45c51913 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 21 Jun 2024 11:25:12 -0700 Subject: [PATCH 28/34] Simplify cohere --- griptape/drivers/prompt/cohere_prompt_driver.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/griptape/drivers/prompt/cohere_prompt_driver.py b/griptape/drivers/prompt/cohere_prompt_driver.py index 3ee27d56f..a18d3a3a2 100644 --- a/griptape/drivers/prompt/cohere_prompt_driver.py +++ b/griptape/drivers/prompt/cohere_prompt_driver.py @@ -74,10 +74,7 @@ def _message_stack_messages_to_messages(self, messages: list[Message]) -> list[d def _base_params(self, message_stack: MessageStack) -> dict: last_input = message_stack.messages[-1] - if last_input is not None and len(last_input.content) == 1: - user_message = last_input.content[0].artifact.to_text() - else: - raise ValueError("User message must have exactly one content.") + user_message = last_input.to_text_artifact().to_text() history_messages = self._message_stack_messages_to_messages( [message for message in message_stack.messages[:-1] if not message.is_system()] From 182ff92eee222b1f082ffa616c892ecd9662a98a Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 21 Jun 2024 11:45:39 -0700 Subject: [PATCH 29/34] Simplify system_prompt access --- .../common/message_stack/message_stack.py | 12 +++++ .../prompt/amazon_bedrock_prompt_driver.py | 4 +- .../drivers/prompt/anthropic_prompt_driver.py | 6 +-- .../drivers/prompt/cohere_prompt_driver.py | 8 ++-- .../drivers/prompt/google_prompt_driver.py | 6 +-- .../test_amazon_bedrock_prompt_driver.py | 11 +++-- .../prompt/test_cohere_prompt_driver.py | 13 +++--- .../prompt/test_google_prompt_driver.py | 46 ++++++++++++------- 8 files changed, 66 insertions(+), 40 deletions(-) diff --git a/griptape/common/message_stack/message_stack.py b/griptape/common/message_stack/message_stack.py index 4a8bb6985..39e85e6f1 100644 --- a/griptape/common/message_stack/message_stack.py +++ b/griptape/common/message_stack/message_stack.py @@ -10,6 +10,18 @@ class MessageStack(SerializableMixin): messages: list[Message] = field(factory=list, kw_only=True, metadata={"serializable": True}) + @property + def system_messages(self) -> list[Message]: + return [message for message in self.messages if message.is_system()] + + @property + def user_messages(self) -> list[Message]: + return [message for message in self.messages if message.is_user()] + + @property + def assistant_messages(self) -> list[Message]: + return [message for message in self.messages if message.is_assistant()] + def add_message(self, artifact: str | BaseArtifact, role: str) -> Message: new_content = self.__process_artifact(artifact) diff --git a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py index 6673a8283..b692c5fe2 100644 --- a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py +++ b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py @@ -78,9 +78,7 @@ def _message_stack_messages_to_messages(self, messages: list[Message]) -> list[d ] def _base_params(self, message_stack: MessageStack) -> dict: - system_messages = [ - {"text": message.to_text_artifact().to_text()} for message in message_stack.messages if message.is_system() - ] + system_messages = [{"text": message.to_text_artifact().to_text()} for message in message_stack.system_messages] messages = self._message_stack_messages_to_messages( [message for message in message_stack.messages if not message.is_system()] diff --git a/griptape/drivers/prompt/anthropic_prompt_driver.py b/griptape/drivers/prompt/anthropic_prompt_driver.py index a7d7b62bc..aca074343 100644 --- a/griptape/drivers/prompt/anthropic_prompt_driver.py +++ b/griptape/drivers/prompt/anthropic_prompt_driver.py @@ -76,9 +76,9 @@ def _base_params(self, message_stack: MessageStack) -> dict: [message for message in message_stack.messages if not message.is_system()] ) - system_message = next((message for message in message_stack.messages if message.is_system()), None) - if system_message: - system_message = system_message.to_text_artifact().to_text() + system_messages = message_stack.system_messages + if system_messages: + system_message = system_messages[0].to_text_artifact().to_text() else: system_message = None diff --git a/griptape/drivers/prompt/cohere_prompt_driver.py b/griptape/drivers/prompt/cohere_prompt_driver.py index a18d3a3a2..d40842284 100644 --- a/griptape/drivers/prompt/cohere_prompt_driver.py +++ b/griptape/drivers/prompt/cohere_prompt_driver.py @@ -80,11 +80,11 @@ def _base_params(self, message_stack: MessageStack) -> dict: [message for message in message_stack.messages[:-1] if not message.is_system()] ) - system_message = next((message for message in message_stack.messages if message.is_system()), None) - if system_message is None: - preamble = None + system_messages = message_stack.system_messages + if system_messages: + preamble = system_messages[0].to_text_artifact().to_text() else: - preamble = system_message.to_text_artifact().to_text() + preamble = None return { "message": user_message, diff --git a/griptape/drivers/prompt/google_prompt_driver.py b/griptape/drivers/prompt/google_prompt_driver.py index 83582a7be..d95a48aef 100644 --- a/griptape/drivers/prompt/google_prompt_driver.py +++ b/griptape/drivers/prompt/google_prompt_driver.py @@ -122,9 +122,9 @@ def _message_stack_to_messages(self, message_stack: MessageStack) -> list[dict]: ] # Gemini does not have the notion of a system message, so we insert it as part of the first message in the history. - system = next((i for i in message_stack.messages if i.is_system()), None) - if system is not None: - inputs[0]["parts"].insert(0, "\n".join(content.to_text() for content in system.content)) + system_messages = message_stack.system_messages + if system_messages: + inputs[0]["parts"].insert(0, system_messages[0].to_text_artifact().to_text()) return inputs diff --git a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py index 816692075..d49eb700d 100644 --- a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py @@ -30,10 +30,11 @@ def mock_converse_stream(self, mocker): return mock_converse_stream - @pytest.fixture - def message_stack(self): + @pytest.fixture(params=[True, False]) + def message_stack(self, request): message_stack = MessageStack() - message_stack.add_system_message("system-input") + if request.param: + message_stack.add_system_message("system-input") message_stack.add_user_message("user-input") message_stack.add_user_message(TextArtifact("user-input")) message_stack.add_user_message(ImageArtifact(value=b"image-data", format="png", width=100, height=100)) @@ -61,7 +62,7 @@ def test_try_run(self, mock_converse, message_stack, messages): mock_converse.assert_called_once_with( modelId=driver.model, messages=messages, - system=[{"text": "system-input"}], + **({"system": [{"text": "system-input"}]} if message_stack.system_messages else {"system": []}), inferenceConfig={"temperature": driver.temperature}, additionalModelRequestFields={}, ) @@ -81,7 +82,7 @@ def test_try_stream_run(self, mock_converse_stream, message_stack, messages): mock_converse_stream.assert_called_once_with( modelId=driver.model, messages=messages, - system=[{"text": "system-input"}], + **({"system": [{"text": "system-input"}]} if message_stack.system_messages else {"system": []}), inferenceConfig={"temperature": driver.temperature}, additionalModelRequestFields={}, ) diff --git a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py index 45cecff64..86b248c26 100644 --- a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py @@ -32,10 +32,11 @@ def mock_stream_client(self, mocker): def mock_tokenizer(self, mocker): return mocker.patch("griptape.tokenizers.CohereTokenizer").return_value - @pytest.fixture - def message_stack(self): + @pytest.fixture(params=[True, False]) + def message_stack(self, request): message_stack = MessageStack() - message_stack.add_system_message("system-input") + if request.param: + message_stack.add_system_message("system-input") message_stack.add_user_message("user-input") message_stack.add_assistant_message("assistant-input") message_stack.add_user_message("user-input") @@ -45,7 +46,7 @@ def message_stack(self): def test_init(self): assert CoherePromptDriver(model="command", api_key="foobar") - def test_try_run(self, mock_client, message_stack): # pyright: ignore + def test_try_run(self, mock_client, message_stack): # Given driver = CoherePromptDriver(model="command", api_key="api-key") @@ -61,7 +62,7 @@ def test_try_run(self, mock_client, message_stack): # pyright: ignore ], max_tokens=None, message="assistant-input", - preamble="system-input", + **({"preamble": "system-input"} if message_stack.system_messages else {}), stop_sequences=[], temperature=0.1, ) @@ -88,7 +89,7 @@ def test_try_stream_run(self, mock_stream_client, message_stack): # pyright: ig ], max_tokens=None, message="assistant-input", - preamble="system-input", + **({"preamble": "system-input"} if message_stack.system_messages else {}), stop_sequences=[], temperature=0.1, ) diff --git a/tests/unit/drivers/prompt/test_google_prompt_driver.py b/tests/unit/drivers/prompt/test_google_prompt_driver.py index 7ef4beb08..93aa5bb3b 100644 --- a/tests/unit/drivers/prompt/test_google_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_google_prompt_driver.py @@ -32,10 +32,12 @@ def test_init(self): driver = GooglePromptDriver(model="gemini-pro", api_key="1234") assert driver - def test_try_run(self, mock_generative_model): + @pytest.mark.parametrize("system_enabled", [True, False]) + def test_try_run(self, mock_generative_model, system_enabled): # Given message_stack = MessageStack() - message_stack.add_system_message("system-input") + if system_enabled: + message_stack.add_system_message("system-input") message_stack.add_user_message("user-input") message_stack.add_user_message(TextArtifact("user-input")) message_stack.add_user_message(ImageArtifact(value=b"image-data", format="png", width=100, height=100)) @@ -46,13 +48,18 @@ def test_try_run(self, mock_generative_model): text_artifact = driver.try_run(message_stack) # Then + messages = [ + *( + [{"parts": ["system-input", "user-input"], "role": "user"}] + if system_enabled + else [{"parts": ["user-input"], "role": "user"}] + ), + {"parts": ["user-input"], "role": "user"}, + {"parts": [{"data": b"image-data", "mime_type": "image/png"}], "role": "user"}, + {"parts": ["assistant-input"], "role": "model"}, + ] mock_generative_model.return_value.generate_content.assert_called_once_with( - [ - {"parts": ["system-input", "user-input"], "role": "user"}, - {"parts": ["user-input"], "role": "user"}, - {"parts": [{"data": b"image-data", "mime_type": "image/png"}], "role": "user"}, - {"parts": ["assistant-input"], "role": "model"}, - ], + messages, generation_config=GenerationConfig( max_output_tokens=None, temperature=0.1, top_p=0.5, top_k=50, stop_sequences=[] ), @@ -61,10 +68,12 @@ def test_try_run(self, mock_generative_model): assert text_artifact.usage.input_tokens == 5 assert text_artifact.usage.output_tokens == 10 - def test_try_stream(self, mock_stream_generative_model): + @pytest.mark.parametrize("system_enabled", [True, False]) + def test_try_stream(self, mock_stream_generative_model, system_enabled): # Given message_stack = MessageStack() - message_stack.add_system_message("system-input") + if system_enabled: + message_stack.add_system_message("system-input") message_stack.add_user_message("user-input") message_stack.add_user_message(TextArtifact("user-input")) message_stack.add_user_message(ImageArtifact(value=b"image-data", format="png", width=100, height=100)) @@ -76,13 +85,18 @@ def test_try_stream(self, mock_stream_generative_model): # Then event = next(stream) + messages = [ + *( + [{"parts": ["system-input", "user-input"], "role": "user"}] + if system_enabled + else [{"parts": ["user-input"], "role": "user"}] + ), + {"parts": ["user-input"], "role": "user"}, + {"parts": [{"data": b"image-data", "mime_type": "image/png"}], "role": "user"}, + {"parts": ["assistant-input"], "role": "model"}, + ] mock_stream_generative_model.return_value.generate_content.assert_called_once_with( - [ - {"parts": ["system-input", "user-input"], "role": "user"}, - {"parts": ["user-input"], "role": "user"}, - {"parts": [{"data": b"image-data", "mime_type": "image/png"}], "role": "user"}, - {"parts": ["assistant-input"], "role": "model"}, - ], + messages, stream=True, generation_config=GenerationConfig(temperature=0.1, top_p=0.5, top_k=50, stop_sequences=[]), ) From 8dc49449593fc714cdcda01df7628d8845bcb7d8 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 21 Jun 2024 15:21:11 -0700 Subject: [PATCH 30/34] Change prompt driver run return value --- CHANGELOG.md | 1 + .../common/message_stack/messages/message.py | 20 +++++++++---------- .../prompt/amazon_bedrock_prompt_driver.py | 2 +- ...mazon_sagemaker_jumpstart_prompt_driver.py | 2 +- .../drivers/prompt/anthropic_prompt_driver.py | 4 ++-- griptape/drivers/prompt/base_prompt_driver.py | 7 +++---- .../drivers/prompt/cohere_prompt_driver.py | 4 ++-- .../drivers/prompt/google_prompt_driver.py | 2 +- .../prompt/huggingface_hub_prompt_driver.py | 8 ++++---- .../huggingface_pipeline_prompt_driver.py | 4 ++-- .../drivers/prompt/ollama_prompt_driver.py | 2 +- .../prompt/openai_chat_prompt_driver.py | 2 +- griptape/engines/query/vector_query_engine.py | 2 +- .../engines/summary/prompt_summary_engine.py | 9 +++++++-- griptape/tasks/prompt_task.py | 4 ++-- griptape/utils/conversation.py | 2 +- .../drivers/prompt/test_base_prompt_driver.py | 5 +++-- 17 files changed, 43 insertions(+), 37 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 506d7be69..74cd52c4e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: Removed `BasePromptEvent.token_count` in favor of `FinishPromptEvent.input_token_count` and `FinishPromptEvent.output_token_count`. - **BREAKING**: Removed `StartPromptEvent.prompt`. Use `StartPromptEvent.message_stack` instead. - **BREAKING**: Removed `Agent.input_template` in favor of `Agent.input`. +- **BREAKING**: `BasePromptDriver.run` now returns a `Message` instead of a `TextArtifact`. For compatibility, `Message.value` contains the Message's Artifact value - Default Prompt Driver model in `GoogleStructureConfig` to `gemini-1.5-pro`. diff --git a/griptape/common/message_stack/messages/message.py b/griptape/common/message_stack/messages/message.py index 3a4e07ccb..fcf6750ea 100644 --- a/griptape/common/message_stack/messages/message.py +++ b/griptape/common/message_stack/messages/message.py @@ -4,7 +4,7 @@ from attrs import define, field -from griptape.artifacts import TextArtifact +from griptape.artifacts import BaseArtifact, ListArtifact, TextArtifact from griptape.common import BaseMessageContent, TextMessageContent from .base_message import BaseMessage @@ -21,18 +21,18 @@ def __init__(self, content: str | list[BaseMessageContent], **kwargs: Any): @property def value(self) -> Any: - if len(self.content) == 1: - return self.content[0].artifact.value - else: - return [content.artifact for content in self.content] + return self.to_artifact().value def __str__(self) -> str: return self.to_text() def to_text(self) -> str: - return self.to_text_artifact().to_text() - - def to_text_artifact(self) -> TextArtifact: - return TextArtifact( - "".join([content.artifact.to_text() for content in self.content if isinstance(content, TextMessageContent)]) + return "".join( + [content.artifact.to_text() for content in self.content if isinstance(content, TextMessageContent)] ) + + def to_artifact(self) -> BaseArtifact: + if len(self.content) == 1: + return self.content[0].artifact + else: + return ListArtifact([content.artifact for content in self.content]) diff --git a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py index b692c5fe2..cec7aeff0 100644 --- a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py +++ b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py @@ -78,7 +78,7 @@ def _message_stack_messages_to_messages(self, messages: list[Message]) -> list[d ] def _base_params(self, message_stack: MessageStack) -> dict: - system_messages = [{"text": message.to_text_artifact().to_text()} for message in message_stack.system_messages] + system_messages = [{"text": message.to_text()} for message in message_stack.system_messages] messages = self._message_stack_messages_to_messages( [message for message in message_stack.messages if not message.is_system()] diff --git a/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py b/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py index ff59ced5a..5f5b1a473 100644 --- a/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py +++ b/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py @@ -98,7 +98,7 @@ def _message_stack_to_messages(self, message_stack: MessageStack) -> list[dict]: messages = [] for message in message_stack.messages: - messages.append({"role": message.role, "content": TextMessageContent(message.to_text_artifact())}) + messages.append({"role": message.role, "content": message.to_text()}) return messages diff --git a/griptape/drivers/prompt/anthropic_prompt_driver.py b/griptape/drivers/prompt/anthropic_prompt_driver.py index aca074343..154758b4e 100644 --- a/griptape/drivers/prompt/anthropic_prompt_driver.py +++ b/griptape/drivers/prompt/anthropic_prompt_driver.py @@ -78,7 +78,7 @@ def _base_params(self, message_stack: MessageStack) -> dict: system_messages = message_stack.system_messages if system_messages: - system_message = system_messages[0].to_text_artifact().to_text() + system_message = system_messages[0].to_text() else: system_message = None @@ -101,7 +101,7 @@ def __to_role(self, message: Message) -> str: def __to_content(self, message: Message) -> str | list[dict]: if all(isinstance(content, TextMessageContent) for content in message.content): - return message.to_text_artifact().to_text() + return message.to_text() else: return [self.__message_stack_content_message_content(content) for content in message.content] diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index 3f5551d07..c2307f956 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -6,7 +6,6 @@ from attrs import Factory, define, field -from griptape.artifacts.text_artifact import TextArtifact from griptape.common import ( BaseDeltaMessageContent, DeltaMessage, @@ -63,7 +62,7 @@ def after_run(self, result: Message) -> None: ) ) - def run(self, message_stack: MessageStack) -> TextArtifact: + def run(self, message_stack: MessageStack) -> Message: for attempt in self.retrying(): with attempt: self.before_run(message_stack) @@ -75,7 +74,7 @@ def run(self, message_stack: MessageStack) -> TextArtifact: self.after_run(result) - return result.to_text_artifact() + return result else: raise Exception("prompt driver failed after all retry attempts") @@ -92,7 +91,7 @@ def message_stack_to_string(self, message_stack: MessageStack) -> str: prompt_lines = [] for i in message_stack.messages: - content = i.to_text_artifact().to_text() + content = i.to_text() if i.is_user(): prompt_lines.append(f"User: {content}") elif i.is_assistant(): diff --git a/griptape/drivers/prompt/cohere_prompt_driver.py b/griptape/drivers/prompt/cohere_prompt_driver.py index d40842284..b62dca8d2 100644 --- a/griptape/drivers/prompt/cohere_prompt_driver.py +++ b/griptape/drivers/prompt/cohere_prompt_driver.py @@ -74,7 +74,7 @@ def _message_stack_messages_to_messages(self, messages: list[Message]) -> list[d def _base_params(self, message_stack: MessageStack) -> dict: last_input = message_stack.messages[-1] - user_message = last_input.to_text_artifact().to_text() + user_message = last_input.to_text() history_messages = self._message_stack_messages_to_messages( [message for message in message_stack.messages[:-1] if not message.is_system()] @@ -82,7 +82,7 @@ def _base_params(self, message_stack: MessageStack) -> dict: system_messages = message_stack.system_messages if system_messages: - preamble = system_messages[0].to_text_artifact().to_text() + preamble = system_messages[0].to_text() else: preamble = None diff --git a/griptape/drivers/prompt/google_prompt_driver.py b/griptape/drivers/prompt/google_prompt_driver.py index d95a48aef..313c9b1a3 100644 --- a/griptape/drivers/prompt/google_prompt_driver.py +++ b/griptape/drivers/prompt/google_prompt_driver.py @@ -124,7 +124,7 @@ def _message_stack_to_messages(self, message_stack: MessageStack) -> list[dict]: # Gemini does not have the notion of a system message, so we insert it as part of the first message in the history. system_messages = message_stack.system_messages if system_messages: - inputs[0]["parts"].insert(0, system_messages[0].to_text_artifact().to_text()) + inputs[0]["parts"].insert(0, system_messages[0].to_text()) return inputs diff --git a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py index 74b1b68d2..0363bf31d 100644 --- a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py +++ b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py @@ -7,7 +7,7 @@ from griptape.drivers import BasePromptDriver from griptape.tokenizers import HuggingFaceTokenizer -from griptape.common import MessageStack, Message, DeltaMessage, TextMessageContent, TextDeltaMessageContent +from griptape.common import MessageStack, Message, DeltaMessage, TextDeltaMessageContent from griptape.utils import import_optional_dependency if TYPE_CHECKING: @@ -84,9 +84,9 @@ def message_stack_to_string(self, message_stack: MessageStack) -> str: def _message_stack_to_messages(self, message_stack: MessageStack) -> list[dict]: messages = [] - for i in message_stack.messages: - if len(i.content) == 1: - messages.append({"role": i.role, "content": TextMessageContent(i.to_text_artifact())}) + for message in message_stack.messages: + if len(message.content) == 1: + messages.append({"role": message.role, "content": message.to_text()}) else: raise ValueError("Invalid input content length.") diff --git a/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py b/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py index 155d0c488..61e05a4bf 100644 --- a/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py +++ b/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py @@ -75,8 +75,8 @@ def message_stack_to_string(self, message_stack: MessageStack) -> str: def _message_stack_to_messages(self, message_stack: MessageStack) -> list[dict]: messages = [] - for i in message_stack.messages: - messages.append({"role": i.role, "content": i.to_text_artifact().to_text()}) + for message in message_stack.messages: + messages.append({"role": message.role, "content": message.to_text()}) return messages diff --git a/griptape/drivers/prompt/ollama_prompt_driver.py b/griptape/drivers/prompt/ollama_prompt_driver.py index a93270a1c..0743e6032 100644 --- a/griptape/drivers/prompt/ollama_prompt_driver.py +++ b/griptape/drivers/prompt/ollama_prompt_driver.py @@ -78,7 +78,7 @@ def _message_stack_to_messages(self, message_stack: MessageStack) -> list[dict]: return [ { "role": message.role, - "content": message.to_text_artifact().to_text(), + "content": message.to_text(), **( { "images": [ diff --git a/griptape/drivers/prompt/openai_chat_prompt_driver.py b/griptape/drivers/prompt/openai_chat_prompt_driver.py index e181502da..9397124ca 100644 --- a/griptape/drivers/prompt/openai_chat_prompt_driver.py +++ b/griptape/drivers/prompt/openai_chat_prompt_driver.py @@ -147,7 +147,7 @@ def __to_role(self, message: Message) -> str: def __to_content(self, message: Message) -> str | list[dict]: if all(isinstance(content, TextMessageContent) for content in message.content): - return message.to_text_artifact().to_text() + return message.to_text() else: return [self.__message_stack_content_message_content(content) for content in message.content] diff --git a/griptape/engines/query/vector_query_engine.py b/griptape/engines/query/vector_query_engine.py index 08253b73f..4214d04b1 100644 --- a/griptape/engines/query/vector_query_engine.py +++ b/griptape/engines/query/vector_query_engine.py @@ -79,7 +79,7 @@ def query( Message(user_message, role=Message.USER_ROLE), ] ) - ) + ).to_artifact() if isinstance(result, TextArtifact): return result diff --git a/griptape/engines/summary/prompt_summary_engine.py b/griptape/engines/summary/prompt_summary_engine.py index 13958ee9c..47863dfa4 100644 --- a/griptape/engines/summary/prompt_summary_engine.py +++ b/griptape/engines/summary/prompt_summary_engine.py @@ -61,14 +61,19 @@ def summarize_artifacts_rec( self.prompt_driver.tokenizer.count_input_tokens_left(user_prompt + system_prompt) >= self.min_response_tokens ): - return self.prompt_driver.run( + result = self.prompt_driver.run( MessageStack( messages=[ Message(system_prompt, role=Message.SYSTEM_ROLE), Message(user_prompt, role=Message.USER_ROLE), ] ) - ) + ).to_artifact() + + if isinstance(result, TextArtifact): + return result + else: + raise ValueError("Prompt driver did not return a TextArtifact") else: chunks = self.chunker.chunk(artifacts_text) diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index 281910ffb..5adba029b 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -87,9 +87,9 @@ def after_run(self) -> None: self.structure.logger.info(f"{self.__class__.__name__} {self.id}\nOutput: {self.output.to_text()}") def run(self) -> BaseArtifact: - self.output = self.prompt_driver.run(self.message_stack) + message = self.prompt_driver.run(self.message_stack) - return self.output + return message.to_artifact() def _process_task_input( self, task_input: str | tuple | list | BaseArtifact | Callable[[BaseTask], BaseArtifact] diff --git a/griptape/utils/conversation.py b/griptape/utils/conversation.py index 0bdc078dd..e885800cc 100644 --- a/griptape/utils/conversation.py +++ b/griptape/utils/conversation.py @@ -23,7 +23,7 @@ def message_stack(self) -> list[str]: lines = [] for stack in self.memory.to_message_stack().messages: - lines.append(f"{stack.role}: {stack.to_text_artifact().to_text()}") + lines.append(f"{stack.role}: {stack.to_text()}") return lines diff --git a/tests/unit/drivers/prompt/test_base_prompt_driver.py b/tests/unit/drivers/prompt/test_base_prompt_driver.py index 5b3d938b6..ac54dc9a1 100644 --- a/tests/unit/drivers/prompt/test_base_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_base_prompt_driver.py @@ -1,3 +1,4 @@ +from griptape.common.message_stack.messages.message import Message from griptape.events import FinishPromptEvent, StartPromptEvent from griptape.common import MessageStack from tests.mocks.mock_prompt_driver import MockPromptDriver @@ -37,12 +38,12 @@ def test_run_via_pipeline_publishes_events(self, mocker): assert instance_count(events, FinishPromptEvent) == 1 def test_run(self): - assert isinstance(MockPromptDriver().run(MessageStack(messages=[])), TextArtifact) + assert isinstance(MockPromptDriver().run(MessageStack(messages=[])), Message) def test_run_with_stream(self): pipeline = Pipeline() result = MockPromptDriver(stream=True, structure=pipeline).run(MessageStack(messages=[])) - assert isinstance(result, TextArtifact) + assert isinstance(result, Message) assert result.value == "mock output" From 49b93f935afcbecd1107692af05aac53b659d538 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Fri, 21 Jun 2024 17:07:40 -0700 Subject: [PATCH 31/34] Regenerate lock file --- poetry.lock | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/poetry.lock b/poetry.lock index e030c8906..6c5a4f07e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3360,7 +3360,6 @@ description = "Nvidia JIT LTO Library" optional = true python-versions = ">=3" files = [ - {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:4abe7fef64914ccfa909bc2ba39739670ecc9e820c83ccc7a6ed414122599b83"}, {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57"}, {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:fd9020c501d27d135f983c6d3e244b197a7ccad769e34df53a42e276b0e25fa1"}, ] @@ -4598,7 +4597,6 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -6395,4 +6393,4 @@ loaders-pdf = ["pypdf"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "ce26764ee2c4a9a99d24ef4afc7efa6aa894a7560a725388ff24db15f6014e9a" +content-hash = "25929d245a12253a2536123b44c2b7f654a0570df7e88500c563368c53fb25da" From ebedd83397c68e6b3131a2afd3627a27dc4c293c Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Mon, 1 Jul 2024 16:52:11 -0500 Subject: [PATCH 32/34] Fix bad merge, tests --- .../engines/rag/modules/base_rag_module.py | 11 ++-- .../prompt_generation_rag_module.py | 10 ++- poetry.lock | 4 +- .../prompt/test_cohere_prompt_driver.py | 62 +++++++------------ tests/unit/tasks/test_toolkit_task.py | 1 + 5 files changed, 38 insertions(+), 50 deletions(-) diff --git a/griptape/engines/rag/modules/base_rag_module.py b/griptape/engines/rag/modules/base_rag_module.py index 8e9b42e93..0dcdb3c35 100644 --- a/griptape/engines/rag/modules/base_rag_module.py +++ b/griptape/engines/rag/modules/base_rag_module.py @@ -4,7 +4,7 @@ from attrs import define, field, Factory -from griptape.utils import PromptStack +from griptape.common import MessageStack, Message @define(kw_only=True) @@ -13,10 +13,7 @@ class BaseRagModule(ABC): default=Factory(lambda: lambda: futures.ThreadPoolExecutor()) ) - def generate_query_prompt_stack(self, system_prompt: str, query: str) -> PromptStack: - return PromptStack( - inputs=[ - PromptStack.Input(system_prompt, role=PromptStack.SYSTEM_ROLE), - PromptStack.Input(query, role=PromptStack.USER_ROLE), - ] + def generate_query_prompt_stack(self, system_prompt: str, query: str) -> MessageStack: + return MessageStack( + messages=[Message(system_prompt, role=Message.SYSTEM_ROLE), Message(query, role=Message.USER_ROLE)] ) diff --git a/griptape/engines/rag/modules/generation/prompt_generation_rag_module.py b/griptape/engines/rag/modules/generation/prompt_generation_rag_module.py index 9e7c3b08f..bd6c82191 100644 --- a/griptape/engines/rag/modules/generation/prompt_generation_rag_module.py +++ b/griptape/engines/rag/modules/generation/prompt_generation_rag_module.py @@ -1,5 +1,6 @@ from typing import Callable from attrs import define, field, Factory +from griptape.artifacts.text_artifact import TextArtifact from griptape.drivers import BasePromptDriver from griptape.engines.rag import RagContext from griptape.engines.rag.modules import BaseGenerationRagModule @@ -30,7 +31,7 @@ def run(self, context: RagContext) -> RagContext: system_prompt = self.generate_system_template(text_chunks, before_query, after_query) message_token_count = self.prompt_driver.tokenizer.count_tokens( - self.prompt_driver.prompt_stack_to_string(self.generate_query_prompt_stack(system_prompt, query)) + self.prompt_driver.message_stack_to_string(self.generate_query_prompt_stack(system_prompt, query)) ) if message_token_count + self.answer_token_offset >= tokenizer.max_input_tokens: @@ -40,7 +41,12 @@ def run(self, context: RagContext) -> RagContext: break - context.output = self.prompt_driver.run(self.generate_query_prompt_stack(system_prompt, query)) + output = self.prompt_driver.run(self.generate_query_prompt_stack(system_prompt, query)).to_artifact() + + if isinstance(output, TextArtifact): + context.output = output + else: + raise ValueError("Prompt driver did not return a TextArtifact") return context diff --git a/poetry.lock b/poetry.lock index 6c5a4f07e..becf778c2 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "aiohttp" @@ -6393,4 +6393,4 @@ loaders-pdf = ["pypdf"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "25929d245a12253a2536123b44c2b7f654a0570df7e88500c563368c53fb25da" +content-hash = "ce26764ee2c4a9a99d24ef4afc7efa6aa894a7560a725388ff24db15f6014e9a" diff --git a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py index e58eb304c..86b248c26 100644 --- a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py @@ -9,16 +9,22 @@ class TestCoherePromptDriver: @pytest.fixture def mock_client(self, mocker): - mock_client = mocker.patch("cohere.Client") - mock_client.return_value.chat.return_value = Mock(text="model-output") + mock_client = mocker.patch("cohere.Client").return_value + mock_client.chat.return_value = Mock( + text="model-output", meta=Mock(tokens=Mock(input_tokens=5, output_tokens=10)) + ) return mock_client @pytest.fixture def mock_stream_client(self, mocker): - mock_client = mocker.patch("cohere.Client") - mock_chunk = Mock(text="model-output", event_type="text-generation") - mock_client.return_value.chat_stream.return_value = iter([mock_chunk]) + mock_client = mocker.patch("cohere.Client").return_value + mock_client.chat_stream.return_value = iter( + [ + Mock(text="model-output", event_type="text-generation"), + Mock(response=Mock(meta=Mock(tokens=Mock(input_tokens=5, output_tokens=10))), event_type="stream-end"), + ] + ) return mock_client @@ -45,43 +51,21 @@ def test_try_run(self, mock_client, message_stack): driver = CoherePromptDriver(model="command", api_key="api-key") # When - text_artifact = driver.try_run(prompt_stack) - print(f"Called methods: {mock_client}") - - # Then - expected_message = "assistant-input" - expected_history = [ - {"role": "ASSISTANT", "text": "generic-input"}, - {"role": "SYSTEM", "text": "system-input"}, - {"role": "USER", "text": "user-input"}, - ] - mock_client.return_value.chat.assert_called_once_with( - message=expected_message, - temperature=driver.temperature, - stop_sequences=driver.tokenizer.stop_sequences, - max_tokens=driver.max_tokens, - chat_history=expected_history, - ) - assert text_artifact.value == "model-output" - - def test_try_run_no_history(self, mock_client, prompt_stack): - # Given - prompt_stack_no_history = PromptStack() - prompt_stack_no_history.add_user_input("user-input") - driver = CoherePromptDriver(model="command", api_key="api-key") - - # When - text_artifact = driver.try_run(prompt_stack_no_history) + text_artifact = driver.try_run(message_stack) # Then - expected_message = "user-input" - mock_client.return_value.chat.assert_called_once_with( - message=expected_message, - temperature=driver.temperature, - stop_sequences=driver.tokenizer.stop_sequences, - max_tokens=driver.max_tokens, + mock_client.chat.assert_called_once_with( + chat_history=[ + {"content": [{"text": "user-input"}], "role": "USER"}, + {"content": [{"text": "assistant-input"}], "role": "CHATBOT"}, + {"content": [{"text": "user-input"}], "role": "USER"}, + ], + max_tokens=None, + message="assistant-input", + **({"preamble": "system-input"} if message_stack.system_messages else {}), + stop_sequences=[], + temperature=0.1, ) - assert text_artifact.value == "model-output" assert text_artifact.value == "model-output" assert text_artifact.usage.input_tokens == 5 diff --git a/tests/unit/tasks/test_toolkit_task.py b/tests/unit/tasks/test_toolkit_task.py index 6df9193d3..cc05a5caf 100644 --- a/tests/unit/tasks/test_toolkit_task.py +++ b/tests/unit/tasks/test_toolkit_task.py @@ -2,6 +2,7 @@ from griptape.structures import Agent from griptape.tasks import ToolkitTask, ActionsSubtask, PromptTask from tests.mocks.mock_tool.tool import MockTool +from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.utils import defaults From 7802c117312f14460b116131970618c568190c62 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Mon, 1 Jul 2024 16:57:17 -0500 Subject: [PATCH 33/34] Rename MessageStack back to PromptStack. Snip-snap! Snip-snap! Snip-snap! --- CHANGELOG.md | 10 +-- .../drivers/prompt-drivers.md | 4 +- docs/griptape-framework/misc/events.md | 8 +- griptape/common/__init__.py | 20 ++--- .../__init__.py | 0 .../contents/__init__.py | 0 .../contents/base_delta_message_content.py | 0 .../contents/base_message_content.py | 0 .../contents/image_message_content.py | 0 .../contents/text_delta_message_content.py | 0 .../contents/text_message_content.py | 0 .../messages/__init__.py | 0 .../messages/base_message.py | 0 .../messages/delta_message.py | 2 +- .../messages/message.py | 0 .../prompt_stack.py} | 2 +- .../prompt/amazon_bedrock_prompt_driver.py | 24 +++--- ...mazon_sagemaker_jumpstart_prompt_driver.py | 28 +++---- .../drivers/prompt/anthropic_prompt_driver.py | 32 ++++---- .../prompt/azure_openai_chat_prompt_driver.py | 6 +- griptape/drivers/prompt/base_prompt_driver.py | 38 +++++----- .../drivers/prompt/cohere_prompt_driver.py | 26 +++---- .../drivers/prompt/dummy_prompt_driver.py | 6 +- .../drivers/prompt/google_prompt_driver.py | 20 ++--- .../prompt/huggingface_hub_prompt_driver.py | 26 +++---- .../huggingface_pipeline_prompt_driver.py | 22 +++--- .../drivers/prompt/ollama_prompt_driver.py | 18 ++--- .../prompt/openai_chat_prompt_driver.py | 32 ++++---- .../extraction/csv_extraction_engine.py | 10 +-- .../extraction/json_extraction_engine.py | 8 +- .../engines/rag/modules/base_rag_module.py | 6 +- .../prompt_generation_rag_module.py | 2 +- .../engines/summary/prompt_summary_engine.py | 8 +- griptape/events/start_prompt_event.py | 4 +- .../structure/base_conversation_memory.py | 36 ++++----- .../memory/structure/conversation_memory.py | 12 +-- .../structure/summary_conversation_memory.py | 10 +-- griptape/schemas/base_schema.py | 4 +- griptape/tasks/prompt_task.py | 10 +-- griptape/tasks/tool_task.py | 2 +- griptape/tasks/toolkit_task.py | 12 +-- griptape/utils/conversation.py | 4 +- tests/mocks/mock_failing_prompt_driver.py | 6 +- tests/mocks/mock_prompt_driver.py | 12 +-- .../test_amazon_bedrock_prompt_driver.py | 30 ++++---- ...mazon_sagemaker_jumpstart_prompt_driver.py | 22 +++--- .../prompt/test_anthropic_prompt_driver.py | 34 ++++----- .../test_azure_openai_chat_prompt_driver.py | 8 +- .../drivers/prompt/test_base_prompt_driver.py | 8 +- .../prompt/test_cohere_prompt_driver.py | 30 ++++---- .../prompt/test_google_prompt_driver.py | 30 ++++---- .../test_hugging_face_hub_prompt_driver.py | 22 +++--- ...est_hugging_face_pipeline_prompt_driver.py | 34 ++++----- .../prompt/test_ollama_prompt_driver.py | 36 ++++----- .../prompt/test_openai_chat_prompt_driver.py | 42 +++++------ .../summary/test_prompt_summary_engine.py | 6 +- tests/unit/events/test_base_event.py | 12 +-- tests/unit/events/test_start_prompt_event.py | 18 ++--- .../structure/test_conversation_memory.py | 74 +++++++++---------- .../test_summary_conversation_memory.py | 10 +-- tests/unit/structures/test_agent.py | 16 ++-- tests/unit/structures/test_pipeline.py | 28 +++---- .../unit/tokenizers/test_google_tokenizer.py | 6 +- tests/unit/utils/test_conversation.py | 8 +- tests/unit/utils/test_message_stack.py | 70 +++++++++--------- 65 files changed, 506 insertions(+), 508 deletions(-) rename griptape/common/{message_stack => prompt_stack}/__init__.py (100%) rename griptape/common/{message_stack => prompt_stack}/contents/__init__.py (100%) rename griptape/common/{message_stack => prompt_stack}/contents/base_delta_message_content.py (100%) rename griptape/common/{message_stack => prompt_stack}/contents/base_message_content.py (100%) rename griptape/common/{message_stack => prompt_stack}/contents/image_message_content.py (100%) rename griptape/common/{message_stack => prompt_stack}/contents/text_delta_message_content.py (100%) rename griptape/common/{message_stack => prompt_stack}/contents/text_message_content.py (100%) rename griptape/common/{message_stack => prompt_stack}/messages/__init__.py (100%) rename griptape/common/{message_stack => prompt_stack}/messages/base_message.py (100%) rename griptape/common/{message_stack => prompt_stack}/messages/delta_message.py (79%) rename griptape/common/{message_stack => prompt_stack}/messages/message.py (100%) rename griptape/common/{message_stack/message_stack.py => prompt_stack/prompt_stack.py} (98%) diff --git a/CHANGELOG.md b/CHANGELOG.md index 36e1601d2..12c113dd2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,11 +6,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased ### Added -- `Message` for storing messages in a `MessageStack`. Messages consist of a role, content, and usage. -- `DeltaMessage` for storing partial messages in a `MessageStack`. Multiple `DeltaMessage` can be combined to form a `Message`. +- `Message` for storing messages in a `PromptStack`. Messages consist of a role, content, and usage. +- `DeltaMessage` for storing partial messages in a `PromptStack`. Multiple `DeltaMessage` can be combined to form a `Message`. - `TextMessageContent` for storing textual content in a `Message`. - `ImageMessageContent` for storing image content in a `Message`. -- Support for adding `TextArtifact`s, `ImageArtifact`s, and `ListArtifact`s to `MessageStack`. +- Support for adding `TextArtifact`s, `ImageArtifact`s, and `ListArtifact`s to `PromptStack`. - Support for image inputs to `OpenAiChatPromptDriver`, `AzureOpenAiChatPromptDriver`, `AmazonBedrockPromptDriver`, `AnthropicPromptDriver`, and `GooglePromptDriver`. - Input/output token usage metrics to all Prompt Drivers. - `FinishPromptEvent.input_token_count` and `FinishPromptEvent.output_token_count`. @@ -19,13 +19,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Support for `PromptTask`s to take `TextArtifact`s, `ImageArtifact`s, and `ListArtifact`s as input. ### Changed -- **BREAKING**: Moved/renamed `griptape.utils.PromptStack` to `griptape.common.MessageStack`. +- **BREAKING**: Moved/renamed `griptape.utils.PromptStack` to `griptape.common.PromptStack`. - **BREAKING**: Renamed `PromptStack.inputs` to `PromptStack.messages`. - **BREAKING**: Moved `PromptStack.USER_ROLE`, `PromptStack.ASSISTANT_ROLE`, and `PromptStack.SYSTEM_ROLE` to `Message`. - **BREAKING**: Updated return type of `PromptDriver.try_run` from `TextArtifact` to `Message`. - **BREAKING**: Updated return type of `PromptDriver.try_stream` from `Iterator[TextArtifact]` to `Iterator[DeltaMessage]`. - **BREAKING**: Removed `BasePromptEvent.token_count` in favor of `FinishPromptEvent.input_token_count` and `FinishPromptEvent.output_token_count`. -- **BREAKING**: Removed `StartPromptEvent.prompt`. Use `StartPromptEvent.message_stack` instead. +- **BREAKING**: Removed `StartPromptEvent.prompt`. Use `StartPromptEvent.prompt_stack` instead. - **BREAKING**: Removed `Agent.input_template` in favor of `Agent.input`. - **BREAKING**: `BasePromptDriver.run` now returns a `Message` instead of a `TextArtifact`. For compatibility, `Message.value` contains the Message's Artifact value - Default Prompt Driver model in `GoogleStructureConfig` to `gemini-1.5-pro`. diff --git a/docs/griptape-framework/drivers/prompt-drivers.md b/docs/griptape-framework/drivers/prompt-drivers.md index 271763cca..5d665d855 100644 --- a/docs/griptape-framework/drivers/prompt-drivers.md +++ b/docs/griptape-framework/drivers/prompt-drivers.md @@ -28,10 +28,10 @@ agent.run("I loved the new Batman movie!") Or use them independently: ```python -from griptape.common import MessageStack +from griptape.common import PromptStack from griptape.drivers import OpenAiChatPromptDriver -stack = MessageStack() +stack = PromptStack() stack.add_system_input( "You will be provided with Python code, and your task is to calculate its time complexity." diff --git a/docs/griptape-framework/misc/events.md b/docs/griptape-framework/misc/events.md index bea54f761..2d3645e94 100644 --- a/docs/griptape-framework/misc/events.md +++ b/docs/griptape-framework/misc/events.md @@ -235,7 +235,7 @@ total tokens: 273 ## Inspecting Payloads -You can use the [StartPromptEvent](../../reference/griptape/events/start_prompt_event.md) to inspect the Message Stack and final prompt string before it is sent to the LLM. +You can use the [StartPromptEvent](../../reference/griptape/events/start_prompt_event.md) to inspect the Prompt Stack and final prompt string before it is sent to the LLM. ```python from griptape.structures import Agent @@ -244,8 +244,8 @@ from griptape.events import BaseEvent, StartPromptEvent, EventListener def handler(event: BaseEvent): if isinstance(event, StartPromptEvent): - print("Message Stack MessageStack:") - for message in event.message_stack.messages: + print("Prompt Stack PromptStack:") + for message in event.prompt_stack.messages: print(f"{message.role}: {message.content}") print("Final Prompt String:") print(event.prompt) @@ -259,7 +259,7 @@ agent.run("Write me a poem.") ``` ``` ... -Message Stack Messages: +Prompt Stack Messages: system: user: Write me a poem. Final Prompt String: diff --git a/griptape/common/__init__.py b/griptape/common/__init__.py index 87c81c2ee..2b1189472 100644 --- a/griptape/common/__init__.py +++ b/griptape/common/__init__.py @@ -1,14 +1,14 @@ -from .message_stack.contents.base_message_content import BaseMessageContent -from .message_stack.contents.base_delta_message_content import BaseDeltaMessageContent -from .message_stack.contents.text_delta_message_content import TextDeltaMessageContent -from .message_stack.contents.text_message_content import TextMessageContent -from .message_stack.contents.image_message_content import ImageMessageContent +from .prompt_stack.contents.base_message_content import BaseMessageContent +from .prompt_stack.contents.base_delta_message_content import BaseDeltaMessageContent +from .prompt_stack.contents.text_delta_message_content import TextDeltaMessageContent +from .prompt_stack.contents.text_message_content import TextMessageContent +from .prompt_stack.contents.image_message_content import ImageMessageContent -from .message_stack.messages.base_message import BaseMessage -from .message_stack.messages.delta_message import DeltaMessage -from .message_stack.messages.message import Message +from .prompt_stack.messages.base_message import BaseMessage +from .prompt_stack.messages.delta_message import DeltaMessage +from .prompt_stack.messages.message import Message -from .message_stack.message_stack import MessageStack +from .prompt_stack.prompt_stack import PromptStack __all__ = [ "BaseMessage", @@ -19,5 +19,5 @@ "TextDeltaMessageContent", "TextMessageContent", "ImageMessageContent", - "MessageStack", + "PromptStack", ] diff --git a/griptape/common/message_stack/__init__.py b/griptape/common/prompt_stack/__init__.py similarity index 100% rename from griptape/common/message_stack/__init__.py rename to griptape/common/prompt_stack/__init__.py diff --git a/griptape/common/message_stack/contents/__init__.py b/griptape/common/prompt_stack/contents/__init__.py similarity index 100% rename from griptape/common/message_stack/contents/__init__.py rename to griptape/common/prompt_stack/contents/__init__.py diff --git a/griptape/common/message_stack/contents/base_delta_message_content.py b/griptape/common/prompt_stack/contents/base_delta_message_content.py similarity index 100% rename from griptape/common/message_stack/contents/base_delta_message_content.py rename to griptape/common/prompt_stack/contents/base_delta_message_content.py diff --git a/griptape/common/message_stack/contents/base_message_content.py b/griptape/common/prompt_stack/contents/base_message_content.py similarity index 100% rename from griptape/common/message_stack/contents/base_message_content.py rename to griptape/common/prompt_stack/contents/base_message_content.py diff --git a/griptape/common/message_stack/contents/image_message_content.py b/griptape/common/prompt_stack/contents/image_message_content.py similarity index 100% rename from griptape/common/message_stack/contents/image_message_content.py rename to griptape/common/prompt_stack/contents/image_message_content.py diff --git a/griptape/common/message_stack/contents/text_delta_message_content.py b/griptape/common/prompt_stack/contents/text_delta_message_content.py similarity index 100% rename from griptape/common/message_stack/contents/text_delta_message_content.py rename to griptape/common/prompt_stack/contents/text_delta_message_content.py diff --git a/griptape/common/message_stack/contents/text_message_content.py b/griptape/common/prompt_stack/contents/text_message_content.py similarity index 100% rename from griptape/common/message_stack/contents/text_message_content.py rename to griptape/common/prompt_stack/contents/text_message_content.py diff --git a/griptape/common/message_stack/messages/__init__.py b/griptape/common/prompt_stack/messages/__init__.py similarity index 100% rename from griptape/common/message_stack/messages/__init__.py rename to griptape/common/prompt_stack/messages/__init__.py diff --git a/griptape/common/message_stack/messages/base_message.py b/griptape/common/prompt_stack/messages/base_message.py similarity index 100% rename from griptape/common/message_stack/messages/base_message.py rename to griptape/common/prompt_stack/messages/base_message.py diff --git a/griptape/common/message_stack/messages/delta_message.py b/griptape/common/prompt_stack/messages/delta_message.py similarity index 79% rename from griptape/common/message_stack/messages/delta_message.py rename to griptape/common/prompt_stack/messages/delta_message.py index f022c8e0a..7ff90b08f 100644 --- a/griptape/common/message_stack/messages/delta_message.py +++ b/griptape/common/prompt_stack/messages/delta_message.py @@ -3,7 +3,7 @@ from attrs import define, field -from griptape.common.message_stack.contents.text_delta_message_content import TextDeltaMessageContent +from griptape.common.prompt_stack.contents.text_delta_message_content import TextDeltaMessageContent from .base_message import BaseMessage diff --git a/griptape/common/message_stack/messages/message.py b/griptape/common/prompt_stack/messages/message.py similarity index 100% rename from griptape/common/message_stack/messages/message.py rename to griptape/common/prompt_stack/messages/message.py diff --git a/griptape/common/message_stack/message_stack.py b/griptape/common/prompt_stack/prompt_stack.py similarity index 98% rename from griptape/common/message_stack/message_stack.py rename to griptape/common/prompt_stack/prompt_stack.py index 39e85e6f1..ce19696e3 100644 --- a/griptape/common/message_stack/message_stack.py +++ b/griptape/common/prompt_stack/prompt_stack.py @@ -7,7 +7,7 @@ @define -class MessageStack(SerializableMixin): +class PromptStack(SerializableMixin): messages: list[Message] = field(factory=list, kw_only=True, metadata={"serializable": True}) @property diff --git a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py index cec7aeff0..21d81724d 100644 --- a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py +++ b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py @@ -21,7 +21,7 @@ if TYPE_CHECKING: import boto3 - from griptape.common import MessageStack + from griptape.common import PromptStack @define @@ -35,8 +35,8 @@ class AmazonBedrockPromptDriver(BasePromptDriver): default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True), kw_only=True ) - def try_run(self, message_stack: MessageStack) -> Message: - response = self.bedrock_client.converse(**self._base_params(message_stack)) + def try_run(self, prompt_stack: PromptStack) -> Message: + response = self.bedrock_client.converse(**self._base_params(prompt_stack)) usage = response["usage"] output_message = response["output"]["message"] @@ -47,8 +47,8 @@ def try_run(self, message_stack: MessageStack) -> Message: usage=Message.Usage(input_tokens=usage["inputTokens"], output_tokens=usage["outputTokens"]), ) - def try_stream(self, message_stack: MessageStack) -> Iterator[DeltaMessage]: - response = self.bedrock_client.converse_stream(**self._base_params(message_stack)) + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: + response = self.bedrock_client.converse_stream(**self._base_params(prompt_stack)) stream = response.get("stream") if stream is not None: @@ -68,20 +68,20 @@ def try_stream(self, message_stack: MessageStack) -> Iterator[DeltaMessage]: else: raise Exception("model response is empty") - def _message_stack_messages_to_messages(self, messages: list[Message]) -> list[dict]: + def _prompt_stack_messages_to_messages(self, messages: list[Message]) -> list[dict]: return [ { "role": self.__to_role(message), - "content": [self.__message_stack_content_message_content(content) for content in message.content], + "content": [self.__prompt_stack_content_message_content(content) for content in message.content], } for message in messages ] - def _base_params(self, message_stack: MessageStack) -> dict: - system_messages = [{"text": message.to_text()} for message in message_stack.system_messages] + def _base_params(self, prompt_stack: PromptStack) -> dict: + system_messages = [{"text": message.to_text()} for message in prompt_stack.system_messages] - messages = self._message_stack_messages_to_messages( - [message for message in message_stack.messages if not message.is_system()] + messages = self._prompt_stack_messages_to_messages( + [message for message in prompt_stack.messages if not message.is_system()] ) return { @@ -92,7 +92,7 @@ def _base_params(self, message_stack: MessageStack) -> dict: "additionalModelRequestFields": self.additional_model_request_fields, } - def __message_stack_content_message_content(self, content: BaseMessageContent) -> dict: + def __prompt_stack_content_message_content(self, content: BaseMessageContent) -> dict: if isinstance(content, TextMessageContent): return {"text": content.artifact.to_text()} elif isinstance(content, ImageMessageContent): diff --git a/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py b/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py index 5f5b1a473..dacaa62b0 100644 --- a/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py +++ b/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py @@ -7,7 +7,7 @@ from attrs import Factory, define, field from griptape.artifacts import TextArtifact -from griptape.common import MessageStack, Message, TextMessageContent, DeltaMessage +from griptape.common import PromptStack, Message, TextMessageContent, DeltaMessage from griptape.drivers import BasePromptDriver from griptape.tokenizers import HuggingFaceTokenizer from griptape.utils import import_optional_dependency @@ -15,7 +15,7 @@ if TYPE_CHECKING: import boto3 - from griptape.common import MessageStack + from griptape.common import PromptStack @define @@ -41,10 +41,10 @@ def validate_stream(self, _, stream): if stream: raise ValueError("streaming is not supported") - def try_run(self, message_stack: MessageStack) -> Message: + def try_run(self, prompt_stack: PromptStack) -> Message: payload = { - "inputs": self.message_stack_to_string(message_stack), - "parameters": {**self._base_params(message_stack)}, + "inputs": self.prompt_stack_to_string(prompt_stack), + "parameters": {**self._base_params(prompt_stack)}, } response = self.sagemaker_client.invoke_endpoint( @@ -69,7 +69,7 @@ def try_run(self, message_stack: MessageStack) -> Message: else: generated_text = decoded_body["generated_text"] - input_tokens = len(self.__message_stack_to_tokens(message_stack)) + input_tokens = len(self.__prompt_stack_to_tokens(prompt_stack)) output_tokens = len(self.tokenizer.tokenizer.encode(generated_text)) return Message( @@ -78,13 +78,13 @@ def try_run(self, message_stack: MessageStack) -> Message: usage=Message.Usage(input_tokens=input_tokens, output_tokens=output_tokens), ) - def try_stream(self, message_stack: MessageStack) -> Iterator[DeltaMessage]: + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: raise NotImplementedError("streaming is not supported") - def message_stack_to_string(self, message_stack: MessageStack) -> str: - return self.tokenizer.tokenizer.decode(self.__message_stack_to_tokens(message_stack)) + def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str: + return self.tokenizer.tokenizer.decode(self.__prompt_stack_to_tokens(prompt_stack)) - def _base_params(self, message_stack: MessageStack) -> dict: + def _base_params(self, prompt_stack: PromptStack) -> dict: return { "temperature": self.temperature, "max_new_tokens": self.max_tokens, @@ -94,16 +94,16 @@ def _base_params(self, message_stack: MessageStack) -> dict: "return_full_text": False, } - def _message_stack_to_messages(self, message_stack: MessageStack) -> list[dict]: + def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]: messages = [] - for message in message_stack.messages: + for message in prompt_stack.messages: messages.append({"role": message.role, "content": message.to_text()}) return messages - def __message_stack_to_tokens(self, message_stack: MessageStack) -> list[int]: - messages = self._message_stack_to_messages(message_stack) + def __prompt_stack_to_tokens(self, prompt_stack: PromptStack) -> list[int]: + messages = self._prompt_stack_to_messages(prompt_stack) tokens = self.tokenizer.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True) diff --git a/griptape/drivers/prompt/anthropic_prompt_driver.py b/griptape/drivers/prompt/anthropic_prompt_driver.py index 154758b4e..1dd7bb7f9 100644 --- a/griptape/drivers/prompt/anthropic_prompt_driver.py +++ b/griptape/drivers/prompt/anthropic_prompt_driver.py @@ -11,7 +11,7 @@ DeltaMessage, TextDeltaMessageContent, ImageMessageContent, - MessageStack, + PromptStack, Message, TextMessageContent, ) @@ -48,35 +48,35 @@ class AnthropicPromptDriver(BasePromptDriver): top_k: int = field(default=250, kw_only=True, metadata={"serializable": True}) max_tokens: int = field(default=1000, kw_only=True, metadata={"serializable": True}) - def try_run(self, message_stack: MessageStack) -> Message: - response = self.client.messages.create(**self._base_params(message_stack)) + def try_run(self, prompt_stack: PromptStack) -> Message: + response = self.client.messages.create(**self._base_params(prompt_stack)) return Message( - content=[self.__message_content_to_message_stack_content(content) for content in response.content], + content=[self.__message_content_to_prompt_stack_content(content) for content in response.content], role=Message.ASSISTANT_ROLE, usage=Message.Usage(input_tokens=response.usage.input_tokens, output_tokens=response.usage.output_tokens), ) - def try_stream(self, message_stack: MessageStack) -> Iterator[DeltaMessage]: - events = self.client.messages.create(**self._base_params(message_stack), stream=True) + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: + events = self.client.messages.create(**self._base_params(prompt_stack), stream=True) for event in events: if event.type == "content_block_delta": - yield DeltaMessage(content=self.__message_content_delta_to_message_stack_content_delta(event)) + yield DeltaMessage(content=self.__message_content_delta_to_prompt_stack_content_delta(event)) elif event.type == "message_start": yield DeltaMessage(usage=DeltaMessage.Usage(input_tokens=event.message.usage.input_tokens)) elif event.type == "message_delta": yield DeltaMessage(usage=DeltaMessage.Usage(output_tokens=event.usage.output_tokens)) - def _message_stack_messages_to_messages(self, messages: list[Message]) -> list[dict]: + def _prompt_stack_messages_to_messages(self, messages: list[Message]) -> list[dict]: return [{"role": self.__to_role(message), "content": self.__to_content(message)} for message in messages] - def _base_params(self, message_stack: MessageStack) -> dict: - messages = self._message_stack_messages_to_messages( - [message for message in message_stack.messages if not message.is_system()] + def _base_params(self, prompt_stack: PromptStack) -> dict: + messages = self._prompt_stack_messages_to_messages( + [message for message in prompt_stack.messages if not message.is_system()] ) - system_messages = message_stack.system_messages + system_messages = prompt_stack.system_messages if system_messages: system_message = system_messages[0].to_text() else: @@ -103,9 +103,9 @@ def __to_content(self, message: Message) -> str | list[dict]: if all(isinstance(content, TextMessageContent) for content in message.content): return message.to_text() else: - return [self.__message_stack_content_message_content(content) for content in message.content] + return [self.__prompt_stack_content_message_content(content) for content in message.content] - def __message_stack_content_message_content(self, content: BaseMessageContent) -> dict: + def __prompt_stack_content_message_content(self, content: BaseMessageContent) -> dict: if isinstance(content, TextMessageContent): return {"type": "text", "text": content.artifact.to_text()} elif isinstance(content, ImageMessageContent): @@ -116,13 +116,13 @@ def __message_stack_content_message_content(self, content: BaseMessageContent) - else: raise ValueError(f"Unsupported prompt content type: {type(content)}") - def __message_content_to_message_stack_content(self, content: ContentBlock) -> BaseMessageContent: + def __message_content_to_prompt_stack_content(self, content: ContentBlock) -> BaseMessageContent: if content.type == "text": return TextMessageContent(TextArtifact(content.text)) else: raise ValueError(f"Unsupported message content type: {content.type}") - def __message_content_delta_to_message_stack_content_delta( + def __message_content_delta_to_prompt_stack_content_delta( self, content_delta: ContentBlockDeltaEvent ) -> TextDeltaMessageContent: index = content_delta.index diff --git a/griptape/drivers/prompt/azure_openai_chat_prompt_driver.py b/griptape/drivers/prompt/azure_openai_chat_prompt_driver.py index 37350f0a6..50e9effe6 100644 --- a/griptape/drivers/prompt/azure_openai_chat_prompt_driver.py +++ b/griptape/drivers/prompt/azure_openai_chat_prompt_driver.py @@ -1,6 +1,6 @@ from attrs import define, field, Factory from typing import Callable, Optional -from griptape.common import MessageStack +from griptape.common import PromptStack from griptape.drivers import OpenAiChatPromptDriver import openai @@ -41,8 +41,8 @@ class AzureOpenAiChatPromptDriver(OpenAiChatPromptDriver): ) ) - def _base_params(self, message_stack: MessageStack) -> dict: - params = super()._base_params(message_stack) + def _base_params(self, prompt_stack: PromptStack) -> dict: + params = super()._base_params(prompt_stack) # TODO: Add `seed` parameter once Azure supports it. del params["seed"] diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index b954e15b1..ff7cb988d 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -10,7 +10,7 @@ BaseDeltaMessageContent, DeltaMessage, TextDeltaMessageContent, - MessageStack, + PromptStack, Message, TextMessageContent, ) @@ -30,7 +30,7 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC): temperature: The temperature to use for the completion. max_tokens: The maximum number of tokens to generate. If not specified, the value will be automatically generated based by the tokenizer. structure: An optional `Structure` to publish events to. - message_stack_to_string: A function that converts a `MessageStack` to a string. + prompt_stack_to_string: A function that converts a `PromptStack` to a string. ignored_exception_types: A tuple of exception types to ignore. model: The model name. tokenizer: An instance of `BaseTokenizer` to when calculating tokens. @@ -45,9 +45,9 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC): tokenizer: BaseTokenizer stream: bool = field(default=False, metadata={"serializable": True}) - def before_run(self, message_stack: MessageStack) -> None: + def before_run(self, prompt_stack: PromptStack) -> None: if self.structure: - self.structure.publish_event(StartPromptEvent(model=self.model, message_stack=message_stack)) + self.structure.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack)) def after_run(self, result: Message) -> None: if self.structure: @@ -60,15 +60,15 @@ def after_run(self, result: Message) -> None: ) ) - def run(self, message_stack: MessageStack) -> Message: + def run(self, prompt_stack: PromptStack) -> Message: for attempt in self.retrying(): with attempt: - self.before_run(message_stack) + self.before_run(prompt_stack) if self.stream: - result = self.__process_stream(message_stack) + result = self.__process_stream(prompt_stack) else: - result = self.__process_run(message_stack) + result = self.__process_run(prompt_stack) self.after_run(result) @@ -76,19 +76,19 @@ def run(self, message_stack: MessageStack) -> Message: else: raise Exception("prompt driver failed after all retry attempts") - def message_stack_to_string(self, message_stack: MessageStack) -> str: - """Converts a Message Stack to a string for token counting or model input. + def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str: + """Converts a Prompt Stack to a string for token counting or model input. This base implementation is only a rough approximation, and should be overridden by subclasses with model-specific tokens. Args: - message_stack: The Message Stack to convert to a string. + prompt_stack: The Prompt Stack to convert to a string. Returns: - A single string representation of the Message Stack. + A single string representation of the Prompt Stack. """ prompt_lines = [] - for i in message_stack.messages: + for i in prompt_stack.messages: content = i.to_text() if i.is_user(): prompt_lines.append(f"User: {content}") @@ -102,22 +102,22 @@ def message_stack_to_string(self, message_stack: MessageStack) -> str: return "\n\n".join(prompt_lines) @abstractmethod - def try_run(self, message_stack: MessageStack) -> Message: ... + def try_run(self, prompt_stack: PromptStack) -> Message: ... @abstractmethod - def try_stream(self, message_stack: MessageStack) -> Iterator[DeltaMessage]: ... + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: ... - def __process_run(self, message_stack: MessageStack) -> Message: - result = self.try_run(message_stack) + def __process_run(self, prompt_stack: PromptStack) -> Message: + result = self.try_run(prompt_stack) return result - def __process_stream(self, message_stack: MessageStack) -> Message: + def __process_stream(self, prompt_stack: PromptStack) -> Message: delta_contents: dict[int, list[BaseDeltaMessageContent]] = {} usage = DeltaMessage.Usage() # Aggregate all content deltas from the stream - deltas = self.try_stream(message_stack) + deltas = self.try_stream(prompt_stack) for delta in deltas: usage += delta.usage diff --git a/griptape/drivers/prompt/cohere_prompt_driver.py b/griptape/drivers/prompt/cohere_prompt_driver.py index 7615d8e12..331c2c039 100644 --- a/griptape/drivers/prompt/cohere_prompt_driver.py +++ b/griptape/drivers/prompt/cohere_prompt_driver.py @@ -6,7 +6,7 @@ from griptape.drivers import BasePromptDriver from griptape.tokenizers import CohereTokenizer from griptape.common import ( - MessageStack, + PromptStack, Message, DeltaMessage, TextMessageContent, @@ -38,8 +38,8 @@ class CoherePromptDriver(BasePromptDriver): default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True) ) - def try_run(self, message_stack: MessageStack) -> Message: - result = self.client.chat(**self._base_params(message_stack)) + def try_run(self, prompt_stack: PromptStack) -> Message: + result = self.client.chat(**self._base_params(prompt_stack)) usage = result.meta.tokens return Message( @@ -48,8 +48,8 @@ def try_run(self, message_stack: MessageStack) -> Message: usage=Message.Usage(input_tokens=usage.input_tokens, output_tokens=usage.output_tokens), ) - def try_stream(self, message_stack: MessageStack) -> Iterator[DeltaMessage]: - result = self.client.chat_stream(**self._base_params(message_stack)) + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: + result = self.client.chat_stream(**self._base_params(prompt_stack)) for event in result: if event.event_type == "text-generation": @@ -61,24 +61,24 @@ def try_stream(self, message_stack: MessageStack) -> Iterator[DeltaMessage]: usage=DeltaMessage.Usage(input_tokens=usage.input_tokens, output_tokens=usage.output_tokens) ) - def _message_stack_messages_to_messages(self, messages: list[Message]) -> list[dict]: + def _prompt_stack_messages_to_messages(self, messages: list[Message]) -> list[dict]: return [ { "role": self.__to_role(message), - "content": [self.__message_stack_content_message_content(content) for content in message.content], + "content": [self.__prompt_stack_content_message_content(content) for content in message.content], } for message in messages ] - def _base_params(self, message_stack: MessageStack) -> dict: - last_input = message_stack.messages[-1] + def _base_params(self, prompt_stack: PromptStack) -> dict: + last_input = prompt_stack.messages[-1] user_message = last_input.to_text() - history_messages = self._message_stack_messages_to_messages( - [message for message in message_stack.messages[:-1] if not message.is_system()] + history_messages = self._prompt_stack_messages_to_messages( + [message for message in prompt_stack.messages[:-1] if not message.is_system()] ) - system_messages = message_stack.system_messages + system_messages = prompt_stack.system_messages if system_messages: preamble = system_messages[0].to_text() else: @@ -93,7 +93,7 @@ def _base_params(self, message_stack: MessageStack) -> dict: **({"preamble": preamble} if preamble else {}), } - def __message_stack_content_message_content(self, content: BaseMessageContent) -> dict: + def __prompt_stack_content_message_content(self, content: BaseMessageContent) -> dict: if isinstance(content, TextMessageContent): return {"text": content.artifact.to_text()} else: diff --git a/griptape/drivers/prompt/dummy_prompt_driver.py b/griptape/drivers/prompt/dummy_prompt_driver.py index aadb0ee8c..72757624d 100644 --- a/griptape/drivers/prompt/dummy_prompt_driver.py +++ b/griptape/drivers/prompt/dummy_prompt_driver.py @@ -3,7 +3,7 @@ from attrs import Factory, define, field -from griptape.common import MessageStack, Message, DeltaMessage +from griptape.common import PromptStack, Message, DeltaMessage from griptape.drivers import BasePromptDriver from griptape.exceptions import DummyException from griptape.tokenizers import DummyTokenizer @@ -14,8 +14,8 @@ class DummyPromptDriver(BasePromptDriver): model: None = field(init=False, default=None, kw_only=True) tokenizer: DummyTokenizer = field(default=Factory(lambda: DummyTokenizer()), kw_only=True) - def try_run(self, message_stack: MessageStack) -> Message: + def try_run(self, prompt_stack: PromptStack) -> Message: raise DummyException(__class__.__name__, "try_run") - def try_stream(self, message_stack: MessageStack) -> Iterator[DeltaMessage]: + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: raise DummyException(__class__.__name__, "try_stream") diff --git a/griptape/drivers/prompt/google_prompt_driver.py b/griptape/drivers/prompt/google_prompt_driver.py index 313c9b1a3..ff4b07a93 100644 --- a/griptape/drivers/prompt/google_prompt_driver.py +++ b/griptape/drivers/prompt/google_prompt_driver.py @@ -11,7 +11,7 @@ DeltaMessage, TextDeltaMessageContent, ImageMessageContent, - MessageStack, + PromptStack, Message, TextMessageContent, ) @@ -45,10 +45,10 @@ class GooglePromptDriver(BasePromptDriver): top_p: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True}) top_k: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True}) - def try_run(self, message_stack: MessageStack) -> Message: + def try_run(self, prompt_stack: PromptStack) -> Message: GenerationConfig = import_optional_dependency("google.generativeai.types").GenerationConfig - messages = self._message_stack_to_messages(message_stack) + messages = self._prompt_stack_to_messages(prompt_stack) response: GenerateContentResponse = self.model_client.generate_content( messages, generation_config=GenerationConfig( @@ -70,10 +70,10 @@ def try_run(self, message_stack: MessageStack) -> Message: ), ) - def try_stream(self, message_stack: MessageStack) -> Iterator[DeltaMessage]: + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: GenerationConfig = import_optional_dependency("google.generativeai.types").GenerationConfig - messages = self._message_stack_to_messages(message_stack) + messages = self._prompt_stack_to_messages(prompt_stack) response: Iterator[GenerateContentResponse] = self.model_client.generate_content( messages, stream=True, @@ -114,21 +114,21 @@ def _default_model_client(self) -> GenerativeModel: return genai.GenerativeModel(self.model) - def _message_stack_to_messages(self, message_stack: MessageStack) -> list[dict]: + def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]: inputs = [ {"role": self.__to_role(message), "parts": self.__to_content(message)} - for message in message_stack.messages + for message in prompt_stack.messages if not message.is_system() ] # Gemini does not have the notion of a system message, so we insert it as part of the first message in the history. - system_messages = message_stack.system_messages + system_messages = prompt_stack.system_messages if system_messages: inputs[0]["parts"].insert(0, system_messages[0].to_text()) return inputs - def __message_stack_content_message_content(self, content: BaseMessageContent) -> ContentDict | str: + def __prompt_stack_content_message_content(self, content: BaseMessageContent) -> ContentDict | str: ContentDict = import_optional_dependency("google.generativeai.types").ContentDict if isinstance(content, TextMessageContent): @@ -145,4 +145,4 @@ def __to_role(self, message: Message) -> str: return "user" def __to_content(self, message: Message) -> list[ContentDict | str]: - return [self.__message_stack_content_message_content(content) for content in message.content] + return [self.__prompt_stack_content_message_content(content) for content in message.content] diff --git a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py index 0363bf31d..f327010c4 100644 --- a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py +++ b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py @@ -7,7 +7,7 @@ from griptape.drivers import BasePromptDriver from griptape.tokenizers import HuggingFaceTokenizer -from griptape.common import MessageStack, Message, DeltaMessage, TextDeltaMessageContent +from griptape.common import PromptStack, Message, DeltaMessage, TextDeltaMessageContent from griptape.utils import import_optional_dependency if TYPE_CHECKING: @@ -47,13 +47,13 @@ class HuggingFaceHubPromptDriver(BasePromptDriver): kw_only=True, ) - def try_run(self, message_stack: MessageStack) -> Message: - prompt = self.message_stack_to_string(message_stack) + def try_run(self, prompt_stack: PromptStack) -> Message: + prompt = self.prompt_stack_to_string(prompt_stack) response = self.client.text_generation( prompt, return_full_text=False, max_new_tokens=self.max_tokens, **self.params ) - input_tokens = len(self.__message_stack_to_tokens(message_stack)) + input_tokens = len(self.__prompt_stack_to_tokens(prompt_stack)) output_tokens = len(self.tokenizer.tokenizer.encode(response)) return Message( @@ -62,14 +62,14 @@ def try_run(self, message_stack: MessageStack) -> Message: usage=Message.Usage(input_tokens=input_tokens, output_tokens=output_tokens), ) - def try_stream(self, message_stack: MessageStack) -> Iterator[DeltaMessage]: - prompt = self.message_stack_to_string(message_stack) + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: + prompt = self.prompt_stack_to_string(prompt_stack) response = self.client.text_generation( prompt, return_full_text=False, max_new_tokens=self.max_tokens, stream=True, **self.params ) - input_tokens = len(self.__message_stack_to_tokens(message_stack)) + input_tokens = len(self.__prompt_stack_to_tokens(prompt_stack)) full_text = "" for token in response: @@ -79,12 +79,12 @@ def try_stream(self, message_stack: MessageStack) -> Iterator[DeltaMessage]: output_tokens = len(self.tokenizer.tokenizer.encode(full_text)) yield DeltaMessage(usage=DeltaMessage.Usage(input_tokens=input_tokens, output_tokens=output_tokens)) - def message_stack_to_string(self, message_stack: MessageStack) -> str: - return self.tokenizer.tokenizer.decode(self.__message_stack_to_tokens(message_stack)) + def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str: + return self.tokenizer.tokenizer.decode(self.__prompt_stack_to_tokens(prompt_stack)) - def _message_stack_to_messages(self, message_stack: MessageStack) -> list[dict]: + def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]: messages = [] - for message in message_stack.messages: + for message in prompt_stack.messages: if len(message.content) == 1: messages.append({"role": message.role, "content": message.to_text()}) else: @@ -92,8 +92,8 @@ def _message_stack_to_messages(self, message_stack: MessageStack) -> list[dict]: return messages - def __message_stack_to_tokens(self, message_stack: MessageStack) -> list[int]: - messages = self._message_stack_to_messages(message_stack) + def __prompt_stack_to_tokens(self, prompt_stack: PromptStack) -> list[int]: + messages = self._prompt_stack_to_messages(prompt_stack) tokens = self.tokenizer.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True) if isinstance(tokens, list): diff --git a/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py b/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py index 61e05a4bf..6dabc3e20 100644 --- a/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py +++ b/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py @@ -6,7 +6,7 @@ from attrs import Factory, define, field from griptape.artifacts import TextArtifact -from griptape.common import DeltaMessage, MessageStack, Message, TextMessageContent +from griptape.common import DeltaMessage, PromptStack, Message, TextMessageContent from griptape.drivers import BasePromptDriver from griptape.tokenizers import HuggingFaceTokenizer from griptape.utils import import_optional_dependency @@ -42,8 +42,8 @@ class HuggingFacePipelinePromptDriver(BasePromptDriver): ) ) - def try_run(self, message_stack: MessageStack) -> Message: - messages = self._message_stack_to_messages(message_stack) + def try_run(self, prompt_stack: PromptStack) -> Message: + messages = self._prompt_stack_to_messages(prompt_stack) result = self.pipe( messages, max_new_tokens=self.max_tokens, temperature=self.temperature, do_sample=True, **self.params @@ -53,7 +53,7 @@ def try_run(self, message_stack: MessageStack) -> Message: if len(result) == 1: generated_text = result[0]["generated_text"][-1]["content"] - input_tokens = len(self.__message_stack_to_tokens(message_stack)) + input_tokens = len(self.__prompt_stack_to_tokens(prompt_stack)) output_tokens = len(self.tokenizer.tokenizer.encode(generated_text)) return Message( @@ -66,22 +66,22 @@ def try_run(self, message_stack: MessageStack) -> Message: else: raise Exception("invalid output format") - def try_stream(self, message_stack: MessageStack) -> Iterator[DeltaMessage]: + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: raise NotImplementedError("streaming is not supported") - def message_stack_to_string(self, message_stack: MessageStack) -> str: - return self.tokenizer.tokenizer.decode(self.__message_stack_to_tokens(message_stack)) + def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str: + return self.tokenizer.tokenizer.decode(self.__prompt_stack_to_tokens(prompt_stack)) - def _message_stack_to_messages(self, message_stack: MessageStack) -> list[dict]: + def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]: messages = [] - for message in message_stack.messages: + for message in prompt_stack.messages: messages.append({"role": message.role, "content": message.to_text()}) return messages - def __message_stack_to_tokens(self, message_stack: MessageStack) -> list[int]: - messages = self._message_stack_to_messages(message_stack) + def __prompt_stack_to_tokens(self, prompt_stack: PromptStack) -> list[int]: + messages = self._prompt_stack_to_messages(prompt_stack) tokens = self.tokenizer.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True) if isinstance(tokens, list): diff --git a/griptape/drivers/prompt/ollama_prompt_driver.py b/griptape/drivers/prompt/ollama_prompt_driver.py index 0743e6032..5a78abf29 100644 --- a/griptape/drivers/prompt/ollama_prompt_driver.py +++ b/griptape/drivers/prompt/ollama_prompt_driver.py @@ -5,7 +5,7 @@ from griptape.artifacts import TextArtifact from griptape.drivers import BasePromptDriver from griptape.tokenizers.base_tokenizer import BaseTokenizer -from griptape.common import MessageStack, TextMessageContent +from griptape.common import PromptStack, TextMessageContent from griptape.utils import import_optional_dependency from griptape.tokenizers import SimpleTokenizer from griptape.common import Message, DeltaMessage, TextDeltaMessageContent @@ -49,8 +49,8 @@ class OllamaPromptDriver(BasePromptDriver): kw_only=True, ) - def try_run(self, message_stack: MessageStack) -> Message: - response = self.client.chat(**self._base_params(message_stack)) + def try_run(self, prompt_stack: PromptStack) -> Message: + response = self.client.chat(**self._base_params(prompt_stack)) if isinstance(response, dict): return Message( @@ -60,8 +60,8 @@ def try_run(self, message_stack: MessageStack) -> Message: else: raise Exception("invalid model response") - def try_stream(self, message_stack: MessageStack) -> Iterator[DeltaMessage]: - stream = self.client.chat(**self._base_params(message_stack), stream=True) + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: + stream = self.client.chat(**self._base_params(prompt_stack), stream=True) if isinstance(stream, Iterator): for chunk in stream: @@ -69,12 +69,12 @@ def try_stream(self, message_stack: MessageStack) -> Iterator[DeltaMessage]: else: raise Exception("invalid model response") - def _base_params(self, message_stack: MessageStack) -> dict: - messages = self._message_stack_to_messages(message_stack) + def _base_params(self, prompt_stack: PromptStack) -> dict: + messages = self._prompt_stack_to_messages(prompt_stack) return {"messages": messages, "model": self.model, "options": self.options} - def _message_stack_to_messages(self, message_stack: MessageStack) -> list[dict]: + def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]: return [ { "role": message.role, @@ -91,5 +91,5 @@ def _message_stack_to_messages(self, message_stack: MessageStack) -> list[dict]: else {} ), } - for message in message_stack.messages + for message in prompt_stack.messages ] diff --git a/griptape/drivers/prompt/openai_chat_prompt_driver.py b/griptape/drivers/prompt/openai_chat_prompt_driver.py index 9397124ca..e1e046d11 100644 --- a/griptape/drivers/prompt/openai_chat_prompt_driver.py +++ b/griptape/drivers/prompt/openai_chat_prompt_driver.py @@ -12,7 +12,7 @@ DeltaMessage, TextDeltaMessageContent, ImageMessageContent, - MessageStack, + PromptStack, Message, TextMessageContent, ) @@ -73,14 +73,14 @@ class OpenAiChatPromptDriver(BasePromptDriver): kw_only=True, ) - def try_run(self, message_stack: MessageStack) -> Message: - result = self.client.chat.completions.create(**self._base_params(message_stack)) + def try_run(self, prompt_stack: PromptStack) -> Message: + result = self.client.chat.completions.create(**self._base_params(prompt_stack)) if len(result.choices) == 1: message = result.choices[0].message return Message( - content=[self.__message_to_message_stack_content(message)], + content=[self.__message_to_prompt_stack_content(message)], role=message.role, usage=Message.Usage( input_tokens=result.usage.prompt_tokens, output_tokens=result.usage.completion_tokens @@ -89,9 +89,9 @@ def try_run(self, message_stack: MessageStack) -> Message: else: raise Exception("Completion with more than one choice is not supported yet.") - def try_stream(self, message_stack: MessageStack) -> Iterator[DeltaMessage]: + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: result = self.client.chat.completions.create( - **self._base_params(message_stack), stream=True, stream_options={"include_usage": True} + **self._base_params(prompt_stack), stream=True, stream_options={"include_usage": True} ) for chunk in result: @@ -106,17 +106,17 @@ def try_stream(self, message_stack: MessageStack) -> Iterator[DeltaMessage]: choice = chunk.choices[0] delta = choice.delta - yield DeltaMessage(content=self.__message_delta_to_message_stack_content_delta(delta)) + yield DeltaMessage(content=self.__message_delta_to_prompt_stack_content_delta(delta)) else: raise Exception("Completion with more than one choice is not supported yet.") - def _message_stack_to_messages(self, message_stack: MessageStack) -> list[dict]: + def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]: return [ {"role": self.__to_role(message), "content": self.__to_content(message)} - for message in message_stack.messages + for message in prompt_stack.messages ] - def _base_params(self, message_stack: MessageStack) -> dict: + def _base_params(self, prompt_stack: PromptStack) -> dict: params = { "model": self.model, "temperature": self.temperature, @@ -129,9 +129,9 @@ def _base_params(self, message_stack: MessageStack) -> dict: if self.response_format == "json_object": params["response_format"] = {"type": "json_object"} # JSON mode still requires a system message instructing the LLM to output JSON. - message_stack.add_system_message("Provide your response as a valid JSON object.") + prompt_stack.add_system_message("Provide your response as a valid JSON object.") - messages = self._message_stack_to_messages(message_stack) + messages = self._prompt_stack_to_messages(prompt_stack) params["messages"] = messages @@ -149,9 +149,9 @@ def __to_content(self, message: Message) -> str | list[dict]: if all(isinstance(content, TextMessageContent) for content in message.content): return message.to_text() else: - return [self.__message_stack_content_message_content(content) for content in message.content] + return [self.__prompt_stack_content_message_content(content) for content in message.content] - def __message_stack_content_message_content(self, content: BaseMessageContent) -> dict: + def __prompt_stack_content_message_content(self, content: BaseMessageContent) -> dict: if isinstance(content, TextMessageContent): return {"type": "text", "text": content.artifact.to_text()} elif isinstance(content, ImageMessageContent): @@ -162,13 +162,13 @@ def __message_stack_content_message_content(self, content: BaseMessageContent) - else: raise ValueError(f"Unsupported content type: {type(content)}") - def __message_to_message_stack_content(self, message: ChatCompletionMessage) -> BaseMessageContent: + def __message_to_prompt_stack_content(self, message: ChatCompletionMessage) -> BaseMessageContent: if message.content is not None: return TextMessageContent(TextArtifact(message.content)) else: raise ValueError(f"Unsupported message type: {message}") - def __message_delta_to_message_stack_content_delta(self, content_delta: ChoiceDelta) -> TextDeltaMessageContent: + def __message_delta_to_prompt_stack_content_delta(self, content_delta: ChoiceDelta) -> TextDeltaMessageContent: if content_delta.content is None: return TextDeltaMessageContent("") else: diff --git a/griptape/engines/extraction/csv_extraction_engine.py b/griptape/engines/extraction/csv_extraction_engine.py index e59b6ec23..6fe9bb879 100644 --- a/griptape/engines/extraction/csv_extraction_engine.py +++ b/griptape/engines/extraction/csv_extraction_engine.py @@ -4,8 +4,8 @@ import io from attrs import field, Factory, define from griptape.artifacts import TextArtifact, CsvRowArtifact, ListArtifact, ErrorArtifact -from griptape.common import MessageStack -from griptape.common.message_stack.messages.message import Message +from griptape.common import PromptStack +from griptape.common.prompt_stack.messages.message import Message from griptape.engines import BaseExtractionEngine from griptape.utils import J2 from griptape.rules import Ruleset @@ -64,7 +64,7 @@ def _extract_rec( if self.prompt_driver.tokenizer.count_input_tokens_left(full_text) >= self.min_response_tokens: rows.extend( self.text_to_csv_rows( - self.prompt_driver.run(MessageStack(messages=[Message(full_text, role=Message.USER_ROLE)])).value, + self.prompt_driver.run(PromptStack(messages=[Message(full_text, role=Message.USER_ROLE)])).value, column_names, ) ) @@ -80,9 +80,7 @@ def _extract_rec( rows.extend( self.text_to_csv_rows( - self.prompt_driver.run( - MessageStack(messages=[Message(partial_text, role=Message.USER_ROLE)]) - ).value, + self.prompt_driver.run(PromptStack(messages=[Message(partial_text, role=Message.USER_ROLE)])).value, column_names, ) ) diff --git a/griptape/engines/extraction/json_extraction_engine.py b/griptape/engines/extraction/json_extraction_engine.py index 744cab563..830092dab 100644 --- a/griptape/engines/extraction/json_extraction_engine.py +++ b/griptape/engines/extraction/json_extraction_engine.py @@ -3,10 +3,10 @@ import json from attrs import field, Factory, define from griptape.artifacts import TextArtifact, ListArtifact, ErrorArtifact -from griptape.common.message_stack.messages.message import Message +from griptape.common.prompt_stack.messages.message import Message from griptape.engines import BaseExtractionEngine from griptape.utils import J2 -from griptape.common import MessageStack +from griptape.common import PromptStack from griptape.rules import Ruleset @@ -59,7 +59,7 @@ def _extract_rec( if self.prompt_driver.tokenizer.count_input_tokens_left(full_text) >= self.min_response_tokens: extractions.extend( self.json_to_text_artifacts( - self.prompt_driver.run(MessageStack(messages=[Message(full_text, role=Message.USER_ROLE)])).value + self.prompt_driver.run(PromptStack(messages=[Message(full_text, role=Message.USER_ROLE)])).value ) ) @@ -74,7 +74,7 @@ def _extract_rec( extractions.extend( self.json_to_text_artifacts( - self.prompt_driver.run(MessageStack(messages=[Message(partial_text, role=Message.USER_ROLE)])).value + self.prompt_driver.run(PromptStack(messages=[Message(partial_text, role=Message.USER_ROLE)])).value ) ) diff --git a/griptape/engines/rag/modules/base_rag_module.py b/griptape/engines/rag/modules/base_rag_module.py index 0dcdb3c35..13563970d 100644 --- a/griptape/engines/rag/modules/base_rag_module.py +++ b/griptape/engines/rag/modules/base_rag_module.py @@ -4,7 +4,7 @@ from attrs import define, field, Factory -from griptape.common import MessageStack, Message +from griptape.common import PromptStack, Message @define(kw_only=True) @@ -13,7 +13,7 @@ class BaseRagModule(ABC): default=Factory(lambda: lambda: futures.ThreadPoolExecutor()) ) - def generate_query_prompt_stack(self, system_prompt: str, query: str) -> MessageStack: - return MessageStack( + def generate_query_prompt_stack(self, system_prompt: str, query: str) -> PromptStack: + return PromptStack( messages=[Message(system_prompt, role=Message.SYSTEM_ROLE), Message(query, role=Message.USER_ROLE)] ) diff --git a/griptape/engines/rag/modules/generation/prompt_generation_rag_module.py b/griptape/engines/rag/modules/generation/prompt_generation_rag_module.py index bd6c82191..26d95bdd1 100644 --- a/griptape/engines/rag/modules/generation/prompt_generation_rag_module.py +++ b/griptape/engines/rag/modules/generation/prompt_generation_rag_module.py @@ -31,7 +31,7 @@ def run(self, context: RagContext) -> RagContext: system_prompt = self.generate_system_template(text_chunks, before_query, after_query) message_token_count = self.prompt_driver.tokenizer.count_tokens( - self.prompt_driver.message_stack_to_string(self.generate_query_prompt_stack(system_prompt, query)) + self.prompt_driver.prompt_stack_to_string(self.generate_query_prompt_stack(system_prompt, query)) ) if message_token_count + self.answer_token_offset >= tokenizer.max_input_tokens: diff --git a/griptape/engines/summary/prompt_summary_engine.py b/griptape/engines/summary/prompt_summary_engine.py index 47863dfa4..51259e444 100644 --- a/griptape/engines/summary/prompt_summary_engine.py +++ b/griptape/engines/summary/prompt_summary_engine.py @@ -2,8 +2,8 @@ from attrs import define, Factory, field from griptape.artifacts import TextArtifact, ListArtifact from griptape.chunkers import BaseChunker, TextChunker -from griptape.common import MessageStack -from griptape.common.message_stack.messages.message import Message +from griptape.common import PromptStack +from griptape.common.prompt_stack.messages.message import Message from griptape.drivers import BasePromptDriver from griptape.engines import BaseSummaryEngine from griptape.utils import J2 @@ -62,7 +62,7 @@ def summarize_artifacts_rec( >= self.min_response_tokens ): result = self.prompt_driver.run( - MessageStack( + PromptStack( messages=[ Message(system_prompt, role=Message.SYSTEM_ROLE), Message(user_prompt, role=Message.USER_ROLE), @@ -82,7 +82,7 @@ def summarize_artifacts_rec( return self.summarize_artifacts_rec( chunks[1:], self.prompt_driver.run( - MessageStack( + PromptStack( messages=[ Message(system_prompt, role=Message.SYSTEM_ROLE), Message(partial_text, role=Message.USER_ROLE), diff --git a/griptape/events/start_prompt_event.py b/griptape/events/start_prompt_event.py index c6cd4aab4..35dae95d6 100644 --- a/griptape/events/start_prompt_event.py +++ b/griptape/events/start_prompt_event.py @@ -5,9 +5,9 @@ from griptape.events.base_prompt_event import BasePromptEvent if TYPE_CHECKING: - from griptape.common import MessageStack + from griptape.common import PromptStack @define class StartPromptEvent(BasePromptEvent): - message_stack: MessageStack = field(kw_only=True, metadata={"serializable": True}) + prompt_stack: PromptStack = field(kw_only=True, metadata={"serializable": True}) diff --git a/griptape/memory/structure/base_conversation_memory.py b/griptape/memory/structure/base_conversation_memory.py index 4e29d9925..a8133d64b 100644 --- a/griptape/memory/structure/base_conversation_memory.py +++ b/griptape/memory/structure/base_conversation_memory.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Optional from attrs import define, field from griptape.memory.structure import Run -from griptape.common import MessageStack +from griptape.common import PromptStack from griptape.mixins import SerializableMixin from abc import ABC, abstractmethod @@ -44,40 +44,40 @@ def after_add_run(self) -> None: def try_add_run(self, run: Run) -> None: ... @abstractmethod - def to_message_stack(self, last_n: Optional[int] = None) -> MessageStack: ... + def to_prompt_stack(self, last_n: Optional[int] = None) -> PromptStack: ... - def add_to_message_stack(self, message_stack: MessageStack, index: Optional[int] = None) -> MessageStack: - """Add the Conversation Memory runs to the Message Stack by modifying the messages in place. + def add_to_prompt_stack(self, prompt_stack: PromptStack, index: Optional[int] = None) -> PromptStack: + """Add the Conversation Memory runs to the Prompt Stack by modifying the messages in place. - If autoprune is enabled, this will fit as many Conversation Memory runs into the Message Stack + If autoprune is enabled, this will fit as many Conversation Memory runs into the Prompt Stack as possible without exceeding the token limit. Args: - message_stack: The Message Stack to add the Conversation Memory to. + prompt_stack: The Prompt Stack to add the Conversation Memory to. index: Optional index to insert the Conversation Memory runs at. - Defaults to appending to the end of the Message Stack. + Defaults to appending to the end of the Prompt Stack. """ num_runs_to_fit_in_prompt = len(self.runs) if self.autoprune and hasattr(self, "structure"): should_prune = True prompt_driver = self.structure.config.prompt_driver - temp_stack = MessageStack() + temp_stack = PromptStack() # Try to determine how many Conversation Memory runs we can - # fit into the Message Stack without exceeding the token limit. + # fit into the Prompt Stack without exceeding the token limit. while should_prune and num_runs_to_fit_in_prompt > 0: - temp_stack.messages = message_stack.messages.copy() + temp_stack.messages = prompt_stack.messages.copy() # Add n runs from Conversation Memory. - # Where we insert into the Message Stack doesn't matter here + # Where we insert into the Prompt Stack doesn't matter here # since we only care about the total token count. - memory_inputs = self.to_message_stack(num_runs_to_fit_in_prompt).messages + memory_inputs = self.to_prompt_stack(num_runs_to_fit_in_prompt).messages temp_stack.messages.extend(memory_inputs) - # Convert the Message Stack into tokens left. + # Convert the Prompt Stack into tokens left. tokens_left = prompt_driver.tokenizer.count_input_tokens_left( - prompt_driver.message_stack_to_string(temp_stack) + prompt_driver.prompt_stack_to_string(temp_stack) ) if tokens_left > 0: # There are still tokens left, no need to prune. @@ -87,10 +87,10 @@ def add_to_message_stack(self, message_stack: MessageStack, index: Optional[int] num_runs_to_fit_in_prompt -= 1 if num_runs_to_fit_in_prompt: - memory_inputs = self.to_message_stack(num_runs_to_fit_in_prompt).messages + memory_inputs = self.to_prompt_stack(num_runs_to_fit_in_prompt).messages if index: - message_stack.messages[index:index] = memory_inputs + prompt_stack.messages[index:index] = memory_inputs else: - message_stack.messages.extend(memory_inputs) + prompt_stack.messages.extend(memory_inputs) - return message_stack + return prompt_stack diff --git a/griptape/memory/structure/conversation_memory.py b/griptape/memory/structure/conversation_memory.py index b1401c1e4..42d160abd 100644 --- a/griptape/memory/structure/conversation_memory.py +++ b/griptape/memory/structure/conversation_memory.py @@ -2,7 +2,7 @@ from attrs import define from typing import Optional from griptape.memory.structure import Run, BaseConversationMemory -from griptape.common import MessageStack +from griptape.common import PromptStack @define @@ -14,10 +14,10 @@ def try_add_run(self, run: Run) -> None: while len(self.runs) > self.max_runs: self.runs.pop(0) - def to_message_stack(self, last_n: Optional[int] = None) -> MessageStack: - message_stack = MessageStack() + def to_prompt_stack(self, last_n: Optional[int] = None) -> PromptStack: + prompt_stack = PromptStack() runs = self.runs[-last_n:] if last_n else self.runs for run in runs: - message_stack.add_user_message(run.input) - message_stack.add_assistant_message(run.output) - return message_stack + prompt_stack.add_user_message(run.input) + prompt_stack.add_assistant_message(run.output) + return prompt_stack diff --git a/griptape/memory/structure/summary_conversation_memory.py b/griptape/memory/structure/summary_conversation_memory.py index de1a198d9..b88a4b4e6 100644 --- a/griptape/memory/structure/summary_conversation_memory.py +++ b/griptape/memory/structure/summary_conversation_memory.py @@ -2,9 +2,9 @@ import logging from typing import TYPE_CHECKING, Optional from attrs import define, field, Factory -from griptape.common.message_stack.messages.message import Message +from griptape.common.prompt_stack.messages.message import Message from griptape.utils import J2 -from griptape.common import MessageStack +from griptape.common import PromptStack from griptape.memory.structure import ConversationMemory if TYPE_CHECKING: @@ -36,8 +36,8 @@ def prompt_driver(self) -> BasePromptDriver: def prompt_driver(self, value: BasePromptDriver) -> None: self._prompt_driver = value - def to_message_stack(self, last_n: Optional[int] = None) -> MessageStack: - stack = MessageStack() + def to_prompt_stack(self, last_n: Optional[int] = None) -> PromptStack: + stack = PromptStack() if self.summary: stack.add_user_message(self.summary_template_generator.render(summary=self.summary)) @@ -75,7 +75,7 @@ def summarize_runs(self, previous_summary: str | None, runs: list[Run]) -> str | if len(runs) > 0: summary = self.summarize_conversation_template_generator.render(summary=previous_summary, runs=runs) return self.prompt_driver.run( - message_stack=MessageStack(messages=[Message(summary, role=Message.USER_ROLE)]) + prompt_stack=PromptStack(messages=[Message(summary, role=Message.USER_ROLE)]) ).to_text() else: return previous_summary diff --git a/griptape/schemas/base_schema.py b/griptape/schemas/base_schema.py index 1026ade1a..1099e1bf7 100644 --- a/griptape/schemas/base_schema.py +++ b/griptape/schemas/base_schema.py @@ -105,7 +105,7 @@ def _resolve_types(cls, attrs_cls: type) -> None: # These modules are required to avoid `NameError`s when resolving types. from griptape.drivers import BaseConversationMemoryDriver, BasePromptDriver from griptape.structures import Structure - from griptape.common import MessageStack, Message + from griptape.common import PromptStack, Message from griptape.tokenizers.base_tokenizer import BaseTokenizer from typing import Any @@ -115,7 +115,7 @@ def _resolve_types(cls, attrs_cls: type) -> None: attrs.resolve_types( attrs_cls, localns={ - "MessageStack": MessageStack, + "PromptStack": PromptStack, "Usage": Message.Usage, "Structure": Structure, "BaseConversationMemoryDriver": BaseConversationMemoryDriver, diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index 5adba029b..f10899f76 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -5,7 +5,7 @@ from attrs import Factory, define, field from griptape.artifacts import BaseArtifact -from griptape.common import MessageStack +from griptape.common import PromptStack from griptape.tasks import BaseTask from griptape.utils import J2 from griptape.artifacts import TextArtifact, ListArtifact @@ -38,8 +38,8 @@ def input(self, value: str | list | tuple | BaseArtifact | Callable[[BaseTask], output: Optional[BaseArtifact] = field(default=None, init=False) @property - def message_stack(self) -> MessageStack: - stack = MessageStack() + def prompt_stack(self) -> PromptStack: + stack = PromptStack() memory = self.structure.conversation_memory stack.add_system_message(self.generate_system_template(self)) @@ -51,7 +51,7 @@ def message_stack(self) -> MessageStack: if memory: # inserting at index 1 to place memory right after system prompt - memory.add_to_message_stack(stack, 1) + memory.add_to_prompt_stack(stack, 1) return stack @@ -87,7 +87,7 @@ def after_run(self) -> None: self.structure.logger.info(f"{self.__class__.__name__} {self.id}\nOutput: {self.output.to_text()}") def run(self) -> BaseArtifact: - message = self.prompt_driver.run(self.message_stack) + message = self.prompt_driver.run(self.prompt_stack) return message.to_artifact() diff --git a/griptape/tasks/tool_task.py b/griptape/tasks/tool_task.py index aec246da9..edd90c26e 100644 --- a/griptape/tasks/tool_task.py +++ b/griptape/tasks/tool_task.py @@ -48,7 +48,7 @@ def actions_schema(self) -> Schema: return self._actions_schema_for_tools([self.tool]) def run(self) -> BaseArtifact: - prompt_output = self.prompt_driver.run(message_stack=self.message_stack).to_text() + prompt_output = self.prompt_driver.run(prompt_stack=self.prompt_stack).to_text() action_matches = re.findall(self.ACTION_PATTERN, prompt_output, re.DOTALL) if action_matches: diff --git a/griptape/tasks/toolkit_task.py b/griptape/tasks/toolkit_task.py index 9c0fe33bb..58300b529 100644 --- a/griptape/tasks/toolkit_task.py +++ b/griptape/tasks/toolkit_task.py @@ -10,7 +10,7 @@ from griptape.tasks import ActionsSubtask from griptape.tasks import PromptTask from griptape.utils import J2 -from griptape.common import MessageStack +from griptape.common import PromptStack if TYPE_CHECKING: from griptape.tools import BaseTool @@ -61,8 +61,8 @@ def tool_output_memory(self) -> list[TaskMemory]: return list(unique_memory_dict.values()) @property - def message_stack(self) -> MessageStack: - stack = MessageStack() + def prompt_stack(self) -> PromptStack: + stack = PromptStack() memory = self.structure.conversation_memory stack.add_system_message(self.generate_system_template(self)) @@ -78,7 +78,7 @@ def message_stack(self) -> MessageStack: if memory: # inserting at index 1 to place memory right after system prompt - memory.add_to_message_stack(stack, 1) + memory.add_to_prompt_stack(stack, 1) return stack @@ -131,7 +131,7 @@ def run(self) -> BaseArtifact: self.subtasks.clear() self.prompt_driver.tokenizer.stop_sequences.extend([self.response_stop_sequence]) - subtask = self.add_subtask(ActionsSubtask(self.prompt_driver.run(message_stack=self.message_stack).to_text())) + subtask = self.add_subtask(ActionsSubtask(self.prompt_driver.run(prompt_stack=self.prompt_stack).to_text())) while True: if subtask.output is None: @@ -146,7 +146,7 @@ def run(self) -> BaseArtifact: subtask.after_run() subtask = self.add_subtask( - ActionsSubtask(self.prompt_driver.run(message_stack=self.message_stack).to_text()) + ActionsSubtask(self.prompt_driver.run(prompt_stack=self.prompt_stack).to_text()) ) else: break diff --git a/griptape/utils/conversation.py b/griptape/utils/conversation.py index e885800cc..ae05e8b99 100644 --- a/griptape/utils/conversation.py +++ b/griptape/utils/conversation.py @@ -19,10 +19,10 @@ def lines(self) -> list[str]: return lines - def message_stack(self) -> list[str]: + def prompt_stack(self) -> list[str]: lines = [] - for stack in self.memory.to_message_stack().messages: + for stack in self.memory.to_prompt_stack().messages: lines.append(f"{stack.role}: {stack.to_text()}") return lines diff --git a/tests/mocks/mock_failing_prompt_driver.py b/tests/mocks/mock_failing_prompt_driver.py index d376d1981..0dbeb8fda 100644 --- a/tests/mocks/mock_failing_prompt_driver.py +++ b/tests/mocks/mock_failing_prompt_driver.py @@ -3,7 +3,7 @@ from attrs import define from griptape.artifacts import TextArtifact -from griptape.common import MessageStack, Message, TextMessageContent, DeltaMessage, TextDeltaMessageContent +from griptape.common import PromptStack, Message, TextMessageContent, DeltaMessage, TextDeltaMessageContent from griptape.drivers import BasePromptDriver from griptape.tokenizers import BaseTokenizer, OpenAiTokenizer @@ -15,7 +15,7 @@ class MockFailingPromptDriver(BasePromptDriver): model: str = "test-model" tokenizer: BaseTokenizer = OpenAiTokenizer(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL) - def try_run(self, message_stack: MessageStack) -> Message: + def try_run(self, prompt_stack: PromptStack) -> Message: if self.current_attempt < self.max_failures: self.current_attempt += 1 @@ -27,7 +27,7 @@ def try_run(self, message_stack: MessageStack) -> Message: usage=Message.Usage(input_tokens=100, output_tokens=100), ) - def try_stream(self, message_stack: MessageStack) -> Iterator[DeltaMessage]: + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: if self.current_attempt < self.max_failures: self.current_attempt += 1 diff --git a/tests/mocks/mock_prompt_driver.py b/tests/mocks/mock_prompt_driver.py index 0a64daa8f..4786b78a6 100644 --- a/tests/mocks/mock_prompt_driver.py +++ b/tests/mocks/mock_prompt_driver.py @@ -6,7 +6,7 @@ from attrs import define, field from griptape.artifacts import TextArtifact -from griptape.common import MessageStack, Message, DeltaMessage, TextMessageContent, TextDeltaMessageContent +from griptape.common import PromptStack, Message, DeltaMessage, TextMessageContent, TextDeltaMessageContent from griptape.drivers import BasePromptDriver from griptape.tokenizers import BaseTokenizer @@ -18,10 +18,10 @@ class MockPromptDriver(BasePromptDriver): model: str = "test-model" tokenizer: BaseTokenizer = MockTokenizer(model="test-model", max_input_tokens=4096, max_output_tokens=4096) mock_input: str | Callable[[], str] = field(default="mock input", kw_only=True) - mock_output: str | Callable[[MessageStack], str] = field(default="mock output", kw_only=True) + mock_output: str | Callable[[PromptStack], str] = field(default="mock output", kw_only=True) - def try_run(self, message_stack: MessageStack) -> Message: - output = self.mock_output(message_stack) if isinstance(self.mock_output, Callable) else self.mock_output + def try_run(self, prompt_stack: PromptStack) -> Message: + output = self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output return Message( content=[TextMessageContent(TextArtifact(output))], @@ -29,8 +29,8 @@ def try_run(self, message_stack: MessageStack) -> Message: usage=Message.Usage(input_tokens=100, output_tokens=100), ) - def try_stream(self, message_stack: MessageStack) -> Iterator[DeltaMessage]: - output = self.mock_output(message_stack) if isinstance(self.mock_output, Callable) else self.mock_output + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: + output = self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output yield DeltaMessage(content=TextDeltaMessageContent(output)) diff --git a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py index d49eb700d..3cd165140 100644 --- a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py @@ -1,7 +1,7 @@ import pytest from griptape.artifacts import ImageArtifact, TextArtifact -from griptape.common import MessageStack +from griptape.common import PromptStack from griptape.drivers import AmazonBedrockPromptDriver @@ -31,16 +31,16 @@ def mock_converse_stream(self, mocker): return mock_converse_stream @pytest.fixture(params=[True, False]) - def message_stack(self, request): - message_stack = MessageStack() + def prompt_stack(self, request): + prompt_stack = PromptStack() if request.param: - message_stack.add_system_message("system-input") - message_stack.add_user_message("user-input") - message_stack.add_user_message(TextArtifact("user-input")) - message_stack.add_user_message(ImageArtifact(value=b"image-data", format="png", width=100, height=100)) - message_stack.add_assistant_message("assistant-input") + prompt_stack.add_system_message("system-input") + prompt_stack.add_user_message("user-input") + prompt_stack.add_user_message(TextArtifact("user-input")) + prompt_stack.add_user_message(ImageArtifact(value=b"image-data", format="png", width=100, height=100)) + prompt_stack.add_assistant_message("assistant-input") - return message_stack + return prompt_stack @pytest.fixture def messages(self): @@ -51,18 +51,18 @@ def messages(self): {"role": "assistant", "content": [{"text": "assistant-input"}]}, ] - def test_try_run(self, mock_converse, message_stack, messages): + def test_try_run(self, mock_converse, prompt_stack, messages): # Given driver = AmazonBedrockPromptDriver(model="ai21.j2") # When - text_artifact = driver.try_run(message_stack) + text_artifact = driver.try_run(prompt_stack) # Then mock_converse.assert_called_once_with( modelId=driver.model, messages=messages, - **({"system": [{"text": "system-input"}]} if message_stack.system_messages else {"system": []}), + **({"system": [{"text": "system-input"}]} if prompt_stack.system_messages else {"system": []}), inferenceConfig={"temperature": driver.temperature}, additionalModelRequestFields={}, ) @@ -70,19 +70,19 @@ def test_try_run(self, mock_converse, message_stack, messages): assert text_artifact.usage.input_tokens == 5 assert text_artifact.usage.output_tokens == 10 - def test_try_stream_run(self, mock_converse_stream, message_stack, messages): + def test_try_stream_run(self, mock_converse_stream, prompt_stack, messages): # Given driver = AmazonBedrockPromptDriver(model="ai21.j2", stream=True) # When - stream = driver.try_stream(message_stack) + stream = driver.try_stream(prompt_stack) event = next(stream) # Then mock_converse_stream.assert_called_once_with( modelId=driver.model, messages=messages, - **({"system": [{"text": "system-input"}]} if message_stack.system_messages else {"system": []}), + **({"system": [{"text": "system-input"}]} if prompt_stack.system_messages else {"system": []}), inferenceConfig={"temperature": driver.temperature}, additionalModelRequestFields={}, ) diff --git a/tests/unit/drivers/prompt/test_amazon_sagemaker_jumpstart_prompt_driver.py b/tests/unit/drivers/prompt/test_amazon_sagemaker_jumpstart_prompt_driver.py index 5d70f3aea..a75fc6ed0 100644 --- a/tests/unit/drivers/prompt/test_amazon_sagemaker_jumpstart_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_amazon_sagemaker_jumpstart_prompt_driver.py @@ -2,7 +2,7 @@ from botocore.response import StreamingBody from griptape.tokenizers import HuggingFaceTokenizer from griptape.drivers.prompt.amazon_sagemaker_jumpstart_prompt_driver import AmazonSageMakerJumpstartPromptDriver -from griptape.common import MessageStack +from griptape.common import PromptStack from io import BytesIO import json import pytest @@ -36,13 +36,13 @@ def test_init(self): def test_try_run(self, mock_client): # Given driver = AmazonSageMakerJumpstartPromptDriver(endpoint="model", model="model") - message_stack = MessageStack() - message_stack.add_user_message("prompt-stack") + prompt_stack = PromptStack() + prompt_stack.add_user_message("prompt-stack") # When response_body = [{"generated_text": "foobar"}] mock_client.invoke_endpoint.return_value = {"Body": to_streaming_body(response_body)} - text_artifact = driver.try_run(message_stack) + text_artifact = driver.try_run(prompt_stack) assert isinstance(driver.tokenizer, HuggingFaceTokenizer) # Then @@ -72,7 +72,7 @@ def test_try_run(self, mock_client): # When response_body = {"generated_text": "foobar"} mock_client.invoke_endpoint.return_value = {"Body": to_streaming_body(response_body)} - text_artifact = driver.try_run(message_stack) + text_artifact = driver.try_run(prompt_stack) assert isinstance(driver.tokenizer, HuggingFaceTokenizer) # Then @@ -100,12 +100,12 @@ def test_try_run(self, mock_client): def test_try_stream(self, mock_client): # Given driver = AmazonSageMakerJumpstartPromptDriver(endpoint="model", model="model") - message_stack = MessageStack() - message_stack.add_user_message("prompt-stack") + prompt_stack = PromptStack() + prompt_stack.add_user_message("prompt-stack") # When with pytest.raises(NotImplementedError) as e: - driver.try_stream(message_stack) + driver.try_stream(prompt_stack) # Then assert e.value.args[0] == "streaming is not supported" @@ -125,12 +125,12 @@ def test_try_run_throws_on_empty_response(self, mock_client): # Given driver = AmazonSageMakerJumpstartPromptDriver(endpoint="model", model="model") mock_client.invoke_endpoint.return_value = {"Body": to_streaming_body([])} - message_stack = MessageStack() - message_stack.add_user_message("prompt-stack") + prompt_stack = PromptStack() + prompt_stack.add_user_message("prompt-stack") # When with pytest.raises(Exception) as e: - driver.try_run(message_stack) + driver.try_run(prompt_stack) # Then assert e.value.args[0] == "model response is empty" diff --git a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py index 4328c7f1e..858c05ee6 100644 --- a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py @@ -1,5 +1,5 @@ from griptape.drivers import AnthropicPromptDriver -from griptape.common import MessageStack +from griptape.common import PromptStack from griptape.artifacts import TextArtifact, ImageArtifact, ListArtifact from unittest.mock import Mock import pytest @@ -60,13 +60,13 @@ def test_init(self, model): @pytest.mark.parametrize("system_enabled", [True, False]) def test_try_run(self, mock_client, model, system_enabled): # Given - message_stack = MessageStack() + prompt_stack = PromptStack() if system_enabled: - message_stack.add_system_message("system-input") - message_stack.add_user_message("user-input") - message_stack.add_user_message(TextArtifact("user-input")) - message_stack.add_user_message(ImageArtifact(value=b"image-data", format="png", width=100, height=100)) - message_stack.add_assistant_message("assistant-input") + prompt_stack.add_system_message("system-input") + prompt_stack.add_user_message("user-input") + prompt_stack.add_user_message(TextArtifact("user-input")) + prompt_stack.add_user_message(ImageArtifact(value=b"image-data", format="png", width=100, height=100)) + prompt_stack.add_assistant_message("assistant-input") driver = AnthropicPromptDriver(model=model, api_key="api-key") expected_messages = [ {"role": "user", "content": "user-input"}, @@ -84,7 +84,7 @@ def test_try_run(self, mock_client, model, system_enabled): ] # When - message = driver.try_run(message_stack) + message = driver.try_run(prompt_stack) # Then mock_client.return_value.messages.create.assert_called_once_with( @@ -115,16 +115,16 @@ def test_try_run(self, mock_client, model, system_enabled): @pytest.mark.parametrize("system_enabled", [True, False]) def test_try_stream_run(self, mock_stream_client, model, system_enabled): # Given - message_stack = MessageStack() + prompt_stack = PromptStack() if system_enabled: - message_stack.add_system_message("system-input") - message_stack.add_user_message("user-input") - message_stack.add_user_message( + prompt_stack.add_system_message("system-input") + prompt_stack.add_user_message("user-input") + prompt_stack.add_user_message( ListArtifact( [TextArtifact("user-input"), ImageArtifact(value=b"image-data", format="png", width=100, height=100)] ) ) - message_stack.add_assistant_message("assistant-input") + prompt_stack.add_assistant_message("assistant-input") expected_messages = [ {"role": "user", "content": "user-input"}, { @@ -142,7 +142,7 @@ def test_try_stream_run(self, mock_stream_client, model, system_enabled): driver = AnthropicPromptDriver(model=model, api_key="api-key", stream=True) # When - stream = driver.try_stream(message_stack) + stream = driver.try_stream(prompt_stack) event = next(stream) # Then @@ -165,14 +165,14 @@ def test_try_stream_run(self, mock_stream_client, model, system_enabled): event = next(stream) assert event.usage.output_tokens == 10 - def test_try_run_throws_when_message_stack_is_string(self): + def test_try_run_throws_when_prompt_stack_is_string(self): # Given - message_stack = "prompt-stack" + prompt_stack = "prompt-stack" driver = AnthropicPromptDriver(model="claude", api_key="api-key") # When with pytest.raises(Exception) as e: - driver.try_run(message_stack) # pyright: ignore + driver.try_run(prompt_stack) # pyright: ignore # Then assert e.value.args[0] == "'str' object has no attribute 'messages'" diff --git a/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py b/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py index dedd2b3f6..92544a74e 100644 --- a/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py @@ -30,12 +30,12 @@ def test_init(self): assert AzureOpenAiChatPromptDriver(azure_endpoint="foobar", azure_deployment="foobar", model="gpt-4") assert AzureOpenAiChatPromptDriver(azure_endpoint="foobar", model="gpt-4").azure_deployment == "gpt-4" - def test_try_run(self, mock_chat_completion_create, message_stack, messages): + def test_try_run(self, mock_chat_completion_create, prompt_stack, messages): # Given driver = AzureOpenAiChatPromptDriver(azure_endpoint="endpoint", azure_deployment="deployment-id", model="gpt-4") # When - text_artifact = driver.try_run(message_stack) + text_artifact = driver.try_run(prompt_stack) # Then mock_chat_completion_create.assert_called_once_with( @@ -45,14 +45,14 @@ def test_try_run(self, mock_chat_completion_create, message_stack, messages): assert text_artifact.usage.input_tokens == 5 assert text_artifact.usage.output_tokens == 10 - def test_try_stream_run(self, mock_chat_completion_stream_create, message_stack, messages): + def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, messages): # Given driver = AzureOpenAiChatPromptDriver( azure_endpoint="endpoint", azure_deployment="deployment-id", model="gpt-4", stream=True ) # When - stream = driver.try_stream(message_stack) + stream = driver.try_stream(prompt_stack) event = next(stream) # Then diff --git a/tests/unit/drivers/prompt/test_base_prompt_driver.py b/tests/unit/drivers/prompt/test_base_prompt_driver.py index ac54dc9a1..6eb000e1f 100644 --- a/tests/unit/drivers/prompt/test_base_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_base_prompt_driver.py @@ -1,6 +1,6 @@ -from griptape.common.message_stack.messages.message import Message +from griptape.common.prompt_stack.messages.message import Message from griptape.events import FinishPromptEvent, StartPromptEvent -from griptape.common import MessageStack +from griptape.common import PromptStack from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_failing_prompt_driver import MockFailingPromptDriver from griptape.artifacts import ErrorArtifact, TextArtifact @@ -38,11 +38,11 @@ def test_run_via_pipeline_publishes_events(self, mocker): assert instance_count(events, FinishPromptEvent) == 1 def test_run(self): - assert isinstance(MockPromptDriver().run(MessageStack(messages=[])), Message) + assert isinstance(MockPromptDriver().run(PromptStack(messages=[])), Message) def test_run_with_stream(self): pipeline = Pipeline() - result = MockPromptDriver(stream=True, structure=pipeline).run(MessageStack(messages=[])) + result = MockPromptDriver(stream=True, structure=pipeline).run(PromptStack(messages=[])) assert isinstance(result, Message) assert result.value == "mock output" diff --git a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py index 86b248c26..d65775e8b 100644 --- a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py @@ -2,7 +2,7 @@ import pytest -from griptape.common import MessageStack +from griptape.common import PromptStack from griptape.drivers import CoherePromptDriver @@ -33,25 +33,25 @@ def mock_tokenizer(self, mocker): return mocker.patch("griptape.tokenizers.CohereTokenizer").return_value @pytest.fixture(params=[True, False]) - def message_stack(self, request): - message_stack = MessageStack() + def prompt_stack(self, request): + prompt_stack = PromptStack() if request.param: - message_stack.add_system_message("system-input") - message_stack.add_user_message("user-input") - message_stack.add_assistant_message("assistant-input") - message_stack.add_user_message("user-input") - message_stack.add_assistant_message("assistant-input") - return message_stack + prompt_stack.add_system_message("system-input") + prompt_stack.add_user_message("user-input") + prompt_stack.add_assistant_message("assistant-input") + prompt_stack.add_user_message("user-input") + prompt_stack.add_assistant_message("assistant-input") + return prompt_stack def test_init(self): assert CoherePromptDriver(model="command", api_key="foobar") - def test_try_run(self, mock_client, message_stack): + def test_try_run(self, mock_client, prompt_stack): # Given driver = CoherePromptDriver(model="command", api_key="api-key") # When - text_artifact = driver.try_run(message_stack) + text_artifact = driver.try_run(prompt_stack) # Then mock_client.chat.assert_called_once_with( @@ -62,7 +62,7 @@ def test_try_run(self, mock_client, message_stack): ], max_tokens=None, message="assistant-input", - **({"preamble": "system-input"} if message_stack.system_messages else {}), + **({"preamble": "system-input"} if prompt_stack.system_messages else {}), stop_sequences=[], temperature=0.1, ) @@ -71,12 +71,12 @@ def test_try_run(self, mock_client, message_stack): assert text_artifact.usage.input_tokens == 5 assert text_artifact.usage.output_tokens == 10 - def test_try_stream_run(self, mock_stream_client, message_stack): # pyright: ignore + def test_try_stream_run(self, mock_stream_client, prompt_stack): # pyright: ignore # Given driver = CoherePromptDriver(model="command", api_key="api-key", stream=True) # When - stream = driver.try_stream(message_stack) + stream = driver.try_stream(prompt_stack) event = next(stream) # Then @@ -89,7 +89,7 @@ def test_try_stream_run(self, mock_stream_client, message_stack): # pyright: ig ], max_tokens=None, message="assistant-input", - **({"preamble": "system-input"} if message_stack.system_messages else {}), + **({"preamble": "system-input"} if prompt_stack.system_messages else {}), stop_sequences=[], temperature=0.1, ) diff --git a/tests/unit/drivers/prompt/test_google_prompt_driver.py b/tests/unit/drivers/prompt/test_google_prompt_driver.py index 93aa5bb3b..6a25ec3d3 100644 --- a/tests/unit/drivers/prompt/test_google_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_google_prompt_driver.py @@ -1,7 +1,7 @@ from google.generativeai.types import GenerationConfig from griptape.artifacts import TextArtifact, ImageArtifact from griptape.drivers import GooglePromptDriver -from griptape.common import MessageStack +from griptape.common import PromptStack from unittest.mock import Mock import pytest @@ -35,17 +35,17 @@ def test_init(self): @pytest.mark.parametrize("system_enabled", [True, False]) def test_try_run(self, mock_generative_model, system_enabled): # Given - message_stack = MessageStack() + prompt_stack = PromptStack() if system_enabled: - message_stack.add_system_message("system-input") - message_stack.add_user_message("user-input") - message_stack.add_user_message(TextArtifact("user-input")) - message_stack.add_user_message(ImageArtifact(value=b"image-data", format="png", width=100, height=100)) - message_stack.add_assistant_message("assistant-input") + prompt_stack.add_system_message("system-input") + prompt_stack.add_user_message("user-input") + prompt_stack.add_user_message(TextArtifact("user-input")) + prompt_stack.add_user_message(ImageArtifact(value=b"image-data", format="png", width=100, height=100)) + prompt_stack.add_assistant_message("assistant-input") driver = GooglePromptDriver(model="gemini-pro", api_key="api-key", top_p=0.5, top_k=50) # When - text_artifact = driver.try_run(message_stack) + text_artifact = driver.try_run(prompt_stack) # Then messages = [ @@ -71,17 +71,17 @@ def test_try_run(self, mock_generative_model, system_enabled): @pytest.mark.parametrize("system_enabled", [True, False]) def test_try_stream(self, mock_stream_generative_model, system_enabled): # Given - message_stack = MessageStack() + prompt_stack = PromptStack() if system_enabled: - message_stack.add_system_message("system-input") - message_stack.add_user_message("user-input") - message_stack.add_user_message(TextArtifact("user-input")) - message_stack.add_user_message(ImageArtifact(value=b"image-data", format="png", width=100, height=100)) - message_stack.add_assistant_message("assistant-input") + prompt_stack.add_system_message("system-input") + prompt_stack.add_user_message("user-input") + prompt_stack.add_user_message(TextArtifact("user-input")) + prompt_stack.add_user_message(ImageArtifact(value=b"image-data", format="png", width=100, height=100)) + prompt_stack.add_assistant_message("assistant-input") driver = GooglePromptDriver(model="gemini-pro", api_key="api-key", stream=True, top_p=0.5, top_k=50) # When - stream = driver.try_stream(message_stack) + stream = driver.try_stream(prompt_stack) # Then event = next(stream) diff --git a/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py b/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py index 8079ffe13..4618e1de3 100644 --- a/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py @@ -1,5 +1,5 @@ from griptape.drivers import HuggingFaceHubPromptDriver -from griptape.common import MessageStack +from griptape.common import PromptStack import pytest @@ -28,12 +28,12 @@ def mock_client_stream(self, mocker): return mock_client @pytest.fixture - def message_stack(self): - message_stack = MessageStack() - message_stack.add_system_message("system-input") - message_stack.add_user_message("user-input") - message_stack.add_assistant_message("assistant-input") - return message_stack + def prompt_stack(self): + prompt_stack = PromptStack() + prompt_stack.add_system_message("system-input") + prompt_stack.add_user_message("user-input") + prompt_stack.add_assistant_message("assistant-input") + return prompt_stack @pytest.fixture(autouse=True) def mock_autotokenizer(self, mocker): @@ -44,24 +44,24 @@ def mock_autotokenizer(self, mocker): def test_init(self): assert HuggingFaceHubPromptDriver(api_token="foobar", model="gpt2") - def test_try_run(self, message_stack, mock_client): + def test_try_run(self, prompt_stack, mock_client): # Given driver = HuggingFaceHubPromptDriver(api_token="api-token", model="repo-id") # When - message = driver.try_run(message_stack) + message = driver.try_run(prompt_stack) # Then assert message.value == "model-output" assert message.usage.input_tokens == 3 assert message.usage.output_tokens == 3 - def test_try_stream(self, message_stack, mock_client_stream): + def test_try_stream(self, prompt_stack, mock_client_stream): # Given driver = HuggingFaceHubPromptDriver(api_token="api-token", model="repo-id", stream=True) # When - stream = driver.try_stream(message_stack) + stream = driver.try_stream(prompt_stack) event = next(stream) # Then diff --git a/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py b/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py index 5b31d8fe5..a63d697fb 100644 --- a/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py @@ -1,5 +1,5 @@ from griptape.drivers import HuggingFacePipelinePromptDriver -from griptape.common import MessageStack +from griptape.common import PromptStack import pytest @@ -26,69 +26,69 @@ def mock_autotokenizer(self, mocker): return mock_autotokenizer @pytest.fixture - def message_stack(self): - message_stack = MessageStack() - message_stack.add_system_message("system-input") - message_stack.add_user_message("user-input") - message_stack.add_assistant_message("assistant-input") - return message_stack + def prompt_stack(self): + prompt_stack = PromptStack() + prompt_stack.add_system_message("system-input") + prompt_stack.add_user_message("user-input") + prompt_stack.add_assistant_message("assistant-input") + return prompt_stack def test_init(self): assert HuggingFacePipelinePromptDriver(model="gpt2", max_tokens=42) - def test_try_run(self, message_stack): + def test_try_run(self, prompt_stack): # Given driver = HuggingFacePipelinePromptDriver(model="foo", max_tokens=42) # When - message = driver.try_run(message_stack) + message = driver.try_run(prompt_stack) # Then assert message.value == "model-output" assert message.usage.input_tokens == 3 assert message.usage.output_tokens == 3 - def test_try_stream(self, message_stack): + def test_try_stream(self, prompt_stack): # Given driver = HuggingFacePipelinePromptDriver(model="foo", max_tokens=42) # When with pytest.raises(Exception) as e: - driver.try_stream(message_stack) + driver.try_stream(prompt_stack) assert e.value.args[0] == "streaming is not supported" @pytest.mark.parametrize("choices", [[], [1, 2]]) - def test_try_run_throws_when_multiple_choices_returned(self, choices, mock_generator, message_stack): + def test_try_run_throws_when_multiple_choices_returned(self, choices, mock_generator, prompt_stack): # Given driver = HuggingFacePipelinePromptDriver(model="foo", max_tokens=42) mock_generator.return_value = choices # When with pytest.raises(Exception) as e: - driver.try_run(message_stack) + driver.try_run(prompt_stack) # Then assert e.value.args[0] == "completion with more than one choice is not supported yet" - def test_try_run_throws_when_non_list(self, mock_generator, message_stack): + def test_try_run_throws_when_non_list(self, mock_generator, prompt_stack): # Given driver = HuggingFacePipelinePromptDriver(model="foo", max_tokens=42) mock_generator.return_value = {} # When with pytest.raises(Exception) as e: - driver.try_run(message_stack) + driver.try_run(prompt_stack) # Then assert e.value.args[0] == "invalid output format" - def test_message_stack_to_string(self, message_stack): + def test_prompt_stack_to_string(self, prompt_stack): # Given driver = HuggingFacePipelinePromptDriver(model="foo", max_tokens=42) # When - result = driver.message_stack_to_string(message_stack) + result = driver.prompt_stack_to_string(prompt_stack) # Then assert result == "model-output" diff --git a/tests/unit/drivers/prompt/test_ollama_prompt_driver.py b/tests/unit/drivers/prompt/test_ollama_prompt_driver.py index c13aa2e60..a247a77ab 100644 --- a/tests/unit/drivers/prompt/test_ollama_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_ollama_prompt_driver.py @@ -1,6 +1,6 @@ -from griptape.common.message_stack.contents.text_delta_message_content import TextDeltaMessageContent +from griptape.common.prompt_stack.contents.text_delta_message_content import TextDeltaMessageContent from griptape.drivers import OllamaPromptDriver -from griptape.common import MessageStack +from griptape.common import PromptStack from griptape.artifacts import ImageArtifact, ListArtifact, TextArtifact import pytest @@ -26,15 +26,15 @@ def test_init(self): def test_try_run(self, mock_client): # Given - message_stack = MessageStack() - message_stack.add_system_message("system-input") - message_stack.add_user_message("user-input") - message_stack.add_user_message( + prompt_stack = PromptStack() + prompt_stack.add_system_message("system-input") + prompt_stack.add_user_message("user-input") + prompt_stack.add_user_message( ListArtifact( [TextArtifact("user-input"), ImageArtifact(value=b"image-data", format="png", width=100, height=100)] ) ) - message_stack.add_assistant_message("assistant-input") + prompt_stack.add_assistant_message("assistant-input") driver = OllamaPromptDriver(model="llama") expected_messages = [ {"role": "system", "content": "system-input"}, @@ -44,7 +44,7 @@ def test_try_run(self, mock_client): ] # When - message = driver.try_run(message_stack) + message = driver.try_run(prompt_stack) # Then mock_client.return_value.chat.assert_called_once_with( @@ -58,25 +58,25 @@ def test_try_run(self, mock_client): def test_try_run_bad_response(self, mock_client): # Given - message_stack = MessageStack() + prompt_stack = PromptStack() driver = OllamaPromptDriver(model="llama") mock_client.return_value.chat.return_value = "bad-response" # When/Then with pytest.raises(Exception, match="invalid model response"): - driver.try_run(message_stack) + driver.try_run(prompt_stack) def test_try_stream_run(self, mock_stream_client): # Given - message_stack = MessageStack() - message_stack.add_system_message("system-input") - message_stack.add_user_message("user-input") - message_stack.add_user_message( + prompt_stack = PromptStack() + prompt_stack.add_system_message("system-input") + prompt_stack.add_user_message("user-input") + prompt_stack.add_user_message( ListArtifact( [TextArtifact("user-input"), ImageArtifact(value=b"image-data", format="png", width=100, height=100)] ) ) - message_stack.add_assistant_message("assistant-input") + prompt_stack.add_assistant_message("assistant-input") expected_messages = [ {"role": "system", "content": "system-input"}, {"role": "user", "content": "user-input"}, @@ -86,7 +86,7 @@ def test_try_stream_run(self, mock_stream_client): driver = OllamaPromptDriver(model="llama", stream=True) # When - text_artifact = next(driver.try_stream(message_stack)) + text_artifact = next(driver.try_stream(prompt_stack)) # Then mock_stream_client.return_value.chat.assert_called_once_with( @@ -100,10 +100,10 @@ def test_try_stream_run(self, mock_stream_client): def test_try_stream_bad_response(self, mock_stream_client): # Given - message_stack = MessageStack() + prompt_stack = PromptStack() driver = OllamaPromptDriver(model="llama", stream=True) mock_stream_client.return_value.chat.return_value = "bad-response" # When/Then with pytest.raises(Exception, match="invalid model response"): - next(driver.try_stream(message_stack)) + next(driver.try_stream(prompt_stack)) diff --git a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py index 8bd190c3a..5c217ed06 100644 --- a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py @@ -1,7 +1,7 @@ from griptape.artifacts import ImageArtifact, ListArtifact from griptape.artifacts import TextArtifact from griptape.drivers import OpenAiChatPromptDriver -from griptape.common import MessageStack +from griptape.common import PromptStack from griptape.tokenizers import OpenAiTokenizer from unittest.mock import Mock from tests.mocks.mock_tokenizer import MockTokenizer @@ -33,17 +33,17 @@ def mock_chat_completion_stream_create(self, mocker): return mock_chat_create @pytest.fixture - def message_stack(self): - message_stack = MessageStack() - message_stack.add_system_message("system-input") - message_stack.add_user_message("user-input") - message_stack.add_user_message( + def prompt_stack(self): + prompt_stack = PromptStack() + prompt_stack.add_system_message("system-input") + prompt_stack.add_user_message("user-input") + prompt_stack.add_user_message( ListArtifact( [TextArtifact("user-input"), ImageArtifact(value=b"image-data", format="png", width=100, height=100)] ) ) - message_stack.add_assistant_message("assistant-input") - return message_stack + prompt_stack.add_assistant_message("assistant-input") + return prompt_stack @pytest.fixture def messages(self): @@ -98,12 +98,12 @@ class TestOpenAiChatPromptDriver(TestOpenAiChatPromptDriverFixtureMixin): def test_init(self): assert OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_4_MODEL) - def test_try_run(self, mock_chat_completion_create, message_stack, messages): + def test_try_run(self, mock_chat_completion_create, prompt_stack, messages): # Given driver = OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL) # When - event = driver.try_run(message_stack) + event = driver.try_run(prompt_stack) # Then mock_chat_completion_create.assert_called_once_with( @@ -111,14 +111,14 @@ def test_try_run(self, mock_chat_completion_create, message_stack, messages): ) assert event.value == "model-output" - def test_try_run_response_format(self, mock_chat_completion_create, message_stack, messages): + def test_try_run_response_format(self, mock_chat_completion_create, prompt_stack, messages): # Given driver = OpenAiChatPromptDriver( model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, response_format="json_object" ) # When - message = driver.try_run(message_stack) + message = driver.try_run(prompt_stack) # Then mock_chat_completion_create.assert_called_once_with( @@ -133,12 +133,12 @@ def test_try_run_response_format(self, mock_chat_completion_create, message_stac assert message.usage.input_tokens == 5 assert message.usage.output_tokens == 10 - def test_try_stream_run(self, mock_chat_completion_stream_create, message_stack, messages): + def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, messages): # Given driver = OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, stream=True) # When - stream = driver.try_stream(message_stack) + stream = driver.try_stream(prompt_stack) event = next(stream) # Then @@ -160,12 +160,12 @@ def test_try_stream_run(self, mock_chat_completion_stream_create, message_stack, event = next(stream) assert event.content.text == "" - def test_try_run_with_max_tokens(self, mock_chat_completion_create, message_stack, messages): + def test_try_run_with_max_tokens(self, mock_chat_completion_create, prompt_stack, messages): # Given driver = OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, max_tokens=1) # When - event = driver.try_run(message_stack) + event = driver.try_run(prompt_stack) # Then mock_chat_completion_create.assert_called_once_with( @@ -178,7 +178,7 @@ def test_try_run_with_max_tokens(self, mock_chat_completion_create, message_stac ) assert event.value == "model-output" - def test_try_run_throws_when_message_stack_is_string(self): + def test_try_run_throws_when_prompt_stack_is_string(self): # Given driver = OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL) @@ -189,19 +189,19 @@ def test_try_run_throws_when_message_stack_is_string(self): # Then assert e.value.args[0] == "'str' object has no attribute 'messages'" - def test_try_run_throws_when_multiple_choices_returned(self, mock_chat_completion_create, message_stack): + def test_try_run_throws_when_multiple_choices_returned(self, mock_chat_completion_create, prompt_stack): # Given driver = OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, api_key="api-key") mock_chat_completion_create.return_value.choices = [Mock(message=Mock(content="model-output"))] * 10 # When with pytest.raises(Exception) as e: - driver.try_run(message_stack) + driver.try_run(prompt_stack) # Then assert e.value.args[0] == "Completion with more than one choice is not supported yet." - def test_custom_tokenizer(self, mock_chat_completion_create, message_stack, messages): + def test_custom_tokenizer(self, mock_chat_completion_create, prompt_stack, messages): driver = OpenAiChatPromptDriver( model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, tokenizer=MockTokenizer(model="mock-model", stop_sequences=["mock-stop"]), @@ -209,7 +209,7 @@ def test_custom_tokenizer(self, mock_chat_completion_create, message_stack, mess ) # When - event = driver.try_run(message_stack) + event = driver.try_run(prompt_stack) # Then mock_chat_completion_create.assert_called_once_with( diff --git a/tests/unit/engines/summary/test_prompt_summary_engine.py b/tests/unit/engines/summary/test_prompt_summary_engine.py index 5de0016a9..34c6e3563 100644 --- a/tests/unit/engines/summary/test_prompt_summary_engine.py +++ b/tests/unit/engines/summary/test_prompt_summary_engine.py @@ -1,7 +1,7 @@ import pytest from griptape.artifacts import TextArtifact, ListArtifact from griptape.engines import PromptSummaryEngine -from griptape.common import MessageStack +from griptape.common import PromptStack from tests.mocks.mock_prompt_driver import MockPromptDriver import os @@ -27,8 +27,8 @@ def test_max_token_multiplier_invalid(self, engine): PromptSummaryEngine(prompt_driver=MockPromptDriver(), max_token_multiplier=10000) def test_chunked_summary(self, engine): - def smaller_input(message_stack: MessageStack): - return message_stack.messages[0].content[: (len(message_stack.messages[0].content) // 2)] + def smaller_input(prompt_stack: PromptStack): + return prompt_stack.messages[0].content[: (len(prompt_stack.messages[0].content) // 2)] engine = PromptSummaryEngine(prompt_driver=MockPromptDriver(mock_output="smaller_input")) diff --git a/tests/unit/events/test_base_event.py b/tests/unit/events/test_base_event.py index c4393fa3d..595c90f1f 100644 --- a/tests/unit/events/test_base_event.py +++ b/tests/unit/events/test_base_event.py @@ -32,8 +32,8 @@ def test_start_prompt_event_from_dict(self): "id": "917298d4bf894b0a824a8fdb26717a0c", "timestamp": 123, "model": "foo bar", - "message_stack": { - "type": "MessageStack", + "prompt_stack": { + "type": "PromptStack", "messages": [ { "type": "Message", @@ -59,10 +59,10 @@ def test_start_prompt_event_from_dict(self): assert isinstance(event, StartPromptEvent) assert event.timestamp == 123 - assert event.message_stack.messages[0].content[0].artifact.value == "foo" - assert event.message_stack.messages[0].role == "user" - assert event.message_stack.messages[1].content[0].artifact.value == "bar" - assert event.message_stack.messages[1].role == "system" + assert event.prompt_stack.messages[0].content[0].artifact.value == "foo" + assert event.prompt_stack.messages[0].role == "user" + assert event.prompt_stack.messages[1].content[0].artifact.value == "bar" + assert event.prompt_stack.messages[1].role == "system" assert event.model == "foo bar" def test_finish_prompt_event_from_dict(self): diff --git a/tests/unit/events/test_start_prompt_event.py b/tests/unit/events/test_start_prompt_event.py index 51e609458..4ef08ec5c 100644 --- a/tests/unit/events/test_start_prompt_event.py +++ b/tests/unit/events/test_start_prompt_event.py @@ -1,22 +1,22 @@ import pytest from griptape.events import StartPromptEvent -from griptape.common import MessageStack +from griptape.common import PromptStack class TestStartPromptEvent: @pytest.fixture def start_prompt_event(self): - message_stack = MessageStack() - message_stack.add_user_message("foo") - message_stack.add_system_message("bar") - return StartPromptEvent(message_stack=message_stack, model="foo bar") + prompt_stack = PromptStack() + prompt_stack.add_user_message("foo") + prompt_stack.add_system_message("bar") + return StartPromptEvent(prompt_stack=prompt_stack, model="foo bar") def test_to_dict(self, start_prompt_event): assert "timestamp" in start_prompt_event.to_dict() - assert start_prompt_event.to_dict()["message_stack"]["messages"][0]["content"][0]["artifact"]["value"] == "foo" - assert start_prompt_event.to_dict()["message_stack"]["messages"][0]["role"] == "user" - assert start_prompt_event.to_dict()["message_stack"]["messages"][1]["content"][0]["artifact"]["value"] == "bar" - assert start_prompt_event.to_dict()["message_stack"]["messages"][1]["role"] == "system" + assert start_prompt_event.to_dict()["prompt_stack"]["messages"][0]["content"][0]["artifact"]["value"] == "foo" + assert start_prompt_event.to_dict()["prompt_stack"]["messages"][0]["role"] == "user" + assert start_prompt_event.to_dict()["prompt_stack"]["messages"][1]["content"][0]["artifact"]["value"] == "bar" + assert start_prompt_event.to_dict()["prompt_stack"]["messages"][1]["role"] == "system" assert start_prompt_event.to_dict()["model"] == "foo bar" diff --git a/tests/unit/memory/structure/test_conversation_memory.py b/tests/unit/memory/structure/test_conversation_memory.py index 82df4fc4d..613d4b1fe 100644 --- a/tests/unit/memory/structure/test_conversation_memory.py +++ b/tests/unit/memory/structure/test_conversation_memory.py @@ -1,6 +1,6 @@ import json from griptape.structures import Agent -from griptape.common import MessageStack +from griptape.common import PromptStack from griptape.memory.structure import ConversationMemory, Run, BaseConversationMemory from griptape.structures import Pipeline from tests.mocks.mock_prompt_driver import MockPromptDriver @@ -32,14 +32,14 @@ def test_to_dict(self): assert memory.to_dict()["type"] == "ConversationMemory" assert memory.to_dict()["runs"][0]["input"]["value"] == "foo" - def test_to_message_stack(self): + def test_to_prompt_stack(self): memory = ConversationMemory() memory.add_run(Run(input=TextArtifact("foo"), output=TextArtifact("bar"))) - message_stack = memory.to_message_stack() + prompt_stack = memory.to_prompt_stack() - assert message_stack.messages[0].content[0].artifact.value == "foo" - assert message_stack.messages[1].content[0].artifact.value == "bar" + assert prompt_stack.messages[0].content[0].artifact.value == "foo" + assert prompt_stack.messages[1].content[0].artifact.value == "bar" def test_from_dict(self): memory = ConversationMemory() @@ -74,7 +74,7 @@ def test_buffering(self): assert pipeline.conversation_memory.runs[0].input.value == "run4" assert pipeline.conversation_memory.runs[1].input.value == "run5" - def test_add_to_message_stack_autopruing_disabled(self): + def test_add_to_prompt_stack_autopruing_disabled(self): agent = Agent(prompt_driver=MockPromptDriver()) memory = ConversationMemory( autoprune=False, @@ -87,14 +87,14 @@ def test_add_to_message_stack_autopruing_disabled(self): ], ) memory.structure = agent - message_stack = MessageStack() - message_stack.add_user_message(TextArtifact("foo")) - message_stack.add_assistant_message("bar") - memory.add_to_message_stack(message_stack) + prompt_stack = PromptStack() + prompt_stack.add_user_message(TextArtifact("foo")) + prompt_stack.add_assistant_message("bar") + memory.add_to_prompt_stack(prompt_stack) - assert len(message_stack.messages) == 12 + assert len(prompt_stack.messages) == 12 - def test_add_to_message_stack_autopruning_enabled(self): + def test_add_to_prompt_stack_autopruning_enabled(self): # All memory is pruned. agent = Agent(prompt_driver=MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=0))) memory = ConversationMemory( @@ -108,13 +108,13 @@ def test_add_to_message_stack_autopruning_enabled(self): ], ) memory.structure = agent - message_stack = MessageStack() - message_stack.add_system_message("fizz") - message_stack.add_user_message("foo") - message_stack.add_assistant_message("bar") - memory.add_to_message_stack(message_stack) + prompt_stack = PromptStack() + prompt_stack.add_system_message("fizz") + prompt_stack.add_user_message("foo") + prompt_stack.add_assistant_message("bar") + memory.add_to_prompt_stack(prompt_stack) - assert len(message_stack.messages) == 3 + assert len(prompt_stack.messages) == 3 # No memory is pruned. agent = Agent(prompt_driver=MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=1000))) @@ -129,13 +129,13 @@ def test_add_to_message_stack_autopruning_enabled(self): ], ) memory.structure = agent - message_stack = MessageStack() - message_stack.add_system_message("fizz") - message_stack.add_user_message("foo") - message_stack.add_assistant_message("bar") - memory.add_to_message_stack(message_stack) + prompt_stack = PromptStack() + prompt_stack.add_system_message("fizz") + prompt_stack.add_user_message("foo") + prompt_stack.add_assistant_message("bar") + memory.add_to_prompt_stack(prompt_stack) - assert len(message_stack.messages) == 13 + assert len(prompt_stack.messages) == 13 # One memory is pruned. # MockTokenizer's max_input_tokens set to one below the sum of memory + system prompt tokens @@ -153,17 +153,17 @@ def test_add_to_message_stack_autopruning_enabled(self): ], ) memory.structure = agent - message_stack = MessageStack() + prompt_stack = PromptStack() # And then another 6 tokens from fizz for a total of 161 tokens. - message_stack.add_system_message("fizz") - message_stack.add_user_message("foo") - message_stack.add_assistant_message("bar") - memory.add_to_message_stack(message_stack, 1) - - # We expect one run (2 Message Stack inputs) to be pruned. - assert len(message_stack.messages) == 11 - assert message_stack.messages[0].content[0].artifact.value == "fizz" - assert message_stack.messages[1].content[0].artifact.value == "foo2" - assert message_stack.messages[2].content[0].artifact.value == "bar2" - assert message_stack.messages[-2].content[0].artifact.value == "foo" - assert message_stack.messages[-1].content[0].artifact.value == "bar" + prompt_stack.add_system_message("fizz") + prompt_stack.add_user_message("foo") + prompt_stack.add_assistant_message("bar") + memory.add_to_prompt_stack(prompt_stack, 1) + + # We expect one run (2 Prompt Stack inputs) to be pruned. + assert len(prompt_stack.messages) == 11 + assert prompt_stack.messages[0].content[0].artifact.value == "fizz" + assert prompt_stack.messages[1].content[0].artifact.value == "foo2" + assert prompt_stack.messages[2].content[0].artifact.value == "bar2" + assert prompt_stack.messages[-2].content[0].artifact.value == "foo" + assert prompt_stack.messages[-1].content[0].artifact.value == "bar" diff --git a/tests/unit/memory/structure/test_summary_conversation_memory.py b/tests/unit/memory/structure/test_summary_conversation_memory.py index e98fb6724..e625ac6c6 100644 --- a/tests/unit/memory/structure/test_summary_conversation_memory.py +++ b/tests/unit/memory/structure/test_summary_conversation_memory.py @@ -53,15 +53,15 @@ def test_to_dict(self): assert memory.to_dict()["type"] == "SummaryConversationMemory" assert memory.to_dict()["runs"][0]["input"]["value"] == "foo" - def test_to_message_stack(self): + def test_to_prompt_stack(self): memory = SummaryConversationMemory(summary="foobar") memory.add_run(Run(input=TextArtifact("foo"), output=TextArtifact("bar"))) - message_stack = memory.to_message_stack() + prompt_stack = memory.to_prompt_stack() - assert message_stack.messages[0].content[0].artifact.value == "Summary of the conversation so far: foobar" - assert message_stack.messages[1].content[0].artifact.value == "foo" - assert message_stack.messages[2].content[0].artifact.value == "bar" + assert prompt_stack.messages[0].content[0].artifact.value == "Summary of the conversation so far: foobar" + assert prompt_stack.messages[1].content[0].artifact.value == "foo" + assert prompt_stack.messages[2].content[0].artifact.value == "bar" def test_from_dict(self): memory = SummaryConversationMemory() diff --git a/tests/unit/structures/test_agent.py b/tests/unit/structures/test_agent.py index 65562f8b2..942a960f9 100644 --- a/tests/unit/structures/test_agent.py +++ b/tests/unit/structures/test_agent.py @@ -159,39 +159,39 @@ def test_add_tasks(self): except ValueError: assert True - def test_message_stack_without_memory(self): + def test_prompt_stack_without_memory(self): agent = Agent(prompt_driver=MockPromptDriver(), conversation_memory=None) task1 = PromptTask("test") agent.add_task(task1) - assert len(task1.message_stack.messages) == 2 + assert len(task1.prompt_stack.messages) == 2 agent.run() - assert len(task1.message_stack.messages) == 3 + assert len(task1.prompt_stack.messages) == 3 agent.run() - assert len(task1.message_stack.messages) == 3 + assert len(task1.prompt_stack.messages) == 3 - def test_message_stack_with_memory(self): + def test_prompt_stack_with_memory(self): agent = Agent(prompt_driver=MockPromptDriver(), conversation_memory=ConversationMemory()) task1 = PromptTask("test") agent.add_task(task1) - assert len(task1.message_stack.messages) == 2 + assert len(task1.prompt_stack.messages) == 2 agent.run() - assert len(task1.message_stack.messages) == 5 + assert len(task1.prompt_stack.messages) == 5 agent.run() - assert len(task1.message_stack.messages) == 7 + assert len(task1.prompt_stack.messages) == 7 def test_run(self): task = PromptTask("test") diff --git a/tests/unit/structures/test_pipeline.py b/tests/unit/structures/test_pipeline.py index a84ab31dc..99c4141bf 100644 --- a/tests/unit/structures/test_pipeline.py +++ b/tests/unit/structures/test_pipeline.py @@ -269,7 +269,7 @@ def test_insert_task_at_end(self): assert [parent.id for parent in third_task.parents] == ["test2"] assert [child.id for child in third_task.children] == [] - def test_message_stack_without_memory(self): + def test_prompt_stack_without_memory(self): pipeline = Pipeline(conversation_memory=None, prompt_driver=MockPromptDriver()) task1 = PromptTask("test") @@ -277,20 +277,20 @@ def test_message_stack_without_memory(self): pipeline.add_tasks(task1, task2) - assert len(task1.message_stack.messages) == 2 - assert len(task2.message_stack.messages) == 2 + assert len(task1.prompt_stack.messages) == 2 + assert len(task2.prompt_stack.messages) == 2 pipeline.run() - assert len(task1.message_stack.messages) == 3 - assert len(task2.message_stack.messages) == 3 + assert len(task1.prompt_stack.messages) == 3 + assert len(task2.prompt_stack.messages) == 3 pipeline.run() - assert len(task1.message_stack.messages) == 3 - assert len(task2.message_stack.messages) == 3 + assert len(task1.prompt_stack.messages) == 3 + assert len(task2.prompt_stack.messages) == 3 - def test_message_stack_with_memory(self): + def test_prompt_stack_with_memory(self): pipeline = Pipeline(prompt_driver=MockPromptDriver()) task1 = PromptTask("test") @@ -298,18 +298,18 @@ def test_message_stack_with_memory(self): pipeline.add_tasks(task1, task2) - assert len(task1.message_stack.messages) == 2 - assert len(task2.message_stack.messages) == 2 + assert len(task1.prompt_stack.messages) == 2 + assert len(task2.prompt_stack.messages) == 2 pipeline.run() - assert len(task1.message_stack.messages) == 5 - assert len(task2.message_stack.messages) == 5 + assert len(task1.prompt_stack.messages) == 5 + assert len(task2.prompt_stack.messages) == 5 pipeline.run() - assert len(task1.message_stack.messages) == 7 - assert len(task2.message_stack.messages) == 7 + assert len(task1.prompt_stack.messages) == 7 + assert len(task2.prompt_stack.messages) == 7 def test_text_artifact_token_count(self): text = "foobar" diff --git a/tests/unit/tokenizers/test_google_tokenizer.py b/tests/unit/tokenizers/test_google_tokenizer.py index f8ca60452..34510cdac 100644 --- a/tests/unit/tokenizers/test_google_tokenizer.py +++ b/tests/unit/tokenizers/test_google_tokenizer.py @@ -1,7 +1,7 @@ import pytest from unittest.mock import Mock -from griptape.common import MessageStack -from griptape.common.message_stack.messages.message import Message +from griptape.common import PromptStack +from griptape.common.prompt_stack.messages.message import Message from griptape.tokenizers import GoogleTokenizer @@ -20,7 +20,7 @@ def tokenizer(self, request): @pytest.mark.parametrize("tokenizer,expected", [("gemini-pro", 5)], indirect=["tokenizer"]) def test_token_count(self, tokenizer, expected): assert tokenizer.count_tokens("foo bar huzzah") == expected - assert tokenizer.count_tokens(MessageStack(messages=[Message(content="foo", role="user")])) == expected + assert tokenizer.count_tokens(PromptStack(messages=[Message(content="foo", role="user")])) == expected assert tokenizer.count_tokens(["foo", "bar", "huzzah"]) == expected @pytest.mark.parametrize("tokenizer,expected", [("gemini-pro", 30715)], indirect=["tokenizer"]) diff --git a/tests/unit/utils/test_conversation.py b/tests/unit/utils/test_conversation.py index 963903cc6..cce067f73 100644 --- a/tests/unit/utils/test_conversation.py +++ b/tests/unit/utils/test_conversation.py @@ -21,7 +21,7 @@ def test_lines(self): assert lines[2] == "Q: question 1" assert lines[3] == "A: mock output" - def test_message_stack_conversation_memory(self): + def test_prompt_stack_conversation_memory(self): pipeline = Pipeline(prompt_driver=MockPromptDriver(), conversation_memory=ConversationMemory()) pipeline.add_tasks(PromptTask("question 1")) @@ -29,12 +29,12 @@ def test_message_stack_conversation_memory(self): pipeline.run() pipeline.run() - lines = Conversation(pipeline.conversation_memory).message_stack() + lines = Conversation(pipeline.conversation_memory).prompt_stack() assert lines[0] == "user: question 1" assert lines[1] == "assistant: mock output" - def test_message_stack_summary_conversation_memory(self): + def test_prompt_stack_summary_conversation_memory(self): pipeline = Pipeline( prompt_driver=MockPromptDriver(), conversation_memory=SummaryConversationMemory(summary="foobar", prompt_driver=MockPromptDriver()), @@ -45,7 +45,7 @@ def test_message_stack_summary_conversation_memory(self): pipeline.run() pipeline.run() - lines = Conversation(pipeline.conversation_memory).message_stack() + lines = Conversation(pipeline.conversation_memory).prompt_stack() assert lines[0] == "user: Summary of the conversation so far: mock output" assert lines[1] == "user: question 1" diff --git a/tests/unit/utils/test_message_stack.py b/tests/unit/utils/test_message_stack.py index 9bab66d0b..908388a33 100644 --- a/tests/unit/utils/test_message_stack.py +++ b/tests/unit/utils/test_message_stack.py @@ -1,55 +1,55 @@ import pytest from griptape.artifacts import ImageArtifact, ListArtifact, TextArtifact -from griptape.common import ImageMessageContent, MessageStack, TextMessageContent +from griptape.common import ImageMessageContent, PromptStack, TextMessageContent class TestPromptStack: @pytest.fixture - def message_stack(self): - return MessageStack() + def prompt_stack(self): + return PromptStack() def test_init(self): - assert MessageStack() + assert PromptStack() - def test_add_message(self, message_stack): - message_stack.add_message("foo", "role") - message_stack.add_message(TextArtifact("foo"), "role") - message_stack.add_message(ImageArtifact(b"foo", format="png", width=100, height=100), "role") - message_stack.add_message(ListArtifact([TextArtifact("foo"), TextArtifact("bar")]), "role") + def test_add_message(self, prompt_stack): + prompt_stack.add_message("foo", "role") + prompt_stack.add_message(TextArtifact("foo"), "role") + prompt_stack.add_message(ImageArtifact(b"foo", format="png", width=100, height=100), "role") + prompt_stack.add_message(ListArtifact([TextArtifact("foo"), TextArtifact("bar")]), "role") - assert message_stack.messages[0].role == "role" - assert isinstance(message_stack.messages[0].content[0], TextMessageContent) - assert message_stack.messages[0].content[0].artifact.value == "foo" + assert prompt_stack.messages[0].role == "role" + assert isinstance(prompt_stack.messages[0].content[0], TextMessageContent) + assert prompt_stack.messages[0].content[0].artifact.value == "foo" - assert message_stack.messages[1].role == "role" - assert isinstance(message_stack.messages[1].content[0], TextMessageContent) - assert message_stack.messages[1].content[0].artifact.value == "foo" + assert prompt_stack.messages[1].role == "role" + assert isinstance(prompt_stack.messages[1].content[0], TextMessageContent) + assert prompt_stack.messages[1].content[0].artifact.value == "foo" - assert message_stack.messages[2].role == "role" - assert isinstance(message_stack.messages[2].content[0], ImageMessageContent) - assert message_stack.messages[2].content[0].artifact.value == b"foo" + assert prompt_stack.messages[2].role == "role" + assert isinstance(prompt_stack.messages[2].content[0], ImageMessageContent) + assert prompt_stack.messages[2].content[0].artifact.value == b"foo" - assert message_stack.messages[3].role == "role" - assert isinstance(message_stack.messages[3].content[0], TextMessageContent) - assert message_stack.messages[3].content[0].artifact.value == "foo" - assert isinstance(message_stack.messages[3].content[1], TextMessageContent) - assert message_stack.messages[3].content[1].artifact.value == "bar" + assert prompt_stack.messages[3].role == "role" + assert isinstance(prompt_stack.messages[3].content[0], TextMessageContent) + assert prompt_stack.messages[3].content[0].artifact.value == "foo" + assert isinstance(prompt_stack.messages[3].content[1], TextMessageContent) + assert prompt_stack.messages[3].content[1].artifact.value == "bar" - def test_add_system_message(self, message_stack): - message_stack.add_system_message("foo") + def test_add_system_message(self, prompt_stack): + prompt_stack.add_system_message("foo") - assert message_stack.messages[0].role == "system" - assert message_stack.messages[0].content[0].artifact.value == "foo" + assert prompt_stack.messages[0].role == "system" + assert prompt_stack.messages[0].content[0].artifact.value == "foo" - def test_add_user_message(self, message_stack): - message_stack.add_user_message("foo") + def test_add_user_message(self, prompt_stack): + prompt_stack.add_user_message("foo") - assert message_stack.messages[0].role == "user" - assert message_stack.messages[0].content[0].artifact.value == "foo" + assert prompt_stack.messages[0].role == "user" + assert prompt_stack.messages[0].content[0].artifact.value == "foo" - def test_add_assistant_message(self, message_stack): - message_stack.add_assistant_message("foo") + def test_add_assistant_message(self, prompt_stack): + prompt_stack.add_assistant_message("foo") - assert message_stack.messages[0].role == "assistant" - assert message_stack.messages[0].content[0].artifact.value == "foo" + assert prompt_stack.messages[0].role == "assistant" + assert prompt_stack.messages[0].content[0].artifact.value == "foo" From 2662ff7b653eb757d4fe01a3f9cf8c9529070ca1 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Mon, 1 Jul 2024 18:55:14 -0500 Subject: [PATCH 34/34] Fix docs --- docs/griptape-framework/drivers/prompt-drivers.md | 2 +- docs/griptape-framework/structures/tasks.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/griptape-framework/drivers/prompt-drivers.md b/docs/griptape-framework/drivers/prompt-drivers.md index 5d665d855..6d68d42e2 100644 --- a/docs/griptape-framework/drivers/prompt-drivers.md +++ b/docs/griptape-framework/drivers/prompt-drivers.md @@ -33,7 +33,7 @@ from griptape.drivers import OpenAiChatPromptDriver stack = PromptStack() -stack.add_system_input( +stack.add_system_message( "You will be provided with Python code, and your task is to calculate its time complexity." ) stack.add_user_input( diff --git a/docs/griptape-framework/structures/tasks.md b/docs/griptape-framework/structures/tasks.md index dcc2e6f4a..847e2a47f 100644 --- a/docs/griptape-framework/structures/tasks.md +++ b/docs/griptape-framework/structures/tasks.md @@ -95,7 +95,7 @@ from griptape.structures import Agent from griptape.loaders import ImageLoader agent = Agent() -with open("assets/mountain.jpg", "rb") as f: +with open("tests/resources/mountain.jpg", "rb") as f: image_artifact = ImageLoader().load(f.read()) agent.run(["What's in this image?", image_artifact])