Skip to content

Commit

Permalink
Add structure to Task init
Browse files Browse the repository at this point in the history
  • Loading branch information
vachillo committed Aug 5, 2024
1 parent fe53c41 commit c289a91
Show file tree
Hide file tree
Showing 10 changed files with 118 additions and 22 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- `AstraDbVectorStoreDriver` to support DataStax Astra DB as a vector store.
- Ability to set custom schema properties on Tool Activities via `extra_schema_properties`.
- Parameter `structure` to `BaseTask`.
- Method `try_find_task` to `Structure`.

### Changed
- `BaseTask.add_parent/child` will now call `self.structure.add_task` if possible.

## [0.29.0] - 2024-07-30

Expand Down
15 changes: 7 additions & 8 deletions docs/griptape-framework/structures/workflows.md
Original file line number Diff line number Diff line change
Expand Up @@ -269,18 +269,17 @@ from griptape.tasks import PromptTask
from griptape.structures import Workflow
from griptape.rules import Rule

animal_task = PromptTask("Name an animal", id="animal")
adjective_task = PromptTask("Describe {{ parent_outputs['animal'] }} with an adjective", id="adjective")
new_animal_task = PromptTask("Name a {{ parent_outputs['adjective'] }} animal", id="new-animal")

adjective_task.add_parent(animal_task)
adjective_task.add_child(new_animal_task)

workflow = Workflow(
tasks=[animal_task, adjective_task, new_animal_task],
rules=[Rule("output a single lowercase word")],
)

animal_task = PromptTask("Name an animal", id="animal", structure=workflow)
adjective_task = PromptTask("Describe {{ parent_outputs['animal'] }} with an adjective", id="adjective", structure=workflow)
new_animal_task = PromptTask("Name a {{ parent_outputs['adjective'] }} animal", id="new-animal", structure=workflow)

adjective_task.add_parent(animal_task)
adjective_task.add_child(new_animal_task)

