Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor prompt stack #861

Merged
merged 38 commits into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
aa00bdd
Refactor prompt stack
collindutter Jun 13, 2024
d3995b8
Add support for more modalities to conversation memory
collindutter Jun 14, 2024
524a30d
Update default artifact
collindutter Jun 14, 2024
7f38288
Fix bad merge
collindutter Jun 18, 2024
2380b37
Rename Prompt Stack Element to Prompt Stack Message
collindutter Jun 18, 2024
1e4aac4
Fix Ollama
collindutter Jun 19, 2024
9da420d
Clean up roles
collindutter Jun 19, 2024
0caf028
Rename deltas
collindutter Jun 19, 2024
784aafc
PR cleanup
collindutter Jun 19, 2024
52c1a04
Change task hierarchy
collindutter Jun 19, 2024
b82ae14
Update changelog
collindutter Jun 19, 2024
fc63c54
Regenerate lock file
collindutter Jun 20, 2024
1169a50
Add back missing logs
collindutter Jun 20, 2024
46efbee
Fix doc var names
collindutter Jun 20, 2024
d071def
Clean up message building
collindutter Jun 20, 2024
5fdf810
Add tests
collindutter Jun 20, 2024
43826d3
Add image input support to ollama
collindutter Jun 20, 2024
c981637
Fix tests
collindutter Jun 20, 2024
828e892
Rename inputs to messages
collindutter Jun 21, 2024
7af4e4e
Big rename
collindutter Jun 21, 2024
1297528
Improve test coverage
collindutter Jun 21, 2024
2b7fe92
Add missing module
collindutter Jun 21, 2024
85b8cd0
Update docs
collindutter Jun 21, 2024
635249c
Simplify inputs
collindutter Jun 21, 2024
52ea784
Regenerate lock file
collindutter Jun 21, 2024
75a5c1a
Update changelog
collindutter Jun 21, 2024
06cb2a3
Fix test
collindutter Jun 21, 2024
5c797ad
Simplify cohere
collindutter Jun 21, 2024
182ff92
Simplify system_prompt access
collindutter Jun 21, 2024
8dc4944
Change prompt driver run return value
collindutter Jun 21, 2024
5c663fe
Merge branch 'dev' into refactor/prompt-stack-elements2
collindutter Jun 22, 2024
49b93f9
Regenerate lock file
collindutter Jun 22, 2024
ebedd83
Fix bad merge, tests
collindutter Jul 1, 2024
7802c11
Rename MessageStack back to PromptStack. Snip-snap! Snip-snap! Snip-s…
collindutter Jul 1, 2024
36b6e37
Merge branch 'dev' into refactor/prompt-stack-elements2
collindutter Jul 1, 2024
44d54dc
Merge branch 'dev' into refactor/prompt-stack-elements2
collindutter Jul 1, 2024
2662ff7
Fix docs
collindutter Jul 1, 2024
919b480
Merge branch 'dev' into refactor/prompt-stack-elements2
collindutter Jul 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,32 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## Unreleased
### Added
- `Message` for storing messages in a `PromptStack`. Messages consist of a role, content, and usage.
- `DeltaMessage` for storing partial messages in a `PromptStack`. Multiple `DeltaMessage` can be combined to form a `Message`.
- `TextMessageContent` for storing textual content in a `Message`.
- `ImageMessageContent` for storing image content in a `Message`.
- Support for adding `TextArtifact`s, `ImageArtifact`s, and `ListArtifact`s to `PromptStack`.
- Support for image inputs to `OpenAiChatPromptDriver`, `AzureOpenAiChatPromptDriver`, `AmazonBedrockPromptDriver`, `AnthropicPromptDriver`, and `GooglePromptDriver`.
- Input/output token usage metrics to all Prompt Drivers.
- `FinishPromptEvent.input_token_count` and `FinishPromptEvent.output_token_count`.
- Support for storing Artifacts as inputs/outputs in Conversation Memory Runs.
- `Agent.input` for passing Artifacts as input.
- Support for `PromptTask`s to take `TextArtifact`s, `ImageArtifact`s, and `ListArtifact`s as input.

