Skip to content

Commit

Permalink
Clean up actions subtask parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Oct 18, 2024
1 parent ace7596 commit d99b7cc
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 15 deletions.
21 changes: 12 additions & 9 deletions griptape/tasks/actions_subtask.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,13 +217,12 @@ def __init_from_prompt(self, value: str) -> None:
actions_matches = re.findall(self.ACTIONS_PATTERN, value, re.DOTALL)
answer_matches = re.findall(self.ANSWER_PATTERN, value, re.MULTILINE)

if self.thought is None and thought_matches:
self.thought = thought_matches[-1]
self.actions = self.__parse_actions(actions_matches)

self.__parse_actions(actions_matches)
if thought_matches:
self.thought = thought_matches[-1]

# If there are no actions and no output, there may still be an output we can set.
if len(self.actions) == 0 and self.output is None:
if not self.actions and self.output is None:
if answer_matches:
# A direct answer is provided, set it as the output.
self.output = TextArtifact(answer_matches[-1])
Expand All @@ -248,26 +247,30 @@ def __init_from_artifacts(self, artifacts: ListArtifact) -> None:
if isinstance(artifact, ActionArtifact)
]

# When parsing from Artifacts we can't determine the thought unless there are also Actions
if self.actions:
thoughts = [artifact.value for artifact in artifacts.value if isinstance(artifact, TextArtifact)]
if thoughts:
self.thought = thoughts[0]
else:
self.output = TextArtifact(artifacts.to_text())
if self.output is None:
self.output = TextArtifact(artifacts.to_text())

def __parse_actions(self, actions_matches: list[str]) -> None:
def __parse_actions(self, actions_matches: list[str]) -> list[ToolAction]:
if len(actions_matches) == 0:
return
return []
try:
data = actions_matches[-1]
actions_list: list[dict] = json.loads(data, strict=False)

self.actions = [self.__process_action_object(action_object) for action_object in actions_list]
return [self.__process_action_object(action_object) for action_object in actions_list]
except json.JSONDecodeError as e:
logger.exception("Subtask %s\nInvalid actions JSON: %s", self.origin_task.id, e)

self.output = ErrorArtifact(f"Actions JSON decoding error: {e}", exception=e)

return []

def __process_action_object(self, action_object: dict) -> ToolAction:
# Load action tag; throw exception if the key is not present
action_tag = action_object["tag"]
Expand Down
59 changes: 53 additions & 6 deletions tests/unit/tasks/test_actions_subtask.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


class TestActionsSubtask:
def test_basic_input(self):
def test_prompt_input(self):
valid_input = (
"Thought: need to test\n"
'Actions: [{"tag": "foo", "name": "MockTool", "path": "test", "input": {"values": {"test": "value"}}}]\n'
Expand All @@ -25,22 +25,31 @@ def test_basic_input(self):
assert json_dict[0]["name"] == "MockTool"
assert json_dict[0]["path"] == "test"
assert json_dict[0]["input"] == {"values": {"test": "value"}}
assert subtask.thought == "need to test"
assert subtask.output is None

def test_action_input(self):
valid_input = ActionArtifact(
ToolAction(tag="foo", name="MockTool", path="test", input={"values": {"test": "value"}})
def test_artifact_input(self):
valid_input = ListArtifact(
[
TextArtifact("need to test"),
ActionArtifact(
ToolAction(tag="foo", name="MockTool", path="test", input={"values": {"test": "value"}})
),
TextArtifact("answer"),
]
)
task = ToolkitTask(tools=[MockTool()])
Agent().add_task(task)
subtask = task.add_subtask(ActionsSubtask(valid_input))
json_dict = json.loads(subtask.actions_to_json())

assert subtask.thought is None
assert json_dict[0]["name"] == "MockTool"
assert json_dict[0]["path"] == "test"
assert json_dict[0]["input"] == {"values": {"test": "value"}}
assert subtask.thought == "need to test"
assert subtask.output is None

def test_action_and_thought_input(self):
def test_artifact_action_and_thought_input(self):
valid_input = ListArtifact(
[
TextArtifact("thought"),
Expand All @@ -59,6 +68,42 @@ def test_action_and_thought_input(self):
assert json_dict[0]["path"] == "test"
assert json_dict[0]["input"] == {"values": {"test": "value"}}

def test_prompt_answer(self):
valid_input = "Answer: test output"

task = ToolkitTask(tools=[MockTool()])
Agent().add_task(task)
subtask = task.add_subtask(ActionsSubtask(valid_input))

assert subtask.thought is None
assert subtask.actions == []
assert subtask.output.value == "test output"

def test_prompt_implicit_answer(self):
valid_input = "test output"

task = ToolkitTask(tools=[MockTool()])
Agent().add_task(task)
subtask = task.add_subtask(ActionsSubtask(valid_input))

assert subtask.thought is None
assert subtask.actions == []
assert subtask.output.value == "test output"

def test_artifact_answer(self):
valid_input = ListArtifact(
[
TextArtifact("answer"),
]
)
task = ToolkitTask(tools=[MockTool()])
Agent().add_task(task)
subtask = task.add_subtask(ActionsSubtask(valid_input))

assert subtask.thought is None
assert subtask.actions == []
assert subtask.output.value == "answer"

def test_callable_input(self):
valid_input = ListArtifact(
[
Expand Down Expand Up @@ -146,6 +191,8 @@ def test_invalid_actions(self):

assert isinstance(subtask.output, ErrorArtifact)
assert "Actions JSON decoding error" in subtask.output.value
assert subtask.thought == "need to test"
assert subtask.actions == []

def test_implicit_values(self):
valid_input = (
Expand Down

0 comments on commit d99b7cc

Please sign in to comment.