From 2ae50b3c826b2c41788c42bc581ba5b2fb96fd69 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Wed, 11 Sep 2024 14:15:45 -0700 Subject: [PATCH] Add JsonSchemaRule (#1165) --- CHANGELOG.md | 1 + .../griptape-framework/structures/rulesets.md | 62 ++++++++++++++++++- .../structures/src/basic_rule.py | 13 ++++ .../structures/src/json_schema_rule.py | 18 ++++++ .../src/json_schema_rule_pydantic.py | 22 +++++++ griptape/mixins/rule_mixin.py | 6 +- griptape/rules/__init__.py | 4 +- griptape/rules/base_rule.py | 17 +++++ griptape/rules/json_schema_rule.py | 17 +++++ griptape/rules/rule.py | 11 +++- griptape/rules/ruleset.py | 6 +- griptape/structures/structure.py | 4 +- griptape/templates/rules/json_schema.j2 | 1 + griptape/templates/rulesets/rulesets.j2 | 2 +- tests/unit/rules/test_json_schema_rule.py | 32 ++++++++++ tests/unit/rules/test_rule.py | 8 +++ 16 files changed, 209 insertions(+), 15 deletions(-) create mode 100644 docs/griptape-framework/structures/src/basic_rule.py create mode 100644 docs/griptape-framework/structures/src/json_schema_rule.py create mode 100644 docs/griptape-framework/structures/src/json_schema_rule_pydantic.py create mode 100644 griptape/rules/base_rule.py create mode 100644 griptape/rules/json_schema_rule.py create mode 100644 griptape/templates/rules/json_schema.j2 create mode 100644 tests/unit/rules/test_json_schema_rule.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 788dd2e23..25387aafd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Parameter `meta: dict` on `BaseEvent`. - `AzureOpenAiTextToSpeechDriver`. - Ability to use Event Listeners as Context Managers for temporarily setting the Event Bus listeners. +- `JsonSchemaRule` for instructing the LLM to output a JSON object that conforms to a schema. ### Changed - **BREAKING**: Drivers, Loaders, and Engines now raise exceptions rather than returning `ErrorArtifacts`. diff --git a/docs/griptape-framework/structures/rulesets.md b/docs/griptape-framework/structures/rulesets.md index d69b085ac..a0773856f 100644 --- a/docs/griptape-framework/structures/rulesets.md +++ b/docs/griptape-framework/structures/rulesets.md @@ -5,8 +5,66 @@ search: ## Overview -A [Ruleset](../../reference/griptape/rules/ruleset.md) can be used to define rules for [Structures](../structures/agents.md) and [Tasks](../structures/tasks.md). -Rulesets can be used to shape personality, format output, restrict topics, and more. +A [Ruleset](../../reference/griptape/rules/ruleset.md) can be used to define [Rule](../../reference/griptape/rules/base_rule.md)s for [Structures](../structures/agents.md) and [Tasks](../structures/tasks.md). Griptape places Rules into the LLM's system prompt for strong control over the output. + +## Types of Rules + +### Rule + +[Rule](../../reference/griptape/rules/base_rule.md)s shape the LLM's behavior by defining specific guidelines or instructions for how it should interpret and respond to inputs. Rules can be used to modify language style, tone, or even behavior based on what you define. + +```python +--8<-- "docs/griptape-framework/structures/src/basic_rule.py" +``` + +``` +[09/10/24 14:41:52] INFO PromptTask b7b23a88ea9e4cd0befb7e7a4ed596b0 + Input: Hi there! How are you? + INFO PromptTask b7b23a88ea9e4cd0befb7e7a4ed596b0 + Output: Ahoy, matey! I be doing just fine, thank ye fer askin'. How be the winds blowin' in yer sails today? +``` + +### Json Schema + +[JsonSchemaRule](../../reference/griptape/rules/json_schema_rule.md)s defines a structured format for the LLM's output by providing a JSON schema. +This is particularly useful when you need the LLM to return well-formed data, such as JSON objects, with specific fields and data types. + +!!! warning + `JsonSchemaRule` may break [ToolkitTask](../structures/tasks.md#toolkittask) which relies on a specific [output token](https://github.com/griptape-ai/griptape/blob/e6a04c7b88cf9fa5d6bcf4c833ffebfab89a3258/griptape/tasks/toolkit_task.py#L28). + + +```python +--8<-- "docs/griptape-framework/structures/src/json_schema_rule.py" +``` + +``` +[09/10/24 14:44:53] INFO PromptTask fb26dd41803443c0b51c3d861626e07a + Input: What is the sentiment of this message?: 'I am so happy!' +[09/10/24 14:44:54] INFO PromptTask fb26dd41803443c0b51c3d861626e07a + Output: { + "answer": "The sentiment of the message is positive.", + "relevant_emojis": ["😊", "😃"] + } +``` + +Although Griptape leverages the `schema` library, you're free to use any JSON schema generation library to define your schema! + +For example, using `pydantic`: + +```python +--8<-- "docs/griptape-framework/structures/src/json_schema_rule_pydantic.py" +``` + +``` +[09/11/24 09:45:58] INFO PromptTask eae43f52829c4289a6cca9ee7950e075 + Input: What is the sentiment of this message?: 'I am so happy!' + INFO PromptTask eae43f52829c4289a6cca9ee7950e075 + Output: { + "answer": "The sentiment of the message is positive.", + "relevant_emojis": ["😊", "😄"] + } +answer='The sentiment of the message is positive.' relevant_emojis=['😊', '😄'] +``` ## Structure diff --git a/docs/griptape-framework/structures/src/basic_rule.py b/docs/griptape-framework/structures/src/basic_rule.py new file mode 100644 index 000000000..75511f514 --- /dev/null +++ b/docs/griptape-framework/structures/src/basic_rule.py @@ -0,0 +1,13 @@ +from griptape.rules import Rule, Ruleset +from griptape.structures import Agent + +pipeline = Agent( + rulesets=[ + Ruleset( + name="Personality", + rules=[Rule("Talk like a pirate.")], + ), + ] +) + +pipeline.run("Hi there! How are you?") diff --git a/docs/griptape-framework/structures/src/json_schema_rule.py b/docs/griptape-framework/structures/src/json_schema_rule.py new file mode 100644 index 000000000..1f78de928 --- /dev/null +++ b/docs/griptape-framework/structures/src/json_schema_rule.py @@ -0,0 +1,18 @@ +import json + +import schema + +from griptape.rules.json_schema_rule import JsonSchemaRule +from griptape.structures import Agent + +agent = Agent( + rules=[ + JsonSchemaRule( + schema.Schema({"answer": str, "relevant_emojis": schema.Schema(["str"])}).json_schema("Output Format") + ) + ] +) + +output = agent.run("What is the sentiment of this message?: 'I am so happy!'").output + +print(json.dumps(json.loads(output.value), indent=2)) diff --git a/docs/griptape-framework/structures/src/json_schema_rule_pydantic.py b/docs/griptape-framework/structures/src/json_schema_rule_pydantic.py new file mode 100644 index 000000000..bfbcf7cf3 --- /dev/null +++ b/docs/griptape-framework/structures/src/json_schema_rule_pydantic.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +import pydantic + +from griptape.rules.json_schema_rule import JsonSchemaRule +from griptape.structures import Agent + + +class SentimentModel(pydantic.BaseModel): + answer: str + relevant_emojis: list[str] + + +agent = Agent(rules=[JsonSchemaRule(SentimentModel.model_json_schema())]) + +output = agent.run("What is the sentiment of this message?: 'I am so happy!'").output + +sentiment_analysis = SentimentModel.model_validate_json(output.value) + +# Autocomplete via dot notation 🤩 +print(sentiment_analysis.answer) +print(sentiment_analysis.relevant_emojis) diff --git a/griptape/mixins/rule_mixin.py b/griptape/mixins/rule_mixin.py index ff4395270..7fe6a6346 100644 --- a/griptape/mixins/rule_mixin.py +++ b/griptape/mixins/rule_mixin.py @@ -4,7 +4,7 @@ from attrs import Attribute, define, field -from griptape.rules import Rule, Ruleset +from griptape.rules import BaseRule, Ruleset if TYPE_CHECKING: from griptape.structures import Structure @@ -16,7 +16,7 @@ class RuleMixin: ADDITIONAL_RULESET_NAME = "Additional Ruleset" rulesets: list[Ruleset] = field(factory=list, kw_only=True) - rules: list[Rule] = field(factory=list, kw_only=True) + rules: list[BaseRule] = field(factory=list, kw_only=True) structure: Optional[Structure] = field(default=None, kw_only=True) @rulesets.validator # pyright: ignore[reportAttributeAccessIssue] @@ -28,7 +28,7 @@ def validate_rulesets(self, _: Attribute, rulesets: list[Ruleset]) -> None: raise ValueError("Can't have both rulesets and rules specified.") @rules.validator # pyright: ignore[reportAttributeAccessIssue] - def validate_rules(self, _: Attribute, rules: list[Rule]) -> None: + def validate_rules(self, _: Attribute, rules: list[BaseRule]) -> None: if not rules: return diff --git a/griptape/rules/__init__.py b/griptape/rules/__init__.py index 4becdc1e5..a2e8ae08b 100644 --- a/griptape/rules/__init__.py +++ b/griptape/rules/__init__.py @@ -1,5 +1,7 @@ +from griptape.rules.base_rule import BaseRule from griptape.rules.rule import Rule +from griptape.rules.json_schema_rule import JsonSchemaRule from griptape.rules.ruleset import Ruleset -__all__ = ["Rule", "Ruleset"] +__all__ = ["BaseRule", "Rule", "JsonSchemaRule", "Ruleset"] diff --git a/griptape/rules/base_rule.py b/griptape/rules/base_rule.py new file mode 100644 index 000000000..190fc71e4 --- /dev/null +++ b/griptape/rules/base_rule.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any + +from attrs import define, field + + +@define(frozen=True) +class BaseRule(ABC): + value: Any = field() + + def __str__(self) -> str: + return self.to_text() + + @abstractmethod + def to_text(self) -> str: ... diff --git a/griptape/rules/json_schema_rule.py b/griptape/rules/json_schema_rule.py new file mode 100644 index 000000000..1bd418464 --- /dev/null +++ b/griptape/rules/json_schema_rule.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +import json + +from attrs import define, field + +from griptape.rules import BaseRule +from griptape.utils import J2 + + +@define(frozen=True) +class JsonSchemaRule(BaseRule): + value: dict = field() + template_generator: J2 = field(default=J2("rules/json_schema.j2")) + + def to_text(self) -> str: + return self.template_generator.render(json_schema=json.dumps(self.value)) diff --git a/griptape/rules/rule.py b/griptape/rules/rule.py index 1063d174e..952770adf 100644 --- a/griptape/rules/rule.py +++ b/griptape/rules/rule.py @@ -1,8 +1,13 @@ from __future__ import annotations -from attrs import define +from attrs import define, field + +from griptape.rules import BaseRule @define(frozen=True) -class Rule: - value: str +class Rule(BaseRule): + value: str = field() + + def to_text(self) -> str: + return self.value diff --git a/griptape/rules/ruleset.py b/griptape/rules/ruleset.py index 1f158411a..eec1203f9 100644 --- a/griptape/rules/ruleset.py +++ b/griptape/rules/ruleset.py @@ -1,14 +1,14 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Sequence from attrs import define, field if TYPE_CHECKING: - from griptape.rules import Rule + from griptape.rules import BaseRule @define class Ruleset: name: str = field() - rules: list[Rule] = field() + rules: Sequence[BaseRule] = field() diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index 63ba02373..b066c336e 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -15,7 +15,7 @@ if TYPE_CHECKING: from griptape.artifacts import BaseArtifact from griptape.memory.structure import BaseConversationMemory - from griptape.rules import Rule, Ruleset + from griptape.rules import BaseRule, Rule, Ruleset from griptape.tasks import BaseTask @@ -23,7 +23,7 @@ class Structure(ABC): id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True) rulesets: list[Ruleset] = field(factory=list, kw_only=True) - rules: list[Rule] = field(factory=list, kw_only=True) + rules: list[BaseRule] = field(factory=list, kw_only=True) tasks: list[BaseTask] = field(factory=list, kw_only=True) conversation_memory: Optional[BaseConversationMemory] = field( default=Factory(lambda: ConversationMemory()), diff --git a/griptape/templates/rules/json_schema.j2 b/griptape/templates/rules/json_schema.j2 new file mode 100644 index 000000000..9a351c1cd --- /dev/null +++ b/griptape/templates/rules/json_schema.j2 @@ -0,0 +1 @@ +You must respond with a JSON object that successfully validates against the following schema: {{json_schema}} diff --git a/griptape/templates/rulesets/rulesets.j2 b/griptape/templates/rulesets/rulesets.j2 index 1f58aa811..5b149adbc 100644 --- a/griptape/templates/rulesets/rulesets.j2 +++ b/griptape/templates/rulesets/rulesets.j2 @@ -6,7 +6,7 @@ Ruleset name: {{ ruleset.name }} "{{ ruleset.name }}" rules: {% for rule in ruleset.rules %} Rule #{{loop.index}} -{{ rule.value }} +{{ rule.to_text() }} {% endfor %} {% endfor %} diff --git a/tests/unit/rules/test_json_schema_rule.py b/tests/unit/rules/test_json_schema_rule.py new file mode 100644 index 000000000..a1a4f2361 --- /dev/null +++ b/tests/unit/rules/test_json_schema_rule.py @@ -0,0 +1,32 @@ +import json + +import schema + +from griptape.rules import JsonSchemaRule + + +class TestJsonSchemaRule: + def test_init(self): + json_schema = schema.Schema({"type": "string"}).json_schema("test") + rule = JsonSchemaRule(json_schema) + assert rule.value == { + "type": "object", + "properties": {"type": {"const": "string"}}, + "required": ["type"], + "additionalProperties": False, + "$id": "test", + "$schema": "http://json-schema.org/draft-07/schema#", + } + + def test_to_text(self): + json_schema = schema.Schema({"type": "string"}).json_schema("test") + rule = JsonSchemaRule(json_schema) + assert ( + rule.to_text() + == f"You must respond with a JSON object that successfully validates against the following schema: {json.dumps(json_schema)}" + ) + + def test___str__(self): + json_schema = schema.Schema({"type": "string"}).json_schema("test") + rule = JsonSchemaRule(json_schema) + assert str(rule) == rule.to_text() diff --git a/tests/unit/rules/test_rule.py b/tests/unit/rules/test_rule.py index f3bbf4664..9afce8dd5 100644 --- a/tests/unit/rules/test_rule.py +++ b/tests/unit/rules/test_rule.py @@ -5,3 +5,11 @@ class TestRule: def test_init(self): rule = Rule("foobar") assert rule.value == "foobar" + + def test_to_text(self): + rule = Rule("foobar") + assert rule.to_text() == "foobar" + + def test___str__(self): + rule = Rule("foobar") + assert str(rule) == "foobar"