### Changed
- **BREAKING**: Moved/renamed `griptape.utils.PromptStack` to `griptape.common.PromptStack`.
- **BREAKING**: Renamed `PromptStack.inputs` to `PromptStack.messages`.
- **BREAKING**: Moved `PromptStack.USER_ROLE`, `PromptStack.ASSISTANT_ROLE`, and `PromptStack.SYSTEM_ROLE` to `Message`.
- **BREAKING**: Updated return type of `PromptDriver.try_run` from `TextArtifact` to `Message`.
- **BREAKING**: Updated return type of `PromptDriver.try_stream` from `Iterator[TextArtifact]` to `Iterator[DeltaMessage]`.
- **BREAKING**: Removed `BasePromptEvent.token_count` in favor of `FinishPromptEvent.input_token_count` and `FinishPromptEvent.output_token_count`.
- **BREAKING**: Removed `StartPromptEvent.prompt`. Use `StartPromptEvent.prompt_stack` instead.
- **BREAKING**: Removed `Agent.input_template` in favor of `Agent.input`.
- **BREAKING**: `BasePromptDriver.run` now returns a `Message` instead of a `TextArtifact`. For compatibility, `Message.value` contains the Message's Artifact value
- Default Prompt Driver model in `GoogleStructureConfig` to `gemini-1.5-pro`.



### Added
- `RagEngine` is an abstraction for implementing modular RAG pipelines.
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ from griptape.structures import Agent
from griptape.tools import WebScraper, FileManager, TaskMemoryClient

