Skip to content

Commit

Permalink
Add JsonSchemaRule
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Sep 10, 2024
1 parent 9d9b643 commit 674a61a
Show file tree
Hide file tree
Showing 15 changed files with 164 additions and 15 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Parameter `meta: dict` on `BaseEvent`.
- `AzureOpenAiTextToSpeechDriver`.
- `JsonSchemaRule` for instructing the LLM to output a JSON object that conforms to a schema.

### Changed
- **BREAKING**: Drivers, Loaders, and Engines now raise exceptions rather than returning `ErrorArtifacts`.
Expand Down
43 changes: 41 additions & 2 deletions docs/griptape-framework/structures/rulesets.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,47 @@ search:

## Overview

A [Ruleset](../../reference/griptape/rules/ruleset.md) can be used to define rules for [Structures](../structures/agents.md) and [Tasks](../structures/tasks.md).
Rulesets can be used to shape personality, format output, restrict topics, and more.
A [Ruleset](../../reference/griptape/rules/ruleset.md) can be used to define [Rule](../../reference/griptape/rules/base_rule.md)s for [Structures](../structures/agents.md) and [Tasks](../structures/tasks.md). Griptape places Rules into the LLM's system prompt for strong control over the output.

## Types of Rules

### Rule

[Rule](../../reference/griptape/rules/base_rule.md)s shape the LLM's behavior by defining specific guidelines or instructions for how it should interpret and respond to inputs. Rules can be used to modify language style, tone, or even behavior based on what you define.

```python
--8<-- "docs/griptape-framework/structures/src/basic_rule.py"
```

```
[09/10/24 14:41:52] INFO PromptTask b7b23a88ea9e4cd0befb7e7a4ed596b0
Input: Hi there! How are you?
INFO PromptTask b7b23a88ea9e4cd0befb7e7a4ed596b0
Output: Ahoy, matey! I be doing just fine, thank ye fer askin'. How be the winds blowin' in yer sails today?
```

### Json Schema

[JsonSchemaRule](../../reference/griptape/rules/json_schema_rule.md)s defines a structured format for the LLM's output by providing a JSON schema.
This is particularly useful when you need the LLM to return well-formed data, such as JSON objects, with specific fields and data types.

