Add JsonSchemaRule #1165

Merged
merged 2 commits into from
Sep 11, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Parameter `meta: dict` on `BaseEvent`.
- `AzureOpenAiTextToSpeechDriver`.
- Ability to use Event Listeners as Context Managers for temporarily setting the Event Bus listeners.
- `JsonSchemaRule` for instructing the LLM to output a JSON object that conforms to a schema.
Contributor

Responding with an inline comment, so it's easier to reply.

Question: should this be the mechanism by which we enable features like OpenAI Structured Outputs? If so, we'll need to make the prompt rendering happen at the Driver level rather than the Task level.

If we don't use this mechanism for OpenAI Structured Outputs, what would the mechanism look like? Would it require passing a JSON schema? (I would expect so.) If so, we'd have two ways of passing in a JSON schema, which doesn't seem like the right approach (imagine if we had taken that approach for tools 🤮).

Seems like moving prompt rendering to the Driver level is the way forward, because I don't see duplicate ways of passing in a JSON schema as a viable option. This could technically wait until we introduce native structured outputs, right?

(Also, if we do use this rule for native structured outputs, what mechanism will we use to opt in or out of structured outputs when a driver supports them? Should we take the same approach as the `use_native_tools` flag?)

Member Author

> Would it require passing a json schema?

Yes, it'd likely require passing a schema to the Prompt Driver 🤢.

Rendering at the Prompt Driver level allows us to lean on more provider-specific nuances so I think it's a worthwhile change. But as you point out it's not required for this PR -- this can still serve as a good intermediate solution.
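
To make the trade-off in this thread concrete, here is a minimal sketch of what driver-level handling might look like. Everything in it is hypothetical: `SchemaRule`, `SketchPromptDriver`, `supports_structured_output`, `use_native_structured_output`, and the `response_format` request key are illustrative names, not Griptape or provider APIs. The only point is that a single rule could feed either path, so the schema is passed in one place.

```python
from __future__ import annotations

import json
from dataclasses import dataclass


@dataclass(frozen=True)
class SchemaRule:
    """Hypothetical stand-in for JsonSchemaRule: it only carries a schema dict."""

    value: dict

    def to_text(self) -> str:
        return (
            "You must respond with a JSON object that successfully validates "
            f"against the following schema: {json.dumps(self.value)}"
        )


@dataclass
class SketchPromptDriver:
    """Hypothetical driver; neither the class nor its flags exist in Griptape."""

    supports_structured_output: bool = False
    use_native_structured_output: bool = True  # modelled on use_native_tools

    def build_request(self, rules: list[SchemaRule], prompt: str) -> dict:
        request: dict = {"system": "", "prompt": prompt}
        for rule in rules:
            if self.supports_structured_output and self.use_native_structured_output:
                # Provider-native path: hand the schema to the API as data.
                request["response_format"] = rule.value
            else:
                # Fallback path: render the rule into the system prompt.
                request["system"] += rule.to_text()
        return request


schema = {"type": "object", "properties": {"answer": {"type": "string"}}}
rule = SchemaRule(schema)

native = SketchPromptDriver(supports_structured_output=True).build_request([rule], "Hi")
fallback = SketchPromptDriver(supports_structured_output=False).build_request([rule], "Hi")
```

In this sketch the opt-out is just `use_native_structured_output=False`, which sends a capable driver down the same prompt-rendering path as the fallback.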


### Changed
- **BREAKING**: Drivers, Loaders, and Engines now raise exceptions rather than returning `ErrorArtifacts`.
62 changes: 60 additions & 2 deletions docs/griptape-framework/structures/rulesets.md
@@ -5,8 +5,66 @@ search:

## Overview

A [Ruleset](../../reference/griptape/rules/ruleset.md) can be used to define rules for [Structures](../structures/agents.md) and [Tasks](../structures/tasks.md).
Rulesets can be used to shape personality, format output, restrict topics, and more.
A [Ruleset](../../reference/griptape/rules/ruleset.md) can be used to define [Rule](../../reference/griptape/rules/base_rule.md)s for [Structures](../structures/agents.md) and [Tasks](../structures/tasks.md). Griptape places Rules into the LLM's system prompt for strong control over the output.

## Types of Rules

### Rule

[Rule](../../reference/griptape/rules/base_rule.md)s shape the LLM's behavior by defining specific guidelines or instructions for how it should interpret and respond to inputs. Rules can be used to modify language style, tone, or even behavior based on what you define.

```python
--8<-- "docs/griptape-framework/structures/src/basic_rule.py"
```

```
[09/10/24 14:41:52] INFO PromptTask b7b23a88ea9e4cd0befb7e7a4ed596b0
Input: Hi there! How are you?
INFO PromptTask b7b23a88ea9e4cd0befb7e7a4ed596b0
Output: Ahoy, matey! I be doing just fine, thank ye fer askin'. How be the winds blowin' in yer sails today?
```

### Json Schema

A [JsonSchemaRule](../../reference/griptape/rules/json_schema_rule.md) defines a structured format for the LLM's output by providing a JSON schema.
This is particularly useful when you need the LLM to return well-formed data, such as JSON objects, with specific fields and data types.

!!! warning
    `JsonSchemaRule` may break [ToolkitTask](../structures/tasks.md#toolkittask), which relies on a specific [output token](https://github.com/griptape-ai/griptape/blob/e6a04c7b88cf9fa5d6bcf4c833ffebfab89a3258/griptape/tasks/toolkit_task.py#L28).


```python
--8<-- "docs/griptape-framework/structures/src/json_schema_rule.py"
```

```
[09/10/24 14:44:53] INFO PromptTask fb26dd41803443c0b51c3d861626e07a
Input: What is the sentiment of this message?: 'I am so happy!'
[09/10/24 14:44:54] INFO PromptTask fb26dd41803443c0b51c3d861626e07a
Output: {
"answer": "The sentiment of the message is positive.",
"relevant_emojis": ["😊", "😃"]
}
```

Although Griptape leverages the `schema` library, you're free to use any JSON schema generation library to define your schema!

For example, using `pydantic`:

```python
--8<-- "docs/griptape-framework/structures/src/json_schema_rule_pydantic.py"
```

```
[09/11/24 09:45:58] INFO PromptTask eae43f52829c4289a6cca9ee7950e075
Input: What is the sentiment of this message?: 'I am so happy!'
INFO PromptTask eae43f52829c4289a6cca9ee7950e075
Output: {
"answer": "The sentiment of the message is positive.",
"relevant_emojis": ["😊", "😄"]
}
answer='The sentiment of the message is positive.' relevant_emojis=['😊', '😄']
```

## Structure

13 changes: 13 additions & 0 deletions docs/griptape-framework/structures/src/basic_rule.py
@@ -0,0 +1,13 @@
from griptape.rules import Rule, Ruleset
from griptape.structures import Agent

pipeline = Agent(
    rulesets=[
        Ruleset(
            name="Personality",
            rules=[Rule("Talk like a pirate.")],
        ),
    ]
)

pipeline.run("Hi there! How are you?")
18 changes: 18 additions & 0 deletions docs/griptape-framework/structures/src/json_schema_rule.py
@@ -0,0 +1,18 @@
import json

import schema

from griptape.rules.json_schema_rule import JsonSchemaRule
from griptape.structures import Agent

agent = Agent(
    rules=[
        JsonSchemaRule(
            schema.Schema({"answer": str, "relevant_emojis": schema.Schema([str])}).json_schema("Output Format")
        )
    ]
)

output = agent.run("What is the sentiment of this message?: 'I am so happy!'").output

print(json.dumps(json.loads(output.value), indent=2))
22 changes: 22 additions & 0 deletions docs/griptape-framework/structures/src/json_schema_rule_pydantic.py
@@ -0,0 +1,22 @@
from __future__ import annotations

import pydantic

from griptape.rules.json_schema_rule import JsonSchemaRule
from griptape.structures import Agent


class SentimentModel(pydantic.BaseModel):
    answer: str
    relevant_emojis: list[str]


agent = Agent(rules=[JsonSchemaRule(SentimentModel.model_json_schema())])

output = agent.run("What is the sentiment of this message?: 'I am so happy!'").output

sentiment_analysis = SentimentModel.model_validate_json(output.value)

# Autocomplete via dot notation 🤩
print(sentiment_analysis.answer)
print(sentiment_analysis.relevant_emojis)
6 changes: 3 additions & 3 deletions griptape/mixins/rule_mixin.py
@@ -4,7 +4,7 @@

 from attrs import Attribute, define, field

-from griptape.rules import Rule, Ruleset
+from griptape.rules import BaseRule, Ruleset

 if TYPE_CHECKING:
     from griptape.structures import Structure
@@ -16,7 +16,7 @@ class RuleMixin:
     ADDITIONAL_RULESET_NAME = "Additional Ruleset"

     rulesets: list[Ruleset] = field(factory=list, kw_only=True)
-    rules: list[Rule] = field(factory=list, kw_only=True)
+    rules: list[BaseRule] = field(factory=list, kw_only=True)
     structure: Optional[Structure] = field(default=None, kw_only=True)

     @rulesets.validator  # pyright: ignore[reportAttributeAccessIssue]
@@ -28,7 +28,7 @@ def validate_rulesets(self, _: Attribute, rulesets: list[Ruleset]) -> None:
             raise ValueError("Can't have both rulesets and rules specified.")

     @rules.validator  # pyright: ignore[reportAttributeAccessIssue]
-    def validate_rules(self, _: Attribute, rules: list[Rule]) -> None:
+    def validate_rules(self, _: Attribute, rules: list[BaseRule]) -> None:
         if not rules:
             return

4 changes: 3 additions & 1 deletion griptape/rules/__init__.py
@@ -1,5 +1,7 @@
+from griptape.rules.base_rule import BaseRule
 from griptape.rules.rule import Rule
+from griptape.rules.json_schema_rule import JsonSchemaRule
 from griptape.rules.ruleset import Ruleset


-__all__ = ["Rule", "Ruleset"]
+__all__ = ["BaseRule", "Rule", "JsonSchemaRule", "Ruleset"]
17 changes: 17 additions & 0 deletions griptape/rules/base_rule.py
@@ -0,0 +1,17 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any

from attrs import define, field


@define(frozen=True)
class BaseRule(ABC):
    value: Any = field()

    def __str__(self) -> str:
        return self.to_text()

    @abstractmethod
    def to_text(self) -> str: ...
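
With this base class, a new rule type only needs to subclass it and implement `to_text()`. The sketch below illustrates that contract under two assumptions: `BaseRuleSketch` is a stdlib `dataclasses` stand-in for the attrs-based `BaseRule` in this file, and `AllowedTopicsRule` is a hypothetical rule that does not exist in Griptape.

```python
from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class BaseRuleSketch(ABC):
    """Stdlib stand-in for BaseRule (the real class uses attrs, not dataclasses)."""

    value: Any

    def __str__(self) -> str:
        # str(rule) delegates to to_text(), mirroring BaseRule.__str__.
        return self.to_text()

    @abstractmethod
    def to_text(self) -> str: ...


@dataclass(frozen=True)
class AllowedTopicsRule(BaseRuleSketch):
    """Hypothetical rule whose value is a list of allowed topics."""

    value: list

    def to_text(self) -> str:
        return "Only discuss these topics: " + ", ".join(self.value)


rule = AllowedTopicsRule(["sailing", "weather"])
text = str(rule)

# The abstract base itself cannot be instantiated.
try:
    BaseRuleSketch("anything")
    base_is_abstract = False
except TypeError:
    base_is_abstract = True
```

Because `value` can be any type, each subclass decides how its value becomes prompt text.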
17 changes: 17 additions & 0 deletions griptape/rules/json_schema_rule.py
@@ -0,0 +1,17 @@
from __future__ import annotations

import json

from attrs import define, field

from griptape.rules import BaseRule
from griptape.utils import J2


@define(frozen=True)
class JsonSchemaRule(BaseRule):
    value: dict = field()
    template_generator: J2 = field(default=J2("rules/json_schema.j2"))

    def to_text(self) -> str:
        return self.template_generator.render(json_schema=json.dumps(self.value))
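
Since `value` is a plain `dict` that gets passed through `json.dumps`, a hand-written schema should work as well as a generated one. The sketch below approximates what `to_text()` produces without importing Griptape: `TEMPLATE` is an inlined copy of the `json_schema.j2` wording, `str.format` stands in for the `J2` renderer, and `render_json_schema_rule` and `sentiment_schema` are hypothetical names used only for illustration.

```python
import json

# Inlined copy of griptape/templates/rules/json_schema.j2 (assumption: kept in sync by hand).
TEMPLATE = (
    "You must respond with a JSON object that successfully validates "
    "against the following schema: {json_schema}"
)


def render_json_schema_rule(schema: dict) -> str:
    """Approximate JsonSchemaRule.to_text() using str.format instead of Jinja."""
    return TEMPLATE.format(json_schema=json.dumps(schema))


# A hand-written JSON schema, no schema-generation library involved.
sentiment_schema = {
    "type": "object",
    "properties": {
        "answer": {"type": "string"},
        "relevant_emojis": {"type": "array", "items": {"type": "string"}},
    },
    "required": ["answer", "relevant_emojis"],
}

text = render_json_schema_rule(sentiment_schema)
print(text)
```

The schema is embedded as compact JSON at the end of the sentence, so the rule text can grow long for large schemas.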
11 changes: 8 additions & 3 deletions griptape/rules/rule.py
@@ -1,8 +1,13 @@
 from __future__ import annotations

-from attrs import define
+from attrs import define, field
+
+from griptape.rules import BaseRule


 @define(frozen=True)
-class Rule:
-    value: str
+class Rule(BaseRule):
+    value: str = field()
+
+    def to_text(self) -> str:
+        return self.value
6 changes: 3 additions & 3 deletions griptape/rules/ruleset.py
@@ -1,14 +1,14 @@
 from __future__ import annotations

-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Sequence

 from attrs import define, field

 if TYPE_CHECKING:
-    from griptape.rules import Rule
+    from griptape.rules import BaseRule


 @define
 class Ruleset:
     name: str = field()
-    rules: list[Rule] = field()
+    rules: Sequence[BaseRule] = field()
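
The PR does not say why `rules` moved from `list` to `Sequence`, but one plausible reason is variance: `list` is invariant, so a `list[Rule]` is not a valid `list[BaseRule]` for a type checker, whereas `Sequence` is covariant and accepts it. The sketch below uses stand-in classes and hypothetical functions (`takes_list`, `takes_sequence`) to show the difference. Both calls run, and only static checkers such as pyright or mypy flag the first.

```python
from __future__ import annotations

from typing import Sequence


class BaseRule:  # stand-in for griptape.rules.BaseRule
    pass


class Rule(BaseRule):  # stand-in for griptape.rules.Rule
    pass


def takes_list(rules: list[BaseRule]) -> int:
    return len(rules)


def takes_sequence(rules: Sequence[BaseRule]) -> int:
    return len(rules)


pirate_rules: list[Rule] = [Rule()]

# Runs fine, but a static type checker rejects it: list is invariant.
list_count = takes_list(pirate_rules)  # type: ignore[arg-type]

# Accepted by type checkers: Sequence is covariant, and tuples work too.
sequence_count = takes_sequence(pirate_rules)
tuple_count = takes_sequence((Rule(), Rule()))
```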
4 changes: 2 additions & 2 deletions griptape/structures/structure.py
@@ -15,15 +15,15 @@
 if TYPE_CHECKING:
     from griptape.artifacts import BaseArtifact
     from griptape.memory.structure import BaseConversationMemory
-    from griptape.rules import Rule, Ruleset
+    from griptape.rules import BaseRule, Rule, Ruleset
     from griptape.tasks import BaseTask


 @define
 class Structure(ABC):
     id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True)
     rulesets: list[Ruleset] = field(factory=list, kw_only=True)
-    rules: list[Rule] = field(factory=list, kw_only=True)
+    rules: list[BaseRule] = field(factory=list, kw_only=True)
     tasks: list[BaseTask] = field(factory=list, kw_only=True)
     conversation_memory: Optional[BaseConversationMemory] = field(
         default=Factory(lambda: ConversationMemory()),
1 change: 1 addition & 0 deletions griptape/templates/rules/json_schema.j2
@@ -0,0 +1 @@
You must respond with a JSON object that successfully validates against the following schema: {{json_schema}}
2 changes: 1 addition & 1 deletion griptape/templates/rulesets/rulesets.j2
@@ -6,7 +6,7 @@ Ruleset name: {{ ruleset.name }}
 "{{ ruleset.name }}" rules:
 {% for rule in ruleset.rules %}
 Rule #{{loop.index}}
-{{ rule.value }}
+{{ rule.to_text() }}
 {% endfor %}

 {% endfor %}
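
The switch from `rule.value` to `rule.to_text()` matters once `value` is no longer a string. For a `JsonSchemaRule`, rendering `value` directly would drop the instruction sentence and, assuming Jinja stringifies a dict the way Python's `str()` does, emit a Python repr rather than JSON. A small stdlib-only illustration of that difference:

```python
import json

schema = {"type": "object", "additionalProperties": False}

# Interpolating the dict directly (as "{{ rule.value }}" would) uses Python's repr.
as_value = f"{schema}"
# JsonSchemaRule.to_text() serializes with json.dumps before rendering.
as_json = json.dumps(schema)

print(as_value)  # {'type': 'object', 'additionalProperties': False}
print(as_json)   # {"type": "object", "additionalProperties": false}

# The repr form is not valid JSON: single quotes and a capitalized False.
try:
    json.loads(as_value)
    value_is_valid_json = True
except json.JSONDecodeError:
    value_is_valid_json = False
```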
32 changes: 32 additions & 0 deletions tests/unit/rules/test_json_schema_rule.py
@@ -0,0 +1,32 @@
import json

import schema

from griptape.rules import JsonSchemaRule


class TestJsonSchemaRule:
    def test_init(self):
        json_schema = schema.Schema({"type": "string"}).json_schema("test")
        rule = JsonSchemaRule(json_schema)
        assert rule.value == {
            "type": "object",
            "properties": {"type": {"const": "string"}},
            "required": ["type"],
            "additionalProperties": False,
            "$id": "test",
            "$schema": "http://json-schema.org/draft-07/schema#",
        }

    def test_to_text(self):
        json_schema = schema.Schema({"type": "string"}).json_schema("test")
        rule = JsonSchemaRule(json_schema)
        assert (
            rule.to_text()
            == f"You must respond with a JSON object that successfully validates against the following schema: {json.dumps(json_schema)}"
        )

    def test___str__(self):
        json_schema = schema.Schema({"type": "string"}).json_schema("test")
        rule = JsonSchemaRule(json_schema)
        assert str(rule) == rule.to_text()
8 changes: 8 additions & 0 deletions tests/unit/rules/test_rule.py
@@ -5,3 +5,11 @@ class TestRule:
     def test_init(self):
         rule = Rule("foobar")
         assert rule.value == "foobar"
+
+    def test_to_text(self):
+        rule = Rule("foobar")
+        assert rule.to_text() == "foobar"
+
+    def test___str__(self):
+        rule = Rule("foobar")
+        assert str(rule) == "foobar"