workflow.run()
```

Expand Down
3 changes: 3 additions & 0 deletions griptape/structures/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
@define
class Pipeline(Structure):
def add_task(self, task: BaseTask) -> BaseTask:
if (existing_task := self.try_find_task(task.id)) is not None:
return existing_task

task.preprocess(self)

if self.output_task:
Expand Down
13 changes: 11 additions & 2 deletions griptape/structures/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,10 +215,15 @@ def is_executing(self) -> bool:
return any(s for s in self.tasks if s.is_executing())

def find_task(self, task_id: str) -> BaseTask:
if (task := self.try_find_task(task_id)) is not None:
return task
raise ValueError(f"Task with id {task_id} doesn't exist.")

def try_find_task(self, task_id: str) -> Optional[BaseTask]:
for task in self.tasks:
if task.id == task_id:
return task
raise ValueError(f"Task with id {task_id} doesn't exist.")
return None

def add_tasks(self, *tasks: BaseTask) -> list[BaseTask]:
return [self.add_task(s) for s in tasks]
Expand All @@ -227,7 +232,11 @@ def context(self, task: BaseTask) -> dict[str, Any]:
return {"args": self.execution_args, "structure": self}

def resolve_relationships(self) -> None:
task_by_id = {task.id: task for task in self.tasks}
task_by_id = {}
for task in self.tasks:
if task.id in task_by_id:
raise ValueError(f"Duplicate task with id {task.id} found.")
task_by_id[task.id] = task

for task in self.tasks:
# Ensure parents include this task as a child
Expand Down
7 changes: 7 additions & 0 deletions griptape/structures/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,18 @@ class Workflow(Structure):
kw_only=True,
)

@property
def input_task(self) -> Optional[BaseTask]:
return self.order_tasks()[0] if self.tasks else None

@property
def output_task(self) -> Optional[BaseTask]:
return self.order_tasks()[-1] if self.tasks else None

def add_task(self, task: BaseTask) -> BaseTask:
if (existing_task := self.try_find_task(task.id)) is not None:
return existing_task

task.preprocess(self)

self.tasks.append(task)
Expand Down
12 changes: 12 additions & 0 deletions griptape/tasks/actions_subtask.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,18 @@ def children(self) -> list[BaseTask]:
else:
raise Exception("ActionSubtask must be attached to a Task that implements ActionSubtaskOriginMixin.")

def add_child(self, child: str | BaseTask) -> None:
child_id = child if isinstance(child, str) else child.id

if child_id not in self.child_ids:
self.child_ids.append(child_id)

def add_parent(self, parent: str | BaseTask) -> None:
parent_id = parent if isinstance(parent, str) else parent.id

if parent_id not in self.parent_ids:
self.parent_ids.append(parent_id)

def attach_to(self, parent_task: BaseTask) -> None:
self.parent_task_id = parent_task.id
self.structure = parent_task.structure
Expand Down
28 changes: 25 additions & 3 deletions griptape/tasks/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,26 +29,34 @@ class State(Enum):
parent_ids: list[str] = field(factory=list, kw_only=True)
child_ids: list[str] = field(factory=list, kw_only=True)
max_meta_memory_entries: Optional[int] = field(default=20, kw_only=True)
structure: Optional[Structure] = field(default=None, kw_only=True)

output: Optional[BaseArtifact] = field(default=None, init=False)
structure: Optional[Structure] = field(default=None, init=False)
context: dict[str, Any] = field(factory=dict, kw_only=True)
futures_executor_fn: Callable[[], futures.Executor] = field(
default=Factory(lambda: lambda: futures.ThreadPoolExecutor()),
kw_only=True,
)

def __attrs_post_init__(self) -> None:
if self.structure is not None:
self.structure.add_task(self)

@property
@abstractmethod
def input(self) -> BaseArtifact: ...

@property
def parents(self) -> list[BaseTask]:
return [self.structure.find_task(parent_id) for parent_id in self.parent_ids]
if self.structure is not None:
return [self.structure.find_task(parent_id) for parent_id in self.parent_ids]
raise ValueError("Structure must be set to access parents")

Check warning on line 53 in griptape/tasks/base_task.py

View check run for this annotation

Codecov / codecov/patch

griptape/tasks/base_task.py#L53

Added line #L53 was not covered by tests

@property
def children(self) -> list[BaseTask]:
return [self.structure.find_task(child_id) for child_id in self.child_ids]
if self.structure is not None:
return [self.structure.find_task(child_id) for child_id in self.child_ids]
raise ValueError("Structure must be set to access children")

Check warning on line 59 in griptape/tasks/base_task.py

View check run for this annotation

Codecov / codecov/patch

griptape/tasks/base_task.py#L59

Added line #L59 was not covered by tests

@property
def parent_outputs(self) -> dict[str, str]:
Expand Down Expand Up @@ -76,21 +84,35 @@ def add_parents(self, parents: list[str | BaseTask]) -> None:
self.add_parent(parent)

def add_parent(self, parent: str | BaseTask) -> None:
parent_task = parent if isinstance(parent, BaseTask) else None
parent_id = parent if isinstance(parent, str) else parent.id

if parent_id not in self.parent_ids:
self.parent_ids.append(parent_id)

if parent_task is not None:
parent_task.add_child(self.id)

if self.structure is not None:
self.structure.add_task(parent_task)

def add_children(self, children: list[str | BaseTask]) -> None:
for child in children:
self.add_child(child)

def add_child(self, child: str | BaseTask) -> None:
child_task = child if isinstance(child, BaseTask) else None
child_id = child if isinstance(child, str) else child.id

if child_id not in self.child_ids:
self.child_ids.append(child_id)

if child_task is not None:
child_task.add_parent(self.id)

if self.structure is not None:
self.structure.add_task(child_task)

def preprocess(self, structure: Structure) -> BaseTask:
self.structure = structure

Expand Down
1 change: 1 addition & 0 deletions griptape/tasks/toolkit_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ def find_subtask(self, subtask_id: str) -> ActionsSubtask:

def add_subtask(self, subtask: ActionsSubtask) -> ActionsSubtask:
subtask.attach_to(self)
subtask.structure = self.structure

if len(self.subtasks) > 0:
self.subtasks[-1].add_child(subtask)
Expand Down
19 changes: 19 additions & 0 deletions tests/unit/structures/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,3 +390,22 @@ def test_run_with_error_artifact_no_fail_fast(self, error_artifact_task, waiting
pipeline.run()

assert pipeline.output is not None

def test_add_duplicate_task(self):
task = PromptTask("test")
pipeline = Pipeline(prompt_driver=MockPromptDriver())

pipeline + task
pipeline + task

assert len(pipeline.tasks) == 1

def test_add_duplicate_task_directly(self):
task = PromptTask("test")
pipeline = Pipeline(prompt_driver=MockPromptDriver())

pipeline + task
pipeline.tasks.append(task)

with pytest.raises(ValueError, match=f"Duplicate task with id {task.id} found."):
pipeline.run()
37 changes: 28 additions & 9 deletions tests/unit/structures/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,34 @@ def test_run_topology_1_imperative_children(self):

self._validate_topology_1(workflow)

def test_run_topology_1_imperative_parents_structure_init(self):
workflow = Workflow(prompt_driver=MockPromptDriver())
task1 = PromptTask("test1", id="task1")
task2 = PromptTask("test2", id="task2", structure=workflow)
task3 = PromptTask("test3", id="task3", structure=workflow)
task4 = PromptTask("test4", id="task4", structure=workflow)
task2.add_parent(task1)
task3.add_parent("task1")
task4.add_parents([task2, "task3"])

workflow.run()

self._validate_topology_1(workflow)

def test_run_topology_1_imperative_children_structure_init(self):
workflow = Workflow(prompt_driver=MockPromptDriver())
task1 = PromptTask("test1", id="task1", structure=workflow)
task2 = PromptTask("test2", id="task2", structure=workflow)
task3 = PromptTask("test3", id="task3", structure=workflow)
task4 = PromptTask("test4", id="task4")
task1.add_children([task2, task3])
task2.add_child(task4)
task3.add_child(task4)

workflow.run()

self._validate_topology_1(workflow)

def test_run_topology_1_imperative_mixed(self):
task1 = PromptTask("test1", id="task1")
task2 = PromptTask("test2", id="task2")
Expand Down Expand Up @@ -781,8 +809,6 @@ def _validate_topology_1(workflow) -> None:
assert len(workflow.tasks) == 4
assert workflow.input_task.id == "task1"
assert workflow.output_task.id == "task4"
assert workflow.input_task.id == workflow.tasks[0].id
assert workflow.output_task.id == workflow.tasks[-1].id

task1 = workflow.find_task("task1")
assert task1.state == BaseTask.State.FINISHED
Expand Down Expand Up @@ -810,8 +836,6 @@ def _validate_topology_2(workflow) -> None:
assert len(workflow.tasks) == 5
assert workflow.input_task.id == "taska"
assert workflow.output_task.id == "taske"
assert workflow.input_task.id == workflow.tasks[0].id
assert workflow.output_task.id == workflow.tasks[-1].id

taska = workflow.find_task("taska")
assert taska.state == BaseTask.State.FINISHED
Expand Down Expand Up @@ -843,9 +867,6 @@ def _validate_topology_3(workflow) -> None:
assert len(workflow.tasks) == 4
assert workflow.input_task.id == "task1"
assert workflow.output_task.id == "task3"
assert workflow.input_task.id == workflow.tasks[0].id
assert workflow.output_task.id == workflow.tasks[-1].id

task1 = workflow.find_task("task1")
assert task1.state == BaseTask.State.FINISHED
assert task1.parent_ids == []
Expand All @@ -871,8 +892,6 @@ def _validate_topology_4(workflow) -> None:
assert len(workflow.tasks) == 9
assert workflow.input_task.id == "collect_movie_info"
assert workflow.output_task.id == "summarize_to_slack"
assert workflow.input_task.id == workflow.tasks[0].id
assert workflow.output_task.id == workflow.tasks[-1].id

collect_movie_info = workflow.find_task("collect_movie_info")
assert collect_movie_info.parent_ids == []
Expand Down

0 comments on commit c289a91

Please sign in to comment.