!!! warning
`JsonSchemaRule` may break [ToolkitTask](../structures/tasks.md#toolkittask) which relies on a specific [output token](https://github.com/griptape-ai/griptape/blob/e6a04c7b88cf9fa5d6bcf4c833ffebfab89a3258/griptape/tasks/toolkit_task.py#L28).


```python
--8<-- "docs/griptape-framework/structures/src/json_schema_rule.py"
```

```
[09/10/24 14:44:53] INFO PromptTask fb26dd41803443c0b51c3d861626e07a
Input: What is the sentiment of this message?: 'I am so happy!'
[09/10/24 14:44:54] INFO PromptTask fb26dd41803443c0b51c3d861626e07a
Output: {
"answer": "The sentiment of the message is positive.",
"relevant_emojis": ["😊", "😃"]
}
```

## Structure

Expand Down
13 changes: 13 additions & 0 deletions docs/griptape-framework/structures/src/basic_rule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from griptape.rules import Rule, Ruleset
from griptape.structures import Agent

pipeline = Agent(
rulesets=[
Ruleset(
name="Personality",
rules=[Rule("Talk like a pirate.")],
),
]
)

pipeline.run("Hi there! How are you?")
14 changes: 14 additions & 0 deletions docs/griptape-framework/structures/src/json_schema_rule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import schema

from griptape.rules.json_schema_rule import JsonSchemaRule
from griptape.structures import Agent

agent = Agent(
rules=[
JsonSchemaRule(
schema.Schema({"answer": str, "relevant_emojis": schema.Schema(["str"])}).json_schema("Output Format")
)
]
)

agent.run("What is the sentiment of this message?: 'I am so happy!'")
6 changes: 3 additions & 3 deletions griptape/mixins/rule_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from attrs import Attribute, define, field

from griptape.rules import Rule, Ruleset
from griptape.rules import BaseRule, Ruleset

if TYPE_CHECKING:
from griptape.structures import Structure
Expand All @@ -16,7 +16,7 @@ class RuleMixin:
ADDITIONAL_RULESET_NAME = "Additional Ruleset"

rulesets: list[Ruleset] = field(factory=list, kw_only=True)
rules: list[Rule] = field(factory=list, kw_only=True)
rules: list[BaseRule] = field(factory=list, kw_only=True)
structure: Optional[Structure] = field(default=None, kw_only=True)

@rulesets.validator # pyright: ignore[reportAttributeAccessIssue]
Expand All @@ -28,7 +28,7 @@ def validate_rulesets(self, _: Attribute, rulesets: list[Ruleset]) -> None:
raise ValueError("Can't have both rulesets and rules specified.")

@rules.validator # pyright: ignore[reportAttributeAccessIssue]
def validate_rules(self, _: Attribute, rules: list[Rule]) -> None:
def validate_rules(self, _: Attribute, rules: list[BaseRule]) -> None:
if not rules:
return

Expand Down
4 changes: 3 additions & 1 deletion griptape/rules/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from griptape.rules.base_rule import BaseRule
from griptape.rules.rule import Rule
from griptape.rules.json_schema_rule import JsonSchemaRule
from griptape.rules.ruleset import Ruleset


__all__ = ["Rule", "Ruleset"]
__all__ = ["BaseRule", "Rule", "JsonSchemaRule", "Ruleset"]
17 changes: 17 additions & 0 deletions griptape/rules/base_rule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any

from attrs import define, field


@define(frozen=True)
class BaseRule(ABC):
value: Any = field()

def __str__(self) -> str:
return self.to_text()

@abstractmethod
def to_text(self) -> str: ...
17 changes: 17 additions & 0 deletions griptape/rules/json_schema_rule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from __future__ import annotations

import json

from attrs import define, field

from griptape.rules import BaseRule
from griptape.utils import J2


@define(frozen=True)
class JsonSchemaRule(BaseRule):
value: dict = field()
template_generator: J2 = field(default=J2("rules/json_schema.j2"))

def to_text(self) -> str:
return self.template_generator.render(json_schema=json.dumps(self.value))
11 changes: 8 additions & 3 deletions griptape/rules/rule.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
from __future__ import annotations

from attrs import define
from attrs import define, field

from griptape.rules import BaseRule


@define(frozen=True)
class Rule:
value: str
class Rule(BaseRule):
value: str = field()

def to_text(self) -> str:
return self.value
6 changes: 3 additions & 3 deletions griptape/rules/ruleset.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Sequence

from attrs import define, field

if TYPE_CHECKING:
from griptape.rules import Rule
from griptape.rules import BaseRule


@define
class Ruleset:
name: str = field()
rules: list[Rule] = field()
rules: Sequence[BaseRule] = field()
4 changes: 2 additions & 2 deletions griptape/structures/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@
if TYPE_CHECKING:
from griptape.artifacts import BaseArtifact
from griptape.memory.structure import BaseConversationMemory
from griptape.rules import Rule, Ruleset
from griptape.rules import BaseRule, Rule, Ruleset
from griptape.tasks import BaseTask


@define
class Structure(ABC):
id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True)
rulesets: list[Ruleset] = field(factory=list, kw_only=True)
rules: list[Rule] = field(factory=list, kw_only=True)
rules: list[BaseRule] = field(factory=list, kw_only=True)
tasks: list[BaseTask] = field(factory=list, kw_only=True)
conversation_memory: Optional[BaseConversationMemory] = field(
default=Factory(lambda: ConversationMemory()),
Expand Down
1 change: 1 addition & 0 deletions griptape/templates/rules/json_schema.j2
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
You must respond with a JSON object that successfully validates against the following schema: {{json_schema}}
2 changes: 1 addition & 1 deletion griptape/templates/rulesets/rulesets.j2
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ Ruleset name: {{ ruleset.name }}
"{{ ruleset.name }}" rules:
{% for rule in ruleset.rules %}
Rule #{{loop.index}}
{{ rule.value }}
{{ rule.to_text() }}
{% endfor %}

{% endfor %}
Expand Down
32 changes: 32 additions & 0 deletions tests/unit/rules/test_json_schema_rule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import json

import schema

from griptape.rules import JsonSchemaRule


class TestJsonSchemaRule:
def test_init(self):
json_schema = schema.Schema({"type": "string"}).json_schema("test")
rule = JsonSchemaRule(json_schema)
assert rule.value == {
"type": "object",
"properties": {"type": {"const": "string"}},
"required": ["type"],
"additionalProperties": False,
"$id": "test",
"$schema": "http://json-schema.org/draft-07/schema#",
}

def test_to_text(self):
json_schema = schema.Schema({"type": "string"}).json_schema("test")
rule = JsonSchemaRule(json_schema)
assert (
rule.to_text()
== f"You must respond with a JSON object that successfully validates against the following schema: {json.dumps(json_schema)}"
)

def test___str__(self):
json_schema = schema.Schema({"type": "string"}).json_schema("test")
rule = JsonSchemaRule(json_schema)
assert str(rule) == rule.to_text()
8 changes: 8 additions & 0 deletions tests/unit/rules/test_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,11 @@ class TestRule:
def test_init(self):
rule = Rule("foobar")
assert rule.value == "foobar"

def test_to_text(self):
rule = Rule("foobar")
assert rule.to_text() == "foobar"

def test___str__(self):
rule = Rule("foobar")
assert str(rule) == "foobar"

0 comments on commit 674a61a

Please sign in to comment.