diff --git a/CHANGELOG.md b/CHANGELOG.md index c41c46879..f37819e25 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -63,6 +63,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: `BaseExtractionEngine` no longer catches exceptions and returns `ErrorArtifact`s. - **BREAKING**: `JsonExtractionEngine.template_schema` is now required. - **BREAKING**: `CsvExtractionEngine.column_names` is now required. +- `StructureRunTask` now inherits from `PromptTask`. - `JsonExtractionEngine.extract_artifacts` now returns a `ListArtifact[JsonArtifact]`. - `CsvExtractionEngine.extract_artifacts` now returns a `ListArtifact[CsvRowArtifact]`. - Remove `manifest.yml` requirements for custom tool creation. diff --git a/griptape/tasks/__init__.py b/griptape/tasks/__init__.py index 7d08cf858..4f65a4226 100644 --- a/griptape/tasks/__init__.py +++ b/griptape/tasks/__init__.py @@ -1,6 +1,5 @@ from .base_task import BaseTask from .base_text_input_task import BaseTextInputTask -from .base_multi_text_input_task import BaseMultiTextInputTask from .prompt_task import PromptTask from .actions_subtask import ActionsSubtask from .toolkit_task import ToolkitTask @@ -23,7 +22,6 @@ __all__ = [ "BaseTask", "BaseTextInputTask", - "BaseMultiTextInputTask", "PromptTask", "ActionsSubtask", "ToolkitTask", diff --git a/griptape/tasks/base_multi_text_input_task.py b/griptape/tasks/base_multi_text_input_task.py deleted file mode 100644 index 347dd7e29..000000000 --- a/griptape/tasks/base_multi_text_input_task.py +++ /dev/null @@ -1,60 +0,0 @@ -from __future__ import annotations - -import logging -from abc import ABC -from typing import Callable - -from attrs import Factory, define, field - -from griptape.artifacts import ListArtifact, TextArtifact -from griptape.configs import Defaults -from griptape.mixins.rule_mixin import RuleMixin -from griptape.tasks import BaseTask -from griptape.utils import J2 - -logger = logging.getLogger(Defaults.logging_config.logger_name) - - -@define -class BaseMultiTextInputTask(RuleMixin, BaseTask, ABC): - DEFAULT_INPUT_TEMPLATE = "{{ args[0] }}" - - _input: tuple[str, ...] | tuple[TextArtifact, ...] | tuple[Callable[[BaseTask], TextArtifact], ...] = field( - default=Factory(lambda self: (self.DEFAULT_INPUT_TEMPLATE,), takes_self=True), - alias="input", - ) - - @property - def input(self) -> ListArtifact: - if all(isinstance(elem, TextArtifact) for elem in self._input): - return ListArtifact([artifact for artifact in self._input if isinstance(artifact, TextArtifact)]) - elif all(isinstance(elem, Callable) for elem in self._input): - return ListArtifact( - [callable_input(self) for callable_input in self._input if isinstance(callable_input, Callable)] - ) - else: - return ListArtifact( - [ - TextArtifact(J2().render_from_string(input_template, **self.full_context)) - for input_template in self._input - if isinstance(input_template, str) - ], - ) - - @input.setter - def input( - self, - value: tuple[str, ...] | tuple[TextArtifact, ...] | tuple[Callable[[BaseTask], TextArtifact], ...], - ) -> None: - self._input = value - - def before_run(self) -> None: - super().before_run() - - joined_input = "\n".join([i.to_text() for i in self.input]) - logger.info("%s %s\nInput: %s", self.__class__.__name__, self.id, joined_input) - - def after_run(self) -> None: - super().after_run() - - logger.info("%s %s\nOutput: %s", self.__class__.__name__, self.id, self.output.to_text()) diff --git a/griptape/tasks/structure_run_task.py b/griptape/tasks/structure_run_task.py index 887a33b8a..6860958aa 100644 --- a/griptape/tasks/structure_run_task.py +++ b/griptape/tasks/structure_run_task.py @@ -4,7 +4,8 @@ from attrs import define, field -from griptape.tasks import BaseMultiTextInputTask +from griptape.artifacts.list_artifact import ListArtifact +from griptape.tasks.prompt_task import PromptTask if TYPE_CHECKING: from griptape.artifacts import BaseArtifact @@ -12,7 +13,7 @@ @define -class StructureRunTask(BaseMultiTextInputTask): +class StructureRunTask(PromptTask): """Task to run a Structure. Attributes: @@ -22,4 +23,7 @@ class StructureRunTask(BaseMultiTextInputTask): driver: BaseStructureRunDriver = field(kw_only=True) def run(self) -> BaseArtifact: - return self.driver.run(*self.input) + if isinstance(self.input, ListArtifact): + return self.driver.run(*self.input.value) + else: + return self.driver.run(self.input) diff --git a/tests/mocks/mock_multi_text_input_task.py b/tests/mocks/mock_multi_text_input_task.py deleted file mode 100644 index be00bbf65..000000000 --- a/tests/mocks/mock_multi_text_input_task.py +++ /dev/null @@ -1,10 +0,0 @@ -from attrs import define - -from griptape.artifacts import TextArtifact -from griptape.tasks import BaseMultiTextInputTask - - -@define -class MockMultiTextInputTask(BaseMultiTextInputTask): - def run(self) -> TextArtifact: - return TextArtifact(self.input[0].to_text()) diff --git a/tests/unit/tasks/test_base_multi_text_input_task.py b/tests/unit/tasks/test_base_multi_text_input_task.py deleted file mode 100644 index 8eaa832ae..000000000 --- a/tests/unit/tasks/test_base_multi_text_input_task.py +++ /dev/null @@ -1,56 +0,0 @@ -from griptape.artifacts import TextArtifact -from griptape.structures import Pipeline -from tests.mocks.mock_multi_text_input_task import MockMultiTextInputTask - - -class TestBaseMultiTextInputTask: - def test_string_input(self): - assert MockMultiTextInputTask(("foobar", "bazbar")).input[0].value == "foobar" - assert MockMultiTextInputTask(("foobar", "bazbar")).input[1].value == "bazbar" - - task = MockMultiTextInputTask() - task.input = ("foobar", "bazbar") - assert task.input[0].value == "foobar" - assert task.input[1].value == "bazbar" - - def test_artifact_input(self): - assert MockMultiTextInputTask((TextArtifact("foobar"), TextArtifact("bazbar"))).input[0].value == "foobar" - assert MockMultiTextInputTask((TextArtifact("foobar"), TextArtifact("bazbar"))).input[1].value == "bazbar" - - task = MockMultiTextInputTask() - task.input = (TextArtifact("foobar"), TextArtifact("bazbar")) - assert task.input[0].value == "foobar" - assert task.input[1].value == "bazbar" - - def test_callable_input(self): - assert ( - MockMultiTextInputTask((lambda _: TextArtifact("foobar"), lambda _: TextArtifact("bazbar"))).input[0].value - == "foobar" - ) - assert ( - MockMultiTextInputTask((lambda _: TextArtifact("foobar"), lambda _: TextArtifact("bazbar"))).input[1].value - == "bazbar" - ) - - task = MockMultiTextInputTask() - task.input = (lambda _: TextArtifact("foobar"), lambda _: TextArtifact("bazbar")) - assert task.input[0].value == "foobar" - assert task.input[1].value == "bazbar" - - def test_full_context(self): - parent = MockMultiTextInputTask(("parent1", "parent2")) - subtask = MockMultiTextInputTask(("test1", "test2"), context={"foo": "bar"}) - child = MockMultiTextInputTask(("child2", "child2")) - pipeline = Pipeline() - - pipeline.add_tasks(parent, subtask, child) - - pipeline.run() - - context = subtask.full_context - - assert context["foo"] == "bar" - assert context["parent_output"] == parent.output.to_text() - assert context["structure"] == pipeline - assert context["parent"] == parent - assert context["child"] == child