agent = Agent(
input_template="Load {{ args[0] }}, summarize it, and store it in a file called {{ args[1] }}.",
input="Load {{ args[0] }}, summarize it, and store it in a file called {{ args[1] }}.",
tools=[
WebScraper(off_prompt=True),
TaskMemoryClient(off_prompt=True),
Expand Down
24 changes: 12 additions & 12 deletions docs/griptape-framework/drivers/prompt-drivers.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ agent = Agent(
config=StructureConfig(
prompt_driver=OpenAiChatPromptDriver(model="gpt-4o", temperature=0.3),
),
input_template="You will be provided with a tweet, and your task is to classify its sentiment as positive, neutral, or negative. Tweet: {{ args[0] }}",
input="You will be provided with a tweet, and your task is to classify its sentiment as positive, neutral, or negative. Tweet: {{ args[0] }}",
rules=[
Rule(
value="Output only the sentiment."
Expand All @@ -28,23 +28,23 @@ agent.run("I loved the new Batman movie!")
Or use them independently:

```python
from griptape.utils import PromptStack
from griptape.common import PromptStack
from griptape.drivers import OpenAiChatPromptDriver

stack = PromptStack()

stack.add_system_input(
stack.add_system_message(
"You will be provided with Python code, and your task is to calculate its time complexity."
)
stack.add_user_input(
"""
def foo(n, k):
accum = 0
for i in range(n):
for l in range(k):
accum += i
return accum
"""
"""
def foo(n, k):
accum = 0
for i in range(n):
for l in range(k):
accum += i
return accum
"""
)

result = OpenAiChatPromptDriver(model="gpt-3.5-turbo-16k", temperature=0).run(stack)
Expand Down Expand Up @@ -80,7 +80,7 @@ agent = Agent(
seed=42,
)
),
input_template="You will be provided with a description of a mood, and your task is to generate the CSS code for a color that matches it. Description: {{ args[0] }}",
input="You will be provided with a description of a mood, and your task is to generate the CSS code for a color that matches it. Description: {{ args[0] }}",
rules=[
Rule(
value='Write your output in json with a single key called "css_code".'
Expand Down
8 changes: 4 additions & 4 deletions docs/griptape-framework/misc/events.md
Original file line number Diff line number Diff line change
Expand Up @@ -244,9 +244,9 @@ from griptape.events import BaseEvent, StartPromptEvent, EventListener

def handler(event: BaseEvent):
if isinstance(event, StartPromptEvent):
print("Prompt Stack Inputs:")
for input in event.prompt_stack.inputs:
print(f"{input.role}: {input.content}")
print("Prompt Stack PromptStack:")
for message in event.prompt_stack.messages:
print(f"{message.role}: {message.content}")
print("Final Prompt String:")
print(event.prompt)

Expand All @@ -259,7 +259,7 @@ agent.run("Write me a poem.")
```
```
...
Prompt Stack Inputs:
Prompt Stack Messages:
system:
user: Write me a poem.
Final Prompt String:
Expand Down
2 changes: 1 addition & 1 deletion docs/griptape-framework/structures/agents.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ from griptape.structures import Agent


agent = Agent(
input_template="Calculate the following: {{ args[0] }}",
input="Calculate the following: {{ args[0] }}",
tools=[Calculator()]
)

Expand Down
27 changes: 26 additions & 1 deletion docs/griptape-framework/structures/tasks.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,31 @@ agent.run("Write me a haiku")
Day begins anew.
```

If the model supports it, you can also pass image inputs:

```python
from griptape.structures import Agent
from griptape.loaders import ImageLoader

agent = Agent()
with open("tests/resources/mountain.jpg", "rb") as f:
image_artifact = ImageLoader().load(f.read())

agent.run(["What's in this image?", image_artifact])
```

```
[06/21/24 10:01:08] INFO PromptTask c229d1792da34ab1a7c45768270aada9
Input: What's in this image?

Media, type: image/jpeg, size: 82351 bytes
[06/21/24 10:01:12] INFO PromptTask c229d1792da34ab1a7c45768270aada9
Output: The image depicts a stunning mountain landscape at sunrise or sunset. The sun is partially visible on the left side of the image,
casting a warm golden light over the scene. The mountains are covered with snow at their peaks, and a layer of clouds or fog is settled in the
valleys between them. The sky is a mix of warm colors near the horizon, transitioning to cooler blues higher up, with some scattered clouds
adding texture to the sky. The overall scene is serene and majestic, highlighting the natural beauty of the mountainous terrain.
```

## Toolkit Task

To use [Griptape Tools](../../griptape-framework/tools/index.md), use a [Toolkit Task](../../reference/griptape/tasks/toolkit_task.md).
Expand Down Expand Up @@ -740,7 +765,7 @@ def build_researcher():

def build_writer():
writer = Agent(
input_template="Instructions: {{args[0]}}\nContext: {{args[1]}}",
input="Instructions: {{args[0]}}\nContext: {{args[1]}}",
rulesets=[
Ruleset(
name="Position",
Expand Down
23 changes: 23 additions & 0 deletions griptape/common/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from .prompt_stack.contents.base_message_content import BaseMessageContent
from .prompt_stack.contents.base_delta_message_content import BaseDeltaMessageContent
from .prompt_stack.contents.text_delta_message_content import TextDeltaMessageContent
from .prompt_stack.contents.text_message_content import TextMessageContent
from .prompt_stack.contents.image_message_content import ImageMessageContent

from .prompt_stack.messages.base_message import BaseMessage
from .prompt_stack.messages.delta_message import DeltaMessage
from .prompt_stack.messages.message import Message

from .prompt_stack.prompt_stack import PromptStack

__all__ = [
"BaseMessage",
"BaseDeltaMessageContent",
"BaseMessageContent",
"DeltaMessage",
"Message",
"TextDeltaMessageContent",
"TextMessageContent",
"ImageMessageContent",
"PromptStack",
]
Empty file.
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from __future__ import annotations

from abc import ABC

from attrs import define, field

from griptape.mixins.serializable_mixin import SerializableMixin


@define
class BaseDeltaMessageContent(ABC, SerializableMixin):
index: int = field(kw_only=True, default=0, metadata={"serializable": True})
31 changes: 31 additions & 0 deletions griptape/common/prompt_stack/contents/base_message_content.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from __future__ import annotations

from abc import ABC
from collections.abc import Sequence

from attrs import define, field

from griptape.artifacts.base_artifact import BaseArtifact
from griptape.mixins import SerializableMixin

from .base_delta_message_content import BaseDeltaMessageContent


@define
class BaseMessageContent(ABC, SerializableMixin):
artifact: BaseArtifact = field(metadata={"serializable": True})

def to_text(self) -> str:
return str(self.artifact)

Check warning on line 19 in griptape/common/prompt_stack/contents/base_message_content.py

View check run for this annotation

Codecov / codecov/patch

griptape/common/prompt_stack/contents/base_message_content.py#L19

Added line #L19 was not covered by tests

def __str__(self) -> str:
return self.artifact.to_text()

Check warning on line 22 in griptape/common/prompt_stack/contents/base_message_content.py

View check run for this annotation

Codecov / codecov/patch

griptape/common/prompt_stack/contents/base_message_content.py#L22

Added line #L22 was not covered by tests

def __bool__(self) -> bool:
return bool(self.artifact)

Check warning on line 25 in griptape/common/prompt_stack/contents/base_message_content.py

View check run for this annotation

Codecov / codecov/patch

griptape/common/prompt_stack/contents/base_message_content.py#L25

Added line #L25 was not covered by tests

def __len__(self) -> int:
return len(self.artifact)

Check warning on line 28 in griptape/common/prompt_stack/contents/base_message_content.py

View check run for this annotation

Codecov / codecov/patch

griptape/common/prompt_stack/contents/base_message_content.py#L28

Added line #L28 was not covered by tests

@classmethod
def from_deltas(cls, deltas: Sequence[BaseDeltaMessageContent]) -> BaseMessageContent: ...
17 changes: 17 additions & 0 deletions griptape/common/prompt_stack/contents/image_message_content.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from __future__ import annotations

from collections.abc import Sequence

from attrs import define, field

from griptape.artifacts import ImageArtifact
from griptape.common import BaseDeltaMessageContent, BaseMessageContent


@define
class ImageMessageContent(BaseMessageContent):
artifact: ImageArtifact = field(metadata={"serializable": True})

@classmethod
def from_deltas(cls, deltas: Sequence[BaseDeltaMessageContent]) -> ImageMessageContent:
raise NotImplementedError()
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from __future__ import annotations
from attrs import define, field

from griptape.common import BaseDeltaMessageContent


@define
class TextDeltaMessageContent(BaseDeltaMessageContent):
text: str = field(metadata={"serializable": True})
20 changes: 20 additions & 0 deletions griptape/common/prompt_stack/contents/text_message_content.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from __future__ import annotations

from attrs import define, field
from collections.abc import Sequence

from griptape.artifacts import TextArtifact
from griptape.common import BaseMessageContent, BaseDeltaMessageContent, TextDeltaMessageContent


@define
class TextMessageContent(BaseMessageContent):
artifact: TextArtifact = field(metadata={"serializable": True})

@classmethod
def from_deltas(cls, deltas: Sequence[BaseDeltaMessageContent]) -> TextMessageContent:
text_deltas = [delta for delta in deltas if isinstance(delta, TextDeltaMessageContent)]

artifact = TextArtifact(value="".join(delta.text for delta in text_deltas))

return cls(artifact=artifact)
Empty file.
44 changes: 44 additions & 0 deletions griptape/common/prompt_stack/messages/base_message.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from __future__ import annotations

from abc import ABC
from typing import Optional, Union
from attrs import Factory, define, field


from griptape.common import BaseMessageContent, BaseDeltaMessageContent
from griptape.mixins import SerializableMixin


@define
class BaseMessage(ABC, SerializableMixin):
@define
class Usage(SerializableMixin):
input_tokens: Optional[float] = field(kw_only=True, default=None, metadata={"serializable": True})
output_tokens: Optional[float] = field(kw_only=True, default=None, metadata={"serializable": True})

@property
def total_tokens(self) -> float:
return (self.input_tokens or 0) + (self.output_tokens or 0)

Check warning on line 21 in griptape/common/prompt_stack/messages/base_message.py

View check run for this annotation

Codecov / codecov/patch

griptape/common/prompt_stack/messages/base_message.py#L21

Added line #L21 was not covered by tests

def __add__(self, other: BaseMessage.Usage) -> BaseMessage.Usage:
return BaseMessage.Usage(
input_tokens=(self.input_tokens or 0) + (other.input_tokens or 0),
output_tokens=(self.output_tokens or 0) + (other.output_tokens or 0),
)

USER_ROLE = "user"
ASSISTANT_ROLE = "assistant"
SYSTEM_ROLE = "system"

content: list[Union[BaseMessageContent, BaseDeltaMessageContent]] = field(metadata={"serializable": True})
role: str = field(kw_only=True, metadata={"serializable": True})
usage: Usage = field(kw_only=True, default=Factory(lambda: BaseMessage.Usage()), metadata={"serializable": True})

def is_system(self) -> bool:
return self.role == self.SYSTEM_ROLE

def is_user(self) -> bool:
return self.role == self.USER_ROLE

def is_assistant(self) -> bool:
return self.role == self.ASSISTANT_ROLE
15 changes: 15 additions & 0 deletions griptape/common/prompt_stack/messages/delta_message.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from __future__ import annotations
from typing import Optional

from attrs import define, field

from griptape.common.prompt_stack.contents.text_delta_message_content import TextDeltaMessageContent


from .base_message import BaseMessage


@define
class DeltaMessage(BaseMessage):
role: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": True})
content: Optional[TextDeltaMessageContent] = field(kw_only=True, default=None, metadata={"serializable": True})
38 changes: 38 additions & 0 deletions griptape/common/prompt_stack/messages/message.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from __future__ import annotations

from typing import Any

from attrs import define, field

from griptape.artifacts import BaseArtifact, ListArtifact, TextArtifact
from griptape.common import BaseMessageContent, TextMessageContent

from .base_message import BaseMessage


@define
class Message(BaseMessage):
def __init__(self, content: str | list[BaseMessageContent], **kwargs: Any):
if isinstance(content, str):
content = [TextMessageContent(TextArtifact(value=content))]
self.__attrs_init__(content, **kwargs) # pyright: ignore[reportAttributeAccessIssue]

content: list[BaseMessageContent] = field(metadata={"serializable": True})

@property
def value(self) -> Any:
return self.to_artifact().value

def __str__(self) -> str:
return self.to_text()

Check warning on line 27 in griptape/common/prompt_stack/messages/message.py

View check run for this annotation

Codecov / codecov/patch

griptape/common/prompt_stack/messages/message.py#L27

Added line #L27 was not covered by tests

def to_text(self) -> str:
return "".join(
[content.artifact.to_text() for content in self.content if isinstance(content, TextMessageContent)]
)

def to_artifact(self) -> BaseArtifact:
if len(self.content) == 1:
return self.content[0].artifact
else:
return ListArtifact([content.artifact for content in self.content])
Loading
Loading