diff --git a/.coveragerc b/.coveragerc index a55d8962f..6112755cc 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,28 +1,15 @@ [run] -data_file = coverage/.coverage branch = True [report] -; Regexes for lines to exclude from consideration exclude_also = - ; Don't complain about missing debug-only code: - def __repr__ - if self\.debug - - ; Don't complain if tests don't hit defensive assertion code: - raise AssertionError - raise NotImplementedError - - ; Don't complain if non-runnable code isn't run: - if 0: - if __name__ == .__main__.: - if TYPE_CHECKING: - - ; Don't complain about abstract methods, they aren't run: - @(abc\.)?abstractmethod - -[html] -directory = coverage/html - -[xml] -output = coverage/coverage.xml + def __repr__ + if self.debug: + if settings.DEBUG + raise AssertionError + raise NotImplementedError + if 0: + if __name__ == .__main__.: + if TYPE_CHECKING: + class .*\bProtocol\): + @(abc\.)?abstractmethod diff --git a/.gitignore b/.gitignore index ba63b4aca..c71d930a9 100644 --- a/.gitignore +++ b/.gitignore @@ -16,7 +16,6 @@ __pycache__ npm-debug.log **/.mypy_cache/** !yarn.lock -coverage/ cucumber-report.json **/.vscode-test/** **/.vscode test/** @@ -57,3 +56,7 @@ dist/** # mkdocs build output site reference + +# coverage.py +htmlcov/ +coverage.* diff --git a/CHANGELOG.md b/CHANGELOG.md index d7698127a..bf9ea5b22 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,77 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +## [0.27.0] - 2024-06-19 + +### Added +- `BaseTask.add_child()` to add a child task to a parent task. +- `BaseTask.add_children()` to add multiple child tasks to a parent task. +- `BaseTask.add_parent()` to add a parent task to a child task. +- `BaseTask.add_parents()` to add multiple parent tasks to a child task. +- `Structure.resolve_relationships()` to resolve asymmetrically defined parent/child relationships. In other words, if a parent declares a child, but the child does not declare the parent, the parent will automatically be added as a parent of the child when running this method. The method is invoked automatically by `Structure.before_run()`. +- `CohereEmbeddingDriver` for using Cohere's embeddings API. +- `CohereStructureConfig` for providing Structures with quick Cohere configuration. +- `AmazonSageMakerJumpstartPromptDriver.inference_component_name` for setting the `InferenceComponentName` parameter when invoking an endpoint. +- `AmazonSageMakerJumpstartEmbeddingDriver.inference_component_name` for setting the `InferenceComponentName` parameter when invoking an endpoint. +- `AmazonSageMakerJumpstartEmbeddingDriver.custom_attributes` for setting custom attributes when invoking an endpoint. +- `ToolkitTask.response_stop_sequence` for overriding the default Chain of Thought stop sequence. +- `griptape.utils.StructureVisualizer` for visualizing Workflow structures with Mermaid.js +- `BaseTask.parents_outputs` to get the textual output of all parent tasks. +- `BaseTask.parents_output_text` to get a concatenated string of all parent tasks' outputs. +- `parents_output_text` to Workflow context. +- `OllamaPromptModelDriver` for using models with Ollama. +- Parameter `output` on `Structure` as a convenience for `output_task.output` + +### Changed +- **BREAKING**: `Workflow` no longer modifies task relationships when adding tasks via `tasks` init param, `add_tasks()` or `add_task()`. Previously, adding a task would automatically add the previously added task as its parent. 
Existing code that relies on this behavior will need to be updated to explicitly add parent/child relationships using the API offered by `BaseTask`. +- **BREAKING**: Removed `AmazonBedrockPromptDriver.prompt_model_driver` as it is no longer needed with the `AmazonBedrockPromptDriver` Converse API implementation. +- **BREAKING**: Removed `BedrockClaudePromptModelDriver`. +- **BREAKING**: Removed `BedrockJurassicPromptModelDriver`. +- **BREAKING**: Removed `BedrockLlamaPromptModelDriver`. +- **BREAKING**: Removed `BedrockTitanPromptModelDriver`. +- **BREAKING**: Removed `BedrockClaudeTokenizer`, use `SimpleTokenizer` instead. +- **BREAKING**: Removed `BedrockJurassicTokenizer`, use `SimpleTokenizer` instead. +- **BREAKING**: Removed `BedrockLlamaTokenizer`, use `SimpleTokenizer` instead. +- **BREAKING**: Removed `BedrockTitanTokenizer`, use `SimpleTokenizer` instead. +- **BREAKING**: Removed `OpenAiChatCompletionPromptDriver` as it uses the legacy [OpenAi Completions API](https://platform.openai.com/docs/api-reference/completions). +- **BREAKING**: Removed `BasePromptDriver.count_tokens()`. +- **BREAKING**: Removed `BasePromptDriver.max_output_tokens()`. +- **BREAKING**: Moved/renamed `PromptStack.add_to_conversation_memory` to `BaseConversationMemory.add_to_prompt_stack`. +- **BREAKING**: Moved `griptape.constants.RESPONSE_STOP_SEQUENCE` to `ToolkitTask`. +- **BREAKING**: Renamed `AmazonSagemakerPromptDriver` to `AmazonSageMakerJumpstartPromptDriver`. +- **BREAKING**: Removed `SagemakerFalconPromptModelDriver`, use `AmazonSageMakerJumpstartPromptDriver` instead. +- **BREAKING**: Removed `SagemakerLlamaPromptModelDriver`, use `AmazonSageMakerJumpstartPromptDriver` instead. +- **BREAKING**: Renamed `AmazonSagemakerEmbeddingDriver` to `AmazonSageMakerJumpstartEmbeddingDriver`. +- **BREAKING**: Removed `SagemakerHuggingfaceEmbeddingModelDriver`, use `AmazonSageMakerJumpstartEmbeddingDriver` instead. +- **BREAKING**: Removed `SagemakerTensorflowHubEmbeddingModelDriver`, use `AmazonSageMakerJumpstartEmbeddingDriver` instead. +- **BREAKING**: `AmazonSageMakerJumpstartPromptDriver.model` parameter, which gets passed to `SageMakerRuntime.Client.invoke_endpoint` as `EndpointName`, is now renamed to `AmazonSageMakerPromptDriver.endpoint`. +- **BREAKING**: Removed parameter `template_generator` on `PromptSummaryEngine` and added parameters `system_template_generator` and `user_template_generator`. +- **BREAKING**: Removed template `engines/summary/prompt_summary.j2` and added templates `engines/summary/system.j2` and `engines/summary/user.j2`. +- `ToolkitTask.RESPONSE_STOP_SEQUENCE` is now only added when using `ToolkitTask`. +- Updated Prompt Drivers to use `BasePromptDriver.max_tokens` instead of using `BasePromptDriver.max_output_tokens()`. +- Improved error message when `GriptapeCloudKnowledgeBaseClient` does not have a description set. +- Updated `AmazonBedrockPromptDriver` to use [Converse API](https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html). +- `Structure.before_run()` now automatically resolves asymmetrically defined parent/child relationships using the new `Structure.resolve_relationships()`. +- Updated `HuggingFaceHubPromptDriver` to use `transformers`'s `apply_chat_template`. +- Updated `HuggingFacePipelinePromptDriver` to use chat features of `transformers.TextGenerationPipeline`. +- Updated `CoherePromptDriver` to use Cohere's latest SDK. +- Moved Task reset logic for all Structures to `Structure.before_run`. 
+- Updated default prompt templates for `PromptSummaryEngine`. +- Updated template `templates/tasks/tool_task/system.j2`. + +### Fixed +- `Workflow.insert_task()` no longer inserts duplicate tasks when given multiple parent tasks. +- Performance issue in `OpenAiChatPromptDriver` when extracting unused rate-limiting headers. +- Streaming not working when using deprecated `Structure.stream` field. +- Raw Tool output being lost when being executed by ActionsSubtask. +- Re-order Workflow tasks on every task execution wave. +- Web Loader to catch Exceptions and properly return an ErrorArtifact. +- Conversation Memory entry only added if `output_task.output` is not `None` on all `Structures` +- `TextArtifacts` contained in `ListArtifact` returned by `WebSearch.search` to properly formatted stringified JSON. +- Structure run args not being set immediately. +- Input and output logging in BaseAudioInputTasks and BaseAudioGenerationTasks +- Validation of `max_tokens` < 0 on `BaseChunker` + ## [0.26.0] - 2024-06-04 ### Added @@ -15,7 +86,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `AudioTranscriptionTask` and `AudioTranscriptionClient` for transcribing audio content in Structures. - `OpenAiAudioTranscriptionDriver` for integration with OpenAI's speech-to-text models, including Whisper. - Parameter `env` to `BaseStructureRunDriver` to set environment variables for a Structure Run. -- `PusherEventListenerDriver` to enable sending of framework events over a Pusher WebSocket. +- `PusherEventListenerDriver` to enable sending of framework events over a Pusher WebSocket. ### Changed - **BREAKING**: Removed `StructureConfig.global_drivers`. Pass Drivers directly to the Structure Config instead. diff --git a/Makefile b/Makefile index 61d91e6a1..b4f3d0068 100644 --- a/Makefile +++ b/Makefile @@ -7,7 +7,7 @@ version: ## Bump version and push to release branch. .PHONY: publish publish: ## Push git tag and publish version to PyPI. - @git tag v$$(poetry version -s) + @git tag $$(poetry version -s) @git push --tags @poetry build @poetry publish diff --git a/docs/examples/amazon-dynamodb-sessions.md b/docs/examples/amazon-dynamodb-sessions.md index 279cced2a..aa4050ab9 100644 --- a/docs/examples/amazon-dynamodb-sessions.md +++ b/docs/examples/amazon-dynamodb-sessions.md @@ -1,8 +1,8 @@ Griptape provides [Conversation Memory](../griptape-framework/structures/conversation-memory.md) as a means of persisting conversation context across multiple Structure runs. If you provide it with a suitable Driver, the memory of the previous conversation can be preserved between run of a Structure, giving it additional context for how to respond. -While we can use the [LocalConversationMemoryDriver](../griptape-framework/drivers/conversation-memory-drivers.md#localconversationmemorydriver) to store the conversation history in a local file, this may not be suitable for production use cases. +While we can use the [LocalConversationMemoryDriver](../griptape-framework/drivers/conversation-memory-drivers.md#local) to store the conversation history in a local file, in production use-cases we may want to store in a proper database. -In this example, we will show you how to use the [AmazonDynamoDbConversationMemoryDriver](../griptape-framework/drivers/conversation-memory-drivers.md#amazondynamodbconversationmemorydriver) to persist the memory in an [Amazon DynamoDB](https://aws.amazon.com/dynamodb/) table. 
Please refer to the [Amazon DynamoDB documentation](https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/getting-started-step-1.html) for information on setting up DynamoDB. +In this example, we will show you how to use the [AmazonDynamoDbConversationMemoryDriver](../griptape-framework/drivers/conversation-memory-drivers.md#amazon-dynamodb) to persist the memory in an [Amazon DynamoDB](https://aws.amazon.com/dynamodb/) table. Please refer to the [Amazon DynamoDB documentation](https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/getting-started-step-1.html) for information on setting up DynamoDB. This code implements the idea of a generic "Session" that represents a Conversation Memory entry. For example, a "Session" could be used to represent an individual user's conversation, or a group conversation thread. diff --git a/docs/examples/multi-agent-workflow.md b/docs/examples/multi-agent-workflow.md index 044f4abad..d0fa65232 100644 --- a/docs/examples/multi-agent-workflow.md +++ b/docs/examples/multi-agent-workflow.md @@ -1,7 +1,7 @@ In this example we implement a multi-agent Workflow. We have a single "Researcher" Agent that conducts research on a topic, and then fans out to multiple "Writer" Agents to write blog posts based on the research. By splitting up our workloads across multiple Structures, we can parallelize the work and leverage the strengths of each Agent. The Researcher can focus on gathering data and insights, while the Writers can focus on crafting engaging narratives. -Additionally, this architecture opens us up to using services such as [Griptape Cloud](https://www.griptape.ai/cloud) to have each Agent run on a separate machine, allowing us to scale our Workflow as needed 🤯. +Additionally, this architecture opens us up to using services such as [Griptape Cloud](https://www.griptape.ai/cloud) to have each Agent run completely independently, allowing us to scale our Workflow as needed 🤯. To try out how this would work, you can deploy this example as multiple structures from our [Sample Structures](https://github.com/griptape-ai/griptape-sample-structures/tree/main/griptape-multi-agent-workflows) repo. ```python @@ -155,35 +155,33 @@ if __name__ == "__main__": ), ), ) - end_task = team.add_task( - PromptTask( - 'State "All Done!"', - ) - ) - team.insert_tasks( - research_task, - [ - StructureRunTask( - ( - """Using insights provided, develop an engaging blog + writer_tasks = team.add_tasks(*[ + StructureRunTask( + ( + """Using insights provided, develop an engaging blog post that highlights the most significant AI advancements. Your post should be informative yet accessible, catering to a tech-savvy audience. Make it sound cool, avoid complex words so it doesn't sound like AI. 
Insights: {{ parent_outputs["research"] }}""", - ), - driver=LocalStructureRunDriver( - structure_factory_fn=lambda: build_writer( - role=writer["role"], - goal=writer["goal"], - backstory=writer["backstory"], - ) - ), - ) - for writer in WRITERS - ], - end_task, + ), + driver=LocalStructureRunDriver( + structure_factory_fn=lambda: build_writer( + role=writer["role"], + goal=writer["goal"], + backstory=writer["backstory"], + ) + ), + parent_ids=[research_task.id], + ) + for writer in WRITERS + ]) + end_task = team.add_task( + PromptTask( + 'State "All Done!"', + parent_ids=[writer_task.id for writer_task in writer_tasks], + ) ) team.run() diff --git a/docs/griptape-framework/data/artifacts.md b/docs/griptape-framework/data/artifacts.md index 18d5cb61f..5b69b120a 100644 --- a/docs/griptape-framework/data/artifacts.md +++ b/docs/griptape-framework/data/artifacts.md @@ -14,7 +14,7 @@ and access it with [embedding](../../reference/griptape/artifacts/text_artifact. ## CsvRowArtifact Used for passing structured row data around the framework. It inherits from [TextArtifact](../../reference/griptape/artifacts/text_artifact.md) and overrides the -[to_text()](../../reference/griptape/artifacts/text_artifact.md#griptape.artifacts.text_artifact.TextArtifact.to_text) method, which always returns a valid CSV row. +[to_text()](../../reference/griptape/artifacts/csv_row_artifact.md#griptape.artifacts.csv_row_artifact.CsvRowArtifact.to_text) method, which always returns a valid CSV row. ## InfoArtifact @@ -29,7 +29,7 @@ Used for passing errors back to the LLM without task memory storing them. Used for passing binary large objects (blobs) back to the LLM. Treat it as a way to return unstructured data, such as images, videos, audio, and other files back from tools. Each blob has a [name](../../reference/griptape/artifacts/base_artifact.md#griptape.artifacts.base_artifact.BaseArtifact.name) and -[dir](../../reference/griptape/artifacts/blob_artifact.md#griptape.artifacts.blob_artifact.BlobArtifact.dir) to uniquely identify stored objects. +[dir](../../reference/griptape/artifacts/blob_artifact.md#griptape.artifacts.blob_artifact.BlobArtifact.dir_name) to uniquely identify stored objects. [TaskMemory](../../reference/griptape/memory/task/task_memory.md) automatically stores [BlobArtifact](../../reference/griptape/artifacts/blob_artifact.md)s returned by tool activities that can be reused by other tools. diff --git a/docs/griptape-framework/data/loaders.md b/docs/griptape-framework/data/loaders.md index 69ffeda06..dbc578ccd 100644 --- a/docs/griptape-framework/data/loaders.md +++ b/docs/griptape-framework/data/loaders.md @@ -4,7 +4,7 @@ Loaders are used to load textual data from different sources and chunk it into [ Each loader can be used to load a single "document" with [load()](../../reference/griptape/loaders/base_loader.md#griptape.loaders.base_loader.BaseLoader.load) or multiple documents with [load_collection()](../../reference/griptape/loaders/base_loader.md#griptape.loaders.base_loader.BaseLoader.load_collection). -## Pdf Loader +## PDF !!! info This driver requires the `loaders-pdf` [extra](../index.md#extras). 
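A minimal single-document load can be sketched as follows; the `attention.pdf` filename is illustrative, and the bytes-based `load()` call mirrors the `load_collection()` usage shown in the next hunk rather than being the only supported input type.

```python
from griptape.loaders import PdfLoader

# Read the PDF as bytes and let the loader chunk it into TextArtifacts.
with open("attention.pdf", "rb") as f:
    artifacts = PdfLoader().load(f.read())

print(len(artifacts))
```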
@@ -33,7 +33,7 @@ with open("attention.pdf", "rb") as attention, open("CoT.pdf", "rb") as cot: PdfLoader().load_collection(list(load_files(["attention.pdf", "CoT.pdf"]).values())) ``` -## Sql Loader +## SQL Can be used to load data from a SQL database into [CsvRowArtifact](../../reference/griptape/artifacts/csv_row_artifact.md)s: @@ -54,7 +54,7 @@ SqlLoader( ).load_collection(["SELECT 'foo', 'bar';", "SELECT 'fizz', 'buzz';"]) ``` -## Csv Loader +## CSV Can be used to load CSV files into [CsvRowArtifact](../../reference/griptape/artifacts/csv_row_artifact.md)s: @@ -76,7 +76,7 @@ CsvLoader().load_collection(list(load_files(["tests/resources/cities.csv", "test ``` -## DataFrame Loader +## DataFrame !!! info This driver requires the `loaders-dataframe` [extra](../index.md#extras). @@ -100,7 +100,7 @@ DataFrameLoader().load_collection( ``` -## Text Loader +## Text Used to load arbitrary text and text files: @@ -124,9 +124,9 @@ with open("example.txt", "r") as f: ) ``` -You can set a custom [tokenizer](../../reference/griptape/loaders/text_loader.md#griptape.loaders.text_loader.TextLoader.tokenizer.md), [max_tokens](../../reference/griptape/loaders/text_loader.md#griptape.loaders.text_loader.TextLoader.max_tokens.md) parameter, and [chunker](../../reference/griptape/loaders/text_loader.md#griptape.loaders.text_loader.TextLoader.chunker.md). +You can set a custom [tokenizer](../../reference/griptape/loaders/text_loader.md#griptape.loaders.text_loader.TextLoader.tokenizer), [max_tokens](../../reference/griptape/loaders/text_loader.md#griptape.loaders.text_loader.TextLoader.max_tokens) parameter, and [chunker](../../reference/griptape/loaders/text_loader.md#griptape.loaders.text_loader.TextLoader.chunker). -## Web Loader +## Web !!! info This driver requires the `loaders-web` [extra](../index.md#extras). @@ -145,7 +145,7 @@ WebLoader().load_collection( ) ``` -## Image Loader +## Image !!! info This driver requires the `loaders-image` [extra](../index.md#extras). @@ -183,7 +183,7 @@ ImageLoader().load_collection(list(load_files(["tests/resources/mountain.png", " ``` -## Email Loader +## Email !!! info This driver requires the `loaders-email` [extra](../index.md#extras). @@ -200,7 +200,7 @@ loader.load(EmailLoader.EmailQuery(label="INBOX")) loader.load_collection([EmailLoader.EmailQuery(label="INBOX"), EmailLoader.EmailQuery(label="SENT")]) ``` -## Audio Loader +## Audio !!! info This driver requires the `loaders-audio` [extra](../index.md#extras). diff --git a/docs/griptape-framework/drivers/conversation-memory-drivers.md b/docs/griptape-framework/drivers/conversation-memory-drivers.md index 4c3de1e65..2ca5f8dd3 100644 --- a/docs/griptape-framework/drivers/conversation-memory-drivers.md +++ b/docs/griptape-framework/drivers/conversation-memory-drivers.md @@ -2,7 +2,7 @@ You can persist and load memory by using Conversation Memory Drivers. You can build drivers for your own data stores by extending [BaseConversationMemoryDriver](../../reference/griptape/drivers/memory/conversation/base_conversation_memory_driver.md). -### LocalConversationMemoryDriver +### Local The [LocalConversationMemoryDriver](../../reference/griptape/drivers/memory/conversation/local_conversation_memory_driver.md) allows you to persist Conversation Memory in a local JSON file. @@ -18,7 +18,7 @@ agent.run("Surfing is my favorite sport.") agent.run("What is my favorite sport?") ``` -### AmazonDynamoDbConversationMemoryDriver +### Amazon DynamoDb !!! 
info This driver requires the `drivers-memory-conversation-amazon-dynamodb` [extra](../index.md#extras). @@ -47,7 +47,7 @@ agent.run("What is my name?") ``` -### Redis Conversation Memory Driver +### Redis !!! info This driver requires the `drivers-memory-conversation-redis` [extra](../index.md#extras). diff --git a/docs/griptape-framework/drivers/embedding-drivers.md b/docs/griptape-framework/drivers/embedding-drivers.md index 49dfde4a5..3f0135ac3 100644 --- a/docs/griptape-framework/drivers/embedding-drivers.md +++ b/docs/griptape-framework/drivers/embedding-drivers.md @@ -10,7 +10,7 @@ You can optionally provide a [Tokenizer](../misc/tokenizers.md) via the [tokeniz ## Embedding Drivers -### OpenAI Embeddings +### OpenAI The [OpenAiEmbeddingDriver](../../reference/griptape/drivers/embedding/openai_embedding_driver.md) uses the [OpenAI Embeddings API](https://platform.openai.com/docs/guides/embeddings). @@ -27,12 +27,12 @@ print(embeddings[:3]) [0.0017853748286142945, 0.006118456833064556, -0.005811543669551611] ``` -### Azure OpenAI Embeddings +### Azure OpenAI The [AzureOpenAiEmbeddingDriver](../../reference/griptape/drivers/embedding/azure_openai_embedding_driver.md) uses the same parameters as [OpenAiEmbeddingDriver](../../reference/griptape/drivers/embedding/openai_embedding_driver.md) with updated defaults. -### Bedrock Titan Embeddings +### Bedrock Titan !!! info This driver requires the `drivers-embedding-amazon-bedrock` [extra](../index.md#extras). @@ -51,7 +51,7 @@ print(embeddings[:3]) [-0.234375, -0.024902344, -0.14941406] ``` -### Google Embeddings +### Google !!! info This driver requires the `drivers-embedding-google` [extra](../index.md#extras). @@ -69,7 +69,7 @@ print(embeddings[:3]) [0.0588633, 0.0033929371, -0.072810836] ``` -### Hugging Face Hub Embeddings +### Hugging Face Hub !!! info This driver requires the `drivers-embedding-huggingface` [extra](../index.md#extras). @@ -88,10 +88,8 @@ driver = HuggingFaceHubEmbeddingDriver( api_token=os.environ["HUGGINGFACE_HUB_ACCESS_TOKEN"], model="sentence-transformers/all-MiniLM-L6-v2", tokenizer=HuggingFaceTokenizer( + model="sentence-transformers/all-MiniLM-L6-v2", max_output_tokens=512, - tokenizer=AutoTokenizer.from_pretrained( - "sentence-transformers/all-MiniLM-L6-v2" - ) ), ) @@ -99,27 +97,21 @@ embeddings = driver.embed_string("Hello world!") # display the first 3 embeddings print(embeddings[:3]) -``` -### Multi Model Embedding Drivers -Certain embeddings providers such as Amazon SageMaker support many types of models, each with their own slight differences in parameters and response formats. To support this variation across models, these Embedding Drivers takes a [Embedding Model Driver](../../reference/griptape/drivers/embedding_model/base_embedding_model_driver.md) -through the [embedding_model_driver](../../reference/griptape/drivers/embedding/base_multi_model_embedding_driver.md#griptape.drivers.embedding.base_multi_model_embedding_driver.BaseMultiModelEmbeddingDriver.embedding_model_driver) parameter. -[Embedding Model Driver](../../reference/griptape/drivers/embedding_model/base_embedding_model_driver.md)s allows for model-specific customization for Embedding Drivers. -#### SageMaker Embeddings +``` +### Amazon SageMaker Jumpstart -The [AmazonSageMakerEmbeddingDriver](../../reference/griptape/drivers/embedding/amazon_sagemaker_embedding_driver.md) uses the [Amazon SageMaker Endpoints](https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints.html) to generate embeddings on AWS. 
+The [AmazonSageMakerJumpstartEmbeddingDriver](../../reference/griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.md) uses the [Amazon SageMaker Endpoints](https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints.html) to generate embeddings on AWS. !!! info This driver requires the `drivers-embedding-amazon-sagemaker` [extra](../index.md#extras). -##### TensorFlow Hub Models ```python title="PYTEST_IGNORE" import os -from griptape.drivers import AmazonSageMakerEmbeddingDriver, SageMakerTensorFlowHubEmbeddingModelDriver +from griptape.drivers import AmazonSageMakerJumpstartEmbeddingDriver, SageMakerTensorFlowHubEmbeddingModelDriver -driver = AmazonSageMakerEmbeddingDriver( +driver = AmazonSageMakerJumpstartEmbeddingDriver( model=os.environ["SAGEMAKER_TENSORFLOW_HUB_MODEL"], - embedding_model_driver=SageMakerTensorFlowHubEmbeddingModelDriver(), ) embeddings = driver.embed_string("Hello world!") @@ -128,14 +120,18 @@ embeddings = driver.embed_string("Hello world!") print(embeddings[:3]) ``` -##### HuggingFace Models -```python title="PYTEST_IGNORE" +### VoyageAI +The [VoyageAiEmbeddingDriver](../../reference/griptape/drivers/embedding/voyageai_embedding_driver.md) uses the [VoyageAI Embeddings API](https://www.voyageai.com/). + +!!! info + This driver requires the `drivers-embedding-voyageai` [extra](../index.md#extras). + +```python import os -from griptape.drivers import AmazonSageMakerEmbeddingDriver, SageMakerHuggingFaceEmbeddingModelDriver +from griptape.drivers import VoyageAiEmbeddingDriver -driver = AmazonSageMakerEmbeddingDriver( - model=os.environ["SAGEMAKER_HUGGINGFACE_MODEL"], - embedding_model_driver=SageMakerHuggingFaceEmbeddingModelDriver(), +driver = VoyageAiEmbeddingDriver( + api_key=os.environ["VOYAGE_API_KEY"] ) embeddings = driver.embed_string("Hello world!") @@ -144,21 +140,24 @@ embeddings = driver.embed_string("Hello world!") print(embeddings[:3]) ``` -### VoyageAI Embeddings -The [VoyageAiEmbeddingDriver](../../reference/griptape/drivers/embedding/voyageai_embedding_driver.md) uses the [VoyageAI Embeddings API](https://www.voyageai.com/). +### Cohere + +The [CohereEmbeddingDriver](../../reference/griptape/drivers/embedding/cohere_embedding_driver.md) uses the [Cohere Embeddings API](https://docs.cohere.com/docs/embeddings). !!! info - This driver requires the `drivers-embedding-voyageai` [extra](../index.md#extras). + This driver requires the `drivers-embedding-cohere` [extra](../index.md#extras). ```python import os -from griptape.drivers import VoyageAiEmbeddingDriver +from griptape.drivers import CohereEmbeddingDriver -driver = VoyageAiEmbeddingDriver( - api_key=os.environ["VOYAGE_API_KEY"] +embedding_driver=CohereEmbeddingDriver( + model="embed-english-v3.0", + api_key=os.environ["COHERE_API_KEY"], + input_type="search_document", ) -embeddings = driver.embed_string("Hello world!") +embeddings = embedding_driver.embed_string("Hello world!") # display the first 3 embeddings print(embeddings[:3]) diff --git a/docs/griptape-framework/drivers/event-listener-drivers.md b/docs/griptape-framework/drivers/event-listener-drivers.md index 6e8f59b22..4a85bc9a4 100644 --- a/docs/griptape-framework/drivers/event-listener-drivers.md +++ b/docs/griptape-framework/drivers/event-listener-drivers.md @@ -66,7 +66,7 @@ event_driver.publish_event(done_event) Griptape offers the following Event Listener Drivers for forwarding Griptape Events. -### Amazon SQS Event Listener Driver +### Amazon SQS !!! 
info This driver requires the `drivers-event-listener-amazon-sqs` [extra](../index.md#extras). @@ -108,7 +108,7 @@ agent.run( ) ``` -### AWS IoT Event Listener Driver +### AWS IoT !!! info This driver requires the `drivers-event-listener-amazon-iot` [extra](../index.md#extras). @@ -152,7 +152,7 @@ agent = Agent( agent.run("I want to fly from Orlando to Boston") ``` -### Griptape Cloud Event Listener Driver +### Griptape Cloud The [GriptapeCloudEventListenerDriver](../../reference/griptape/drivers/event_listener/griptape_cloud_event_listener_driver.md) sends Events to [Griptape Cloud](https://www.griptape.ai/cloud). @@ -212,7 +212,7 @@ agent = Agent( agent.run("Analyze the pros and cons of remote work vs. office work") ``` -### Pusher Event Listener Driver +### Pusher !!! info This driver requires the `drivers-event-listener-pusher` [extra](../index.md#extras). diff --git a/docs/griptape-framework/drivers/image-generation-drivers.md b/docs/griptape-framework/drivers/image-generation-drivers.md index 7389ff711..25572ba9c 100644 --- a/docs/griptape-framework/drivers/image-generation-drivers.md +++ b/docs/griptape-framework/drivers/image-generation-drivers.md @@ -27,7 +27,7 @@ agent.run("Generate a watercolor painting of a dog riding a skateboard") The [Amazon Bedrock Image Generation Driver](../../reference/griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.md) provides multi-model access to image generation models hosted by Amazon Bedrock. This Driver manages API calls to the Bedrock API, while the specific Model Drivers below format the API requests and parse the responses. -#### Bedrock Stable Diffusion Model Driver +#### Stable Diffusion The [Bedrock Stable Diffusion Model Driver](../../reference/griptape/drivers/image_generation_model/bedrock_stable_diffusion_image_generation_model_driver.md) provides support for Stable Diffusion models hosted by Amazon Bedrock. This Model Driver supports configurations specific to Stable Diffusion, like style presets, clip guidance presets, and sampler. @@ -58,7 +58,7 @@ agent = Agent(tools=[ agent.run("Generate an image of a dog riding a skateboard") ``` -#### Bedrock Titan Image Generator Model Driver +#### Titan The [Bedrock Titan Image Generator Model Driver](../../reference/griptape/drivers/image_generation_model/bedrock_titan_image_generation_model_driver.md) provides support for Titan Image Generator models hosted by Amazon Bedrock. This Model Driver supports configurations specific to Titan Image Generator, like quality, seed, and cfg_scale. diff --git a/docs/griptape-framework/drivers/image-query-drivers.md b/docs/griptape-framework/drivers/image-query-drivers.md index d795838d4..8003924b1 100644 --- a/docs/griptape-framework/drivers/image-query-drivers.md +++ b/docs/griptape-framework/drivers/image-query-drivers.md @@ -1,11 +1,11 @@ # Image Query Drivers -Image Query Drivers are used by [Image Query Engines](../engines/image-query-engines.md) to execute natural language queries on the contents of images. You can specify the provider and model used to query the image by providing the Engine with a particular Image Query Driver. +Image Query Drivers are used by [Image Query Engines](../engines/query-engines.md#image) to execute natural language queries on the contents of images. You can specify the provider and model used to query the image by providing the Engine with a particular Image Query Driver. !!! info All Image Query Drivers default to a `max_tokens` of 256. 
It is recommended that you set this value to correspond to the desired response length. -## AnthropicImageQueryDriver +## Anthropic !!! info To tune `max_tokens`, see [Anthropic's documentation on image tokens](https://docs.anthropic.com/claude/docs/vision#image-costs) for more information on how to relate token count to response length. @@ -59,7 +59,7 @@ result = engine.run("Describe the weather in the image", [image_artifact1, image print(result) ``` -## OpenAiVisionImageQueryDriver +## OpenAI !!! info While the `max_tokens` field is optional, it is recommended to set this to a value that corresponds to the desired response length. Without an explicit value, the model will default to very short responses. See [OpenAI's documentation](https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them) for more information on how to relate token count to response length. @@ -86,7 +86,7 @@ with open("tests/resources/mountain.png", "rb") as f: engine.run("Describe the weather in the image", [image_artifact]) ``` -## AzureOpenAiVisionImageQueryDriver +## Azure OpenAI !!! info In order to use the `gpt-4-vision-preview` model on Azure OpenAI, the `gpt-4` model must be deployed with the version set to `vision-preview`. More information can be found in the [Azure OpenAI documentation](https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/gpt-with-vision). @@ -117,7 +117,7 @@ with open("tests/resources/mountain.png", "rb") as f: engine.run("Describe the weather in the image", [image_artifact]) ``` -## AmazonBedrockImageQueryDriver +## Amazon Bedrock The [Amazon Bedrock Image Query Driver](../../reference/griptape/drivers/image_query/amazon_bedrock_image_query_driver.md) provides multi-model access to image query models hosted by Amazon Bedrock. This Driver manages API calls to the Bedrock API, while the specific Model Drivers below format the API requests and parse the responses. diff --git a/docs/griptape-framework/drivers/prompt-drivers.md b/docs/griptape-framework/drivers/prompt-drivers.md index 1006f215f..0100ccbac 100644 --- a/docs/griptape-framework/drivers/prompt-drivers.md +++ b/docs/griptape-framework/drivers/prompt-drivers.md @@ -125,38 +125,6 @@ agent = Agent( agent.run("Artificial intelligence is a technology with great promise.") ``` -### Azure OpenAI Completion - -The [AzureOpenAiCompletionPromptDriver](../../reference/griptape/drivers/prompt/azure_openai_completion_prompt_driver.md) connects to Azure OpenAI [Text Completion](https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference) API. - -```python -import os -from griptape.structures import Agent -from griptape.drivers import AzureOpenAiCompletionPromptDriver -from griptape.config import StructureConfig - -agent = Agent( - config=StructureConfig( - prompt_driver=AzureOpenAiCompletionPromptDriver( - api_key=os.environ["AZURE_OPENAI_API_KEY_1"], - model="text-davinci-003", - azure_deployment=os.environ["AZURE_OPENAI_DAVINCI_DEPLOYMENT_ID"], - azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_1"], - temperature=1 - ) - ) -) - -agent.run( - """ - Write a product launch email for new AI-powered headphones that are priced at $79.99 and available at Best Buy, Target and Amazon.com. The target audience is tech-savvy music lovers and the tone is friendly and exciting. - - 1. What should be the subject line of the email? - 2. What should be the body of the email? 
- """ -) -``` - ### Cohere The [CoherePromptDriver](../../reference/griptape/drivers/prompt/cohere_prompt_driver.md) connects to the Cohere [Generate](https://docs.cohere.ai/reference/generate) API. @@ -232,57 +200,96 @@ agent = Agent( agent.run('Briefly explain how a computer works to a young child.') ``` +### Amazon Bedrock + +!!! info + This driver requires the `drivers-prompt-amazon-bedrock` [extra](../index.md#extras). + +The [AmazonBedrockPromptDriver](../../reference/griptape/drivers/prompt/amazon_bedrock_prompt_driver.md) uses [Amazon Bedrock](https://aws.amazon.com/bedrock/)'s [Converse API](https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html). + +All models supported by the Converse API are available for use with this driver. + +```python +from griptape.structures import Agent +from griptape.drivers import AmazonBedrockPromptDriver +from griptape.rules import Rule +from griptape.config import StructureConfig + +agent = Agent( + config=StructureConfig( + prompt_driver=AmazonBedrockPromptDriver( + model="anthropic.claude-3-sonnet-20240229-v1:0", + ) + ), + rules=[ + Rule( + value="You are a customer service agent that is classifying emails by type. I want you to give your answer and then explain it." + ) + ], +) +agent.run( + """How would you categorize this email? + + Can I use my Mixmaster 4000 to mix paint, or is it only meant for mixing food? + + + Categories are: + (A) Pre-sale question + (B) Broken or defective item + (C) Billing question + (D) Other (please explain)""" +) +``` + +### Ollama + +!!! info + This driver requires the `drivers-prompt-ollama` [extra](../index.md#extras). + +The [OllamaPromptDriver](../../reference/griptape/drivers/prompt/ollama_prompt_driver.md) connects to the [Ollama Chat Completion API](https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion). + +```python +from griptape.config import StructureConfig +from griptape.drivers import OllamaPromptDriver +from griptape.structures import Agent + + +agent = Agent( + config=StructureConfig( + prompt_driver=OllamaPromptDriver( + model="llama3", + ), + ), +) +agent.run("What color is the sky at different times of the day?") +``` + ### Hugging Face Hub !!! info This driver requires the `drivers-prompt-huggingface` [extra](../index.md#extras). -The [HuggingFaceHubPromptDriver](../../reference/griptape/drivers/prompt/huggingface_hub_prompt_driver.md) connects to the [Hugging Face Hub API](https://huggingface.co/docs/hub/api). It supports models with the following tasks: +The [HuggingFaceHubPromptDriver](../../reference/griptape/drivers/prompt/huggingface_hub_prompt_driver.md) connects to the [Hugging Face Hub API](https://huggingface.co/docs/hub/api). -- text2text-generation -- text-generation !!! warning Not all models featured on the Hugging Face Hub are supported by this driver. Models that are not supported by [Hugging Face serverless inference](https://huggingface.co/docs/api-inference/en/index) will not work with this driver. Due to the limitations of Hugging Face serverless inference, only models that are than 10GB are supported. -!!! info - The `prompt_stack_to_string_converter` function is intended to convert a `PromptStack` to model specific input. You - should consult the model's documentation to determine the correct format. 
- -Let's recreate the [Falcon-7B-Instruct](https://huggingface.co/tiiuae/falcon-7b-instruct) example using Griptape: - ```python import os from griptape.structures import Agent from griptape.drivers import HuggingFaceHubPromptDriver from griptape.rules import Rule, Ruleset -from griptape.utils import PromptStack from griptape.config import StructureConfig -def prompt_stack_to_string_converter(prompt_stack: PromptStack) -> str: - prompt_lines = [] - - for i in prompt_stack.inputs: - if i.is_user(): - prompt_lines.append(f"User: {i.content}") - elif i.is_assistant(): - prompt_lines.append(f"Girafatron: {i.content}") - else: - prompt_lines.append(f"Instructions: {i.content}") - prompt_lines.append("Girafatron:") - - return "\n".join(prompt_lines) - - agent = Agent( config=StructureConfig( prompt_driver=HuggingFaceHubPromptDriver( - model="tiiuae/falcon-7b-instruct", + model="HuggingFaceH4/zephyr-7b-beta", api_token=os.environ["HUGGINGFACE_HUB_ACCESS_TOKEN"], - prompt_stack_to_string=prompt_stack_to_string_converter, ) ), rulesets=[ @@ -294,7 +301,7 @@ agent = Agent( "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. " "Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe." ) - ] + ], ) ], ) @@ -330,262 +337,68 @@ agent.run("Write the code for a snake game.") !!! info This driver requires the `drivers-prompt-huggingface-pipeline` [extra](../index.md#extras). -The [HuggingFacePipelinePromptDriver](../../reference/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.md) uses [Hugging Face Pipelines](https://huggingface.co/docs/transformers/main_classes/pipelines) for inference locally. It supports models with the following tasks: - -- text2text-generation -- text-generation +The [HuggingFacePipelinePromptDriver](../../reference/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.md) uses [Hugging Face Pipelines](https://huggingface.co/docs/transformers/main_classes/pipelines) for inference locally. !!! warning Running a model locally can be a computationally expensive process. ```python -import os from griptape.structures import Agent from griptape.drivers import HuggingFacePipelinePromptDriver from griptape.rules import Rule, Ruleset -from griptape.utils import PromptStack from griptape.config import StructureConfig -# Override the default Prompt Stack to string converter -# to format the prompt in a way that is easier for this model to understand. -def prompt_stack_to_string_converter(prompt_stack: PromptStack) -> str: - prompt_lines = [] - - for i in prompt_stack.inputs: - if i.is_user(): - prompt_lines.append(f"User: {i.content}") - elif i.is_assistant(): - prompt_lines.append(f"Girafatron: {i.content}") - else: - prompt_lines.append(f"Instructions: {i.content}") - prompt_lines.append("Girafatron:") - - return "\n".join(prompt_lines) - - agent = Agent( config=StructureConfig( prompt_driver=HuggingFacePipelinePromptDriver( - model="TinyLlama/TinyLlama-1.1B-Chat-v0.6", - prompt_stack_to_string=prompt_stack_to_string_converter, + model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", ) ), rulesets=[ Ruleset( - name="Girafatron", + name="Pirate", rules=[ Rule( - value="You are Girafatron, a giraffe-obsessed robot. You are talking to a human. " - "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. " - "Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe." 
+ value="You are a pirate chatbot who always responds in pirate speak!" ) - ] + ], ) ], ) -agent.run("Hello Girafatron, what is your favorite animal?") +agent.run("How many helicopters can a human eat in one sitting?") ``` -### Multi Model Prompt Drivers -Certain LLM providers such as Amazon SageMaker and Amazon Bedrock supports many types of models, each with their own slight differences in prompt structure and parameters. To support this variation across models, these Prompt Drivers takes a [Prompt Model Driver](../../reference/griptape/drivers/prompt_model/base_prompt_model_driver.md) -through the [prompt_model_driver](../../reference/griptape/drivers/prompt/base_multi_model_prompt_driver.md#griptape.drivers.prompt.base_multi_model_prompt_driver.BaseMultiModelPromptDriver.prompt_model_driver) parameter. -[Prompt Model Driver](../../reference/griptape/drivers/prompt_model/base_prompt_model_driver.md)s allows for model-specific customization for Prompt Drivers. - - -#### Amazon SageMaker +### Amazon SageMaker Jumpstart !!! info This driver requires the `drivers-prompt-amazon-sagemaker` [extra](../index.md#extras). -The [AmazonSageMakerPromptDriver](../../reference/griptape/drivers/prompt/amazon_sagemaker_prompt_driver.md) uses [Amazon SageMaker Endpoints](https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints.html) for inference on AWS. - -!!! info - For single model endpoints, the `model` parameter does not need to be specified. - For multi-model endpoints, the `model` parameter should be the inference component name. - -!!! warning - Make sure that the selected prompt model driver is compatible with the selected model. Note that even the same - logical model can require different prompt model drivers depending on how it is bundled in the endpoint. For - example, the reponse format are different for `Meta-Llama-3-8B-Instruct` when deployed via - "Amazon SageMaker JumpStart" and "Hugging Face on Amazon SageMaker". +The [AmazonSageMakerJumpstartPromptDriver](../../reference/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.md) uses [Amazon SageMaker Jumpstart](https://docs.aws.amazon.com/sagemaker/latest/dg/studio-jumpstart.html) for inference on AWS. -##### Llama - -!!! info - `SageMakerLlamaPromptModelDriver` requires a tokenizer corresponding to a [Gated Model](https://huggingface.co/docs/hub/en/models-gated) on Hugging Face. +Amazon Sagemaker Jumpstart provides a wide range of models with varying capabilities. +This Driver has been primarily _chat-optimized_ models that have a [Huggingface Chat Template](https://huggingface.co/docs/transformers/en/chat_templating) available. +If your model does not fit this use-case, we suggest sub-classing [AmazonSageMakerJumpstartPromptDriver](../../reference/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.md) and overriding the `_to_model_input` and `_to_model_params` methods. - Make sure to request access to the [Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) model on Hugging Face and configure your environment for hugging face use. 
- -```python title="PYTEST_IGNORE" -import os -from griptape.structures import Agent -from griptape.drivers import ( - AmazonSageMakerPromptDriver, - SageMakerLlamaPromptModelDriver, -) -from griptape.rules import Rule -from griptape.config import StructureConfig - -agent = Agent( - config=StructureConfig( - prompt_driver=AmazonSageMakerPromptDriver( - endpoint=os.environ["SAGEMAKER_LLAMA_3_INSTRUCT_ENDPOINT_NAME"], - model=os.environ["SAGEMAKER_LLAMA_3_INSTRUCT_INFERENCE_COMPONENT_NAME"], - prompt_model_driver=SageMakerLlamaPromptModelDriver(), - temperature=0.75, - ) - ), - rules=[ - Rule( - value="You are a helpful, respectful and honest assistant who is also a swarthy pirate." - "You only speak like a pirate and you never break character." - ) - ], -) - -agent.run("Hello!") -``` - -##### Falcon ```python title="PYTEST_IGNORE" import os from griptape.structures import Agent from griptape.drivers import ( - AmazonSageMakerPromptDriver, + AmazonSageMakerJumpstartPromptDriver, SageMakerFalconPromptModelDriver, ) from griptape.config import StructureConfig agent = Agent( config=StructureConfig( - prompt_driver=AmazonSageMakerPromptDriver( - endpoint=os.environ["SAGEMAKER_FALCON_ENDPOINT_NAME"], - model=os.environ["SAGEMAKER_FALCON_INFERENCE_COMPONENT_NAME"], - prompt_model_driver=SageMakerFalconPromptModelDriver(), + prompt_driver=AmazonSageMakerJumpstartPromptDriver( + endpoint=os.environ["SAGEMAKER_LLAMA_3_INSTRUCT_ENDPOINT_NAME"], + model="meta-llama/Meta-Llama-3-8B-Instruct", ) ) ) agent.run("What is a good lasagna recipe?") - -``` - -#### Amazon Bedrock - -!!! info - This driver requires the `drivers-prompt-amazon-bedrock` [extra](../index.md#extras). - -The [AmazonBedrockPromptDriver](../../reference/griptape/drivers/prompt/amazon_bedrock_prompt_driver.md) uses [Amazon Bedrock](https://aws.amazon.com/bedrock/) for inference on AWS. - -##### Amazon Titan - -To use this model with Amazon Bedrock, use the [BedrockTitanPromptModelDriver](../../reference/griptape/drivers/prompt_model/bedrock_titan_prompt_model_driver.md). - -```python -from griptape.structures import Agent -from griptape.drivers import AmazonBedrockPromptDriver, BedrockTitanPromptModelDriver -from griptape.config import StructureConfig - -agent = Agent( - config=StructureConfig( - prompt_driver=AmazonBedrockPromptDriver( - model="amazon.titan-text-express-v1", - prompt_model_driver=BedrockTitanPromptModelDriver( - top_p=1, - ) - ) - ) -) -agent.run( - "Write an informational article for children about how birds fly." - "Compare how birds fly to how airplanes fly." - 'Make sure to use the word "Thrust" at least three times.' -) -``` - -##### Anthropic Claude - -To use this model with Amazon Bedrock, use the [BedrockClaudePromptModelDriver](../../reference/griptape/drivers/prompt_model/bedrock_claude_prompt_model_driver.md). - -```python -from griptape.structures import Agent -from griptape.drivers import AmazonBedrockPromptDriver, BedrockClaudePromptModelDriver -from griptape.rules import Rule -from griptape.config import StructureConfig - -agent = Agent( - config=StructureConfig( - prompt_driver=AmazonBedrockPromptDriver( - model="anthropic.claude-3-sonnet-20240229-v1:0", - prompt_model_driver=BedrockClaudePromptModelDriver( - top_p=1, - ) - ) - ), - rules=[ - Rule( - value="You are a customer service agent that is classifying emails by type. I want you to give your answer and then explain it." - ) - ], -) -agent.run( - """How would you categorize this email? 
- - Can I use my Mixmaster 4000 to mix paint, or is it only meant for mixing food? - - - Categories are: - (A) Pre-sale question - (B) Broken or defective item - (C) Billing question - (D) Other (please explain)""" -) -``` -##### Meta Llama 2 - -To use this model with Amazon Bedrock, use the [BedrockLlamaPromptModelDriver](../../reference/griptape/drivers/prompt_model/bedrock_llama_prompt_model_driver.md). - -```python -from griptape.structures import Agent -from griptape.drivers import AmazonBedrockPromptDriver, BedrockLlamaPromptModelDriver -from griptape.config import StructureConfig - -agent = Agent( - config=StructureConfig( - prompt_driver=AmazonBedrockPromptDriver( - model="meta.llama2-13b-chat-v1", - prompt_model_driver=BedrockLlamaPromptModelDriver(), - ) - ) -) -agent.run( - "Write an article about impact of high inflation to GDP of a country" -) -``` - -##### Ai21 Jurassic - -To use this model with Amazon Bedrock, use the [BedrockJurassicPromptModelDriver](../../reference/griptape/drivers/prompt_model/bedrock_jurassic_prompt_model_driver.md). - -```python -from griptape.structures import Agent -from griptape.drivers import AmazonBedrockPromptDriver, BedrockJurassicPromptModelDriver -from griptape.config import StructureConfig - -agent = Agent( - config=StructureConfig( - prompt_driver=AmazonBedrockPromptDriver( - model="ai21.j2-ultra-v1", - prompt_model_driver=BedrockJurassicPromptModelDriver(top_p=0.95), - temperature=0.7, - ) - ) -) -agent.run( - "Suggest an outline for a blog post based on a title. " - "Title: How I put the pro in prompt engineering." -) ``` diff --git a/docs/griptape-framework/drivers/sql-drivers.md b/docs/griptape-framework/drivers/sql-drivers.md index 35377849a..c4b7dbcca 100644 --- a/docs/griptape-framework/drivers/sql-drivers.md +++ b/docs/griptape-framework/drivers/sql-drivers.md @@ -1,14 +1,14 @@ ## Overview SQL drivers can be used to make SQL queries and load table schemas. They are used by the [SqlLoader](../../reference/griptape/loaders/sql_loader.md) to process data. All loaders implement the following methods: -* `execute_query()` executes a query and returns [RowResult](../../reference/griptape/drivers/sql/base_sql_driver.md#griptape.drivers.sql.base_sql_driver.BaseSqlDriver.RowResult.md)s. +* `execute_query()` executes a query and returns [RowResult](../../reference/griptape/drivers/sql/base_sql_driver.md#griptape.drivers.sql.base_sql_driver.BaseSqlDriver.RowResult)s. * `execute_query_row()` executes a query and returns a raw result from SQL. * `get_table_schema()` returns a table schema. !!! info More database-specific SQL drivers are coming soon. -## SqlDriver +## SQL This is a basic SQL loader based on [SQLAlchemy 1.x](https://docs.sqlalchemy.org/en/14/). Here is an example of how to use it: @@ -22,7 +22,7 @@ driver = SqlDriver( driver.execute_query("select 'foo', 'bar';") ``` -## AmazonRedshiftSqlDriver +## Amazon Redshift !!! info This driver requires the `drivers-sql-redshift` [extra](../index.md#extras). @@ -46,7 +46,7 @@ driver = AmazonRedshiftSqlDriver( driver.execute_query("select * from people;") ``` -## SnowflakeSqlDriver +## Snowflake !!! info This driver requires the `drivers-sql-snowflake` [extra](../index.md#extras). 
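A hedged sketch of wiring up the Snowflake driver is shown below. The `connection_func` parameter and the `snowflake-connector-python` usage are assumptions based on the driver's reference page; confirm the exact constructor signature there.

```python
import os

import snowflake.connector
from snowflake.connector import SnowflakeConnection

from griptape.drivers import SnowflakeSqlDriver


def get_snowflake_connection() -> SnowflakeConnection:
    # Credential environment variable names are illustrative.
    return snowflake.connector.connect(
        account=os.environ["SNOWFLAKE_ACCOUNT"],
        user=os.environ["SNOWFLAKE_USER"],
        password=os.environ["SNOWFLAKE_PASSWORD"],
        database=os.environ["SNOWFLAKE_DATABASE"],
        schema=os.environ["SNOWFLAKE_SCHEMA"],
        warehouse=os.environ["SNOWFLAKE_WAREHOUSE"],
    )


# connection_func is assumed here; see the SnowflakeSqlDriver reference for the exact field name.
driver = SnowflakeSqlDriver(connection_func=get_snowflake_connection)

driver.execute_query("select * from people;")
```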
diff --git a/docs/griptape-framework/drivers/structure-run-drivers.md b/docs/griptape-framework/drivers/structure-run-drivers.md index c2b94190c..54413c3a2 100644 --- a/docs/griptape-framework/drivers/structure-run-drivers.md +++ b/docs/griptape-framework/drivers/structure-run-drivers.md @@ -2,7 +2,7 @@ Structure Run Drivers can be used to run Griptape Structures in a variety of runtime environments. When combined with the [Structure Run Task](../../griptape-framework/structures/tasks.md#structure-run-task) or [Structure Run Client](../../griptape-tools/official-tools/structure-run-client.md) you can create complex, multi-agent pipelines that span multiple runtime environments. -## Local Structure Run Driver +## Local The [LocalStructureRunDriver](../../reference/griptape/drivers/structure_run/local_structure_run_driver.md) is used to run Griptape Structures in the same runtime environment as the code that is running the Structure. @@ -53,7 +53,7 @@ joke_coordinator = Pipeline( joke_coordinator.run("Tell me a joke") ``` -## Griptape Cloud Structure Run Driver +## Griptape Cloud The [GriptapeCloudStructureRunDriver](../../reference/griptape/drivers/structure_run/griptape_cloud_structure_run_driver.md) is used to run Griptape Structures in the Griptape Cloud. diff --git a/docs/griptape-framework/drivers/vector-store-drivers.md b/docs/griptape-framework/drivers/vector-store-drivers.md index 73a416c84..481a2dd84 100644 --- a/docs/griptape-framework/drivers/vector-store-drivers.md +++ b/docs/griptape-framework/drivers/vector-store-drivers.md @@ -16,7 +16,7 @@ Each vector driver takes a [BaseEmbeddingDriver](../../reference/griptape/driver !!! info More vector drivers are coming soon. -## Local Vector Store Driver +## Local The [LocalVectorStoreDriver](../../reference/griptape/drivers/vector/local_vector_store_driver.md) can be used to load and query data from memory. Here is a complete example of how the driver can be used to load a webpage into the driver and query it later: @@ -47,7 +47,7 @@ print("\n\n".join(values)) ``` -## Pinecone Vector Store Driver +## Pinecone !!! info This driver requires the `drivers-vector-pinecone` [extra](../index.md#extras). @@ -103,7 +103,7 @@ result = vector_store_driver.query( ) ``` -## Marqo Vector Store Driver +## Marqo !!! info This driver requires the `drivers-vector-marqo` [extra](../index.md#extras). @@ -157,7 +157,7 @@ result = vector_store.query(query="What is griptape?") print(result) ``` -## Mongodb Atlas Vector Store Driver +## Mongodb Atlas !!! info This driver requires the `drivers-vector-mongodb` [extra](../index.md#extras). @@ -225,14 +225,14 @@ The format for creating a vector index should look similar to the following: ``` Replace `path_to_vector` with the expected field name where the vector content will be. -## Azure MongoDB Vector Store Driver +## Azure MongoDB !!! info This driver requires the `drivers-vector-mongodb` [extra](../index.md#extras). The [AzureMongoDbVectorStoreDriver](../../reference/griptape/drivers/vector/azure_mongodb_vector_store_driver.md) provides support for storing vector data in an Azure CosmosDb database account using the MongoDb vCore API -Here is an example of how the driver can be used to load and query information in an Azure CosmosDb MongoDb vCore database. It is almost the same as the [MongodbAtlasVectorStoreDriver](#mongodb-atlas-vector-store-driver): +Here is an example of how the driver can be used to load and query information in an Azure CosmosDb MongoDb vCore database. 
It is very similar to the Driver for [MongoDb Atlas](#mongodb-atlas): ```python from griptape.drivers import AzureMongoDbVectorStoreDriver, OpenAiEmbeddingDriver @@ -274,7 +274,7 @@ result = vector_store.query(query="What is griptape?") print(result) ``` -## Redis Vector Store Driver +## Redis !!! info This driver requires the `drivers-vector-redis` [extra](../index.md#extras). @@ -319,7 +319,7 @@ The format for creating a vector index should be similar to the following: FT.CREATE idx:griptape ON hash PREFIX 1 "griptape:" SCHEMA namespace TAG vector VECTOR FLAT 6 TYPE FLOAT32 DIM 1536 DISTANCE_METRIC COSINE ``` -## OpenSearch Vector Store Driver +## OpenSearch !!! info This driver requires the `drivers-vector-opensearch` [extra](../index.md#extras). @@ -372,7 +372,7 @@ The body mappings for creating a vector index should look similar to the followi } ``` -## PGVector Vector Store Driver +## PGVector !!! info This driver requires the `drivers-vector-postgresql` [extra](../index.md#extras). diff --git a/docs/griptape-framework/drivers/web-scraper-drivers.md b/docs/griptape-framework/drivers/web-scraper-drivers.md index 888605b73..a02365b67 100644 --- a/docs/griptape-framework/drivers/web-scraper-drivers.md +++ b/docs/griptape-framework/drivers/web-scraper-drivers.md @@ -4,7 +4,7 @@ Web Scraper Drivers can be used to scrape text from the web. They are used by [W * `scrape_url()` scrapes text from a website and returns a [TextArtifact](../../reference/griptape/artifacts/text_artifact.md). The format of the scrapped text is determined by the Driver. -## Markdownify Web Scraper Driver +## Markdownify !!! info This driver requires the `drivers-web-scraper-markdownify` [extra](../index.md#extras) and the @@ -64,7 +64,7 @@ agent = Agent( agent.run("List all email addresses on griptape.ai in a flat numbered markdown list.") ``` -## Trafilatura Web Scraper Driver +## Trafilatura !!! info This driver requires the `drivers-web-scraper-trafilatura` [extra](../index.md#extras). diff --git a/docs/griptape-framework/engines/audio-engines.md b/docs/griptape-framework/engines/audio-engines.md index cbef1ef23..6494d5365 100644 --- a/docs/griptape-framework/engines/audio-engines.md +++ b/docs/griptape-framework/engines/audio-engines.md @@ -2,7 +2,7 @@ [Audio Generation Engines](../../reference/griptape/engines/audio/index.md) facilitate audio generation. Audio Generation Engines provides a `run` method that accepts the necessary inputs for its particular mode and provides the request to the configured [Driver](../drivers/text-to-speech-drivers.md). -### Text to Speech Engine +### Text to Speech This Engine facilitates synthesizing speech from text inputs. @@ -28,7 +28,7 @@ engine.run( ) ``` -### Audio Transcription Engine +### Audio Transcription The [Audio Transcription Engine](../../reference/griptape/engines/audio/audio_transcription_engine.md) facilitates transcribing speech from audio inputs. diff --git a/docs/griptape-framework/engines/extraction-engines.md b/docs/griptape-framework/engines/extraction-engines.md index f7969ff4a..101f81ba8 100644 --- a/docs/griptape-framework/engines/extraction-engines.md +++ b/docs/griptape-framework/engines/extraction-engines.md @@ -3,7 +3,7 @@ Extraction Engines in Griptape facilitate the extraction of data from text forma These engines play a crucial role in the functionality of [Extraction Tasks](../../griptape-framework/structures/tasks.md). As of now, Griptape supports two types of Extraction Engines: the CSV Extraction Engine and the JSON Extraction Engine. 
-## CSV Extraction Engine +## CSV The CSV Extraction Engine is designed specifically for extracting data from CSV-formatted content. @@ -39,7 +39,7 @@ Bob,35,California Charlie,40,Texas ``` -## JSON Extraction Engine +## JSON The JSON Extraction Engine is tailored for extracting data from JSON-formatted content. diff --git a/docs/griptape-framework/engines/image-generation-engines.md b/docs/griptape-framework/engines/image-generation-engines.md index 0c3997fa9..9d38fd197 100644 --- a/docs/griptape-framework/engines/image-generation-engines.md +++ b/docs/griptape-framework/engines/image-generation-engines.md @@ -39,7 +39,7 @@ engine.run( ) ``` -### Prompt Image Generation Engine +### Prompt Image This Engine facilitates generating images from text prompts. @@ -65,7 +65,7 @@ engine.run( ) ``` -### Variation Image Generation Engine +### Variation This Engine facilitates generating variations of an input image according to a text prompt. The input image is used as a reference for the model's generation. @@ -95,7 +95,7 @@ engine.run( ) ``` -### Inpainting Image Generation Engine +### Inpainting This Engine facilitates inpainting, or modifying an input image according to a text prompt within the bounds of a mask defined by mask image. After inpainting, the area specified by the mask is replaced with the model's generation, while the rest of the input image remains the same. @@ -130,7 +130,7 @@ engine.run( ) ``` -### Outpainting Image Generation Engine +### Outpainting This Engine facilitates outpainting, or modifying an input image according to a text prompt outside the bounds of a mask defined by a mask image. After outpainting, the area of the input image specified by the mask remains the same, while the rest is replaced with the model's generation. diff --git a/docs/griptape-framework/engines/image-query-engines.md b/docs/griptape-framework/engines/image-query-engines.md deleted file mode 100644 index 0457657f8..000000000 --- a/docs/griptape-framework/engines/image-query-engines.md +++ /dev/null @@ -1,25 +0,0 @@ -# ImageQueryEngine - -The [Image Query Engine](../../reference/griptape/engines/image_query/image_query_engine.md) is used to execute natural language queries on the contents of images. You can specify the provider and model used to query the image by providing the Engine with a particular [Image Query Driver](../drivers/image-query-drivers.md). - -All Image Query Drivers default to a `max_tokens` of 256. You can tune this value based on your use case and the [Image Query Driver](../drivers/image-query-drivers.md) you are providing. - -```python -from griptape.drivers import OpenAiImageQueryDriver -from griptape.engines import ImageQueryEngine -from griptape.loaders import ImageLoader - -driver = OpenAiImageQueryDriver( - model="gpt-4o", - max_tokens=256 -) - -engine = ImageQueryEngine( - image_query_driver=driver -) - -with open("tests/resources/mountain.png", "rb") as f: - image_artifact = ImageLoader().load(f.read()) - -engine.run("Describe the weather in the image", [image_artifact]) -``` diff --git a/docs/griptape-framework/engines/query-engines.md b/docs/griptape-framework/engines/query-engines.md index 8acd4686d..38f5f6610 100644 --- a/docs/griptape-framework/engines/query-engines.md +++ b/docs/griptape-framework/engines/query-engines.md @@ -1,13 +1,13 @@ ## Overview -Query engines are used to search collections of text. +Query engines are used to perform text queries against various modalities. -## VectorQueryEngine +## Vector -Used to query vector storages. 
You can set a custom [prompt_driver](../../reference/griptape/engines/query/vector_query_engine.md#griptape.engines.query.vector_query_engine.VectorQueryEngine.prompt_driver.md) and [vector_store_driver](../../reference/griptape/engines/query/vector_query_engine.md#griptape.engines.query.vector_query_engine.VectorQueryEngine.vector_store_driver.md). Uses [LocalVectorStoreDriver](../../reference/griptape/drivers/vector/local_vector_store_driver.md) by default. +Used to query vector storages. You can set a custom [prompt_driver](../../reference/griptape/engines/query/vector_query_engine.md#griptape.engines.query.vector_query_engine.VectorQueryEngine.prompt_driver) and [vector_store_driver](../../reference/griptape/engines/query/vector_query_engine.md#griptape.engines.query.vector_query_engine.VectorQueryEngine.vector_store_driver). Uses [LocalVectorStoreDriver](../../reference/griptape/drivers/vector/local_vector_store_driver.md) by default. -Use the [upsert_text_artifact](../../reference/griptape/engines/query/vector_query_engine.md#griptape.engines.query.vector_query_engine.VectorQueryEngine.upsert_text_artifact.md) method to insert [TextArtifact](../../reference/griptape/artifacts/text_artifact.md)s into vector storage with an optional `namespace`. +Use the [upsert_text_artifact](../../reference/griptape/engines/query/vector_query_engine.md#griptape.engines.query.vector_query_engine.VectorQueryEngine.upsert_text_artifact) method to insert [TextArtifact](../../reference/griptape/artifacts/text_artifact.md)s into vector storage with an optional `namespace`. -Use the [VectorQueryEngine](../../reference/griptape/engines/query/vector_query_engine.md#griptape.engines.query.vector_query_engine.VectorQueryEngine.query.md) method to query the vector storage. +Use the [query](../../reference/griptape/engines/query/vector_query_engine.md#griptape.engines.query.vector_query_engine.VectorQueryEngine.query) method to query the vector storage. ```python from griptape.drivers import OpenAiChatPromptDriver, LocalVectorStoreDriver, OpenAiEmbeddingDriver @@ -25,3 +25,28 @@ engine.upsert_text_artifacts( engine.query("what is griptape?", namespace="griptape") ``` + +## Image +The [Image Query Engine](../../reference/griptape/engines/image_query/image_query_engine.md) allows you to perform natural language queries on the contents of images. You can specify the provider and model used to query the image by providing the Engine with a particular [Image Query Driver](../drivers/image-query-drivers.md). + +All Image Query Drivers default to a `max_tokens` of 256. You can tune this value based on your use case and the [Image Query Driver](../drivers/image-query-drivers.md) you are providing. + +```python +from griptape.drivers import OpenAiImageQueryDriver +from griptape.engines import ImageQueryEngine +from griptape.loaders import ImageLoader + +driver = OpenAiImageQueryDriver( + model="gpt-4o", + max_tokens=256 +) + +engine = ImageQueryEngine( + image_query_driver=driver +) + +with open("tests/resources/mountain.png", "rb") as f: + image_artifact = ImageLoader().load(f.read()) + +engine.run("Describe the weather in the image", [image_artifact]) +``` diff --git a/docs/griptape-framework/engines/summary-engines.md b/docs/griptape-framework/engines/summary-engines.md index 8ecc1ad09..936c12a7c 100644 --- a/docs/griptape-framework/engines/summary-engines.md +++ b/docs/griptape-framework/engines/summary-engines.md @@ -2,11 +2,11 @@ Summary engines are used to summarize text and collections of [TextArtifact](../../reference/griptape/artifacts/text_artifact.md)s.
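For a quick feel of the API documented in the Prompt section below, here is a minimal sketch. It wires `PromptSummaryEngine` to an explicit `OpenAiChatPromptDriver`; the constructor wiring and the return type of `summarize_text()` are assumptions to verify against the reference docs.

```python
from griptape.drivers import OpenAiChatPromptDriver
from griptape.engines import PromptSummaryEngine

# Assumed wiring: pass a Prompt Driver explicitly rather than relying on defaults.
engine = PromptSummaryEngine(
    prompt_driver=OpenAiChatPromptDriver(model="gpt-4o")
)

# summarize_text() summarizes an arbitrary string; summarize_artifacts() does the
# same for a list of artifacts.
summary = engine.summarize_text(
    "Griptape Structures such as Agents, Pipelines, and Workflows are composed of "
    "Tasks that call Engines, Drivers, and Tools to get work done."
)
print(summary)
```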
-## PromptSummaryEngine +## Prompt Used to summarize texts with LLMs. You can set a custom [prompt_driver](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.prompt_driver), [template_generator](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.template_generator), and [chunker](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.chunker). -Use the [summarize_artifacts](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.summarize_artifacts) method to summarize a list of artifacts or [summarize_text](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.summarize_text) to summarize an arbitrary string. +Use the [summarize_artifacts](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.summarize_artifacts) method to summarize a list of artifacts or [summarize_text](../../reference/griptape/engines/summary/base_summary_engine.md#griptape.engines.summary.base_summary_engine.BaseSummaryEngine.summarize_text) to summarize an arbitrary string. ```python import io diff --git a/docs/griptape-framework/misc/tokenizers.md b/docs/griptape-framework/misc/tokenizers.md index aaf488187..1920e912e 100644 --- a/docs/griptape-framework/misc/tokenizers.md +++ b/docs/griptape-framework/misc/tokenizers.md @@ -69,8 +69,8 @@ from griptape.tokenizers import HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer( + model="sentence-transformers/all-MiniLM-L6-v2", max_output_tokens=512, - tokenizer=AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2") ) print(tokenizer.count_tokens("Hello world!")) @@ -78,57 +78,18 @@ print(tokenizer.count_input_tokens_left("Hello world!")) print(tokenizer.count_output_tokens_left("Hello world!")) ``` -### Bedrock - -#### Anthropic Claude +### Amazon Bedrock ```python -from griptape.tokenizers import BedrockClaudeTokenizer +from griptape.tokenizers import AmazonBedrockTokenizer -tokenizer = BedrockClaudeTokenizer(model="anthropic.claude-3-sonnet-20240229-v1:0") +tokenizer = AmazonBedrockTokenizer(model="amazon.titan-text-express-v1") print(tokenizer.count_tokens("Hello world!")) print(tokenizer.count_input_tokens_left("Hello world!")) print(tokenizer.count_output_tokens_left("Hello world!")) ``` -#### Amazon Titan -```python -from griptape.tokenizers import BedrockTitanTokenizer - - -tokenizer = BedrockTitanTokenizer(model="amazon.titan-text-express-v1") - -print(tokenizer.count_tokens("Hello world!")) -print(tokenizer.count_input_tokens_left("Hello world!")) -print(tokenizer.count_output_tokens_left("Hello world!")) -``` - -#### Meta Llama 2 -```python -from griptape.tokenizers import BedrockLlamaTokenizer - - -tokenizer = BedrockLlamaTokenizer(model="meta.llama2-13b-chat-v1") - -print(tokenizer.count_tokens("Hello world!")) -print(tokenizer.count_input_tokens_left("Hello world!")) -print(tokenizer.count_output_tokens_left("Hello world!")) -``` - -#### Ai21 Jurassic -```python -from griptape.tokenizers import BedrockJurassicTokenizer - - -tokenizer = BedrockJurassicTokenizer(model="ai21.j2-ultra-v1") - -print(tokenizer.count_tokens("Hello world!")) 
-print(tokenizer.count_input_tokens_left("Hello world!")) -print(tokenizer.count_output_tokens_left("Hello world!")) -``` - - ### Simple Not all LLM providers have a public tokenizer API. In this case, you can use the `SimpleTokenizer` to count tokens based on a simple heuristic. diff --git a/docs/griptape-framework/structures/config.md b/docs/griptape-framework/structures/config.md index 9637eaf35..d75354768 100644 --- a/docs/griptape-framework/structures/config.md +++ b/docs/griptape-framework/structures/config.md @@ -89,7 +89,7 @@ agent = Agent( ### Custom Configs -You can create your own [StructureConfig](../../reference/griptape/config/structure_config.md) by overriding the Drivers in [default_config](../../reference/griptape/config/structure_config.md#griptape.config.structure_config.StructureConfig.default_config). +You can create your own [StructureConfig](../../reference/griptape/config/structure_config.md) by overriding relevant Drivers. The [StructureConfig](../../reference/griptape/config/structure_config.md) class includes "Dummy" Drivers for all types, which throw a [DummyException](../../reference/griptape/exceptions/dummy_exception.md) if invoked without being overridden. This approach ensures that you are informed through clear error messages if you attempt to use Structures without proper Driver configurations. diff --git a/docs/griptape-framework/structures/conversation-memory.md b/docs/griptape-framework/structures/conversation-memory.md index 92fdf98c4..1707a2ad9 100644 --- a/docs/griptape-framework/structures/conversation-memory.md +++ b/docs/griptape-framework/structures/conversation-memory.md @@ -57,7 +57,7 @@ agent.run("Hello!") print(agent.conversation_memory) ``` -You can set the [max_runs](../../reference/griptape/memory/structure/conversation_memory.md#griptape.memory.structure.conversation_memory.ConversationMemory.max_runs) parameter to limit how many runs are kept in memory. +You can set the [max_runs](../../reference/griptape/memory/structure/base_conversation_memory.md#griptape.memory.structure.base_conversation_memory.BaseConversationMemory.max_runs) parameter to limit how many runs are kept in memory. ```python from griptape.structures import Agent diff --git a/docs/griptape-framework/structures/rulesets.md b/docs/griptape-framework/structures/rulesets.md index d97ea65cb..f245579b7 100644 --- a/docs/griptape-framework/structures/rulesets.md +++ b/docs/griptape-framework/structures/rulesets.md @@ -154,7 +154,7 @@ pipeline.run("I love skateboarding!") ### Rules -You can pass [rules](../../reference/griptape/tasks/prompt_task.md#griptape.tasks.prompt_task.PromptTask.rules) directly to the Task to have a Ruleset created for you. +You can pass [rules](../../reference/griptape/mixins/rule_mixin.md#griptape.mixins.rule_mixin.RuleMixin.rules) directly to the Task to have a Ruleset created for you. 
```python from griptape.structures import Pipeline diff --git a/docs/griptape-framework/structures/task-memory.md b/docs/griptape-framework/structures/task-memory.md index f76094dad..219fd4412 100644 --- a/docs/griptape-framework/structures/task-memory.md +++ b/docs/griptape-framework/structures/task-memory.md @@ -206,7 +206,6 @@ from griptape.config import ( ) from griptape.drivers import ( AmazonBedrockPromptDriver, - BedrockTitanPromptModelDriver, AmazonBedrockTitanEmbeddingDriver, LocalVectorStoreDriver, OpenAiChatPromptDriver, @@ -227,7 +226,6 @@ agent = Agent( query_engine=VectorQueryEngine( prompt_driver=AmazonBedrockPromptDriver( model="amazon.titan-text-express-v1", - prompt_model_driver=BedrockTitanPromptModelDriver(), ), vector_store_driver=LocalVectorStoreDriver( embedding_driver=AmazonBedrockTitanEmbeddingDriver() diff --git a/docs/griptape-framework/structures/tasks.md b/docs/griptape-framework/structures/tasks.md index 0f05ea3e2..53b1b702e 100644 --- a/docs/griptape-framework/structures/tasks.md +++ b/docs/griptape-framework/structures/tasks.md @@ -627,7 +627,7 @@ pipeline.run("An image of a mountain shrouded by clouds") ## Image Query Task -The [Image Query Task](../../reference/griptape/tasks/image_query_task.md) executes a natural language query on one or more input images. This Task uses an [Image Query Engine](../engines/image-query-engines.md) configured with an [Image Query Driver](../drivers/image-query-drivers.md) to perform the query. The functionality provided by this Task depend on the capabilities of the model provided by the Driver. +The [Image Query Task](../../reference/griptape/tasks/image_query_task.md) performs a natural language query on one or more input images. This Task uses an [Image Query Engine](../engines/query-engines.md#image) configured with an [Image Query Driver](../drivers/image-query-drivers.md) to perform the query. The functionality provided by this Task depends on the capabilities of the model provided by the Driver. This Task accepts two inputs: a query (represented by either a string or a [Text Artifact](../data/artifacts.md#textartifact)) and a list of [Image Artifacts](../data/artifacts.md#imageartifact) or a Callable returning these two values. @@ -668,7 +668,7 @@ pipeline.run("Describe the weather in the image") ``` ## Structure Run Task -The [Structure Run Task](../../reference/griptape/tasks/structure_run_task.md) executes another Structure with a given input. +The [Structure Run Task](../../reference/griptape/tasks/structure_run_task.md) runs another Structure with a given input. This Task is useful for orchestrating multiple specialized Structures in a single run. Note that the input to the Task is a tuple of arguments that will be passed to the Structure. ```python diff --git a/docs/griptape-framework/structures/workflows.md b/docs/griptape-framework/structures/workflows.md index 3c2bac25a..9490ca1c4 100644 --- a/docs/griptape-framework/structures/workflows.md +++ b/docs/griptape-framework/structures/workflows.md @@ -7,7 +7,8 @@ A [Workflow](../../reference/griptape/structures/workflow.md) is a non-sequentia Workflows have access to the following [context](../../reference/griptape/structures/workflow.md#griptape.structures.workflow.Workflow.context) variables in addition to the [base context](./tasks.md#context): -* `parent_outputs`: outputs into the current task referenceable by parent task IDs. +* `parent_outputs`: dictionary mapping parent task IDs to their outputs.
+* `parents_output_text`: string containing the concatenated outputs of all parent tasks. * `parents`: parent tasks referenceable by IDs. * `children`: child tasks referenceable by IDs. @@ -17,8 +18,16 @@ Let's build a simple workflow. Let's say, we want to write a story in a fantasy ```python from griptape.tasks import PromptTask from griptape.structures import Workflow +from griptape.utils import StructureVisualizer -workflow = Workflow() + +world_task = PromptTask( + "Create a fictional world based on the following key words {{ keywords|join(', ') }}", + context={ + "keywords": ["fantasy", "ocean", "tidal lock"] + }, + id="world" +) def character_task(task_id, character_name) -> PromptTask: return PromptTask( @@ -26,34 +35,29 @@ def character_task(task_id, character_name) -> PromptTask: context={ "name": character_name }, - id=task_id + id=task_id, + parent_ids=["world"] ) -world_task = PromptTask( - "Create a fictional world based on the following key words {{ keywords|join(', ') }}", - context={ - "keywords": ["fantasy", "ocean", "tidal lock"] - }, - id="world" -) -workflow.add_task(world_task) +scotty_task = character_task("scotty", "Scotty") +annie_task = character_task("annie", "Annie") story_task = PromptTask( "Based on the following description of the world and characters, write a short story:\n{{ parent_outputs['world'] }}\n{{ parent_outputs['scotty'] }}\n{{ parent_outputs['annie'] }}", - id="story" + id="story", + parent_ids=["world", "scotty", "annie"] ) -workflow.add_task(story_task) -character_task_1 = character_task("scotty", "Scotty") -character_task_2 = character_task("annie", "Annie") +workflow = Workflow(tasks=[world_task, story_task, scotty_task, annie_task]) -# Note the preserve_relationship flag. This ensures that world_task remains a parent of -# story_task so its output can be referenced in the story_task prompt. -workflow.insert_tasks(world_task, [character_task_1, character_task_2], story_task, preserve_relationship=True) +print(StructureVisualizer(workflow).to_url()) workflow.run() ``` +Note that we use the `StructureVisualizer` to get a visual representation of the workflow. If we visit the printed URL, it should look like this: + +![Workflow](https://mermaid.ink/img/Z3JhcGggVEQ7OwoJd29ybGQtLT4gc3RvcnkgJiBzY290dHkgJiBhbm5pZTsKCXNjb3R0eS0tPiBzdG9yeTsKCWFubmllLS0+IHN0b3J5Ow==) !!! Info Output edited for brevity @@ -147,3 +151,208 @@ workflow.run() unity and harmony that can exist in diversity. ``` +### Declarative vs Imperative Syntax + +The above example showed how to create a workflow using the declarative syntax via the `parent_ids` init param, but there are a number of declarative and imperative options for you to choose between. There is no functional difference; they merely exist to allow you to structure your code in whatever way is most readable for your use case. Possibilities are illustrated below.
+ +Declaratively specify parents (same as above example): + +```python +from griptape.tasks import PromptTask +from griptape.structures import Workflow +from griptape.rules import Rule + +workflow = Workflow( + tasks=[ + PromptTask("Name an animal", id="animal"), + PromptTask("Describe {{ parent_outputs['animal'] }} with an adjective", id="adjective", parent_ids=["animal"]), + PromptTask("Name a {{ parent_outputs['adjective'] }} animal", id="new-animal", parent_ids=["adjective"]), + ], + rules=[Rule("output a single lowercase word")] +) + +workflow.run() +``` + +Declaratively specify children: + +```python +from griptape.tasks import PromptTask +from griptape.structures import Workflow +from griptape.rules import Rule + +workflow = Workflow( + tasks=[ + PromptTask("Name an animal", id="animal", child_ids=["adjective"]), + PromptTask("Describe {{ parent_outputs['animal'] }} with an adjective", id="adjective", child_ids=["new-animal"]), + PromptTask("Name a {{ parent_outputs['adjective'] }} animal", id="new-animal"), + ], + rules=[Rule("output a single lowercase word")], +) + +workflow.run() +``` + +Declaratively specifying a mix of parents and children: + +```python +from griptape.tasks import PromptTask +from griptape.structures import Workflow +from griptape.rules import Rule + +workflow = Workflow( + tasks=[ + PromptTask("Name an animal", id="animal"), + PromptTask("Describe {{ parent_outputs['animal'] }} with an adjective", id="adjective", parent_ids=["animal"], child_ids=["new-animal"]), + PromptTask("Name a {{ parent_outputs['adjective'] }} animal", id="new-animal"), + ], + rules=[Rule("output a single lowercase word")], +) + +workflow.run() +``` + +Imperatively specify parents: + +```python +from griptape.tasks import PromptTask +from griptape.structures import Workflow +from griptape.rules import Rule + +animal_task = PromptTask("Name an animal", id="animal") +adjective_task = PromptTask("Describe {{ parent_outputs['animal'] }} with an adjective", id="adjective") +new_animal_task = PromptTask("Name a {{ parent_outputs['adjective'] }} animal", id="new-animal") + +adjective_task.add_parent(animal_task) +new_animal_task.add_parent(adjective_task) + +workflow = Workflow( + tasks=[animal_task, adjective_task, new_animal_task], + rules=[Rule("output a single lowercase word")], +) + +workflow.run() +``` + +Imperatively specify children: + +```python +from griptape.tasks import PromptTask +from griptape.structures import Workflow +from griptape.rules import Rule + +animal_task = PromptTask("Name an animal", id="animal") +adjective_task = PromptTask("Describe {{ parent_outputs['animal'] }} with an adjective", id="adjective") +new_animal_task = PromptTask("Name a {{ parent_outputs['adjective'] }} animal", id="new-animal") + +animal_task.add_child(adjective_task) +adjective_task.add_child(new_animal_task) + +workflow = Workflow( + tasks=[animal_task, adjective_task, new_animal_task], + rules=[Rule("output a single lowercase word")], +) + +workflow.run() +``` + +Imperatively specify a mix of parents and children: + +```python +from griptape.tasks import PromptTask +from griptape.structures import Workflow +from griptape.rules import Rule + +animal_task = PromptTask("Name an animal", id="animal") +adjective_task = PromptTask("Describe {{ parent_outputs['animal'] }} with an adjective", id="adjective") +new_animal_task = PromptTask("Name a {{ parent_outputs['adjective'] }} animal", id="new-animal") + +adjective_task.add_parent(animal_task) +adjective_task.add_child(new_animal_task) + +workflow = 
Workflow( + tasks=[animal_task, adjective_task, new_animal_task], + rules=[Rule("output a single lowercase word")], +) + +workflow.run() +``` + +Or even mix imperative and declarative: + +```python +from griptape.tasks import PromptTask +from griptape.structures import Workflow +from griptape.rules import Rule + +animal_task = PromptTask("Name an animal", id="animal") +adjective_task = PromptTask("Describe {{ parent_outputs['animal'] }} with an adjective", id="adjective", parent_ids=["animal"]) + + +new_animal_task = PromptTask("Name a {{ parent_outputs['adjective'] }} animal", id="new-animal") +new_animal_task.add_parent(adjective_task) + +workflow = Workflow( + tasks=[animal_task, adjective_task, new_animal_task], + rules=[Rule("output a single lowercase word")], +) + +workflow.run() +``` + +### Insert Parallel Tasks + +`Workflow.insert_tasks()` provides a convenient way to insert parallel tasks between parents and children. + +!!! info + By default, all children are removed from the parent task and all parent tasks are removed from the child task. If you want to keep these parent-child relationships, then set the `preserve_relationship` parameter to `True`. + +Imperatively insert parallel tasks between a parent and child: + +```python +from griptape.tasks import PromptTask +from griptape.structures import Workflow +from griptape.rules import Rule + +workflow = Workflow( + rules=[Rule("output a single lowercase word")], +) + +animal_task = PromptTask("Name an animal", id="animal") +adjective_task = PromptTask("Describe {{ parent_outputs['animal'] }} with an adjective", id="adjective") +color_task = PromptTask("Describe {{ parent_outputs['animal'] }} with a color", id="color") +new_animal_task = PromptTask("Name an animal described as: \n{{ parents_output_text }}", id="new-animal") + +# The following workflow runs animal_task, then (adjective_task, and color_task) +# in parallel, then finally new_animal_task. +# +# In other words, the output of animal_task is passed to both adjective_task and color_task +# and the outputs of adjective_task and color_task are then passed to new_animal_task. +workflow.add_task(animal_task) +workflow.add_task(new_animal_task) +workflow.insert_tasks(animal_task, [adjective_task, color_task], new_animal_task) + +workflow.run() +``` + +output: +``` +[06/18/24 09:52:21] INFO PromptTask animal + Input: Name an animal +[06/18/24 09:52:22] INFO PromptTask animal + Output: elephant + INFO PromptTask adjective + Input: Describe elephant with an adjective + INFO PromptTask color + Input: Describe elephant with a color + INFO PromptTask color + Output: gray + INFO PromptTask adjective + Output: majestic + INFO PromptTask new-animal + Input: Name an animal described as: + majestic + gray +[06/18/24 09:52:23] INFO PromptTask new-animal + Output: elephant +``` diff --git a/docs/griptape-framework/tools/index.md b/docs/griptape-framework/tools/index.md index 0ae8054af..4d604a77c 100644 --- a/docs/griptape-framework/tools/index.md +++ b/docs/griptape-framework/tools/index.md @@ -1,4 +1,4 @@ -# Overview +## Overview One of the most powerful features of Griptape is the ability for Toolkit Tasks to generate _chains of thought_ (CoT) and use tools that can interact with the outside world. We use the [ReAct](https://arxiv.org/abs/2210.03629) technique to implement CoT reasoning and acting in the underlying LLMs without using any fine-tuning. 
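As a concrete, hedged illustration of that loop, the sketch below hands an Agent a single tool; the Agent wraps it in a Toolkit Task and lets the LLM decide when to call it. The `Calculator` tool is used purely as an example and is not part of the diff above.

```python
from griptape.structures import Agent
from griptape.tools import Calculator

# The Agent builds a Toolkit Task around the tools it is given. At run time the
# LLM follows the ReAct pattern: it reasons about which tool to use, emits a
# tool call, observes the result, and repeats until it can produce an answer.
agent = Agent(tools=[Calculator()])

agent.run("What is 7 * 123, divided by 3?")
```

No fine-tuning is involved; the behavior comes entirely from prompting the underlying model.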
diff --git a/griptape/chunkers/base_chunker.py b/griptape/chunkers/base_chunker.py index f2cc452ad..793bf24ad 100644 --- a/griptape/chunkers/base_chunker.py +++ b/griptape/chunkers/base_chunker.py @@ -21,6 +21,11 @@ class BaseChunker(ABC): default=Factory(lambda self: self.tokenizer.max_input_tokens, takes_self=True), kw_only=True ) + @max_tokens.validator # pyright: ignore + def validate_max_tokens(self, _, max_tokens: int) -> None: + if max_tokens < 0: + raise ValueError("max_tokens must be 0 or greater.") + def chunk(self, text: TextArtifact | str) -> list[TextArtifact]: text = text.value if isinstance(text, TextArtifact) else text diff --git a/griptape/config/__init__.py b/griptape/config/__init__.py index 7783b3886..541eb0db0 100644 --- a/griptape/config/__init__.py +++ b/griptape/config/__init__.py @@ -8,6 +8,7 @@ from .amazon_bedrock_structure_config import AmazonBedrockStructureConfig from .anthropic_structure_config import AnthropicStructureConfig from .google_structure_config import GoogleStructureConfig +from .cohere_structure_config import CohereStructureConfig __all__ = [ @@ -19,4 +20,5 @@ "AmazonBedrockStructureConfig", "AnthropicStructureConfig", "GoogleStructureConfig", + "CohereStructureConfig", ] diff --git a/griptape/config/amazon_bedrock_structure_config.py b/griptape/config/amazon_bedrock_structure_config.py index cefb97f57..e70d9c819 100644 --- a/griptape/config/amazon_bedrock_structure_config.py +++ b/griptape/config/amazon_bedrock_structure_config.py @@ -11,7 +11,6 @@ BasePromptDriver, BaseVectorStoreDriver, BedrockClaudeImageQueryModelDriver, - BedrockClaudePromptModelDriver, BedrockTitanImageGenerationModelDriver, LocalVectorStoreDriver, ) @@ -21,11 +20,7 @@ class AmazonBedrockStructureConfig(StructureConfig): prompt_driver: BasePromptDriver = field( default=Factory( - lambda: AmazonBedrockPromptDriver( - model="anthropic.claude-3-sonnet-20240229-v1:0", - stream=False, - prompt_model_driver=BedrockClaudePromptModelDriver(), - ) + lambda: AmazonBedrockPromptDriver(model="anthropic.claude-3-sonnet-20240229-v1:0", stream=False) ), metadata={"serializable": True}, ) diff --git a/griptape/config/cohere_structure_config.py b/griptape/config/cohere_structure_config.py new file mode 100644 index 000000000..82f11b8f4 --- /dev/null +++ b/griptape/config/cohere_structure_config.py @@ -0,0 +1,37 @@ +from attrs import Factory, define, field + +from griptape.config import StructureConfig +from griptape.drivers import ( + BaseEmbeddingDriver, + BasePromptDriver, + CoherePromptDriver, + CohereEmbeddingDriver, + BaseVectorStoreDriver, + LocalVectorStoreDriver, +) + + +@define +class CohereStructureConfig(StructureConfig): + api_key: str = field(metadata={"serializable": False}, kw_only=True) + + prompt_driver: BasePromptDriver = field( + default=Factory(lambda self: CoherePromptDriver(model="command-r", api_key=self.api_key), takes_self=True), + metadata={"serializable": True}, + kw_only=True, + ) + embedding_driver: BaseEmbeddingDriver = field( + default=Factory( + lambda self: CohereEmbeddingDriver( + model="embed-english-v3.0", api_key=self.api_key, input_type="search_document" + ), + takes_self=True, + ), + metadata={"serializable": True}, + kw_only=True, + ) + vector_store_driver: BaseVectorStoreDriver = field( + default=Factory(lambda self: LocalVectorStoreDriver(embedding_driver=self.embedding_driver), takes_self=True), + kw_only=True, + metadata={"serializable": True}, + ) diff --git a/griptape/drivers/__init__.py b/griptape/drivers/__init__.py index f44602f8b..8e8128d7a 
100644 --- a/griptape/drivers/__init__.py +++ b/griptape/drivers/__init__.py @@ -1,17 +1,15 @@ from .prompt.base_prompt_driver import BasePromptDriver from .prompt.openai_chat_prompt_driver import OpenAiChatPromptDriver -from .prompt.openai_completion_prompt_driver import OpenAiCompletionPromptDriver from .prompt.azure_openai_chat_prompt_driver import AzureOpenAiChatPromptDriver -from .prompt.azure_openai_completion_prompt_driver import AzureOpenAiCompletionPromptDriver from .prompt.cohere_prompt_driver import CoherePromptDriver from .prompt.huggingface_pipeline_prompt_driver import HuggingFacePipelinePromptDriver from .prompt.huggingface_hub_prompt_driver import HuggingFaceHubPromptDriver from .prompt.anthropic_prompt_driver import AnthropicPromptDriver -from .prompt.amazon_sagemaker_prompt_driver import AmazonSageMakerPromptDriver +from .prompt.amazon_sagemaker_jumpstart_prompt_driver import AmazonSageMakerJumpstartPromptDriver from .prompt.amazon_bedrock_prompt_driver import AmazonBedrockPromptDriver from .prompt.google_prompt_driver import GooglePromptDriver -from .prompt.base_multi_model_prompt_driver import BaseMultiModelPromptDriver from .prompt.dummy_prompt_driver import DummyPromptDriver +from .prompt.ollama_prompt_driver import OllamaPromptDriver from .memory.conversation.base_conversation_memory_driver import BaseConversationMemoryDriver from .memory.conversation.local_conversation_memory_driver import LocalConversationMemoryDriver @@ -21,18 +19,14 @@ from .embedding.base_embedding_driver import BaseEmbeddingDriver from .embedding.openai_embedding_driver import OpenAiEmbeddingDriver from .embedding.azure_openai_embedding_driver import AzureOpenAiEmbeddingDriver -from .embedding.base_multi_model_embedding_driver import BaseMultiModelEmbeddingDriver -from .embedding.amazon_sagemaker_embedding_driver import AmazonSageMakerEmbeddingDriver +from .embedding.amazon_sagemaker_jumpstart_embedding_driver import AmazonSageMakerJumpstartEmbeddingDriver from .embedding.amazon_bedrock_titan_embedding_driver import AmazonBedrockTitanEmbeddingDriver from .embedding.amazon_bedrock_cohere_embedding_driver import AmazonBedrockCohereEmbeddingDriver from .embedding.voyageai_embedding_driver import VoyageAiEmbeddingDriver from .embedding.huggingface_hub_embedding_driver import HuggingFaceHubEmbeddingDriver from .embedding.google_embedding_driver import GoogleEmbeddingDriver from .embedding.dummy_embedding_driver import DummyEmbeddingDriver - -from .embedding_model.base_embedding_model_driver import BaseEmbeddingModelDriver -from .embedding_model.sagemaker_huggingface_embedding_model_driver import SageMakerHuggingFaceEmbeddingModelDriver -from .embedding_model.sagemaker_tensorflow_hub_embedding_model_driver import SageMakerTensorFlowHubEmbeddingModelDriver +from .embedding.cohere_embedding_driver import CohereEmbeddingDriver from .vector.base_vector_store_driver import BaseVectorStoreDriver from .vector.local_vector_store_driver import LocalVectorStoreDriver @@ -51,14 +45,6 @@ from .sql.snowflake_sql_driver import SnowflakeSqlDriver from .sql.sql_driver import SqlDriver -from .prompt_model.base_prompt_model_driver import BasePromptModelDriver -from .prompt_model.sagemaker_llama_prompt_model_driver import SageMakerLlamaPromptModelDriver -from .prompt_model.sagemaker_falcon_prompt_model_driver import SageMakerFalconPromptModelDriver -from .prompt_model.bedrock_titan_prompt_model_driver import BedrockTitanPromptModelDriver -from .prompt_model.bedrock_claude_prompt_model_driver import 
BedrockClaudePromptModelDriver -from .prompt_model.bedrock_jurassic_prompt_model_driver import BedrockJurassicPromptModelDriver -from .prompt_model.bedrock_llama_prompt_model_driver import BedrockLlamaPromptModelDriver - from .image_generation_model.base_image_generation_model_driver import BaseImageGenerationModelDriver from .image_generation_model.bedrock_stable_diffusion_image_generation_model_driver import ( BedrockStableDiffusionImageGenerationModelDriver, @@ -115,18 +101,16 @@ __all__ = [ "BasePromptDriver", "OpenAiChatPromptDriver", - "OpenAiCompletionPromptDriver", "AzureOpenAiChatPromptDriver", - "AzureOpenAiCompletionPromptDriver", "CoherePromptDriver", "HuggingFacePipelinePromptDriver", "HuggingFaceHubPromptDriver", "AnthropicPromptDriver", - "AmazonSageMakerPromptDriver", + "AmazonSageMakerJumpstartPromptDriver", "AmazonBedrockPromptDriver", "GooglePromptDriver", - "BaseMultiModelPromptDriver", "DummyPromptDriver", + "OllamaPromptDriver", "BaseConversationMemoryDriver", "LocalConversationMemoryDriver", "AmazonDynamoDbConversationMemoryDriver", @@ -134,17 +118,14 @@ "BaseEmbeddingDriver", "OpenAiEmbeddingDriver", "AzureOpenAiEmbeddingDriver", - "BaseMultiModelEmbeddingDriver", - "AmazonSageMakerEmbeddingDriver", + "AmazonSageMakerJumpstartEmbeddingDriver", "AmazonBedrockTitanEmbeddingDriver", "AmazonBedrockCohereEmbeddingDriver", "VoyageAiEmbeddingDriver", "HuggingFaceHubEmbeddingDriver", "GoogleEmbeddingDriver", "DummyEmbeddingDriver", - "BaseEmbeddingModelDriver", - "SageMakerHuggingFaceEmbeddingModelDriver", - "SageMakerTensorFlowHubEmbeddingModelDriver", + "CohereEmbeddingDriver", "BaseVectorStoreDriver", "LocalVectorStoreDriver", "PineconeVectorStoreDriver", @@ -160,13 +141,6 @@ "AmazonRedshiftSqlDriver", "SnowflakeSqlDriver", "SqlDriver", - "BasePromptModelDriver", - "SageMakerLlamaPromptModelDriver", - "SageMakerFalconPromptModelDriver", - "BedrockTitanPromptModelDriver", - "BedrockClaudePromptModelDriver", - "BedrockJurassicPromptModelDriver", - "BedrockLlamaPromptModelDriver", "BaseImageGenerationModelDriver", "BedrockStableDiffusionImageGenerationModelDriver", "BedrockTitanImageGenerationModelDriver", diff --git a/griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py b/griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py index 15ce67c4c..903022f86 100644 --- a/griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py +++ b/griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py @@ -3,7 +3,8 @@ from typing import Any, TYPE_CHECKING from attrs import define, field, Factory from griptape.drivers import BaseEmbeddingDriver -from griptape.tokenizers import BedrockCohereTokenizer +from griptape.tokenizers.amazon_bedrock_tokenizer import AmazonBedrockTokenizer +from griptape.tokenizers.base_tokenizer import BaseTokenizer from griptape.utils import import_optional_dependency if TYPE_CHECKING: @@ -28,8 +29,8 @@ class AmazonBedrockCohereEmbeddingDriver(BaseEmbeddingDriver): model: str = field(default=DEFAULT_MODEL, kw_only=True) input_type: str = field(default="search_query", kw_only=True) session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True) - tokenizer: BedrockCohereTokenizer = field( - default=Factory(lambda self: BedrockCohereTokenizer(model=self.model), takes_self=True), kw_only=True + tokenizer: BaseTokenizer = field( + default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True), kw_only=True ) bedrock_client: Any = field( 
default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True), kw_only=True diff --git a/griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py b/griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py index a510c618c..b754b0608 100644 --- a/griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py +++ b/griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py @@ -3,7 +3,8 @@ from typing import Any, TYPE_CHECKING from attrs import define, field, Factory from griptape.drivers import BaseEmbeddingDriver -from griptape.tokenizers import BedrockTitanTokenizer +from griptape.tokenizers.amazon_bedrock_tokenizer import AmazonBedrockTokenizer +from griptape.tokenizers.base_tokenizer import BaseTokenizer from griptape.utils import import_optional_dependency if TYPE_CHECKING: @@ -24,8 +25,8 @@ class AmazonBedrockTitanEmbeddingDriver(BaseEmbeddingDriver): model: str = field(default=DEFAULT_MODEL, kw_only=True, metadata={"serializable": True}) session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True) - tokenizer: BedrockTitanTokenizer = field( - default=Factory(lambda self: BedrockTitanTokenizer(model=self.model), takes_self=True), kw_only=True + tokenizer: BaseTokenizer = field( + default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True), kw_only=True ) bedrock_client: Any = field( default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True), kw_only=True diff --git a/griptape/drivers/embedding/amazon_sagemaker_embedding_driver.py b/griptape/drivers/embedding/amazon_sagemaker_embedding_driver.py deleted file mode 100644 index 4ab6d2bf7..000000000 --- a/griptape/drivers/embedding/amazon_sagemaker_embedding_driver.py +++ /dev/null @@ -1,31 +0,0 @@ -from __future__ import annotations -from typing import TYPE_CHECKING -import json -from typing import Any - -from attrs import Factory, define, field - -from griptape.drivers import BaseMultiModelEmbeddingDriver -from griptape.utils import import_optional_dependency - -if TYPE_CHECKING: - from griptape.drivers import BaseEmbeddingModelDriver - import boto3 - - -@define -class AmazonSageMakerEmbeddingDriver(BaseMultiModelEmbeddingDriver): - session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True) - sagemaker_client: Any = field( - default=Factory(lambda self: self.session.client("sagemaker-runtime"), takes_self=True), kw_only=True - ) - embedding_model_driver: BaseEmbeddingModelDriver = field(kw_only=True) - - def try_embed_chunk(self, chunk: str) -> list[float]: - payload = self.embedding_model_driver.chunk_to_model_params(chunk) - endpoint_response = self.sagemaker_client.invoke_endpoint( - EndpointName=self.model, ContentType="application/x-text", Body=json.dumps(payload).encode("utf-8") - ) - - response = json.loads(endpoint_response.get("Body").read().decode("utf-8")) - return self.embedding_model_driver.process_output(response) diff --git a/griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.py b/griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.py new file mode 100644 index 000000000..2b764c2f4 --- /dev/null +++ b/griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.py @@ -0,0 +1,53 @@ +from __future__ import annotations +from typing import TYPE_CHECKING +import json +from typing import Any, Optional + +from attrs import Factory, define, field + +from 
griptape.drivers import BaseEmbeddingDriver +from griptape.utils import import_optional_dependency + +if TYPE_CHECKING: + import boto3 + + +@define +class AmazonSageMakerJumpstartEmbeddingDriver(BaseEmbeddingDriver): + session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True) + sagemaker_client: Any = field( + default=Factory(lambda self: self.session.client("sagemaker-runtime"), takes_self=True), kw_only=True + ) + endpoint: str = field(kw_only=True, metadata={"serializable": True}) + custom_attributes: str = field(default="accept_eula=true", kw_only=True, metadata={"serializable": True}) + inference_component_name: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) + + def try_embed_chunk(self, chunk: str) -> list[float]: + payload = {"text_inputs": chunk, "mode": "embedding"} + + endpoint_response = self.sagemaker_client.invoke_endpoint( + EndpointName=self.endpoint, + ContentType="application/json", + Body=json.dumps(payload).encode("utf-8"), + CustomAttributes=self.custom_attributes, + **( + {"InferenceComponentName": self.inference_component_name} + if self.inference_component_name is not None + else {} + ), + ) + + response = json.loads(endpoint_response.get("Body").read().decode("utf-8")) + + if "embedding" in response: + embedding = response["embedding"] + + if embedding: + if isinstance(embedding[0], list): + return embedding[0] + else: + return embedding + else: + raise ValueError("model response is empty") + else: + raise ValueError("invalid response from model") diff --git a/griptape/drivers/embedding/base_multi_model_embedding_driver.py b/griptape/drivers/embedding/base_multi_model_embedding_driver.py deleted file mode 100644 index 90f827ad2..000000000 --- a/griptape/drivers/embedding/base_multi_model_embedding_driver.py +++ /dev/null @@ -1,15 +0,0 @@ -from __future__ import annotations -from abc import ABC -from typing import TYPE_CHECKING - -from attrs import define, field - -from griptape.drivers import BaseEmbeddingDriver - -if TYPE_CHECKING: - from griptape.drivers import BaseEmbeddingModelDriver - - -@define -class BaseMultiModelEmbeddingDriver(BaseEmbeddingDriver, ABC): - embedding_model_driver: BaseEmbeddingModelDriver = field(kw_only=True) diff --git a/griptape/drivers/embedding/cohere_embedding_driver.py b/griptape/drivers/embedding/cohere_embedding_driver.py new file mode 100644 index 000000000..5e8bdf4dd --- /dev/null +++ b/griptape/drivers/embedding/cohere_embedding_driver.py @@ -0,0 +1,43 @@ +from __future__ import annotations +from typing import TYPE_CHECKING +from attrs import define, field, Factory +from griptape.drivers import BaseEmbeddingDriver +from griptape.tokenizers import CohereTokenizer +from griptape.utils import import_optional_dependency + +if TYPE_CHECKING: + from cohere import Client + + +@define +class CohereEmbeddingDriver(BaseEmbeddingDriver): + """ + Attributes: + api_key: Cohere API key. + model: Cohere model name. + client: Custom `cohere.Client`. + tokenizer: Custom `CohereTokenizer`. + input_type: Cohere embedding input type. 
+ """ + + DEFAULT_MODEL = "models/embedding-001" + + api_key: str = field(kw_only=True, metadata={"serializable": False}) + client: Client = field( + default=Factory(lambda self: import_optional_dependency("cohere").Client(self.api_key), takes_self=True), + kw_only=True, + ) + tokenizer: CohereTokenizer = field( + default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True), + kw_only=True, + ) + + input_type: str = field(kw_only=True, metadata={"serializable": True}) + + def try_embed_chunk(self, chunk: str) -> list[float]: + result = self.client.embed(texts=[chunk], model=self.model, input_type=self.input_type) + + if isinstance(result.embeddings, list): + return result.embeddings[0] + else: + raise ValueError("Non-float embeddings are not supported.") diff --git a/griptape/drivers/embedding_model/base_embedding_model_driver.py b/griptape/drivers/embedding_model/base_embedding_model_driver.py deleted file mode 100644 index ad7bf3bda..000000000 --- a/griptape/drivers/embedding_model/base_embedding_model_driver.py +++ /dev/null @@ -1,11 +0,0 @@ -from attrs import define -from abc import ABC, abstractmethod - - -@define -class BaseEmbeddingModelDriver(ABC): - @abstractmethod - def chunk_to_model_params(self, chunk: str) -> dict: ... - - @abstractmethod - def process_output(self, output: dict) -> list[float]: ... diff --git a/griptape/drivers/embedding_model/sagemaker_huggingface_embedding_model_driver.py b/griptape/drivers/embedding_model/sagemaker_huggingface_embedding_model_driver.py deleted file mode 100644 index dceffcd8a..000000000 --- a/griptape/drivers/embedding_model/sagemaker_huggingface_embedding_model_driver.py +++ /dev/null @@ -1,11 +0,0 @@ -from attrs import define -from griptape.drivers import BaseEmbeddingModelDriver - - -@define -class SageMakerHuggingFaceEmbeddingModelDriver(BaseEmbeddingModelDriver): - def chunk_to_model_params(self, chunk: str) -> dict: - return {"text_inputs": chunk} - - def process_output(self, output: dict) -> list[float]: - return output["embedding"][0] diff --git a/griptape/drivers/embedding_model/sagemaker_tensorflow_hub_embedding_model_driver.py b/griptape/drivers/embedding_model/sagemaker_tensorflow_hub_embedding_model_driver.py deleted file mode 100644 index 9d9632fb0..000000000 --- a/griptape/drivers/embedding_model/sagemaker_tensorflow_hub_embedding_model_driver.py +++ /dev/null @@ -1,11 +0,0 @@ -from attrs import define -from griptape.drivers import BaseEmbeddingModelDriver - - -@define -class SageMakerTensorFlowHubEmbeddingModelDriver(BaseEmbeddingModelDriver): - def chunk_to_model_params(self, chunk: str) -> dict: - return {"text_inputs": chunk} - - def process_output(self, output: dict) -> list[float]: - return output["embedding"] diff --git a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py index 0675e7f92..849ed0901 100644 --- a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py +++ b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py @@ -1,55 +1,75 @@ from __future__ import annotations -import json -from typing import TYPE_CHECKING, Any + from collections.abc import Iterator -from attrs import define, field, Factory +from typing import TYPE_CHECKING, Any + +from attrs import Factory, define, field + from griptape.artifacts import TextArtifact +from griptape.drivers import BasePromptDriver +from griptape.tokenizers import AmazonBedrockTokenizer, BaseTokenizer from griptape.utils import import_optional_dependency -from 
.base_multi_model_prompt_driver import BaseMultiModelPromptDriver if TYPE_CHECKING: - from griptape.utils import PromptStack import boto3 + from griptape.utils import PromptStack + @define -class AmazonBedrockPromptDriver(BaseMultiModelPromptDriver): +class AmazonBedrockPromptDriver(BasePromptDriver): session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True) bedrock_client: Any = field( default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True), kw_only=True ) + additional_model_request_fields: dict = field(default=Factory(dict), kw_only=True) + tokenizer: BaseTokenizer = field( + default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True), kw_only=True + ) def try_run(self, prompt_stack: PromptStack) -> TextArtifact: - model_input = self.prompt_model_driver.prompt_stack_to_model_input(prompt_stack) - payload = {**self.prompt_model_driver.prompt_stack_to_model_params(prompt_stack)} - if isinstance(model_input, dict): - payload.update(model_input) + response = self.bedrock_client.converse(**self._base_params(prompt_stack)) + + output_message = response["output"]["message"] + output_content = output_message["content"][0]["text"] - response = self.bedrock_client.invoke_model( - modelId=self.model, contentType="application/json", accept="application/json", body=json.dumps(payload) - ) + return TextArtifact(output_content) - response_body = response["body"].read() + def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]: + response = self.bedrock_client.converse_stream(**self._base_params(prompt_stack)) - if response_body: - return self.prompt_model_driver.process_output(response_body) + stream = response.get("stream") + if stream is not None: + for event in stream: + if "contentBlockDelta" in event: + yield TextArtifact(event["contentBlockDelta"]["delta"]["text"]) else: raise Exception("model response is empty") - def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]: - model_input = self.prompt_model_driver.prompt_stack_to_model_input(prompt_stack) - payload = {**self.prompt_model_driver.prompt_stack_to_model_params(prompt_stack)} - if isinstance(model_input, dict): - payload.update(model_input) - - response = self.bedrock_client.invoke_model_with_response_stream( - modelId=self.model, contentType="application/json", accept="application/json", body=json.dumps(payload) - ) - - response_body = response["body"] - if response_body: - for chunk in response["body"]: - chunk_bytes = chunk["chunk"]["bytes"] - yield self.prompt_model_driver.process_output(chunk_bytes) + def _prompt_stack_input_to_message(self, prompt_input: PromptStack.Input) -> dict: + content = [{"text": prompt_input.content}] + + if prompt_input.is_system(): + return {"text": prompt_input.content} + elif prompt_input.is_assistant(): + return {"role": "assistant", "content": content} else: - raise Exception("model response is empty") + return {"role": "user", "content": content} + + def _base_params(self, prompt_stack: PromptStack) -> dict: + system_messages = [ + self._prompt_stack_input_to_message(input) + for input in prompt_stack.inputs + if input.is_system() and input.content + ] + messages = [ + self._prompt_stack_input_to_message(input) for input in prompt_stack.inputs if not input.is_system() + ] + + return { + "modelId": self.model, + "messages": messages, + "system": system_messages, + "inferenceConfig": {"temperature": self.temperature}, + "additionalModelRequestFields": 
self.additional_model_request_fields, + } diff --git a/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py b/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py new file mode 100644 index 000000000..18f8e4b77 --- /dev/null +++ b/griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +import json +from collections.abc import Iterator +from typing import TYPE_CHECKING, Any, Optional + +from attrs import Factory, define, field + +from griptape.artifacts import TextArtifact +from griptape.drivers.prompt.base_prompt_driver import BasePromptDriver +from griptape.tokenizers import HuggingFaceTokenizer +from griptape.utils import import_optional_dependency + +if TYPE_CHECKING: + import boto3 + + from griptape.utils import PromptStack + + +@define +class AmazonSageMakerJumpstartPromptDriver(BasePromptDriver): + session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True) + sagemaker_client: Any = field( + default=Factory(lambda self: self.session.client("sagemaker-runtime"), takes_self=True), kw_only=True + ) + endpoint: str = field(kw_only=True, metadata={"serializable": True}) + custom_attributes: str = field(default="accept_eula=true", kw_only=True, metadata={"serializable": True}) + inference_component_name: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) + stream: bool = field(default=False, kw_only=True, metadata={"serializable": True}) + max_tokens: int = field(default=250, kw_only=True, metadata={"serializable": True}) + tokenizer: HuggingFaceTokenizer = field( + default=Factory( + lambda self: HuggingFaceTokenizer(model=self.model, max_output_tokens=self.max_tokens), takes_self=True + ), + kw_only=True, + ) + + @stream.validator # pyright: ignore + def validate_stream(self, _, stream): + if stream: + raise ValueError("streaming is not supported") + + def try_run(self, prompt_stack: PromptStack) -> TextArtifact: + payload = {"inputs": self._to_model_input(prompt_stack), "parameters": self._to_model_params(prompt_stack)} + + response = self.sagemaker_client.invoke_endpoint( + EndpointName=self.endpoint, + ContentType="application/json", + Body=json.dumps(payload), + CustomAttributes=self.custom_attributes, + **( + {"InferenceComponentName": self.inference_component_name} + if self.inference_component_name is not None + else {} + ), + ) + + decoded_body = json.loads(response["Body"].read().decode("utf8")) + + if isinstance(decoded_body, list): + if decoded_body: + return TextArtifact(decoded_body[0]["generated_text"]) + else: + raise ValueError("model response is empty") + else: + return TextArtifact(decoded_body["generated_text"]) + + def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]: + raise NotImplementedError("streaming is not supported") + + def _prompt_stack_input_to_message(self, prompt_input: PromptStack.Input) -> dict: + return {"role": prompt_input.role, "content": prompt_input.content} + + def _to_model_input(self, prompt_stack: PromptStack) -> str: + prompt = self.tokenizer.tokenizer.apply_chat_template( + [self._prompt_stack_input_to_message(i) for i in prompt_stack.inputs], + tokenize=False, + add_generation_prompt=True, + ) + + if isinstance(prompt, str): + return prompt + else: + raise ValueError("Invalid output type.") + + def _to_model_params(self, prompt_stack: PromptStack) -> dict: + return { + "temperature": self.temperature, + "max_new_tokens": 
self.max_tokens, + "do_sample": True, + "eos_token_id": self.tokenizer.tokenizer.eos_token_id, + "stop_strings": self.tokenizer.stop_sequences, + "return_full_text": False, + } diff --git a/griptape/drivers/prompt/amazon_sagemaker_prompt_driver.py b/griptape/drivers/prompt/amazon_sagemaker_prompt_driver.py deleted file mode 100644 index 2934ea642..000000000 --- a/griptape/drivers/prompt/amazon_sagemaker_prompt_driver.py +++ /dev/null @@ -1,52 +0,0 @@ -from __future__ import annotations -import json -from typing import TYPE_CHECKING, Any -from collections.abc import Iterator -from attrs import define, field, Factory -from griptape.artifacts import TextArtifact -from griptape.utils import import_optional_dependency -from .base_multi_model_prompt_driver import BaseMultiModelPromptDriver - -if TYPE_CHECKING: - from griptape.utils import PromptStack - import boto3 - - -@define -class AmazonSageMakerPromptDriver(BaseMultiModelPromptDriver): - session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True) - sagemaker_client: Any = field( - default=Factory(lambda self: self.session.client("sagemaker-runtime"), takes_self=True), kw_only=True - ) - endpoint: str = field(kw_only=True, metadata={"serializable": True}) - model: str = field(default=None, kw_only=True, metadata={"serializable": True}) - custom_attributes: str = field(default="accept_eula=true", kw_only=True, metadata={"serializable": True}) - stream: bool = field(default=False, kw_only=True, metadata={"serializable": True}) - - @stream.validator # pyright: ignore - def validate_stream(self, _, stream): - if stream: - raise ValueError("streaming is not supported") - - def try_run(self, prompt_stack: PromptStack) -> TextArtifact: - payload = { - "inputs": self.prompt_model_driver.prompt_stack_to_model_input(prompt_stack), - "parameters": self.prompt_model_driver.prompt_stack_to_model_params(prompt_stack), - } - response = self.sagemaker_client.invoke_endpoint( - EndpointName=self.endpoint, - ContentType="application/json", - Body=json.dumps(payload), - CustomAttributes=self.custom_attributes, - **({"InferenceComponentName": self.model} if self.model is not None else {}), - ) - - decoded_body = json.loads(response["Body"].read().decode("utf8")) - - if decoded_body: - return self.prompt_model_driver.process_output(decoded_body) - else: - raise Exception("model response is empty") - - def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]: - raise NotImplementedError("streaming is not supported") diff --git a/griptape/drivers/prompt/anthropic_prompt_driver.py b/griptape/drivers/prompt/anthropic_prompt_driver.py index 486233643..b74a9d5f6 100644 --- a/griptape/drivers/prompt/anthropic_prompt_driver.py +++ b/griptape/drivers/prompt/anthropic_prompt_driver.py @@ -5,7 +5,7 @@ from griptape.artifacts import TextArtifact from griptape.utils import PromptStack, import_optional_dependency from griptape.drivers import BasePromptDriver -from griptape.tokenizers import AnthropicTokenizer +from griptape.tokenizers import AnthropicTokenizer, BaseTokenizer @define @@ -15,7 +15,6 @@ class AnthropicPromptDriver(BasePromptDriver): api_key: Anthropic API key. model: Anthropic model name. client: Custom `Anthropic` client. - tokenizer: Custom `AnthropicTokenizer`. 
""" api_key: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False}) @@ -26,11 +25,12 @@ class AnthropicPromptDriver(BasePromptDriver): ), kw_only=True, ) - tokenizer: AnthropicTokenizer = field( + tokenizer: BaseTokenizer = field( default=Factory(lambda self: AnthropicTokenizer(model=self.model), takes_self=True), kw_only=True ) top_p: float = field(default=0.999, kw_only=True, metadata={"serializable": True}) top_k: int = field(default=250, kw_only=True, metadata={"serializable": True}) + max_tokens: int = field(default=1000, kw_only=True, metadata={"serializable": True}) def try_run(self, prompt_stack: PromptStack) -> TextArtifact: response = self.client.messages.create(**self._base_params(prompt_stack)) @@ -44,34 +44,36 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]: if chunk.type == "content_block_delta": yield TextArtifact(value=chunk.delta.text) + def _prompt_stack_input_to_message(self, prompt_input: PromptStack.Input) -> dict: + content = prompt_input.content + + if prompt_input.is_system(): + return {"role": "system", "content": content} + elif prompt_input.is_assistant(): + return {"role": "assistant", "content": content} + else: + return {"role": "user", "content": content} + def _prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> dict: messages = [ - {"role": self.__to_anthropic_role(prompt_input), "content": prompt_input.content} + self._prompt_stack_input_to_message(prompt_input) for prompt_input in prompt_stack.inputs if not prompt_input.is_system() ] - system = next((i for i in prompt_stack.inputs if i.is_system()), None) + system = next((self._prompt_stack_input_to_message(i) for i in prompt_stack.inputs if i.is_system()), None) if system is None: return {"messages": messages} else: - return {"messages": messages, "system": system.content} + return {"messages": messages, "system": system["content"]} def _base_params(self, prompt_stack: PromptStack) -> dict: return { "model": self.model, "temperature": self.temperature, "stop_sequences": self.tokenizer.stop_sequences, - "max_tokens": self.max_output_tokens(self.prompt_stack_to_string(prompt_stack)), "top_p": self.top_p, "top_k": self.top_k, + "max_tokens": self.max_tokens, **self._prompt_stack_to_model_input(prompt_stack), } - - def __to_anthropic_role(self, prompt_input: PromptStack.Input) -> str: - if prompt_input.is_system(): - return "system" - elif prompt_input.is_assistant(): - return "assistant" - else: - return "user" diff --git a/griptape/drivers/prompt/azure_openai_completion_prompt_driver.py b/griptape/drivers/prompt/azure_openai_completion_prompt_driver.py deleted file mode 100644 index 4ff2a4902..000000000 --- a/griptape/drivers/prompt/azure_openai_completion_prompt_driver.py +++ /dev/null @@ -1,41 +0,0 @@ -from typing import Callable, Optional -from attrs import define, field, Factory -from griptape.drivers import OpenAiCompletionPromptDriver -import openai - - -@define -class AzureOpenAiCompletionPromptDriver(OpenAiCompletionPromptDriver): - """ - Attributes: - azure_deployment: An optional Azure OpenAi deployment id. Defaults to the model name. - azure_endpoint: An Azure OpenAi endpoint. - azure_ad_token: An optional Azure Active Directory token. - azure_ad_token_provider: An optional Azure Active Directory token provider. - api_version: An Azure OpenAi API version. - client: An `openai.AzureOpenAI` client. 
- """ - - azure_deployment: str = field( - kw_only=True, default=Factory(lambda self: self.model, takes_self=True), metadata={"serializable": True} - ) - azure_endpoint: str = field(kw_only=True, metadata={"serializable": True}) - azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False}) - azure_ad_token_provider: Optional[Callable[[], str]] = field( - kw_only=True, default=None, metadata={"serializable": False} - ) - api_version: str = field(default="2023-05-15", kw_only=True, metadata={"serializable": True}) - client: openai.AzureOpenAI = field( - default=Factory( - lambda self: openai.AzureOpenAI( - organization=self.organization, - api_key=self.api_key, - api_version=self.api_version, - azure_endpoint=self.azure_endpoint, - azure_deployment=self.azure_deployment, - azure_ad_token=self.azure_ad_token, - azure_ad_token_provider=self.azure_ad_token_provider, - ), - takes_self=True, - ) - ) diff --git a/griptape/drivers/prompt/base_multi_model_prompt_driver.py b/griptape/drivers/prompt/base_multi_model_prompt_driver.py deleted file mode 100644 index 5411ea730..000000000 --- a/griptape/drivers/prompt/base_multi_model_prompt_driver.py +++ /dev/null @@ -1,38 +0,0 @@ -from __future__ import annotations -from attrs import define, field -from abc import ABC -from .base_prompt_driver import BasePromptDriver -from typing import TYPE_CHECKING, Optional - -if TYPE_CHECKING: - from griptape.tokenizers import BaseTokenizer - from griptape.drivers import BasePromptModelDriver - - -@define -class BaseMultiModelPromptDriver(BasePromptDriver, ABC): - """Prompt Driver for platforms like Amazon SageMaker, and Amazon Bedrock that host many LLM models. - - Instances of this Prompt Driver require a Prompt Model Driver which is used to convert the prompt stack - into a model input and parameters, and to process the model output. - - Attributes: - model: Name of the model to use. - tokenizer: Tokenizer to use. Defaults to the Tokenizer of the Prompt Model Driver. - prompt_model_driver: Prompt Model Driver to use. 
- """ - - tokenizer: Optional[BaseTokenizer] = field(default=None, kw_only=True) - prompt_model_driver: BasePromptModelDriver = field(kw_only=True, metadata={"serializable": True}) - stream: bool = field(default=False, kw_only=True, metadata={"serializable": True}) - - @stream.validator # pyright: ignore - def validate_stream(self, _, stream): - if stream and not self.prompt_model_driver.supports_streaming: - raise ValueError(f"{self.prompt_model_driver.__class__.__name__} does not support streaming") - - def __attrs_post_init__(self) -> None: - self.prompt_model_driver.prompt_driver = self - - if not self.tokenizer: - self.tokenizer = self.prompt_model_driver.tokenizer diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index 096035f8b..9ef076dbc 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -1,6 +1,6 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Optional, Callable +from typing import TYPE_CHECKING, Optional from collections.abc import Iterator from attrs import define, field, Factory from griptape.events import StartPromptEvent, FinishPromptEvent, CompletionChunkEvent @@ -32,9 +32,6 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC): temperature: float = field(default=0.1, kw_only=True, metadata={"serializable": True}) max_tokens: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True}) structure: Optional[Structure] = field(default=None, kw_only=True) - prompt_stack_to_string: Callable[[PromptStack], str] = field( - default=Factory(lambda self: self.default_prompt_stack_to_string_converter, takes_self=True), kw_only=True - ) ignored_exception_types: tuple[type[Exception], ...] 
= field( default=Factory(lambda: (ImportError, ValueError)), kw_only=True ) @@ -42,23 +39,12 @@ class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC): tokenizer: BaseTokenizer stream: bool = field(default=False, kw_only=True, metadata={"serializable": True}) - def max_output_tokens(self, text: str | list) -> int: - tokens_left = self.tokenizer.count_output_tokens_left(text) - - if self.max_tokens: - return min(self.max_tokens, tokens_left) - else: - return tokens_left - - def token_count(self, prompt_stack: PromptStack) -> int: - return self.tokenizer.count_tokens(self.prompt_stack_to_string(prompt_stack)) - def before_run(self, prompt_stack: PromptStack) -> None: if self.structure: self.structure.publish_event( StartPromptEvent( model=self.model, - token_count=self.token_count(prompt_stack), + token_count=self.tokenizer.count_tokens(self.prompt_stack_to_string(prompt_stack)), prompt_stack=prompt_stack, prompt=self.prompt_stack_to_string(prompt_stack), ) @@ -67,7 +53,9 @@ def before_run(self, prompt_stack: PromptStack) -> None: def after_run(self, result: TextArtifact) -> None: if self.structure: self.structure.publish_event( - FinishPromptEvent(model=self.model, token_count=result.token_count(self.tokenizer), result=result.value) + FinishPromptEvent( + model=self.model, result=result.value, token_count=self.tokenizer.count_tokens(result.value) + ) ) def run(self, prompt_stack: PromptStack) -> TextArtifact: @@ -92,7 +80,16 @@ def run(self, prompt_stack: PromptStack) -> TextArtifact: else: raise Exception("prompt driver failed after all retry attempts") - def default_prompt_stack_to_string_converter(self, prompt_stack: PromptStack) -> str: + def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str: + """Converts a Prompt Stack to a string for token counting or model input. + This base implementation is only a rough approximation, and should be overridden by subclasses with model-specific tokens. + + Args: + prompt_stack: The Prompt Stack to convert to a string. + + Returns: + A single string representation of the Prompt Stack. + """ prompt_lines = [] for i in prompt_stack.inputs: diff --git a/griptape/drivers/prompt/cohere_prompt_driver.py b/griptape/drivers/prompt/cohere_prompt_driver.py index 2f85c49bf..3ff2c9e89 100644 --- a/griptape/drivers/prompt/cohere_prompt_driver.py +++ b/griptape/drivers/prompt/cohere_prompt_driver.py @@ -4,8 +4,8 @@ from attrs import define, field, Factory from griptape.artifacts import TextArtifact from griptape.drivers import BasePromptDriver -from griptape.tokenizers import CohereTokenizer from griptape.utils import PromptStack, import_optional_dependency +from griptape.tokenizers import BaseTokenizer, CohereTokenizer if TYPE_CHECKING: from cohere import Client @@ -18,48 +18,48 @@ class CoherePromptDriver(BasePromptDriver): api_key: Cohere API key. model: Cohere model name. client: Custom `cohere.Client`. - tokenizer: Custom `CohereTokenizer`. 
""" - api_key: str = field(kw_only=True, metadata={"serializable": True}) + api_key: str = field(kw_only=True, metadata={"serializable": False}) model: str = field(kw_only=True, metadata={"serializable": True}) client: Client = field( default=Factory(lambda self: import_optional_dependency("cohere").Client(self.api_key), takes_self=True), kw_only=True, ) - tokenizer: CohereTokenizer = field( + tokenizer: BaseTokenizer = field( default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True), kw_only=True, ) def try_run(self, prompt_stack: PromptStack) -> TextArtifact: - result = self.client.generate(**self._base_params(prompt_stack)) + result = self.client.chat(**self._base_params(prompt_stack)) - if result.generations: - if len(result.generations) == 1: - generation = result.generations[0] - - return TextArtifact(value=generation.text.strip()) - else: - raise Exception("completion with more than one choice is not supported yet") - else: - raise Exception("model response is empty") + return TextArtifact(value=result.text) def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]: - result = self.client.generate( - **self._base_params(prompt_stack), - stream=True, # pyright: ignore[reportCallIssue] - ) + result = self.client.chat_stream(**self._base_params(prompt_stack)) + + for event in result: + if event.event_type == "text-generation": + yield TextArtifact(value=event.text) - for chunk in result: - yield TextArtifact(value=chunk.text) + def _prompt_stack_input_to_message(self, prompt_input: PromptStack.Input) -> dict: + if prompt_input.is_system(): + return {"role": "SYSTEM", "text": prompt_input.content} + elif prompt_input.is_user(): + return {"role": "USER", "text": prompt_input.content} + else: + return {"role": "ASSISTANT", "text": prompt_input.content} def _base_params(self, prompt_stack: PromptStack) -> dict: - prompt = self.prompt_stack_to_string(prompt_stack) + user_message = prompt_stack.inputs[-1].content + + history_messages = [self._prompt_stack_input_to_message(input) for input in prompt_stack.inputs[:-1]] + return { - "prompt": self.prompt_stack_to_string(prompt_stack), - "model": self.model, + "message": user_message, + "chat_history": history_messages, "temperature": self.temperature, - "end_sequences": self.tokenizer.stop_sequences, - "max_tokens": self.max_output_tokens(prompt), + "stop_sequences": self.tokenizer.stop_sequences, + "max_tokens": self.max_tokens, } diff --git a/griptape/drivers/prompt/dummy_prompt_driver.py b/griptape/drivers/prompt/dummy_prompt_driver.py index f92f9cbc1..a55ecd4fe 100644 --- a/griptape/drivers/prompt/dummy_prompt_driver.py +++ b/griptape/drivers/prompt/dummy_prompt_driver.py @@ -17,3 +17,6 @@ def try_run(self, prompt_stack: PromptStack) -> TextArtifact: def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]: raise DummyException(__class__.__name__, "try_stream") + + def _prompt_stack_input_to_message(self, prompt_input: PromptStack.Input) -> dict: + raise DummyException(__class__.__name__, "_prompt_stack_input_to_message") diff --git a/griptape/drivers/prompt/google_prompt_driver.py b/griptape/drivers/prompt/google_prompt_driver.py index 9f833c035..67bc19e24 100644 --- a/griptape/drivers/prompt/google_prompt_driver.py +++ b/griptape/drivers/prompt/google_prompt_driver.py @@ -19,7 +19,6 @@ class GooglePromptDriver(BasePromptDriver): api_key: Google API key. model: Google model name. model_client: Custom `GenerativeModel` client. - tokenizer: Custom `GoogleTokenizer`. 
top_p: Optional value for top_p. top_k: Optional value for top_k. """ @@ -42,7 +41,7 @@ def try_run(self, prompt_stack: PromptStack) -> TextArtifact: inputs, generation_config=GenerationConfig( stop_sequences=self.tokenizer.stop_sequences, - max_output_tokens=self.max_output_tokens(inputs), + max_output_tokens=self.max_tokens, temperature=self.temperature, top_p=self.top_p, top_k=self.top_k, @@ -60,7 +59,7 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]: stream=True, generation_config=GenerationConfig( stop_sequences=self.tokenizer.stop_sequences, - max_output_tokens=self.max_output_tokens(inputs), + max_output_tokens=self.max_tokens, temperature=self.temperature, top_p=self.top_p, top_k=self.top_k, @@ -70,6 +69,14 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]: for chunk in response: yield TextArtifact(value=chunk.text) + def _prompt_stack_input_to_message(self, prompt_input: PromptStack.Input) -> dict: + parts = [prompt_input.content] + + if prompt_input.is_assistant(): + return {"role": "model", "parts": parts} + else: + return {"role": "user", "parts": parts} + def _default_model_client(self) -> GenerativeModel: genai = import_optional_dependency("google.generativeai") genai.configure(api_key=self.api_key) @@ -90,13 +97,6 @@ def _prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> list[Conten def __to_content_dict(self, prompt_input: PromptStack.Input) -> ContentDict: ContentDict = import_optional_dependency("google.generativeai.types").ContentDict + message = self._prompt_stack_input_to_message(prompt_input) - return ContentDict({"role": self.__to_google_role(prompt_input), "parts": [prompt_input.content]}) - - def __to_google_role(self, prompt_input: PromptStack.Input) -> str: - if prompt_input.is_system(): - return "user" - elif prompt_input.is_assistant(): - return "model" - else: - return "user" + return ContentDict(message) diff --git a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py index 062672aa8..3edd252cb 100644 --- a/griptape/drivers/prompt/huggingface_hub_prompt_driver.py +++ b/griptape/drivers/prompt/huggingface_hub_prompt_driver.py @@ -42,11 +42,7 @@ class HuggingFaceHubPromptDriver(BasePromptDriver): ) tokenizer: HuggingFaceTokenizer = field( default=Factory( - lambda self: HuggingFaceTokenizer( - tokenizer=import_optional_dependency("transformers").AutoTokenizer.from_pretrained(self.model), - max_output_tokens=self.max_tokens, - ), - takes_self=True, + lambda self: HuggingFaceTokenizer(model=self.model, max_output_tokens=self.max_tokens), takes_self=True ), kw_only=True, ) @@ -55,7 +51,7 @@ def try_run(self, prompt_stack: PromptStack) -> TextArtifact: prompt = self.prompt_stack_to_string(prompt_stack) response = self.client.text_generation( - prompt, return_full_text=False, max_new_tokens=self.max_output_tokens(prompt), **self.params + prompt, return_full_text=False, max_new_tokens=self.max_tokens, **self.params ) return TextArtifact(value=response) @@ -64,8 +60,26 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]: prompt = self.prompt_stack_to_string(prompt_stack) response = self.client.text_generation( - prompt, return_full_text=False, max_new_tokens=self.max_output_tokens(prompt), stream=True, **self.params + prompt, return_full_text=False, max_new_tokens=self.max_tokens, stream=True, **self.params ) for token in response: yield TextArtifact(value=token) + + def _prompt_stack_input_to_message(self, 
prompt_input: PromptStack.Input) -> dict: + return {"role": prompt_input.role, "content": prompt_input.content} + + def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str: + return self.tokenizer.tokenizer.decode(self.__prompt_stack_to_tokens(prompt_stack)) + + def __prompt_stack_to_tokens(self, prompt_stack: PromptStack) -> list[int]: + tokens = self.tokenizer.tokenizer.apply_chat_template( + [self._prompt_stack_input_to_message(i) for i in prompt_stack.inputs], + add_generation_prompt=True, + tokenize=True, + ) + + if isinstance(tokens, list): + return tokens + else: + raise ValueError("Invalid output type.") diff --git a/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py b/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py index bde6d5e4e..4fa291877 100644 --- a/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py +++ b/griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py @@ -1,5 +1,7 @@ +from __future__ import annotations from collections.abc import Iterator +from typing import TYPE_CHECKING from attrs import Factory, define, field from griptape.artifacts import TextArtifact @@ -7,6 +9,9 @@ from griptape.tokenizers import HuggingFaceTokenizer from griptape.utils import PromptStack, import_optional_dependency +if TYPE_CHECKING: + from transformers import TextGenerationPipeline + @define class HuggingFacePipelinePromptDriver(BasePromptDriver): @@ -14,48 +19,66 @@ class HuggingFacePipelinePromptDriver(BasePromptDriver): Attributes: params: Custom model run parameters. model: Hugging Face Hub model name. - tokenizer: Custom `HuggingFaceTokenizer`. """ - SUPPORTED_TASKS = ["text2text-generation", "text-generation"] - DEFAULT_PARAMS = {"return_full_text": False, "num_return_sequences": 1} - max_tokens: int = field(default=250, kw_only=True, metadata={"serializable": True}) model: str = field(kw_only=True, metadata={"serializable": True}) params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True}) tokenizer: HuggingFaceTokenizer = field( default=Factory( - lambda self: HuggingFaceTokenizer( - tokenizer=import_optional_dependency("transformers").AutoTokenizer.from_pretrained(self.model), - max_output_tokens=self.max_tokens, - ), - takes_self=True, + lambda self: HuggingFaceTokenizer(model=self.model, max_output_tokens=self.max_tokens), takes_self=True ), kw_only=True, ) + pipe: TextGenerationPipeline = field( + default=Factory( + lambda self: import_optional_dependency("transformers").pipeline( + "text-generation", model=self.model, max_new_tokens=self.max_tokens, tokenizer=self.tokenizer.tokenizer + ), + takes_self=True, + ) + ) def try_run(self, prompt_stack: PromptStack) -> TextArtifact: - prompt = self.prompt_stack_to_string(prompt_stack) - pipeline = import_optional_dependency("transformers").pipeline + messages = [self._prompt_stack_input_to_message(input) for input in prompt_stack.inputs] - generator = pipeline( + result = self.pipe( + messages, + max_new_tokens=self.max_tokens, tokenizer=self.tokenizer.tokenizer, - model=self.model, - max_new_tokens=self.tokenizer.count_output_tokens_left(prompt), + stop_strings=self.tokenizer.stop_sequences, + temperature=self.temperature, + do_sample=True, ) - if generator.task in self.SUPPORTED_TASKS: - extra_params = {"pad_token_id": self.tokenizer.tokenizer.eos_token_id} - - response = generator(prompt, **(self.DEFAULT_PARAMS | extra_params | self.params)) + if isinstance(result, list): + if len(result) == 1: + generated_text = result[0]["generated_text"][-1]["content"] - if 
len(response) == 1: - return TextArtifact(value=response[0]["generated_text"].strip()) + return TextArtifact(value=generated_text) else: raise Exception("completion with more than one choice is not supported yet") else: - raise Exception(f"only models with the following tasks are supported: {self.SUPPORTED_TASKS}") + raise Exception("invalid output format") def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]: raise NotImplementedError("streaming is not supported") + + def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str: + return self.tokenizer.tokenizer.decode(self.__prompt_stack_to_tokens(prompt_stack)) + + def _prompt_stack_input_to_message(self, prompt_input: PromptStack.Input) -> dict: + return {"role": prompt_input.role, "content": prompt_input.content} + + def __prompt_stack_to_tokens(self, prompt_stack: PromptStack) -> list[int]: + tokens = self.tokenizer.tokenizer.apply_chat_template( + [self._prompt_stack_input_to_message(i) for i in prompt_stack.inputs], + add_generation_prompt=True, + tokenize=True, + ) + + if isinstance(tokens, list): + return tokens + else: + raise ValueError("Invalid output type.") diff --git a/griptape/drivers/prompt/ollama_prompt_driver.py b/griptape/drivers/prompt/ollama_prompt_driver.py new file mode 100644 index 000000000..b21176e82 --- /dev/null +++ b/griptape/drivers/prompt/ollama_prompt_driver.py @@ -0,0 +1,69 @@ +from __future__ import annotations +from collections.abc import Iterator +from typing import TYPE_CHECKING, Optional +from attrs import define, field, Factory +from griptape.artifacts import TextArtifact +from griptape.drivers import BasePromptDriver +from griptape.tokenizers.base_tokenizer import BaseTokenizer +from griptape.utils import PromptStack, import_optional_dependency +from griptape.tokenizers import SimpleTokenizer + +if TYPE_CHECKING: + from ollama import Client + + +@define +class OllamaPromptDriver(BasePromptDriver): + """ + Attributes: + model: Model name. 
+ """ + + model: str = field(kw_only=True, metadata={"serializable": True}) + host: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) + client: Client = field( + default=Factory(lambda self: import_optional_dependency("ollama").Client(host=self.host), takes_self=True), + kw_only=True, + ) + tokenizer: BaseTokenizer = field( + default=Factory( + lambda self: SimpleTokenizer( + characters_per_token=4, max_input_tokens=2000, max_output_tokens=self.max_tokens + ), + takes_self=True, + ), + kw_only=True, + ) + options: dict = field( + default=Factory( + lambda self: { + "temperature": self.temperature, + "stop": self.tokenizer.stop_sequences, + "num_predict": self.max_tokens, + }, + takes_self=True, + ), + kw_only=True, + ) + + def try_run(self, prompt_stack: PromptStack) -> TextArtifact: + response = self.client.chat(**self._base_params(prompt_stack)) + + if isinstance(response, dict): + return TextArtifact(value=response["message"]["content"]) + else: + raise Exception("invalid model response") + + def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]: + stream = self.client.chat(**self._base_params(prompt_stack), stream=True) + + if isinstance(stream, Iterator): + for chunk in stream: + yield TextArtifact(value=chunk["message"]["content"]) + else: + raise Exception("invalid model response") + + def _base_params(self, prompt_stack: PromptStack) -> dict: + messages = [{"role": input.role, "content": input.content} for input in prompt_stack.inputs] + + return {"messages": messages, "model": self.model, "options": self.options} diff --git a/griptape/drivers/prompt/openai_chat_prompt_driver.py b/griptape/drivers/prompt/openai_chat_prompt_driver.py index 3d19063d3..9545bd45a 100644 --- a/griptape/drivers/prompt/openai_chat_prompt_driver.py +++ b/griptape/drivers/prompt/openai_chat_prompt_driver.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Optional, Any, Literal +from typing import Optional, Literal from collections.abc import Iterator import openai from attrs import define, field, Factory @@ -7,8 +7,6 @@ from griptape.utils import PromptStack from griptape.drivers import BasePromptDriver from griptape.tokenizers import OpenAiTokenizer, BaseTokenizer -import dateparser -from datetime import datetime, timedelta @define @@ -25,12 +23,6 @@ class OpenAiChatPromptDriver(BasePromptDriver): response_format: An optional OpenAi Chat Completion response format. Currently only supports `json_object` which will enable OpenAi's JSON mode. seed: An optional OpenAi Chat Completion seed. ignored_exception_types: An optional tuple of exception types to ignore. Defaults to OpenAI's known exception types. - _ratelimit_request_limit: The maximum number of requests allowed in the current rate limit window. - _ratelimit_requests_remaining: The number of requests remaining in the current rate limit window. - _ratelimit_requests_reset_at: The time at which the current rate limit window resets. - _ratelimit_token_limit: The maximum number of tokens allowed in the current rate limit window. - _ratelimit_tokens_remaining: The number of tokens remaining in the current rate limit window. - _ratelimit_tokens_reset_at: The time at which the current rate limit window resets. 
""" base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) @@ -64,22 +56,12 @@ class OpenAiChatPromptDriver(BasePromptDriver): ), kw_only=True, ) - _ratelimit_request_limit: Optional[int] = field(init=False, default=None) - _ratelimit_requests_remaining: Optional[int] = field(init=False, default=None) - _ratelimit_requests_reset_at: Optional[datetime] = field(init=False, default=None) - _ratelimit_token_limit: Optional[int] = field(init=False, default=None) - _ratelimit_tokens_remaining: Optional[int] = field(init=False, default=None) - _ratelimit_tokens_reset_at: Optional[datetime] = field(init=False, default=None) def try_run(self, prompt_stack: PromptStack) -> TextArtifact: - result = self.client.chat.completions.with_raw_response.create(**self._base_params(prompt_stack)) + result = self.client.chat.completions.create(**self._base_params(prompt_stack)) - self._extract_ratelimit_metadata(result) - - parsed_result = result.parse() - - if len(parsed_result.choices) == 1: - return TextArtifact(value=parsed_result.choices[0].message.content.strip()) + if len(result.choices) == 1: + return TextArtifact(value=result.choices[0].message.content.strip()) else: raise Exception("Completion with more than one choice is not supported yet.") @@ -97,14 +79,15 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]: yield TextArtifact(value=delta_content) - def token_count(self, prompt_stack: PromptStack) -> int: - if isinstance(self.tokenizer, OpenAiTokenizer): - return self.tokenizer.count_tokens(self._prompt_stack_to_messages(prompt_stack)) - else: - return self.tokenizer.count_tokens(self.prompt_stack_to_string(prompt_stack)) + def _prompt_stack_input_to_message(self, prompt_input: PromptStack.Input) -> dict: + content = prompt_input.content - def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict[str, Any]]: - return [{"role": self.__to_openai_role(i), "content": i.content} for i in prompt_stack.inputs] + if prompt_input.is_system(): + return {"role": "system", "content": content} + elif prompt_input.is_assistant(): + return {"role": "assistant", "content": content} + else: + return {"role": "user", "content": content} def _base_params(self, prompt_stack: PromptStack) -> dict: params = { @@ -120,7 +103,7 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: # JSON mode still requires a system input instructing the LLM to output JSON. prompt_stack.add_system_input("Provide your response as a valid JSON object.") - messages = self._prompt_stack_to_messages(prompt_stack) + messages = [self._prompt_stack_input_to_message(input) for input in prompt_stack.inputs] if self.max_tokens is not None: params["max_tokens"] = self.max_tokens @@ -128,41 +111,3 @@ def _base_params(self, prompt_stack: PromptStack) -> dict: params["messages"] = messages return params - - def __to_openai_role(self, prompt_input: PromptStack.Input) -> str: - if prompt_input.is_system(): - return "system" - elif prompt_input.is_assistant(): - return "assistant" - else: - return "user" - - def _extract_ratelimit_metadata(self, response): - # The OpenAI SDK's requestssession variable is global, so this hook will fire for all API requests. - # The following headers are not reliably returned in every API call, so we check for the presence of the - # headers before reading and parsing their values to prevent other SDK users from encountering KeyErrors. 
- reset_requests_at = response.headers.get("x-ratelimit-reset-requests") - if reset_requests_at is not None: - self._ratelimit_requests_reset_at = dateparser.parse( - reset_requests_at, settings={"PREFER_DATES_FROM": "future"} - ) - - # The dateparser utility doesn't handle sub-second durations as are sometimes returned by OpenAI's API. - # If the API returns, for example, "13ms", dateparser.parse() returns None. In this case, we will set - # the time value to the current time plus a one second buffer. - if self._ratelimit_requests_reset_at is None: - self._ratelimit_requests_reset_at = datetime.now() + timedelta(seconds=1) - - reset_tokens_at = response.headers.get("x-ratelimit-reset-tokens") - if reset_tokens_at is not None: - self._ratelimit_tokens_reset_at = dateparser.parse( - reset_tokens_at, settings={"PREFER_DATES_FROM": "future"} - ) - - if self._ratelimit_tokens_reset_at is None: - self._ratelimit_tokens_reset_at = datetime.now() + timedelta(seconds=1) - - self._ratelimit_request_limit = response.headers.get("x-ratelimit-limit-requests") - self._ratelimit_requests_remaining = response.headers.get("x-ratelimit-remaining-requests") - self._ratelimit_token_limit = response.headers.get("x-ratelimit-limit-tokens") - self._ratelimit_tokens_remaining = response.headers.get("x-ratelimit-remaining-tokens") diff --git a/griptape/drivers/prompt/openai_completion_prompt_driver.py b/griptape/drivers/prompt/openai_completion_prompt_driver.py deleted file mode 100644 index 1a738a487..000000000 --- a/griptape/drivers/prompt/openai_completion_prompt_driver.py +++ /dev/null @@ -1,83 +0,0 @@ -from typing import Optional -from collections.abc import Iterator -from attrs import define, field, Factory -from griptape.artifacts import TextArtifact -from griptape.utils import PromptStack -from griptape.drivers import BasePromptDriver -from griptape.tokenizers import OpenAiTokenizer -import openai - - -@define -class OpenAiCompletionPromptDriver(BasePromptDriver): - """ - Attributes: - base_url: An optional OpenAi API URL. - api_key: An optional OpenAi API key. If not provided, the `OPENAI_API_KEY` environment variable will be used. - organization: An optional OpenAI organization. If not provided, the `OPENAI_ORG_ID` environment variable will be used. - client: An `openai.OpenAI` client. - model: An OpenAI model name. - tokenizer: An `OpenAiTokenizer`. - user: A user id. Can be used to track requests by user. - ignored_exception_types: An optional tuple of exception types to ignore. Defaults to OpenAI's known exception types. - """ - - base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) - api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False}) - organization: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) - client: openai.OpenAI = field( - default=Factory( - lambda self: openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization), - takes_self=True, - ) - ) - model: str = field(kw_only=True, metadata={"serializable": True}) - tokenizer: OpenAiTokenizer = field( - default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True - ) - user: str = field(default="", kw_only=True, metadata={"serializable": True}) - ignored_exception_types: tuple[type[Exception], ...] 
= field( - default=Factory( - lambda: ( - openai.BadRequestError, - openai.AuthenticationError, - openai.PermissionDeniedError, - openai.NotFoundError, - openai.ConflictError, - openai.UnprocessableEntityError, - ) - ), - kw_only=True, - ) - - def try_run(self, prompt_stack: PromptStack) -> TextArtifact: - result = self.client.completions.create(**self._base_params(prompt_stack)) - - if len(result.choices) == 1: - return TextArtifact(value=result.choices[0].text.strip()) - else: - raise Exception("completion with more than one choice is not supported yet") - - def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]: - result = self.client.completions.create(**self._base_params(prompt_stack), stream=True) - - for chunk in result: - if len(chunk.choices) == 1: - choice = chunk.choices[0] - delta_content = choice.text - yield TextArtifact(value=delta_content) - - else: - raise Exception("completion with more than one choice is not supported yet") - - def _base_params(self, prompt_stack: PromptStack) -> dict: - prompt = self.prompt_stack_to_string(prompt_stack) - - return { - "model": self.model, - "max_tokens": self.max_output_tokens(prompt), - "temperature": self.temperature, - "stop": self.tokenizer.stop_sequences, - "user": self.user, - "prompt": prompt, - } diff --git a/griptape/drivers/prompt_model/__init__.py b/griptape/drivers/prompt_model/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/griptape/drivers/prompt_model/base_prompt_model_driver.py b/griptape/drivers/prompt_model/base_prompt_model_driver.py deleted file mode 100644 index 096802370..000000000 --- a/griptape/drivers/prompt_model/base_prompt_model_driver.py +++ /dev/null @@ -1,29 +0,0 @@ -from __future__ import annotations -from abc import ABC, abstractmethod -from typing import Optional -from attrs import define, field -from griptape.artifacts import TextArtifact -from griptape.utils import PromptStack -from griptape.drivers import BasePromptDriver -from griptape.tokenizers import BaseTokenizer -from griptape.mixins import SerializableMixin - - -@define -class BasePromptModelDriver(SerializableMixin, ABC): - max_tokens: Optional[int] = field(default=None, kw_only=True) - prompt_driver: Optional[BasePromptDriver] = field(default=None, kw_only=True) - supports_streaming: bool = field(default=True, kw_only=True) - - @property - @abstractmethod - def tokenizer(self) -> BaseTokenizer: ... - - @abstractmethod - def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> str | list | dict: ... - - @abstractmethod - def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict: ... - - @abstractmethod - def process_output(self, output: dict | list[dict] | str | bytes) -> TextArtifact: ... 
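With the prompt model driver layer deleted above, SageMaker calls no longer pass through a separate BasePromptModelDriver: the new AmazonSageMakerJumpstartPromptDriver takes the endpoint and generation parameters directly. The sketch below shows the rough shape of a migrated call site. It is illustrative only: the endpoint name and model id are placeholders, and it assumes the new driver is re-exported from griptape.drivers the way its predecessor was.

from griptape.drivers import AmazonSageMakerJumpstartPromptDriver  # assumed re-export
from griptape.utils import PromptStack

# Placeholder endpoint and model values; model also serves as the Hugging Face
# tokenizer id loaded by the driver's HuggingFaceTokenizer.
driver = AmazonSageMakerJumpstartPromptDriver(
    endpoint="my-jumpstart-endpoint",
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    inference_component_name=None,  # set when the endpoint fronts an inference component
    custom_attributes="accept_eula=true",
    max_tokens=250,
)

prompt_stack = PromptStack(
    inputs=[
        PromptStack.Input("You are a concise assistant.", role=PromptStack.SYSTEM_ROLE),
        PromptStack.Input("Say hello.", role=PromptStack.USER_ROLE),
    ]
)

# run() returns a TextArtifact; executing this for real requires AWS credentials
# and a deployed JumpStart endpoint.
print(driver.run(prompt_stack).value)

Note that, unlike the removed AmazonSageMakerPromptDriver, which reused its model field as the InferenceComponentName, the new driver exposes inference_component_name as its own field.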
diff --git a/griptape/drivers/prompt_model/bedrock_claude_prompt_model_driver.py b/griptape/drivers/prompt_model/bedrock_claude_prompt_model_driver.py deleted file mode 100644 index 2b4c547a9..000000000 --- a/griptape/drivers/prompt_model/bedrock_claude_prompt_model_driver.py +++ /dev/null @@ -1,87 +0,0 @@ -from __future__ import annotations -from typing import Optional -import json -from attrs import define, field -from griptape.artifacts import TextArtifact -from griptape.utils import PromptStack -from griptape.drivers import BasePromptModelDriver, AmazonBedrockPromptDriver -from griptape.tokenizers import BedrockClaudeTokenizer - - -@define -class BedrockClaudePromptModelDriver(BasePromptModelDriver): - ANTHROPIC_VERSION = "bedrock-2023-05-31" # static string for AWS: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html#api-inference-examples-claude-multimodal-code-example - - top_p: float = field(default=0.999, kw_only=True, metadata={"serializable": True}) - top_k: int = field(default=250, kw_only=True, metadata={"serializable": True}) - _tokenizer: BedrockClaudeTokenizer = field(default=None, kw_only=True) - prompt_driver: Optional[AmazonBedrockPromptDriver] = field(default=None, kw_only=True) - - @property - def tokenizer(self) -> BedrockClaudeTokenizer: - """Returns the tokenizer for this driver. - - We need to pass the `session` field from the Prompt Driver to the - Tokenizer. However, the Prompt Driver is not initialized until after - the Prompt Model Driver is initialized. To resolve this, we make the `tokenizer` - field a @property that is only initialized when it is first accessed. - This ensures that by the time we need to initialize the Tokenizer, the - Prompt Driver has already been initialized. - - See this thread more more information: https://github.com/griptape-ai/griptape/issues/244 - - Returns: - BedrockClaudeTokenizer: The tokenizer for this driver. 
- """ - if self._tokenizer: - return self._tokenizer - else: - self._tokenizer = BedrockClaudeTokenizer(model=self.prompt_driver.model) - return self._tokenizer - - def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> dict: - messages = [ - {"role": self.__to_anthropic_role(prompt_input), "content": prompt_input.content} - for prompt_input in prompt_stack.inputs - if not prompt_input.is_system() - ] - system = next((i for i in prompt_stack.inputs if i.is_system()), None) - - if system is None: - return {"messages": messages} - else: - return {"messages": messages, "system": system.content} - - def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict: - input = self.prompt_stack_to_model_input(prompt_stack) - - return { - "stop_sequences": self.tokenizer.stop_sequences, - "temperature": self.prompt_driver.temperature, - "top_p": self.top_p, - "top_k": self.top_k, - "max_tokens": self.prompt_driver.max_output_tokens(self.prompt_driver.prompt_stack_to_string(prompt_stack)), - "anthropic_version": self.ANTHROPIC_VERSION, - **input, - } - - def process_output(self, output: dict | list[dict] | str | bytes) -> TextArtifact: - if isinstance(output, bytes): - body = json.loads(output.decode()) - else: - raise Exception("Output must be bytes.") - - if body["type"] == "content_block_delta": - return TextArtifact(value=body["delta"]["text"]) - elif body["type"] == "message": - return TextArtifact(value=body["content"][0]["text"]) - else: - return TextArtifact(value="") - - def __to_anthropic_role(self, prompt_input: PromptStack.Input) -> str: - if prompt_input.is_system(): - return "system" - elif prompt_input.is_assistant(): - return "assistant" - else: - return "user" diff --git a/griptape/drivers/prompt_model/bedrock_jurassic_prompt_model_driver.py b/griptape/drivers/prompt_model/bedrock_jurassic_prompt_model_driver.py deleted file mode 100644 index 4da99e88f..000000000 --- a/griptape/drivers/prompt_model/bedrock_jurassic_prompt_model_driver.py +++ /dev/null @@ -1,76 +0,0 @@ -from __future__ import annotations -from typing import Optional -import json -from attrs import define, field -from griptape.artifacts import TextArtifact -from griptape.utils import PromptStack -from griptape.drivers import BasePromptModelDriver -from griptape.tokenizers import BedrockJurassicTokenizer -from griptape.drivers import AmazonBedrockPromptDriver - - -@define -class BedrockJurassicPromptModelDriver(BasePromptModelDriver): - top_p: float = field(default=0.9, kw_only=True, metadata={"serializable": True}) - _tokenizer: BedrockJurassicTokenizer = field(default=None, kw_only=True) - prompt_driver: Optional[AmazonBedrockPromptDriver] = field(default=None, kw_only=True) - supports_streaming: bool = field(default=False, kw_only=True) - - @property - def tokenizer(self) -> BedrockJurassicTokenizer: - """Returns the tokenizer for this driver. - - We need to pass the `session` field from the Prompt Driver to the - Tokenizer. However, the Prompt Driver is not initialized until after - the Prompt Model Driver is initialized. To resolve this, we make the `tokenizer` - field a @property that is only initialized when it is first accessed. - This ensures that by the time we need to initialize the Tokenizer, the - Prompt Driver has already been initialized. - - See this thread more more information: https://github.com/griptape-ai/griptape/issues/244 - - Returns: - BedrockJurassicTokenizer: The tokenizer for this driver. 
- """ - if self._tokenizer: - return self._tokenizer - else: - self._tokenizer = BedrockJurassicTokenizer(model=self.prompt_driver.model) - return self._tokenizer - - def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> dict: - prompt_lines = [] - - for i in prompt_stack.inputs: - if i.is_user(): - prompt_lines.append(f"User: {i.content}") - elif i.is_assistant(): - prompt_lines.append(f"Assistant: {i.content}") - elif i.is_system(): - prompt_lines.append(f"System: {i.content}") - else: - prompt_lines.append(i.content) - prompt_lines.append("Assistant:") - - prompt = "\n".join(prompt_lines) - - return {"prompt": prompt} - - def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict: - prompt = self.prompt_stack_to_model_input(prompt_stack)["prompt"] - - return { - "maxTokens": self.prompt_driver.max_output_tokens(prompt), - "temperature": self.prompt_driver.temperature, - "stopSequences": self.tokenizer.stop_sequences, - "countPenalty": {"scale": 0}, - "presencePenalty": {"scale": 0}, - "frequencyPenalty": {"scale": 0}, - } - - def process_output(self, output: dict | list[dict] | str | bytes) -> TextArtifact: - if isinstance(output, bytes): - body = json.loads(output.decode()) - else: - raise Exception("Output must be bytes.") - return TextArtifact(body["completions"][0]["data"]["text"]) diff --git a/griptape/drivers/prompt_model/bedrock_llama_prompt_model_driver.py b/griptape/drivers/prompt_model/bedrock_llama_prompt_model_driver.py deleted file mode 100644 index 951583c51..000000000 --- a/griptape/drivers/prompt_model/bedrock_llama_prompt_model_driver.py +++ /dev/null @@ -1,101 +0,0 @@ -from __future__ import annotations -import json -import itertools as it -from typing import Optional -from attrs import define, field -from griptape.artifacts import TextArtifact -from griptape.utils import PromptStack -from griptape.drivers import BasePromptModelDriver -from griptape.tokenizers import BedrockLlamaTokenizer -from griptape.drivers import AmazonBedrockPromptDriver - - -@define -class BedrockLlamaPromptModelDriver(BasePromptModelDriver): - top_p: float = field(default=0.9, kw_only=True) - _tokenizer: BedrockLlamaTokenizer = field(default=None, kw_only=True) - prompt_driver: Optional[AmazonBedrockPromptDriver] = field(default=None, kw_only=True) - - @property - def tokenizer(self) -> BedrockLlamaTokenizer: - """Returns the tokenizer for this driver. - - We need to pass the `session` field from the Prompt Driver to the - Tokenizer. However, the Prompt Driver is not initialized until after - the Prompt Model Driver is initialized. To resolve this, we make the `tokenizer` - field a @property that is only initialized when it is first accessed. - This ensures that by the time we need to initialize the Tokenizer, the - Prompt Driver has already been initialized. - - See this thread more more information: https://github.com/griptape-ai/griptape/issues/244 - - Returns: - BedrockLlamaTokenizer: The tokenizer for this driver. - """ - if self._tokenizer: - return self._tokenizer - else: - self._tokenizer = BedrockLlamaTokenizer(model=self.prompt_driver.model) - return self._tokenizer - - def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> str: - """ - Converts a `PromptStack` to a string that can be used as the input to the model. - - Prompt structure adapted from https://huggingface.co/blog/llama2#how-to-prompt-llama-2 - - Args: - prompt_stack: The `PromptStack` to convert. 
- """ - prompt_lines = [] - - inputs = iter(prompt_stack.inputs) - input_pairs: list[tuple] = list(it.zip_longest(inputs, inputs)) - for input_pair in input_pairs: - first_input: PromptStack.Input = input_pair[0] - second_input: Optional[PromptStack.Input] = input_pair[1] - - if first_input.is_system(): - prompt_lines.append(f"[INST] <>\n{first_input.content}\n<>\n\n") - if second_input: - if second_input.is_user(): - prompt_lines.append(f"{second_input.content} [/INST]") - else: - raise Exception("System input must be followed by user input.") - elif first_input.is_assistant(): - prompt_lines.append(f" {first_input.content} ") - if second_input: - if second_input.is_user(): - prompt_lines.append(f"[INST] {second_input.content} [/INST]") - else: - raise Exception("Assistant input must be followed by user input.") - elif first_input.is_user(): - prompt_lines.append(f"[INST] {first_input.content} [/INST]") - if second_input: - if second_input.is_assistant(): - prompt_lines.append(f" {second_input.content} ") - else: - raise Exception("User input must be followed by assistant input.") - - return "".join(prompt_lines) - - def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict: - prompt = self.prompt_stack_to_model_input(prompt_stack) - - return { - "prompt": prompt, - "max_gen_len": self.prompt_driver.max_output_tokens(prompt), - "temperature": self.prompt_driver.temperature, - "top_p": self.top_p, - } - - def process_output(self, output: dict | list[dict] | str | bytes) -> TextArtifact: - # When streaming, the response body comes back as bytes. - if isinstance(output, bytes): - output = output.decode() - elif isinstance(output, list) or isinstance(output, dict): - raise Exception("Invalid output format.") - - body = json.loads(output) - - return TextArtifact(body["generation"]) diff --git a/griptape/drivers/prompt_model/bedrock_titan_prompt_model_driver.py b/griptape/drivers/prompt_model/bedrock_titan_prompt_model_driver.py deleted file mode 100644 index 5f5bbc1d2..000000000 --- a/griptape/drivers/prompt_model/bedrock_titan_prompt_model_driver.py +++ /dev/null @@ -1,83 +0,0 @@ -from __future__ import annotations -from typing import Optional -import json -from attrs import define, field -from griptape.artifacts import TextArtifact -from griptape.utils import PromptStack -from griptape.drivers import BasePromptModelDriver -from griptape.tokenizers import BedrockTitanTokenizer -from griptape.drivers import AmazonBedrockPromptDriver - - -@define -class BedrockTitanPromptModelDriver(BasePromptModelDriver): - top_p: float = field(default=0.9, kw_only=True, metadata={"serializable": True}) - _tokenizer: BedrockTitanTokenizer = field(default=None, kw_only=True) - prompt_driver: Optional[AmazonBedrockPromptDriver] = field(default=None, kw_only=True) - - @property - def tokenizer(self) -> BedrockTitanTokenizer: - """Returns the tokenizer for this driver. - - We need to pass the `session` field from the Prompt Driver to the - Tokenizer. However, the Prompt Driver is not initialized until after - the Prompt Model Driver is initialized. To resolve this, we make the `tokenizer` - field a @property that is only initialized when it is first accessed. - This ensures that by the time we need to initialize the Tokenizer, the - Prompt Driver has already been initialized. - - See this thread for more information: https://github.com/griptape-ai/griptape/issues/244 - - Returns: - BedrockTitanTokenizer: The tokenizer for this driver. 
- """ - if self._tokenizer: - return self._tokenizer - else: - self._tokenizer = BedrockTitanTokenizer(model=self.prompt_driver.model) - return self._tokenizer - - def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> dict: - prompt_lines = [] - - for i in prompt_stack.inputs: - if i.is_user(): - prompt_lines.append(f"User: {i.content}") - elif i.is_assistant(): - prompt_lines.append(f"Bot: {i.content}") - elif i.is_system(): - prompt_lines.append(f"Instructions: {i.content}") - else: - prompt_lines.append(i.content) - prompt_lines.append("Bot:") - - prompt = "\n\n".join(prompt_lines) - - return {"inputText": prompt} - - def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict: - prompt = self.prompt_stack_to_model_input(prompt_stack)["inputText"] - - return { - "textGenerationConfig": { - "maxTokenCount": self.prompt_driver.max_output_tokens(prompt), - "stopSequences": self.tokenizer.stop_sequences, - "temperature": self.prompt_driver.temperature, - "topP": self.top_p, - } - } - - def process_output(self, output: dict | list[dict] | str | bytes) -> TextArtifact: - # When streaming, the response body comes back as bytes. - if isinstance(output, str) or isinstance(output, bytes): - if isinstance(output, bytes): - output = output.decode() - - body = json.loads(output) - - if self.prompt_driver.stream: - return TextArtifact(body["outputText"]) - else: - return TextArtifact(body["results"][0]["outputText"]) - else: - raise ValueError("output must be an instance of 'str' or 'bytes'") diff --git a/griptape/drivers/prompt_model/sagemaker_falcon_prompt_model_driver.py b/griptape/drivers/prompt_model/sagemaker_falcon_prompt_model_driver.py deleted file mode 100644 index a5a8a4dc9..000000000 --- a/griptape/drivers/prompt_model/sagemaker_falcon_prompt_model_driver.py +++ /dev/null @@ -1,42 +0,0 @@ -from __future__ import annotations -from attrs import define, field -from griptape.artifacts import TextArtifact -from griptape.utils import PromptStack, import_optional_dependency -from griptape.drivers import BasePromptModelDriver -from griptape.tokenizers import HuggingFaceTokenizer - - -@define -class SageMakerFalconPromptModelDriver(BasePromptModelDriver): - DEFAULT_MAX_TOKENS = 600 - - _tokenizer: HuggingFaceTokenizer = field(default=None, kw_only=True) - - @property - def tokenizer(self) -> HuggingFaceTokenizer: - if self._tokenizer is None: - self._tokenizer = HuggingFaceTokenizer( - tokenizer=import_optional_dependency("transformers").AutoTokenizer.from_pretrained("tiiuae/falcon-40b"), - max_output_tokens=self.max_tokens or self.DEFAULT_MAX_TOKENS, - ) - return self._tokenizer - - def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> str: - return self.prompt_driver.prompt_stack_to_string(prompt_stack) - - def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict: - prompt = self.prompt_stack_to_model_input(prompt_stack) - stop_sequences = self.prompt_driver.tokenizer.stop_sequences - - return { - "max_new_tokens": self.prompt_driver.max_output_tokens(prompt), - "temperature": self.prompt_driver.temperature, - "do_sample": True, - "stop": stop_sequences, - } - - def process_output(self, output: dict | list[dict] | str | bytes) -> TextArtifact: - if isinstance(output, list): - return TextArtifact(output[0]["generated_text"].strip()) - else: - raise ValueError("output must be an instance of 'list'") diff --git a/griptape/drivers/prompt_model/sagemaker_llama_prompt_model_driver.py 
b/griptape/drivers/prompt_model/sagemaker_llama_prompt_model_driver.py deleted file mode 100644 index 7e934d4a6..000000000 --- a/griptape/drivers/prompt_model/sagemaker_llama_prompt_model_driver.py +++ /dev/null @@ -1,47 +0,0 @@ -from __future__ import annotations -from attrs import define, field -from griptape.artifacts import TextArtifact -from griptape.utils import PromptStack, import_optional_dependency -from griptape.drivers import BasePromptModelDriver -from griptape.tokenizers import HuggingFaceTokenizer - - -@define -class SageMakerLlamaPromptModelDriver(BasePromptModelDriver): - # Default context length for all Llama 3 models is 8K as per https://huggingface.co/blog/llama3 - DEFAULT_MAX_INPUT_TOKENS = 8000 - - _tokenizer: HuggingFaceTokenizer = field(default=None, kw_only=True) - - @property - def tokenizer(self) -> HuggingFaceTokenizer: - if self._tokenizer is None: - self._tokenizer = HuggingFaceTokenizer( - tokenizer=import_optional_dependency("transformers").AutoTokenizer.from_pretrained( - "meta-llama/Meta-Llama-3-8B-Instruct", model_max_length=self.DEFAULT_MAX_INPUT_TOKENS - ), - max_output_tokens=self.max_tokens or self.DEFAULT_MAX_INPUT_TOKENS, - ) - return self._tokenizer - - def prompt_stack_to_model_input(self, prompt_stack: PromptStack) -> str: - return self.tokenizer.tokenizer.apply_chat_template( # pyright: ignore - [{"role": i.role, "content": i.content} for i in prompt_stack.inputs], - tokenize=False, - add_generation_prompt=True, - ) - - def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict: - prompt = self.prompt_driver.prompt_stack_to_string(prompt_stack) - return { - "max_new_tokens": self.prompt_driver.max_output_tokens(prompt), - "temperature": self.prompt_driver.temperature, - "stop": self.tokenizer.tokenizer.eos_token, - } - - def process_output(self, output: dict | list[dict] | str | bytes) -> TextArtifact: - # This output format is specific to the Llama 3 Instruct models when deployed via SageMaker JumpStart. 
- if isinstance(output, dict): - return TextArtifact(output["generated_text"]) - else: - raise ValueError("Invalid output format.") diff --git a/griptape/engines/query/vector_query_engine.py b/griptape/engines/query/vector_query_engine.py index adfa4b2db..24338b348 100644 --- a/griptape/engines/query/vector_query_engine.py +++ b/griptape/engines/query/vector_query_engine.py @@ -49,12 +49,14 @@ def query( ) user_message = self.user_template_generator.render(query=query) - message_token_count = self.prompt_driver.token_count( - PromptStack( - inputs=[ - PromptStack.Input(system_message, role=PromptStack.SYSTEM_ROLE), - PromptStack.Input(user_message, role=PromptStack.USER_ROLE), - ] + message_token_count = self.prompt_driver.tokenizer.count_input_tokens_left( + self.prompt_driver.prompt_stack_to_string( + PromptStack( + inputs=[ + PromptStack.Input(system_message, role=PromptStack.SYSTEM_ROLE), + PromptStack.Input(user_message, role=PromptStack.USER_ROLE), + ] + ) ) ) diff --git a/griptape/engines/summary/prompt_summary_engine.py b/griptape/engines/summary/prompt_summary_engine.py index 0da99cb0a..9d3e8db78 100644 --- a/griptape/engines/summary/prompt_summary_engine.py +++ b/griptape/engines/summary/prompt_summary_engine.py @@ -13,7 +13,8 @@ class PromptSummaryEngine(BaseSummaryEngine): chunk_joiner: str = field(default="\n\n", kw_only=True) max_token_multiplier: float = field(default=0.5, kw_only=True) - template_generator: J2 = field(default=Factory(lambda: J2("engines/summary/prompt_summary.j2")), kw_only=True) + system_template_generator: J2 = field(default=Factory(lambda: J2("engines/summary/system.j2")), kw_only=True) + user_template_generator: J2 = field(default=Factory(lambda: J2("engines/summary/user.j2")), kw_only=True) prompt_driver: BasePromptDriver = field(kw_only=True) chunker: BaseChunker = field( default=Factory( @@ -49,25 +50,38 @@ def summarize_artifacts_rec( ) -> TextArtifact: artifacts_text = self.chunk_joiner.join([a.to_text() for a in artifacts]) - full_text = self.template_generator.render( - summary=summary, text=artifacts_text, rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets) + system_prompt = self.system_template_generator.render( + summary=summary, rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets) ) - if self.prompt_driver.tokenizer.count_input_tokens_left(full_text) >= self.min_response_tokens: + user_prompt = self.user_template_generator.render(text=artifacts_text) + + if ( + self.prompt_driver.tokenizer.count_input_tokens_left(user_prompt + system_prompt) + >= self.min_response_tokens + ): return self.prompt_driver.run( - PromptStack(inputs=[PromptStack.Input(full_text, role=PromptStack.USER_ROLE)]) + PromptStack( + inputs=[ + PromptStack.Input(system_prompt, role=PromptStack.SYSTEM_ROLE), + PromptStack.Input(user_prompt, role=PromptStack.USER_ROLE), + ] + ) ) else: chunks = self.chunker.chunk(artifacts_text) - partial_text = self.template_generator.render( - summary=summary, text=chunks[0].value, rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets) - ) + partial_text = self.user_template_generator.render(text=chunks[0].value) return self.summarize_artifacts_rec( chunks[1:], self.prompt_driver.run( - PromptStack(inputs=[PromptStack.Input(partial_text, role=PromptStack.USER_ROLE)]) + PromptStack( + inputs=[ + PromptStack.Input(system_prompt, role=PromptStack.SYSTEM_ROLE), + PromptStack.Input(partial_text, role=PromptStack.USER_ROLE), + ] + ) ).value, rulesets=rulesets, ) diff --git a/griptape/loaders/web_loader.py 
b/griptape/loaders/web_loader.py index f8862f7cd..64a52b65f 100644 --- a/griptape/loaders/web_loader.py +++ b/griptape/loaders/web_loader.py @@ -13,5 +13,8 @@ class WebLoader(BaseTextLoader): ) def load(self, source: str, *args, **kwargs) -> ErrorArtifact | list[TextArtifact]: - single_chunk_text_artifact = self.web_scraper_driver.scrape_url(source) - return self._text_to_artifacts(single_chunk_text_artifact.value) + try: + single_chunk_text_artifact = self.web_scraper_driver.scrape_url(source) + return self._text_to_artifacts(single_chunk_text_artifact.value) + except Exception as e: + return ErrorArtifact(f"Error loading from source: {source}", exception=e) diff --git a/griptape/memory/structure/base_conversation_memory.py b/griptape/memory/structure/base_conversation_memory.py index 6db05c92c..f8cc51743 100644 --- a/griptape/memory/structure/base_conversation_memory.py +++ b/griptape/memory/structure/base_conversation_memory.py @@ -45,3 +45,52 @@ def try_add_run(self, run: Run) -> None: ... @abstractmethod def to_prompt_stack(self, last_n: Optional[int] = None) -> PromptStack: ... + + def add_to_prompt_stack(self, prompt_stack: PromptStack, index: Optional[int] = None) -> PromptStack: + """Add the Conversation Memory runs to the Prompt Stack by modifying the inputs in place. + + If autoprune is enabled, this will fit as many Conversation Memory runs into the Prompt Stack + as possible without exceeding the token limit. + + Args: + prompt_stack: The Prompt Stack to add the Conversation Memory to. + index: Optional index to insert the Conversation Memory runs at. + Defaults to appending to the end of the Prompt Stack. + """ + num_runs_to_fit_in_prompt = len(self.runs) + + if self.autoprune and hasattr(self, "structure"): + should_prune = True + prompt_driver = self.structure.config.prompt_driver + temp_stack = PromptStack() + + # Try to determine how many Conversation Memory runs we can + # fit into the Prompt Stack without exceeding the token limit. + while should_prune and num_runs_to_fit_in_prompt > 0: + temp_stack.inputs = prompt_stack.inputs.copy() + + # Add n runs from Conversation Memory. + # Where we insert into the Prompt Stack doesn't matter here + # since we only care about the total token count. + memory_inputs = self.to_prompt_stack(num_runs_to_fit_in_prompt).inputs + temp_stack.inputs.extend(memory_inputs) + + # Convert the prompt stack into tokens left. + tokens_left = prompt_driver.tokenizer.count_input_tokens_left( + prompt_driver.prompt_stack_to_string(temp_stack) + ) + if tokens_left > 0: + # There are still tokens left, no need to prune. + should_prune = False + else: + # There were not any tokens left, prune one run and try again. + num_runs_to_fit_in_prompt -= 1 + + if num_runs_to_fit_in_prompt: + memory_inputs = self.to_prompt_stack(num_runs_to_fit_in_prompt).inputs + if index: + prompt_stack.inputs[index:index] = memory_inputs + else: + prompt_stack.inputs.extend(memory_inputs) + + return prompt_stack diff --git a/griptape/schemas/base_schema.py b/griptape/schemas/base_schema.py index 79ae5e51f..6ba23d6fe 100644 --- a/griptape/schemas/base_schema.py +++ b/griptape/schemas/base_schema.py @@ -103,13 +103,14 @@ def _resolve_types(cls, attrs_cls: type) -> None: from griptape.utils.import_utils import import_optional_dependency, is_dependency_installed # These modules are required to avoid `NameError`s when resolving types. 
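For reference, the relocated `BaseConversationMemory.add_to_prompt_stack()` shown above is now called on the memory object rather than on the stack, replacing the old `PromptStack.add_conversation_memory()` call pattern. A minimal sketch, assuming `agent` is an existing structure with conversation memory configured (it is not part of this diff):

    from griptape.utils import PromptStack

    stack = PromptStack(
        inputs=[
            PromptStack.Input("You are a travel assistant.", role=PromptStack.SYSTEM_ROLE),
            PromptStack.Input("Where did we decide to go?", role=PromptStack.USER_ROLE),
        ]
    )

    # Assumed: `agent` is an already-built structure, so its memory can consult
    # structure.config.prompt_driver when autopruning runs to the token limit.
    memory = agent.conversation_memory

    # Insert memory runs at index 1, right after the system prompt.
    memory.add_to_prompt_stack(stack, 1)
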
- from griptape.drivers import BaseConversationMemoryDriver, BasePromptDriver, BasePromptModelDriver + from griptape.drivers import BaseConversationMemoryDriver, BasePromptDriver from griptape.structures import Structure from griptape.utils import PromptStack from griptape.tokenizers.base_tokenizer import BaseTokenizer from typing import Any boto3 = import_optional_dependency("boto3") if is_dependency_installed("boto3") else Any + Client = import_optional_dependency("cohere").Client if is_dependency_installed("cohere") else Any attrs.resolve_types( attrs_cls, @@ -120,8 +121,8 @@ def _resolve_types(cls, attrs_cls: type) -> None: "BaseConversationMemoryDriver": BaseConversationMemoryDriver, "BasePromptDriver": BasePromptDriver, "BaseTokenizer": BaseTokenizer, - "BasePromptModelDriver": BasePromptModelDriver, "boto3": boto3, + "Client": Client, }, ) diff --git a/griptape/structures/agent.py b/griptape/structures/agent.py index 79b831f63..d0446aff0 100644 --- a/griptape/structures/agent.py +++ b/griptape/structures/agent.py @@ -47,13 +47,9 @@ def add_tasks(self, *tasks: BaseTask) -> list[BaseTask]: return super().add_tasks(*tasks) def try_run(self, *args) -> Agent: - self._execution_args = args - - self.task.reset() - self.task.execute() - if self.conversation_memory: + if self.conversation_memory and self.output is not None: if isinstance(self.task.input, tuple): input_text = self.task.input[0].to_text() else: diff --git a/griptape/structures/pipeline.py b/griptape/structures/pipeline.py index 00c5f1d09..d5724244e 100644 --- a/griptape/structures/pipeline.py +++ b/griptape/structures/pipeline.py @@ -43,19 +43,15 @@ def insert_task(self, parent_task: BaseTask, task: BaseTask) -> BaseTask: return task def try_run(self, *args) -> Pipeline: - self._execution_args = args - - [task.reset() for task in self.tasks] - self.__run_from_task(self.input_task) - if self.conversation_memory: + if self.conversation_memory and self.output is not None: if isinstance(self.input_task.input, tuple): input_text = self.input_task.input[0].to_text() else: input_text = self.input_task.input.to_text() - run = Run(input=input_text, output=self.output_task.output.to_text()) + run = Run(input=input_text, output=self.output.to_text()) self.conversation_memory.add_run(run) diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index 8b71dd905..78dd69633 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -9,7 +9,7 @@ from attrs import Factory, define, field from rich.logging import RichHandler -from griptape.artifacts import BlobArtifact, TextArtifact +from griptape.artifacts import BlobArtifact, TextArtifact, BaseArtifact from griptape.config import BaseStructureConfig, OpenAiStructureConfig, StructureConfig from griptape.drivers import BaseEmbeddingDriver, BasePromptDriver, OpenAiEmbeddingDriver, OpenAiChatPromptDriver from griptape.drivers.vector.local_vector_store_driver import LocalVectorStoreDriver @@ -127,13 +127,17 @@ def input_task(self) -> Optional[BaseTask]: def output_task(self) -> Optional[BaseTask]: return self.tasks[-1] if self.tasks else None + @property + def output(self) -> Optional[BaseArtifact]: + return self.output_task.output if self.output_task is not None else None + @property def finished_tasks(self) -> list[BaseTask]: return [s for s in self.tasks if s.is_finished()] @property def default_config(self) -> BaseStructureConfig: - if self.prompt_driver is not None or self.embedding_driver is not None: + if self.prompt_driver is not None or 
self.embedding_driver is not None or self.stream is not None: config = StructureConfig() if self.prompt_driver is None: @@ -209,13 +213,39 @@ def publish_event(self, event: BaseEvent, flush: bool = False) -> None: def context(self, task: BaseTask) -> dict[str, Any]: return {"args": self.execution_args, "structure": self} - def before_run(self) -> None: + def resolve_relationships(self) -> None: + task_by_id = {task.id: task for task in self.tasks} + + for task in self.tasks: + # Ensure parents include this task as a child + for parent_id in task.parent_ids: + if parent_id not in task_by_id: + raise ValueError(f"Task with id {parent_id} doesn't exist.") + parent = task_by_id[parent_id] + if task.id not in parent.child_ids: + parent.child_ids.append(task.id) + + # Ensure children include this task as a parent + for child_id in task.child_ids: + if child_id not in task_by_id: + raise ValueError(f"Task with id {child_id} doesn't exist.") + child = task_by_id[child_id] + if task.id not in child.parent_ids: + child.parent_ids.append(task.id) + + def before_run(self, args: Any) -> None: + self._execution_args = args + + [task.reset() for task in self.tasks] + self.publish_event( StartStructureRunEvent( structure_id=self.id, input_task_input=self.input_task.input, input_task_output=self.input_task.output ) ) + self.resolve_relationships() + def after_run(self) -> None: self.publish_event( FinishStructureRunEvent( @@ -230,7 +260,7 @@ def after_run(self) -> None: def add_task(self, task: BaseTask) -> BaseTask: ... def run(self, *args) -> Structure: - self.before_run() + self.before_run(args) result = self.try_run(*args) diff --git a/griptape/structures/workflow.py b/griptape/structures/workflow.py index e60efa425..6552fba89 100644 --- a/griptape/structures/workflow.py +++ b/griptape/structures/workflow.py @@ -1,7 +1,7 @@ from __future__ import annotations import concurrent.futures as futures from graphlib import TopologicalSorter -from typing import Any +from typing import Any, Optional from attrs import define, field, Factory from griptape.artifacts import ErrorArtifact from griptape.structures import Structure @@ -13,13 +13,13 @@ class Workflow(Structure): futures_executor: futures.Executor = field(default=Factory(lambda: futures.ThreadPoolExecutor()), kw_only=True) + @property + def output_task(self) -> Optional[BaseTask]: + return self.order_tasks()[-1] if self.tasks else None + def add_task(self, task: BaseTask) -> BaseTask: task.preprocess(self) - if self.output_task: - self.output_task.child_ids.append(task.id) - task.parent_ids.append(self.output_task.id) - self.tasks.append(task) return task @@ -77,6 +77,7 @@ def insert_task( if parent_task.id in child_task.parent_ids: child_task.parent_ids.remove(parent_task.id) + last_parent_index = -1 for parent_task in parent_tasks: # Link the new task to the parent task if parent_task.id not in task.parent_ids: @@ -85,17 +86,20 @@ def insert_task( parent_task.child_ids.append(task.id) parent_index = self.tasks.index(parent_task) - self.tasks.insert(parent_index + 1, task) + if parent_index > last_parent_index: + last_parent_index = parent_index + + # Insert the new task once, just after the last parent task + self.tasks.insert(last_parent_index + 1, task) return task def try_run(self, *args) -> Workflow: - self._execution_args = args - ordered_tasks = self.order_tasks() exit_loop = False while not self.is_finished() and not exit_loop: futures_list = {} + ordered_tasks = self.order_tasks() for task in ordered_tasks: if task.can_execute(): @@ -109,7 
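Because `Workflow.add_task()` above no longer chains a new task to the previously added one, parent/child edges are declared explicitly through the new `BaseTask` helpers, and `resolve_relationships()` (invoked from `before_run()`) fills in whichever side of an edge was left undeclared. A rough sketch with placeholder tasks and prompts:

    from griptape.structures import Workflow
    from griptape.tasks import PromptTask

    start = PromptTask("Collect requirements", id="start")
    build = PromptTask("Draft a plan", id="build")
    review = PromptTask("Review the plan", id="review")

    # Only one side of each relationship is declared here; the missing side is
    # added by resolve_relationships() before the workflow runs.
    build.add_parent(start)
    review.add_parents([start, build])

    workflow = Workflow()
    workflow.add_tasks(start, build, review)
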
+113,7 @@ def try_run(self, *args) -> Workflow: break - if self.conversation_memory: + if self.conversation_memory and self.output is not None: if isinstance(self.input_task.input, tuple): input_text = self.input_task.input[0].to_text() else: @@ -126,9 +130,8 @@ def context(self, task: BaseTask) -> dict[str, Any]: context.update( { - "parent_outputs": { - parent.id: parent.output.to_text() if parent.output else "" for parent in task.parents - }, + "parent_outputs": task.parent_outputs, + "parents_output_text": task.parents_output_text, "parents": {parent.id: parent for parent in task.parents}, "children": {child.id: child for child in task.children}, } diff --git a/griptape/tasks/actions_subtask.py b/griptape/tasks/actions_subtask.py index d47d1df32..1546a825d 100644 --- a/griptape/tasks/actions_subtask.py +++ b/griptape/tasks/actions_subtask.py @@ -101,7 +101,13 @@ def run(self) -> BaseArtifact: else: results = self.execute_actions(self.actions) - self.output = ListArtifact([TextArtifact(name=f"{r[0]} output", value=r[1].to_text()) for r in results]) + actions_output = [] + for result in results: + tag, output = result + output.name = f"{tag} output" + + actions_output.append(output) + self.output = ListArtifact(actions_output) except Exception as e: self.structure.logger.error(f"Subtask {self.id}\n{e}", exc_info=True) @@ -172,24 +178,6 @@ def actions_to_dicts(self) -> list[dict]: def actions_to_json(self) -> str: return json.dumps(self.actions_to_dicts()) - def add_child(self, child: ActionsSubtask) -> ActionsSubtask: - if child.id not in self.child_ids: - self.child_ids.append(child.id) - - if self.id not in child.parent_ids: - child.parent_ids.append(self.id) - - return child - - def add_parent(self, parent: ActionsSubtask) -> ActionsSubtask: - if parent.id not in self.parent_ids: - self.parent_ids.append(parent.id) - - if self.id not in parent.child_ids: - parent.child_ids.append(self.id) - - return parent - def __init_from_prompt(self, value: str) -> None: thought_matches = re.findall(self.THOUGHT_PATTERN, value, re.MULTILINE) actions_matches = re.findall(self.ACTIONS_PATTERN, value, re.DOTALL) diff --git a/griptape/tasks/audio_transcription_task.py b/griptape/tasks/audio_transcription_task.py index c75faa0d4..57dbf6782 100644 --- a/griptape/tasks/audio_transcription_task.py +++ b/griptape/tasks/audio_transcription_task.py @@ -1,37 +1,18 @@ from __future__ import annotations -from abc import ABC -from typing import Callable - from attrs import define, field -from griptape.artifacts.audio_artifact import AudioArtifact from griptape.engines import AudioTranscriptionEngine from griptape.artifacts import TextArtifact -from griptape.mixins import RuleMixin -from griptape.tasks import BaseTask +from griptape.tasks.base_audio_input_task import BaseAudioInputTask @define -class AudioTranscriptionTask(RuleMixin, BaseTask, ABC): - _input: AudioArtifact | Callable[[BaseTask], AudioArtifact] = field() +class AudioTranscriptionTask(BaseAudioInputTask): _audio_transcription_engine: AudioTranscriptionEngine = field( default=None, kw_only=True, alias="audio_transcription_engine" ) - @property - def input(self) -> AudioArtifact: - if isinstance(self._input, AudioArtifact): - return self._input - elif isinstance(self._input, Callable): - return self._input(self) - else: - raise ValueError("Input must be an AudioArtifact.") - - @input.setter - def input(self, value: AudioArtifact | Callable[[BaseTask], AudioArtifact]) -> None: - self._input = value - @property def audio_transcription_engine(self) -> 
AudioTranscriptionEngine: if self._audio_transcription_engine is None: diff --git a/griptape/tasks/base_audio_generation_task.py b/griptape/tasks/base_audio_generation_task.py index d401af0a5..71c2fbdf4 100644 --- a/griptape/tasks/base_audio_generation_task.py +++ b/griptape/tasks/base_audio_generation_task.py @@ -9,4 +9,13 @@ @define -class BaseAudioGenerationTask(BlobArtifactFileOutputMixin, RuleMixin, BaseTask, ABC): ... +class BaseAudioGenerationTask(BlobArtifactFileOutputMixin, RuleMixin, BaseTask, ABC): + def before_run(self) -> None: + super().before_run() + + self.structure.logger.info(f"{self.__class__.__name__} {self.id}\nInput: {self.input.to_text()}") + + def after_run(self) -> None: + super().after_run() + + self.structure.logger.info(f"{self.__class__.__name__} {self.id}\nOutput: {self.output.to_text()}") diff --git a/griptape/tasks/base_audio_input_task.py b/griptape/tasks/base_audio_input_task.py new file mode 100644 index 000000000..0991a6014 --- /dev/null +++ b/griptape/tasks/base_audio_input_task.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from abc import ABC +from typing import Callable + +from attrs import define, field + +from griptape.artifacts.audio_artifact import AudioArtifact +from griptape.mixins import RuleMixin +from griptape.tasks import BaseTask + + +@define +class BaseAudioInputTask(RuleMixin, BaseTask, ABC): + _input: AudioArtifact | Callable[[BaseTask], AudioArtifact] = field(alias="input") + + @property + def input(self) -> AudioArtifact: + if isinstance(self._input, AudioArtifact): + return self._input + elif isinstance(self._input, Callable): + return self._input(self) + else: + raise ValueError("Input must be an AudioArtifact.") + + @input.setter + def input(self, value: AudioArtifact | Callable[[BaseTask], AudioArtifact]) -> None: + self._input = value + + def before_run(self) -> None: + super().before_run() + + self.structure.logger.info(f"{self.__class__.__name__} {self.id}\nInput: {self.input.to_text()}") + + def after_run(self) -> None: + super().after_run() + + self.structure.logger.info(f"{self.__class__.__name__} {self.id}\nOutput: {self.output.to_text()}") diff --git a/griptape/tasks/base_task.py b/griptape/tasks/base_task.py index 771fe4dc8..8a45cb14e 100644 --- a/griptape/tasks/base_task.py +++ b/griptape/tasks/base_task.py @@ -47,6 +47,14 @@ def parents(self) -> list[BaseTask]: def children(self) -> list[BaseTask]: return [self.structure.find_task(child_id) for child_id in self.child_ids] + @property + def parent_outputs(self) -> dict[str, str]: + return {parent.id: parent.output.to_text() if parent.output else "" for parent in self.parents} + + @property + def parents_output_text(self) -> str: + return "\n".join([parent.output.to_text() for parent in self.parents if parent.output]) + @property def meta_memories(self) -> list[BaseMetaEntry]: if self.structure and self.structure.meta_memory: @@ -60,6 +68,26 @@ def meta_memories(self) -> list[BaseMetaEntry]: def __str__(self) -> str: return str(self.output.value) + def add_parents(self, parents: list[str | BaseTask]) -> None: + for parent in parents: + self.add_parent(parent) + + def add_parent(self, parent: str | BaseTask) -> None: + parent_id = parent if isinstance(parent, str) else parent.id + + if parent_id not in self.parent_ids: + self.parent_ids.append(parent_id) + + def add_children(self, children: list[str | BaseTask]) -> None: + for child in children: + self.add_child(child) + + def add_child(self, child: str | BaseTask) -> None: + child_id = child if 
isinstance(child, str) else child.id + + if child_id not in self.child_ids: + self.child_ids.append(child_id) + def preprocess(self, structure: Structure) -> BaseTask: self.structure = structure diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index 75051db74..694a5050d 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -34,7 +34,7 @@ def prompt_stack(self) -> PromptStack: if memory: # inserting at index 1 to place memory right after system prompt - stack.add_conversation_memory(memory, 1) + memory.add_to_prompt_stack(stack, 1) return stack diff --git a/griptape/tasks/toolkit_task.py b/griptape/tasks/toolkit_task.py index ed787aa45..c99f9e23f 100644 --- a/griptape/tasks/toolkit_task.py +++ b/griptape/tasks/toolkit_task.py @@ -21,6 +21,9 @@ @define class ToolkitTask(PromptTask, ActionsSubtaskOriginMixin): DEFAULT_MAX_STEPS = 20 + # Stop sequence for chain-of-thought in the framework. Using this "token-like" string to make it more unique, + # so that it doesn't trigger on accident. + RESPONSE_STOP_SEQUENCE = "<|Response|>" tools: list[BaseTool] = field(factory=list, kw_only=True) max_subtasks: int = field(default=DEFAULT_MAX_STEPS, kw_only=True) @@ -32,6 +35,7 @@ class ToolkitTask(PromptTask, ActionsSubtaskOriginMixin): generate_user_subtask_template: Callable[[ActionsSubtask], str] = field( default=Factory(lambda self: self.default_user_subtask_template_generator, takes_self=True), kw_only=True ) + response_stop_sequence: str = field(default=RESPONSE_STOP_SEQUENCE, kw_only=True) def __attrs_post_init__(self) -> None: if self.task_memory: @@ -74,7 +78,7 @@ def prompt_stack(self) -> PromptStack: if memory: # inserting at index 1 to place memory right after system prompt - stack.add_conversation_memory(memory, 1) + memory.add_to_prompt_stack(stack, 1) return stack @@ -95,17 +99,17 @@ def default_system_template_generator(self, _: PromptTask) -> str: action_names=str.join(", ", [tool.name for tool in self.tools]), actions_schema=utils.minify_json(json.dumps(schema)), meta_memory=J2("memory/meta/meta_memory.j2").render(meta_memories=self.meta_memories), - stop_sequence=utils.constants.RESPONSE_STOP_SEQUENCE, + stop_sequence=self.response_stop_sequence, ) def default_assistant_subtask_template_generator(self, subtask: ActionsSubtask) -> str: return J2("tasks/toolkit_task/assistant_subtask.j2").render( - stop_sequence=utils.constants.RESPONSE_STOP_SEQUENCE, subtask=subtask + stop_sequence=self.response_stop_sequence, subtask=subtask ) def default_user_subtask_template_generator(self, subtask: ActionsSubtask) -> str: return J2("tasks/toolkit_task/user_subtask.j2").render( - stop_sequence=utils.constants.RESPONSE_STOP_SEQUENCE, subtask=subtask + stop_sequence=self.response_stop_sequence, subtask=subtask ) def actions_schema(self) -> Schema: @@ -126,6 +130,7 @@ def run(self) -> BaseArtifact: self.subtasks.clear() + self.prompt_driver.tokenizer.stop_sequences.extend([self.response_stop_sequence]) subtask = self.add_subtask(ActionsSubtask(self.prompt_driver.run(prompt_stack=self.prompt_stack).to_text())) while True: @@ -161,6 +166,7 @@ def add_subtask(self, subtask: ActionsSubtask) -> ActionsSubtask: if len(self.subtasks) > 0: self.subtasks[-1].add_child(subtask) + subtask.add_parent(self.subtasks[-1]) self.subtasks.append(subtask) diff --git a/griptape/templates/engines/summary/prompt_summary.j2 b/griptape/templates/engines/summary/prompt_summary.j2 deleted file mode 100644 index 717810de6..000000000 --- 
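With the chain-of-thought stop sequence now owned by `ToolkitTask` instead of a global constant, individual tasks can override it. A small sketch (the `Calculator` tool and the prompt are placeholders, not taken from this diff):

    from griptape.tasks import ToolkitTask
    from griptape.tools import Calculator

    # response_stop_sequence defaults to RESPONSE_STOP_SEQUENCE ("<|Response|>") and
    # is appended to the prompt driver's tokenizer stop sequences when the task runs.
    task = ToolkitTask(
        "What is 7 * 6?",
        tools=[Calculator()],
        response_stop_sequence="<|END|>",
    )
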
a/griptape/templates/engines/summary/prompt_summary.j2 +++ /dev/null @@ -1,27 +0,0 @@ -{% if summary %} -Current text summary: """ -{{ summary }} -""" - -Rewrite the current text summary to include the following additional text: """ -{{ text }} -""" -{% if rulesets %} - -{{ rulesets }} -{% endif %} - -Rewritten summary: -{% else %} - -Summarize the following text: """ -{{ text }} -""" - -{% if rulesets %} - -{{ rulesets }} -{% endif %} - -Summary: -{% endif %} diff --git a/griptape/templates/engines/summary/system.j2 b/griptape/templates/engines/summary/system.j2 new file mode 100644 index 000000000..b5e132aa0 --- /dev/null +++ b/griptape/templates/engines/summary/system.j2 @@ -0,0 +1,14 @@ +You are an expert in text summarization. +{% if rulesets %} + +{{ rulesets }} + +{% endif %} + +{% if summary %} +Use the current text summary to help summarize the additional text. +Current text summary: """ +{{ summary }} +""" + +{% endif %} diff --git a/griptape/templates/engines/summary/user.j2 b/griptape/templates/engines/summary/user.j2 new file mode 100644 index 000000000..d6a40e412 --- /dev/null +++ b/griptape/templates/engines/summary/user.j2 @@ -0,0 +1,5 @@ +Summarize the following text: """ +{{ text }} +""" + +Summary: diff --git a/griptape/templates/tasks/tool_task/system.j2 b/griptape/templates/tasks/tool_task/system.j2 index eaf858037..7a802d989 100644 --- a/griptape/templates/tasks/tool_task/system.j2 +++ b/griptape/templates/tasks/tool_task/system.j2 @@ -1,5 +1,4 @@ -When appropriate, respond to requests by using the following Action Schema. Your response should be a plain JSON object that successfully validates against the schema. The schema is provided below. If you can't use the Action Schema, say "I don't know how to respond." - +You must respond to requests by using the following Action Schema. Your response should be a plain JSON object that successfully validates against the schema. The schema is provided below. If you can't use the Action Schema, say "I don't know how to respond." 
Action Schema: {{ action_schema }} {% if meta_memory %} diff --git a/griptape/tokenizers/__init__.py b/griptape/tokenizers/__init__.py index b116f9fb0..03b0aefe5 100644 --- a/griptape/tokenizers/__init__.py +++ b/griptape/tokenizers/__init__.py @@ -3,15 +3,11 @@ from griptape.tokenizers.cohere_tokenizer import CohereTokenizer from griptape.tokenizers.huggingface_tokenizer import HuggingFaceTokenizer from griptape.tokenizers.anthropic_tokenizer import AnthropicTokenizer -from griptape.tokenizers.bedrock_titan_tokenizer import BedrockTitanTokenizer -from griptape.tokenizers.bedrock_cohere_tokenizer import BedrockCohereTokenizer -from griptape.tokenizers.bedrock_jurassic_tokenizer import BedrockJurassicTokenizer -from griptape.tokenizers.bedrock_claude_tokenizer import BedrockClaudeTokenizer -from griptape.tokenizers.bedrock_llama_tokenizer import BedrockLlamaTokenizer from griptape.tokenizers.google_tokenizer import GoogleTokenizer from griptape.tokenizers.voyageai_tokenizer import VoyageAiTokenizer from griptape.tokenizers.simple_tokenizer import SimpleTokenizer from griptape.tokenizers.dummy_tokenizer import DummyTokenizer +from griptape.tokenizers.amazon_bedrock_tokenizer import AmazonBedrockTokenizer __all__ = [ @@ -20,13 +16,9 @@ "CohereTokenizer", "HuggingFaceTokenizer", "AnthropicTokenizer", - "BedrockTitanTokenizer", - "BedrockCohereTokenizer", - "BedrockJurassicTokenizer", - "BedrockClaudeTokenizer", - "BedrockLlamaTokenizer", "GoogleTokenizer", "VoyageAiTokenizer", "SimpleTokenizer", "DummyTokenizer", + "AmazonBedrockTokenizer", ] diff --git a/griptape/tokenizers/amazon_bedrock_tokenizer.py b/griptape/tokenizers/amazon_bedrock_tokenizer.py new file mode 100644 index 000000000..670b5739a --- /dev/null +++ b/griptape/tokenizers/amazon_bedrock_tokenizer.py @@ -0,0 +1,40 @@ +from __future__ import annotations +from attrs import define, field +from griptape.tokenizers.base_tokenizer import BaseTokenizer + + +@define() +class AmazonBedrockTokenizer(BaseTokenizer): + MODEL_PREFIXES_TO_MAX_INPUT_TOKENS = { + "anthropic.claude-3": 200000, + "anthropic.claude-v2:1": 200000, + "anthropic.claude": 100000, + "cohere.command-r": 128000, + "cohere.embed": 512, + "cohere.command": 4000, + "cohere": 1024, + "ai21": 8192, + "meta-llama3": 8000, + "meta-llama2": 4096, + "mistral": 32000, + "amazon": 4096, + } + MODEL_PREFIXES_TO_MAX_OUTPUT_TOKENS = { + "anthropic.claude": 4096, + "cohere": 4096, + "ai21.j2": 8191, + "meta": 2048, + "amazon.titan-text-lite": 4096, + "amazon.titan-text-express": 8192, + "amazon.titan-text-premier": 3072, + "amazon": 4096, + "mistral": 8192, + } + + model: str = field(kw_only=True) + characters_per_token: int = field(default=4, kw_only=True) + + def count_tokens(self, text: str) -> int: + num_tokens = (len(text) + self.characters_per_token - 1) // self.characters_per_token + + return num_tokens diff --git a/griptape/tokenizers/anthropic_tokenizer.py b/griptape/tokenizers/anthropic_tokenizer.py index 577df7b93..f5fabab0e 100644 --- a/griptape/tokenizers/anthropic_tokenizer.py +++ b/griptape/tokenizers/anthropic_tokenizer.py @@ -17,8 +17,5 @@ class AnthropicTokenizer(BaseTokenizer): default=Factory(lambda: import_optional_dependency("anthropic").Anthropic()), kw_only=True ) - def count_tokens(self, text: str | list) -> int: - if isinstance(text, str): - return self.client.count_tokens(text) - else: - raise ValueError("Text must be a string.") + def count_tokens(self, text: str) -> int: + return self.client.count_tokens(text) diff --git 
a/griptape/tokenizers/base_tokenizer.py b/griptape/tokenizers/base_tokenizer.py index 179d2fb59..474ccbaa5 100644 --- a/griptape/tokenizers/base_tokenizer.py +++ b/griptape/tokenizers/base_tokenizer.py @@ -1,27 +1,30 @@ from __future__ import annotations -from abc import ABC, abstractmethod +import logging +from abc import ABC from attrs import define, field, Factory -from griptape import utils @define() class BaseTokenizer(ABC): + DEFAULT_MAX_INPUT_TOKENS = 4096 + DEFAULT_MAX_OUTPUT_TOKENS = 1000 MODEL_PREFIXES_TO_MAX_INPUT_TOKENS = {} MODEL_PREFIXES_TO_MAX_OUTPUT_TOKENS = {} model: str = field(kw_only=True) - stop_sequences: list[str] = field(default=Factory(lambda: [utils.constants.RESPONSE_STOP_SEQUENCE]), kw_only=True) + stop_sequences: list[str] = field(default=Factory(list), kw_only=True) max_input_tokens: int = field(kw_only=True, default=None) max_output_tokens: int = field(kw_only=True, default=None) def __attrs_post_init__(self) -> None: - if self.max_input_tokens is None: - self.max_input_tokens = self._default_max_input_tokens() + if hasattr(self, "model"): + if self.max_input_tokens is None: + self.max_input_tokens = self._default_max_input_tokens() - if self.max_output_tokens is None: - self.max_output_tokens = self._default_max_output_tokens() + if self.max_output_tokens is None: + self.max_output_tokens = self._default_max_output_tokens() - def count_input_tokens_left(self, text: str | list) -> int: + def count_input_tokens_left(self, text: str) -> int: diff = self.max_input_tokens - self.count_tokens(text) if diff > 0: @@ -29,7 +32,7 @@ def count_input_tokens_left(self, text: str | list) -> int: else: return 0 - def count_output_tokens_left(self, text: str | list) -> int: + def count_output_tokens_left(self, text: str) -> int: diff = self.max_output_tokens - self.count_tokens(text) if diff > 0: @@ -37,14 +40,16 @@ def count_output_tokens_left(self, text: str | list) -> int: else: return 0 - @abstractmethod - def count_tokens(self, text: str | list[dict]) -> int: ... + def count_tokens(self, text: str) -> int: ... def _default_max_input_tokens(self) -> int: tokens = next((v for k, v in self.MODEL_PREFIXES_TO_MAX_INPUT_TOKENS.items() if self.model.startswith(k)), None) if tokens is None: - raise ValueError(f"Unknown model default max input tokens: {self.model}") + logging.warning( + f"Model {self.model} not found in MODEL_PREFIXES_TO_MAX_INPUT_TOKENS, using default value of {self.DEFAULT_MAX_INPUT_TOKENS}." + ) + return self.DEFAULT_MAX_INPUT_TOKENS else: return tokens @@ -54,6 +59,9 @@ def _default_max_output_tokens(self) -> int: ) if tokens is None: - raise ValueError(f"Unknown model for default max output tokens: {self.model}") + logging.warning( + f"Model {self.model} not found in MODEL_PREFIXES_TO_MAX_OUTPUT_TOKENS, using default value of {self.DEFAULT_MAX_OUTPUT_TOKENS}." 
+ ) + return self.DEFAULT_MAX_OUTPUT_TOKENS else: return tokens diff --git a/griptape/tokenizers/bedrock_claude_tokenizer.py b/griptape/tokenizers/bedrock_claude_tokenizer.py deleted file mode 100644 index d44116e2c..000000000 --- a/griptape/tokenizers/bedrock_claude_tokenizer.py +++ /dev/null @@ -1,12 +0,0 @@ -from attrs import define -from griptape.tokenizers import AnthropicTokenizer - - -@define() -class BedrockClaudeTokenizer(AnthropicTokenizer): - MODEL_PREFIXES_TO_MAX_INPUT_TOKENS = { - "anthropic.claude-3": 200000, - "anthropic.claude-v2:1": 200000, - "anthropic.claude": 100000, - } - MODEL_PREFIXES_TO_MAX_OUTPUT_TOKENS = {"anthropic.claude": 4096} diff --git a/griptape/tokenizers/bedrock_cohere_tokenizer.py b/griptape/tokenizers/bedrock_cohere_tokenizer.py deleted file mode 100644 index 44ccb4ac6..000000000 --- a/griptape/tokenizers/bedrock_cohere_tokenizer.py +++ /dev/null @@ -1,14 +0,0 @@ -from __future__ import annotations -from attrs import define, field -from .simple_tokenizer import SimpleTokenizer - - -@define() -class BedrockCohereTokenizer(SimpleTokenizer): - # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-embed.html - DEFAULT_CHARACTERS_PER_TOKEN = 4 - MODEL_PREFIXES_TO_MAX_INPUT_TOKENS = {"cohere": 1024} - MODEL_PREFIXES_TO_MAX_OUTPUT_TOKENS = {"cohere": 4096} - - model: str = field(kw_only=True) - characters_per_token: int = field(default=DEFAULT_CHARACTERS_PER_TOKEN, kw_only=True) diff --git a/griptape/tokenizers/bedrock_jurassic_tokenizer.py b/griptape/tokenizers/bedrock_jurassic_tokenizer.py deleted file mode 100644 index 7511138b3..000000000 --- a/griptape/tokenizers/bedrock_jurassic_tokenizer.py +++ /dev/null @@ -1,20 +0,0 @@ -from __future__ import annotations -from attrs import define, field, Factory -from .simple_tokenizer import SimpleTokenizer - - -@define() -class BedrockJurassicTokenizer(SimpleTokenizer): - DEFAULT_CHARACTERS_PER_TOKEN = 6 # https://docs.aws.amazon.com/bedrock/latest/userguide/model-customization-prepare.html#model-customization-prepare-finetuning - MODEL_PREFIXES_TO_MAX_INPUT_TOKENS = {"ai21": 8192} - MODEL_PREFIXES_TO_MAX_OUTPUT_TOKENS = { - "ai21.j2-mid-v1": 8191, - "ai21.j2-ultra-v1": 8191, - "ai21.j2-large-v1": 8191, - "ai21": 2048, - } - - model: str = field(kw_only=True) - characters_per_token: int = field( - default=Factory(lambda self: self.DEFAULT_CHARACTERS_PER_TOKEN, takes_self=True), kw_only=True - ) diff --git a/griptape/tokenizers/bedrock_llama_tokenizer.py b/griptape/tokenizers/bedrock_llama_tokenizer.py deleted file mode 100644 index e7d1ec829..000000000 --- a/griptape/tokenizers/bedrock_llama_tokenizer.py +++ /dev/null @@ -1,14 +0,0 @@ -from __future__ import annotations -from attrs import define, field -from .simple_tokenizer import SimpleTokenizer - - -@define() -class BedrockLlamaTokenizer(SimpleTokenizer): - DEFAULT_CHARACTERS_PER_TOKEN = 6 # https://docs.aws.amazon.com/bedrock/latest/userguide/model-customization-prepare.html#model-customization-prepare-finetuning - MODEL_PREFIXES_TO_MAX_INPUT_TOKENS = {"meta": 2048} - MODEL_PREFIXES_TO_MAX_OUTPUT_TOKENS = {"meta": 2048} - - model: str = field(kw_only=True) - characters_per_token: int = field(default=DEFAULT_CHARACTERS_PER_TOKEN, kw_only=True) - stop_sequences: list[str] = field(factory=list, kw_only=True) diff --git a/griptape/tokenizers/bedrock_titan_tokenizer.py b/griptape/tokenizers/bedrock_titan_tokenizer.py deleted file mode 100644 index 0d8ba0273..000000000 --- a/griptape/tokenizers/bedrock_titan_tokenizer.py +++ /dev/null @@ -1,14 +0,0 
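Taken together, the tokenizer changes above mean an unrecognized model prefix now falls back to the class-level defaults with a logged warning instead of raising, and Bedrock token counts are approximated by ceiling division over `characters_per_token`. A small illustration (the model id is made up):

    from griptape.tokenizers import AmazonBedrockTokenizer

    tokenizer = AmazonBedrockTokenizer(model="examplecorp.future-model-v1")

    # No prefix match in the lookup tables, so the warnings fire and defaults apply.
    print(tokenizer.max_input_tokens)   # 4096 (DEFAULT_MAX_INPUT_TOKENS)
    print(tokenizer.max_output_tokens)  # 1000 (DEFAULT_MAX_OUTPUT_TOKENS)

    # Character-based estimate: (10 + 4 - 1) // 4 == 3 tokens for a 10-character string.
    print(tokenizer.count_tokens("a" * 10))
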
@@ -from __future__ import annotations -from attrs import define, field, Factory -from .simple_tokenizer import SimpleTokenizer - - -@define() -class BedrockTitanTokenizer(SimpleTokenizer): - DEFAULT_CHARACTERS_PER_TOKEN = 6 # https://docs.aws.amazon.com/bedrock/latest/userguide/model-customization-prepare.html#model-customization-prepare-finetuning - MODEL_PREFIXES_TO_MAX_INPUT_TOKENS = {"amazon": 4096} - MODEL_PREFIXES_TO_MAX_OUTPUT_TOKENS = {"amazon": 8000} - - model: str = field(kw_only=True) - characters_per_token: int = field(default=DEFAULT_CHARACTERS_PER_TOKEN, kw_only=True) - stop_sequences: list[str] = field(default=Factory(lambda: ["User:"]), kw_only=True) diff --git a/griptape/tokenizers/cohere_tokenizer.py b/griptape/tokenizers/cohere_tokenizer.py index 0a3c6a236..ae3bddd80 100644 --- a/griptape/tokenizers/cohere_tokenizer.py +++ b/griptape/tokenizers/cohere_tokenizer.py @@ -9,13 +9,10 @@ @define() class CohereTokenizer(BaseTokenizer): - MODEL_PREFIXES_TO_MAX_INPUT_TOKENS = {"command": 4096} - MODEL_PREFIXES_TO_MAX_OUTPUT_TOKENS = {"command": 4096} + MODEL_PREFIXES_TO_MAX_INPUT_TOKENS = {"command-r": 128000, "command": 4096, "embed": 512} + MODEL_PREFIXES_TO_MAX_OUTPUT_TOKENS = {"command": 4096, "embed": 512} client: Client = field(kw_only=True) - def count_tokens(self, text: str | list) -> int: - if isinstance(text, str): - return len(self.client.tokenize(text=text, model=self.model).tokens) - else: - raise ValueError("Text must be a string.") + def count_tokens(self, text: str) -> int: + return len(self.client.tokenize(text=text, model=self.model).tokens) diff --git a/griptape/tokenizers/dummy_tokenizer.py b/griptape/tokenizers/dummy_tokenizer.py index 74f6d104c..a36d0343e 100644 --- a/griptape/tokenizers/dummy_tokenizer.py +++ b/griptape/tokenizers/dummy_tokenizer.py @@ -1,14 +1,15 @@ from __future__ import annotations from attrs import define, field +from typing import Optional from griptape.exceptions import DummyException from griptape.tokenizers import BaseTokenizer @define class DummyTokenizer(BaseTokenizer): - model: None = field(init=False, default=None, kw_only=True) + model: Optional[str] = field(default=None, kw_only=True) max_input_tokens: int = field(init=False, default=0, kw_only=True) max_output_tokens: int = field(init=False, default=0, kw_only=True) - def count_tokens(self, text: str | list) -> int: + def count_tokens(self, text: str) -> int: raise DummyException(__class__.__name__, "count_tokens") diff --git a/griptape/tokenizers/google_tokenizer.py b/griptape/tokenizers/google_tokenizer.py index 55942f597..f99a0682f 100644 --- a/griptape/tokenizers/google_tokenizer.py +++ b/griptape/tokenizers/google_tokenizer.py @@ -18,11 +18,8 @@ class GoogleTokenizer(BaseTokenizer): default=Factory(lambda self: self._default_model_client(), takes_self=True), kw_only=True ) - def count_tokens(self, text: str | list) -> int: - if isinstance(text, str) or isinstance(text, list): - return self.model_client.count_tokens(text).total_tokens - else: - raise ValueError("Text must be a string or a list.") + def count_tokens(self, text: str) -> int: + return self.model_client.count_tokens(text).total_tokens def _default_model_client(self) -> GenerativeModel: genai = import_optional_dependency("google.generativeai") diff --git a/griptape/tokenizers/huggingface_tokenizer.py b/griptape/tokenizers/huggingface_tokenizer.py index dbfba5429..a8312567d 100644 --- a/griptape/tokenizers/huggingface_tokenizer.py +++ b/griptape/tokenizers/huggingface_tokenizer.py @@ -1,6 +1,7 @@ from 
__future__ import annotations from typing import TYPE_CHECKING from attrs import define, field, Factory +from griptape.utils import import_optional_dependency from griptape.tokenizers import BaseTokenizer if TYPE_CHECKING: @@ -9,15 +10,17 @@ @define() class HuggingFaceTokenizer(BaseTokenizer): - tokenizer: PreTrainedTokenizerBase = field(kw_only=True) - model: None = field(init=False, default=None, kw_only=True) + tokenizer: PreTrainedTokenizerBase = field( + default=Factory( + lambda self: import_optional_dependency("transformers").AutoTokenizer.from_pretrained(self.model), + takes_self=True, + ), + kw_only=True, + ) max_input_tokens: int = field( default=Factory(lambda self: self.tokenizer.model_max_length, takes_self=True), kw_only=True ) - max_output_tokens: int = field(kw_only=True) # pyright: ignore[reportGeneralTypeIssues] + max_output_tokens: int = field(default=4096, kw_only=True) - def count_tokens(self, text: str | list) -> int: - if isinstance(text, str): - return len(self.tokenizer.encode(text)) - else: - raise ValueError("Text must be a string.") + def count_tokens(self, text: str) -> int: + return len(self.tokenizer.encode(text)) diff --git a/griptape/tokenizers/openai_tokenizer.py b/griptape/tokenizers/openai_tokenizer.py index ec127ca1a..39a2a033e 100644 --- a/griptape/tokenizers/openai_tokenizer.py +++ b/griptape/tokenizers/openai_tokenizer.py @@ -1,9 +1,9 @@ from __future__ import annotations import logging -from attrs import define +from attrs import define, field, Factory import tiktoken -from griptape.tokenizers import BaseTokenizer from typing import Optional +from griptape.tokenizers import BaseTokenizer @define() @@ -41,6 +41,13 @@ class OpenAiTokenizer(BaseTokenizer): "text-embedding-3-large", ] + max_input_tokens: int = field( + kw_only=True, default=Factory(lambda self: self._default_max_input_tokens(), takes_self=True) + ) + max_output_tokens: int = field( + kw_only=True, default=Factory(lambda self: self._default_max_output_tokens(), takes_self=True) + ) + @property def encoding(self) -> tiktoken.Encoding: try: diff --git a/griptape/tokenizers/simple_tokenizer.py b/griptape/tokenizers/simple_tokenizer.py index 484afe69f..b4e125680 100644 --- a/griptape/tokenizers/simple_tokenizer.py +++ b/griptape/tokenizers/simple_tokenizer.py @@ -1,18 +1,14 @@ from __future__ import annotations -from typing import Optional from attrs import define, field from griptape.tokenizers import BaseTokenizer @define() class SimpleTokenizer(BaseTokenizer): - model: Optional[str] = field(init=False, kw_only=True, default=None) + model: str = field(init=False, kw_only=True) characters_per_token: int = field(kw_only=True) - def count_tokens(self, text: str | list) -> int: - if isinstance(text, str): - num_tokens = (len(text) + self.characters_per_token - 1) // self.characters_per_token + def count_tokens(self, text: str) -> int: + num_tokens = (len(text) + self.characters_per_token - 1) // self.characters_per_token - return num_tokens - else: - raise ValueError("Text must be a string.") + return num_tokens diff --git a/griptape/tokenizers/voyageai_tokenizer.py b/griptape/tokenizers/voyageai_tokenizer.py index 565e53faa..d8fb5adf1 100644 --- a/griptape/tokenizers/voyageai_tokenizer.py +++ b/griptape/tokenizers/voyageai_tokenizer.py @@ -26,8 +26,5 @@ class VoyageAiTokenizer(BaseTokenizer): kw_only=True, ) - def count_tokens(self, text: str | list) -> int: - if isinstance(text, str): - return self.client.count_tokens([text]) - else: - raise ValueError("Text must be a str.") + def 
count_tokens(self, text: str) -> int: + return self.client.count_tokens([text]) diff --git a/griptape/tools/__init__.py b/griptape/tools/__init__.py index 0c9b6e01d..1c152c95a 100644 --- a/griptape/tools/__init__.py +++ b/griptape/tools/__init__.py @@ -26,6 +26,8 @@ from .griptape_cloud_knowledge_base_client.tool import GriptapeCloudKnowledgeBaseClient from .structure_run_client.tool import StructureRunClient from .image_query_client.tool import ImageQueryClient +from .text_to_speech_client.tool import TextToSpeechClient +from .audio_transcription_client.tool import AudioTranscriptionClient __all__ = [ "BaseTool", @@ -56,4 +58,6 @@ "GriptapeCloudKnowledgeBaseClient", "StructureRunClient", "ImageQueryClient", + "TextToSpeechClient", + "AudioTranscriptionClient", ] diff --git a/griptape/tools/griptape_cloud_knowledge_base_client/tool.py b/griptape/tools/griptape_cloud_knowledge_base_client/tool.py index 917406cdd..6fed6e618 100644 --- a/griptape/tools/griptape_cloud_knowledge_base_client/tool.py +++ b/griptape/tools/griptape_cloud_knowledge_base_client/tool.py @@ -48,8 +48,14 @@ def _get_knowledge_base_description(self) -> str: else: url = urljoin(self.base_url.strip("/"), f"/api/knowledge-bases/{self.knowledge_base_id}/") - response = get(url, headers=self.headers).json() - if "description" in response: - return response["description"] + response = get(url, headers=self.headers) + response_body = response.json() + if response.status_code == 200: + if "description" in response_body: + return response_body["description"] + else: + raise ValueError( + f"No description found for Knowledge Base {self.knowledge_base_id}. Please set a description, or manually set the `GriptapeCloudKnowledgeBaseClient.description` attribute." + ) else: - raise ValueError(f'Error getting Knowledge Base description: {response["message"]}') + raise ValueError(f"Error accessing Knowledge Base {self.knowledge_base_id}.") diff --git a/griptape/tools/web_search/tool.py b/griptape/tools/web_search/tool.py index 6ad25ebf0..acc358342 100644 --- a/griptape/tools/web_search/tool.py +++ b/griptape/tools/web_search/tool.py @@ -5,6 +5,7 @@ from griptape.tools import BaseTool from griptape.utils.decorators import activity import requests +import json @define @@ -32,7 +33,7 @@ def search(self, props: dict) -> ListArtifact | ErrorArtifact: query = props["values"]["query"] try: - return ListArtifact([TextArtifact(str(result)) for result in self._search_google(query)]) + return ListArtifact([TextArtifact(json.dumps(result)) for result in self._search_google(query)]) except Exception as e: return ErrorArtifact(f"error searching Google: {e}") diff --git a/griptape/utils/__init__.py b/griptape/utils/__init__.py index 64ca9a9f7..daac63f4e 100644 --- a/griptape/utils/__init__.py +++ b/griptape/utils/__init__.py @@ -14,9 +14,9 @@ from .import_utils import import_optional_dependency from .import_utils import is_dependency_installed from .stream import Stream -from .constants import Constants as constants from .load_artifact_from_memory import load_artifact_from_memory from .deprecation import deprecation_warn +from .structure_visualizer import StructureVisualizer def minify_json(value: str) -> str: @@ -40,9 +40,9 @@ def minify_json(value: str) -> str: "remove_null_values_in_dict_recursively", "dict_merge", "Stream", - "constants", "load_artifact_from_memory", "deprecation_warn", "load_file", "load_files", + "StructureVisualizer", ] diff --git a/griptape/utils/constants.py b/griptape/utils/constants.py index 7bee76750..e69de29bb 100644 --- 
a/griptape/utils/constants.py +++ b/griptape/utils/constants.py @@ -1,4 +0,0 @@ -class Constants: - # Stop sequence for chain-of-thought in the framework. Using this "token-like" string to make it more unique, - # so that it doesn't trigger on accident. - RESPONSE_STOP_SEQUENCE = "<|Response|>" diff --git a/griptape/utils/prompt_stack.py b/griptape/utils/prompt_stack.py index f04cef486..378f9dd1e 100644 --- a/griptape/utils/prompt_stack.py +++ b/griptape/utils/prompt_stack.py @@ -1,12 +1,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional from attrs import define, field from griptape.mixins import SerializableMixin -if TYPE_CHECKING: - from griptape.memory.structure import BaseConversationMemory - @define class PromptStack(SerializableMixin): @@ -50,50 +46,3 @@ def add_user_input(self, content: str) -> Input: def add_assistant_input(self, content: str) -> Input: return self.add_input(content, self.ASSISTANT_ROLE) - - def add_conversation_memory(self, memory: BaseConversationMemory, index: Optional[int] = None) -> list[Input]: - """Add the Conversation Memory runs to the Prompt Stack. - - If autoprune is enabled, this will fit as many Conversation Memory runs into the Prompt Stack - as possible without exceeding the token limit. - - Args: - memory: The Conversation Memory to add the Prompt Stack to. - index: Optional index to insert the Conversation Memory runs at. - Defaults to appending to the end of the Prompt Stack. - """ - num_runs_to_fit_in_prompt = len(memory.runs) - - if memory.autoprune and hasattr(memory, "structure"): - should_prune = True - prompt_driver = memory.structure.config.prompt_driver - temp_stack = PromptStack() - - # Try to determine how many Conversation Memory runs we can - # fit into the Prompt Stack without exceeding the token limit. - while should_prune and num_runs_to_fit_in_prompt > 0: - temp_stack.inputs = self.inputs.copy() - - # Add n runs from Conversation Memory. - # Where we insert into the Prompt Stack doesn't matter here - # since we only care about the total token count. - memory_inputs = memory.to_prompt_stack(num_runs_to_fit_in_prompt).inputs - temp_stack.inputs.extend(memory_inputs) - - # Convert the prompt stack into tokens left. - prompt_string = prompt_driver.prompt_stack_to_string(temp_stack) - tokens_left = prompt_driver.tokenizer.count_input_tokens_left(prompt_string) - if tokens_left > 0: - # There are still tokens left, no need to prune. - should_prune = False - else: - # There were not any tokens left, prune one run and try again. 
- num_runs_to_fit_in_prompt -= 1 - - if num_runs_to_fit_in_prompt: - memory_inputs = memory.to_prompt_stack(num_runs_to_fit_in_prompt).inputs - if index: - self.inputs[index:index] = memory_inputs - else: - self.inputs.extend(memory_inputs) - return self.inputs diff --git a/griptape/utils/structure_visualizer.py b/griptape/utils/structure_visualizer.py new file mode 100644 index 000000000..ede282761 --- /dev/null +++ b/griptape/utils/structure_visualizer.py @@ -0,0 +1,42 @@ +from __future__ import annotations +import base64 + +from attrs import define, field +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from griptape.tasks import BaseTask + from griptape.structures import Structure + + +@define +class StructureVisualizer: + """Utility class to visualize a Structure structure""" + + structure: Structure = field() + header: str = field(default="graph TD;", kw_only=True) + + def to_url(self) -> str: + """Generates a url that renders the Workflow structure as a Mermaid flowchart + Reference: https://mermaid.js.org/ecosystem/tutorials#jupyter-integration-with-mermaid-js + + Returns: + str: URL to the rendered image + """ + self.structure.resolve_relationships() + + tasks = "\n\t" + "\n\t".join([self.__render_task(task) for task in self.structure.tasks]) + graph = f"{self.header}{tasks}" + + graph_bytes = graph.encode("utf-8") + base64_string = base64.b64encode(graph_bytes).decode("utf-8") + + url = f"https://mermaid.ink/svg/{base64_string}" + + return url + + def __render_task(self, task: BaseTask) -> str: + if task.children: + return f'{task.id}--> {" & ".join([child.id for child in task.children])};' + else: + return f"{task.id};" diff --git a/mkdocs.yml b/mkdocs.yml index 99231cce7..317409c55 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -1,6 +1,7 @@ site_name: Griptape Docs hooks: - docs/plugins/swagger_ui_plugin.py +strict: true plugins: - glightbox - search @@ -96,7 +97,6 @@ nav: - Extraction Engines: "griptape-framework/engines/extraction-engines.md" - Summary Engines: "griptape-framework/engines/summary-engines.md" - Image Generation Engines: "griptape-framework/engines/image-generation-engines.md" - - Image Query Engines: "griptape-framework/engines/image-query-engines.md" - Audio Engines: "griptape-framework/engines/audio-engines.md" - Drivers: - Prompt Drivers: "griptape-framework/drivers/prompt-drivers.md" @@ -148,6 +148,7 @@ nav: - ImageQueryClient: "griptape-tools/official-tools/image-query-client.md" - TextToSpeechClient: "griptape-tools/official-tools/text-to-speech-client.md" - AudioTranscriptionClient: "griptape-tools/official-tools/audio-transcription-client.md" + - GriptapeCloudKnowledgeBaseClient: "griptape-tools/official-tools/griptape-cloud-knowledge-base-client.md" - Custom Tools: - Building Custom Tools: "griptape-tools/custom-tools/index.md" - Recipes: diff --git a/poetry.lock b/poetry.lock index cf8fde4ac..b94c05b2a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -278,17 +278,17 @@ lxml = ["lxml"] [[package]] name = "boto3" -version = "1.34.106" +version = "1.34.119" description = "The AWS SDK for Python" optional = false python-versions = ">=3.8" files = [ - {file = "boto3-1.34.106-py3-none-any.whl", hash = "sha256:d3be4e1dd5d546a001cd4da805816934cbde9d395316546e9411fec341ade5cf"}, - {file = "boto3-1.34.106.tar.gz", hash = "sha256:6165b8cf1c7e625628ab28b32f9027064c8f5e5fca1c38d7fc228cd22069a19f"}, + {file = "boto3-1.34.119-py3-none-any.whl", hash = "sha256:8f9c43c54b3dfaa36c4a0d7b42c417227a515bc7a2e163e62802780000a5a3e2"}, + {file = "boto3-1.34.119.tar.gz", 
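The new `StructureVisualizer` above turns a structure's task graph into a Mermaid flowchart served from mermaid.ink. A minimal usage sketch with placeholder tasks:

    from griptape.structures import Workflow
    from griptape.tasks import PromptTask
    from griptape.utils import StructureVisualizer

    first = PromptTask(id="first")
    second = PromptTask(id="second")
    second.add_parent(first)

    workflow = Workflow()
    workflow.add_tasks(first, second)

    # resolve_relationships() is called internally before the graph is rendered.
    print(StructureVisualizer(workflow).to_url())
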
hash = "sha256:cea2365a25b2b83a97e77f24ac6f922ef62e20636b42f9f6ee9f97188f9c1c03"}, ] [package.dependencies] -botocore = ">=1.34.106,<1.35.0" +botocore = ">=1.34.119,<1.35.0" jmespath = ">=0.7.1,<2.0.0" s3transfer = ">=0.10.0,<0.11.0" @@ -706,13 +706,13 @@ xray = ["mypy-boto3-xray (>=1.34.0,<1.35.0)"] [[package]] name = "botocore" -version = "1.34.106" +version = "1.34.119" description = "Low-level, data-driven core of boto 3." optional = false python-versions = ">=3.8" files = [ - {file = "botocore-1.34.106-py3-none-any.whl", hash = "sha256:4baf0e27c2dfc4f4d0dee7c217c716e0782f9b30e8e1fff983fce237d88f73ae"}, - {file = "botocore-1.34.106.tar.gz", hash = "sha256:921fa5202f88c3e58fdcb4b3acffd56d65b24bca47092ee4b27aa988556c0be6"}, + {file = "botocore-1.34.119-py3-none-any.whl", hash = "sha256:4bdf7926a1290b2650d62899ceba65073dd2693e61c35f5cdeb3a286a0aaa27b"}, + {file = "botocore-1.34.119.tar.gz", hash = "sha256:b253f15b24b87b070e176af48e8ef146516090429d30a7d8b136a4c079b28008"}, ] [package.dependencies] @@ -955,13 +955,13 @@ colorama = {version = "*", markers = "platform_system == \"Windows\""} [[package]] name = "cohere" -version = "5.5.0" +version = "5.5.4" description = "" optional = true python-versions = "<4.0,>=3.8" files = [ - {file = "cohere-5.5.0-py3-none-any.whl", hash = "sha256:7792e8898c95f2cb955b2d9f23b8602f73f3b698d59f1a1b4896c53809671da0"}, - {file = "cohere-5.5.0.tar.gz", hash = "sha256:00b492ebf8921e83cb2371f2ee36ddf301422daae3024343a87d4316f02b711b"}, + {file = "cohere-5.5.4-py3-none-any.whl", hash = "sha256:8b692dcb5e86b554e5884168a7d2454951ce102fbd983e9053ec933e06bf02fa"}, + {file = "cohere-5.5.4.tar.gz", hash = "sha256:14acb2ccf272e958f79f9241ae972fd82e96b9b8ee9e6922a5687370761203ec"}, ] [package.dependencies] @@ -971,7 +971,7 @@ httpx = ">=0.21.2" httpx-sse = ">=0.4.0,<0.5.0" pydantic = ">=1.9.2" requests = ">=2.0.0,<3.0.0" -tokenizers = ">=0.19,<0.20" +tokenizers = ">=0.15,<0.16" types-requests = ">=2.0.0,<3.0.0" typing_extensions = ">=4.0.0" @@ -1127,7 +1127,7 @@ test-randomorder = ["pytest-randomly"] name = "dateparser" version = "1.2.0" description = "Date parsing library designed to parse dates from HTML pages" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "dateparser-1.2.0-py2.py3-none-any.whl", hash = "sha256:0b21ad96534e562920a0083e97fd45fa959882d4162acc358705144520a35830"}, @@ -3235,6 +3235,20 @@ files = [ {file = "nvidia_nvtx_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:65f4d98982b31b60026e0e6de73fbdfc09d08a96f4656dd3665ca616a11e1e82"}, ] +[[package]] +name = "ollama" +version = "0.2.1" +description = "The official Python client for Ollama." 
+optional = true +python-versions = "<4.0,>=3.8" +files = [ + {file = "ollama-0.2.1-py3-none-any.whl", hash = "sha256:b6e2414921c94f573a903d1069d682ba2fb2607070ea9e19ca4a7872f2a460ec"}, + {file = "ollama-0.2.1.tar.gz", hash = "sha256:fa316baa9a81eac3beb4affb0a17deb3008fdd6ed05b123c26306cfbe4c349b6"}, +] + +[package.dependencies] +httpx = ">=0.27.0,<0.28.0" + [[package]] name = "openai" version = "1.30.1" @@ -4248,7 +4262,7 @@ six = ">=1.5" name = "pytz" version = "2024.1" description = "World timezone definitions, modern and historical" -optional = false +optional = true python-versions = "*" files = [ {file = "pytz-2024.1-py2.py3-none-any.whl", hash = "sha256:328171f4e3623139da4983451950b28e95ac706e13f3f2630a879749e7a8b319"}, @@ -5263,120 +5277,130 @@ files = [ [[package]] name = "tokenizers" -version = "0.19.1" +version = "0.15.2" description = "" optional = true python-versions = ">=3.7" files = [ - {file = "tokenizers-0.19.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:952078130b3d101e05ecfc7fc3640282d74ed26bcf691400f872563fca15ac97"}, - {file = "tokenizers-0.19.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:82c8b8063de6c0468f08e82c4e198763e7b97aabfe573fd4cf7b33930ca4df77"}, - {file = "tokenizers-0.19.1-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:f03727225feaf340ceeb7e00604825addef622d551cbd46b7b775ac834c1e1c4"}, - {file = "tokenizers-0.19.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:453e4422efdfc9c6b6bf2eae00d5e323f263fff62b29a8c9cd526c5003f3f642"}, - {file = "tokenizers-0.19.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:02e81bf089ebf0e7f4df34fa0207519f07e66d8491d963618252f2e0729e0b46"}, - {file = "tokenizers-0.19.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b07c538ba956843833fee1190cf769c60dc62e1cf934ed50d77d5502194d63b1"}, - {file = "tokenizers-0.19.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e28cab1582e0eec38b1f38c1c1fb2e56bce5dc180acb1724574fc5f47da2a4fe"}, - {file = "tokenizers-0.19.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8b01afb7193d47439f091cd8f070a1ced347ad0f9144952a30a41836902fe09e"}, - {file = "tokenizers-0.19.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:7fb297edec6c6841ab2e4e8f357209519188e4a59b557ea4fafcf4691d1b4c98"}, - {file = "tokenizers-0.19.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:2e8a3dd055e515df7054378dc9d6fa8c8c34e1f32777fb9a01fea81496b3f9d3"}, - {file = "tokenizers-0.19.1-cp310-none-win32.whl", hash = "sha256:7ff898780a155ea053f5d934925f3902be2ed1f4d916461e1a93019cc7250837"}, - {file = "tokenizers-0.19.1-cp310-none-win_amd64.whl", hash = "sha256:bea6f9947e9419c2fda21ae6c32871e3d398cba549b93f4a65a2d369662d9403"}, - {file = "tokenizers-0.19.1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:5c88d1481f1882c2e53e6bb06491e474e420d9ac7bdff172610c4f9ad3898059"}, - {file = "tokenizers-0.19.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ddf672ed719b4ed82b51499100f5417d7d9f6fb05a65e232249268f35de5ed14"}, - {file = "tokenizers-0.19.1-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:dadc509cc8a9fe460bd274c0e16ac4184d0958117cf026e0ea8b32b438171594"}, - {file = "tokenizers-0.19.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dfedf31824ca4915b511b03441784ff640378191918264268e6923da48104acc"}, - {file = 
"tokenizers-0.19.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ac11016d0a04aa6487b1513a3a36e7bee7eec0e5d30057c9c0408067345c48d2"}, - {file = "tokenizers-0.19.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:76951121890fea8330d3a0df9a954b3f2a37e3ec20e5b0530e9a0044ca2e11fe"}, - {file = "tokenizers-0.19.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b342d2ce8fc8d00f376af068e3274e2e8649562e3bc6ae4a67784ded6b99428d"}, - {file = "tokenizers-0.19.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d16ff18907f4909dca9b076b9c2d899114dd6abceeb074eca0c93e2353f943aa"}, - {file = "tokenizers-0.19.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:706a37cc5332f85f26efbe2bdc9ef8a9b372b77e4645331a405073e4b3a8c1c6"}, - {file = "tokenizers-0.19.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:16baac68651701364b0289979ecec728546133e8e8fe38f66fe48ad07996b88b"}, - {file = "tokenizers-0.19.1-cp311-none-win32.whl", hash = "sha256:9ed240c56b4403e22b9584ee37d87b8bfa14865134e3e1c3fb4b2c42fafd3256"}, - {file = "tokenizers-0.19.1-cp311-none-win_amd64.whl", hash = "sha256:ad57d59341710b94a7d9dbea13f5c1e7d76fd8d9bcd944a7a6ab0b0da6e0cc66"}, - {file = "tokenizers-0.19.1-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:621d670e1b1c281a1c9698ed89451395d318802ff88d1fc1accff0867a06f153"}, - {file = "tokenizers-0.19.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d924204a3dbe50b75630bd16f821ebda6a5f729928df30f582fb5aade90c818a"}, - {file = "tokenizers-0.19.1-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:4f3fefdc0446b1a1e6d81cd4c07088ac015665d2e812f6dbba4a06267d1a2c95"}, - {file = "tokenizers-0.19.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9620b78e0b2d52ef07b0d428323fb34e8ea1219c5eac98c2596311f20f1f9266"}, - {file = "tokenizers-0.19.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:04ce49e82d100594715ac1b2ce87d1a36e61891a91de774755f743babcd0dd52"}, - {file = "tokenizers-0.19.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c5c2ff13d157afe413bf7e25789879dd463e5a4abfb529a2d8f8473d8042e28f"}, - {file = "tokenizers-0.19.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3174c76efd9d08f836bfccaca7cfec3f4d1c0a4cf3acbc7236ad577cc423c840"}, - {file = "tokenizers-0.19.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7c9d5b6c0e7a1e979bec10ff960fae925e947aab95619a6fdb4c1d8ff3708ce3"}, - {file = "tokenizers-0.19.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:a179856d1caee06577220ebcfa332af046d576fb73454b8f4d4b0ba8324423ea"}, - {file = "tokenizers-0.19.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:952b80dac1a6492170f8c2429bd11fcaa14377e097d12a1dbe0ef2fb2241e16c"}, - {file = "tokenizers-0.19.1-cp312-none-win32.whl", hash = "sha256:01d62812454c188306755c94755465505836fd616f75067abcae529c35edeb57"}, - {file = "tokenizers-0.19.1-cp312-none-win_amd64.whl", hash = "sha256:b70bfbe3a82d3e3fb2a5e9b22a39f8d1740c96c68b6ace0086b39074f08ab89a"}, - {file = "tokenizers-0.19.1-cp37-cp37m-macosx_10_12_x86_64.whl", hash = "sha256:bb9dfe7dae85bc6119d705a76dc068c062b8b575abe3595e3c6276480e67e3f1"}, - {file = "tokenizers-0.19.1-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:1f0360cbea28ea99944ac089c00de7b2e3e1c58f479fb8613b6d8d511ce98267"}, - {file = "tokenizers-0.19.1-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.whl", 
hash = "sha256:71e3ec71f0e78780851fef28c2a9babe20270404c921b756d7c532d280349214"}, - {file = "tokenizers-0.19.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b82931fa619dbad979c0ee8e54dd5278acc418209cc897e42fac041f5366d626"}, - {file = "tokenizers-0.19.1-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e8ff5b90eabdcdaa19af697885f70fe0b714ce16709cf43d4952f1f85299e73a"}, - {file = "tokenizers-0.19.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e742d76ad84acbdb1a8e4694f915fe59ff6edc381c97d6dfdd054954e3478ad4"}, - {file = "tokenizers-0.19.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d8c5d59d7b59885eab559d5bc082b2985555a54cda04dda4c65528d90ad252ad"}, - {file = "tokenizers-0.19.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b2da5c32ed869bebd990c9420df49813709e953674c0722ff471a116d97b22d"}, - {file = "tokenizers-0.19.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:638e43936cc8b2cbb9f9d8dde0fe5e7e30766a3318d2342999ae27f68fdc9bd6"}, - {file = "tokenizers-0.19.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:78e769eb3b2c79687d9cb0f89ef77223e8e279b75c0a968e637ca7043a84463f"}, - {file = "tokenizers-0.19.1-cp37-none-win32.whl", hash = "sha256:72791f9bb1ca78e3ae525d4782e85272c63faaef9940d92142aa3eb79f3407a3"}, - {file = "tokenizers-0.19.1-cp37-none-win_amd64.whl", hash = "sha256:f3bbb7a0c5fcb692950b041ae11067ac54826204318922da754f908d95619fbc"}, - {file = "tokenizers-0.19.1-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:07f9295349bbbcedae8cefdbcfa7f686aa420be8aca5d4f7d1ae6016c128c0c5"}, - {file = "tokenizers-0.19.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:10a707cc6c4b6b183ec5dbfc5c34f3064e18cf62b4a938cb41699e33a99e03c1"}, - {file = "tokenizers-0.19.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:6309271f57b397aa0aff0cbbe632ca9d70430839ca3178bf0f06f825924eca22"}, - {file = "tokenizers-0.19.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4ad23d37d68cf00d54af184586d79b84075ada495e7c5c0f601f051b162112dc"}, - {file = "tokenizers-0.19.1-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:427c4f0f3df9109314d4f75b8d1f65d9477033e67ffaec4bca53293d3aca286d"}, - {file = "tokenizers-0.19.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e83a31c9cf181a0a3ef0abad2b5f6b43399faf5da7e696196ddd110d332519ee"}, - {file = "tokenizers-0.19.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c27b99889bd58b7e301468c0838c5ed75e60c66df0d4db80c08f43462f82e0d3"}, - {file = "tokenizers-0.19.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bac0b0eb952412b0b196ca7a40e7dce4ed6f6926489313414010f2e6b9ec2adf"}, - {file = "tokenizers-0.19.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:8a6298bde623725ca31c9035a04bf2ef63208d266acd2bed8c2cb7d2b7d53ce6"}, - {file = "tokenizers-0.19.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:08a44864e42fa6d7d76d7be4bec62c9982f6f6248b4aa42f7302aa01e0abfd26"}, - {file = "tokenizers-0.19.1-cp38-none-win32.whl", hash = "sha256:1de5bc8652252d9357a666e609cb1453d4f8e160eb1fb2830ee369dd658e8975"}, - {file = "tokenizers-0.19.1-cp38-none-win_amd64.whl", hash = "sha256:0bcce02bf1ad9882345b34d5bd25ed4949a480cf0e656bbd468f4d8986f7a3f1"}, - {file = "tokenizers-0.19.1-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:0b9394bd204842a2a1fd37fe29935353742be4a3460b6ccbaefa93f58a8df43d"}, 
- {file = "tokenizers-0.19.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4692ab92f91b87769d950ca14dbb61f8a9ef36a62f94bad6c82cc84a51f76f6a"}, - {file = "tokenizers-0.19.1-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:6258c2ef6f06259f70a682491c78561d492e885adeaf9f64f5389f78aa49a051"}, - {file = "tokenizers-0.19.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c85cf76561fbd01e0d9ea2d1cbe711a65400092bc52b5242b16cfd22e51f0c58"}, - {file = "tokenizers-0.19.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:670b802d4d82bbbb832ddb0d41df7015b3e549714c0e77f9bed3e74d42400fbe"}, - {file = "tokenizers-0.19.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:85aa3ab4b03d5e99fdd31660872249df5e855334b6c333e0bc13032ff4469c4a"}, - {file = "tokenizers-0.19.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cbf001afbbed111a79ca47d75941e9e5361297a87d186cbfc11ed45e30b5daba"}, - {file = "tokenizers-0.19.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b4c89aa46c269e4e70c4d4f9d6bc644fcc39bb409cb2a81227923404dd6f5227"}, - {file = "tokenizers-0.19.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:39c1ec76ea1027438fafe16ecb0fb84795e62e9d643444c1090179e63808c69d"}, - {file = "tokenizers-0.19.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:c2a0d47a89b48d7daa241e004e71fb5a50533718897a4cd6235cb846d511a478"}, - {file = "tokenizers-0.19.1-cp39-none-win32.whl", hash = "sha256:61b7fe8886f2e104d4caf9218b157b106207e0f2a4905c9c7ac98890688aabeb"}, - {file = "tokenizers-0.19.1-cp39-none-win_amd64.whl", hash = "sha256:f97660f6c43efd3e0bfd3f2e3e5615bf215680bad6ee3d469df6454b8c6e8256"}, - {file = "tokenizers-0.19.1-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:3b11853f17b54c2fe47742c56d8a33bf49ce31caf531e87ac0d7d13d327c9334"}, - {file = "tokenizers-0.19.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:d26194ef6c13302f446d39972aaa36a1dda6450bc8949f5eb4c27f51191375bd"}, - {file = "tokenizers-0.19.1-pp310-pypy310_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:e8d1ed93beda54bbd6131a2cb363a576eac746d5c26ba5b7556bc6f964425594"}, - {file = "tokenizers-0.19.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca407133536f19bdec44b3da117ef0d12e43f6d4b56ac4c765f37eca501c7bda"}, - {file = "tokenizers-0.19.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce05fde79d2bc2e46ac08aacbc142bead21614d937aac950be88dc79f9db9022"}, - {file = "tokenizers-0.19.1-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:35583cd46d16f07c054efd18b5d46af4a2f070a2dd0a47914e66f3ff5efb2b1e"}, - {file = "tokenizers-0.19.1-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:43350270bfc16b06ad3f6f07eab21f089adb835544417afda0f83256a8bf8b75"}, - {file = "tokenizers-0.19.1-pp37-pypy37_pp73-macosx_10_12_x86_64.whl", hash = "sha256:b4399b59d1af5645bcee2072a463318114c39b8547437a7c2d6a186a1b5a0e2d"}, - {file = "tokenizers-0.19.1-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:6852c5b2a853b8b0ddc5993cd4f33bfffdca4fcc5d52f89dd4b8eada99379285"}, - {file = "tokenizers-0.19.1-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bcd266ae85c3d39df2f7e7d0e07f6c41a55e9a3123bb11f854412952deacd828"}, - {file = "tokenizers-0.19.1-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:ecb2651956eea2aa0a2d099434134b1b68f1c31f9a5084d6d53f08ed43d45ff2"}, - {file = "tokenizers-0.19.1-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:b279ab506ec4445166ac476fb4d3cc383accde1ea152998509a94d82547c8e2a"}, - {file = "tokenizers-0.19.1-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:89183e55fb86e61d848ff83753f64cded119f5d6e1f553d14ffee3700d0a4a49"}, - {file = "tokenizers-0.19.1-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:b2edbc75744235eea94d595a8b70fe279dd42f3296f76d5a86dde1d46e35f574"}, - {file = "tokenizers-0.19.1-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:0e64bfde9a723274e9a71630c3e9494ed7b4c0f76a1faacf7fe294cd26f7ae7c"}, - {file = "tokenizers-0.19.1-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:0b5ca92bfa717759c052e345770792d02d1f43b06f9e790ca0a1db62838816f3"}, - {file = "tokenizers-0.19.1-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6f8a20266e695ec9d7a946a019c1d5ca4eddb6613d4f466888eee04f16eedb85"}, - {file = "tokenizers-0.19.1-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63c38f45d8f2a2ec0f3a20073cccb335b9f99f73b3c69483cd52ebc75369d8a1"}, - {file = "tokenizers-0.19.1-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:dd26e3afe8a7b61422df3176e06664503d3f5973b94f45d5c45987e1cb711876"}, - {file = "tokenizers-0.19.1-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:eddd5783a4a6309ce23432353cdb36220e25cbb779bfa9122320666508b44b88"}, - {file = "tokenizers-0.19.1-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:56ae39d4036b753994476a1b935584071093b55c7a72e3b8288e68c313ca26e7"}, - {file = "tokenizers-0.19.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:f9939ca7e58c2758c01b40324a59c034ce0cebad18e0d4563a9b1beab3018243"}, - {file = "tokenizers-0.19.1-pp39-pypy39_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:6c330c0eb815d212893c67a032e9dc1b38a803eccb32f3e8172c19cc69fbb439"}, - {file = "tokenizers-0.19.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ec11802450a2487cdf0e634b750a04cbdc1c4d066b97d94ce7dd2cb51ebb325b"}, - {file = "tokenizers-0.19.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2b718f316b596f36e1dae097a7d5b91fc5b85e90bf08b01ff139bd8953b25af"}, - {file = "tokenizers-0.19.1-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:ed69af290c2b65169f0ba9034d1dc39a5db9459b32f1dd8b5f3f32a3fcf06eab"}, - {file = "tokenizers-0.19.1-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:f8a9c828277133af13f3859d1b6bf1c3cb6e9e1637df0e45312e6b7c2e622b1f"}, - {file = "tokenizers-0.19.1.tar.gz", hash = "sha256:ee59e6680ed0fdbe6b724cf38bd70400a0c1dd623b07ac729087270caeac88e3"}, -] - -[package.dependencies] -huggingface-hub = ">=0.16.4,<1.0" + {file = "tokenizers-0.15.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:52f6130c9cbf70544287575a985bf44ae1bda2da7e8c24e97716080593638012"}, + {file = "tokenizers-0.15.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:054c1cc9c6d68f7ffa4e810b3d5131e0ba511b6e4be34157aa08ee54c2f8d9ee"}, + {file = "tokenizers-0.15.2-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:a9b9b070fdad06e347563b88c278995735292ded1132f8657084989a4c84a6d5"}, + {file = "tokenizers-0.15.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ea621a7eef4b70e1f7a4e84dd989ae3f0eeb50fc8690254eacc08acb623e82f1"}, + {file = 
"tokenizers-0.15.2-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:cf7fd9a5141634fa3aa8d6b7be362e6ae1b4cda60da81388fa533e0b552c98fd"}, + {file = "tokenizers-0.15.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:44f2a832cd0825295f7179eaf173381dc45230f9227ec4b44378322d900447c9"}, + {file = "tokenizers-0.15.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8b9ec69247a23747669ec4b0ca10f8e3dfb3545d550258129bd62291aabe8605"}, + {file = "tokenizers-0.15.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:40b6a4c78da863ff26dbd5ad9a8ecc33d8a8d97b535172601cf00aee9d7ce9ce"}, + {file = "tokenizers-0.15.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:5ab2a4d21dcf76af60e05af8063138849eb1d6553a0d059f6534357bce8ba364"}, + {file = "tokenizers-0.15.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a47acfac7e511f6bbfcf2d3fb8c26979c780a91e06fb5b9a43831b2c0153d024"}, + {file = "tokenizers-0.15.2-cp310-none-win32.whl", hash = "sha256:064ff87bb6acdbd693666de9a4b692add41308a2c0ec0770d6385737117215f2"}, + {file = "tokenizers-0.15.2-cp310-none-win_amd64.whl", hash = "sha256:3b919afe4df7eb6ac7cafd2bd14fb507d3f408db7a68c43117f579c984a73843"}, + {file = "tokenizers-0.15.2-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:89cd1cb93e4b12ff39bb2d626ad77e35209de9309a71e4d3d4672667b4b256e7"}, + {file = "tokenizers-0.15.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:cfed5c64e5be23d7ee0f0e98081a25c2a46b0b77ce99a4f0605b1ec43dd481fa"}, + {file = "tokenizers-0.15.2-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:a907d76dcfda37023ba203ab4ceeb21bc5683436ebefbd895a0841fd52f6f6f2"}, + {file = "tokenizers-0.15.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:20ea60479de6fc7b8ae756b4b097572372d7e4032e2521c1bbf3d90c90a99ff0"}, + {file = "tokenizers-0.15.2-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:48e2b9335be2bc0171df9281385c2ed06a15f5cf121c44094338306ab7b33f2c"}, + {file = "tokenizers-0.15.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:112a1dd436d2cc06e6ffdc0b06d55ac019a35a63afd26475205cb4b1bf0bfbff"}, + {file = "tokenizers-0.15.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4620cca5c2817177ee8706f860364cc3a8845bc1e291aaf661fb899e5d1c45b0"}, + {file = "tokenizers-0.15.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ccd73a82751c523b3fc31ff8194702e4af4db21dc20e55b30ecc2079c5d43cb7"}, + {file = "tokenizers-0.15.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:107089f135b4ae7817affe6264f8c7a5c5b4fd9a90f9439ed495f54fcea56fb4"}, + {file = "tokenizers-0.15.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0ff110ecc57b7aa4a594396525a3451ad70988e517237fe91c540997c4e50e29"}, + {file = "tokenizers-0.15.2-cp311-none-win32.whl", hash = "sha256:6d76f00f5c32da36c61f41c58346a4fa7f0a61be02f4301fd30ad59834977cc3"}, + {file = "tokenizers-0.15.2-cp311-none-win_amd64.whl", hash = "sha256:cc90102ed17271cf0a1262babe5939e0134b3890345d11a19c3145184b706055"}, + {file = "tokenizers-0.15.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:f86593c18d2e6248e72fb91c77d413a815153b8ea4e31f7cd443bdf28e467670"}, + {file = "tokenizers-0.15.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0774bccc6608eca23eb9d620196687c8b2360624619623cf4ba9dc9bd53e8b51"}, + {file = "tokenizers-0.15.2-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.whl", 
hash = "sha256:d0222c5b7c9b26c0b4822a82f6a7011de0a9d3060e1da176f66274b70f846b98"}, + {file = "tokenizers-0.15.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3835738be1de66624fff2f4f6f6684775da4e9c00bde053be7564cbf3545cc66"}, + {file = "tokenizers-0.15.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0143e7d9dcd811855c1ce1ab9bf5d96d29bf5e528fd6c7824d0465741e8c10fd"}, + {file = "tokenizers-0.15.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:db35825f6d54215f6b6009a7ff3eedee0848c99a6271c870d2826fbbedf31a38"}, + {file = "tokenizers-0.15.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3f5e64b0389a2be47091d8cc53c87859783b837ea1a06edd9d8e04004df55a5c"}, + {file = "tokenizers-0.15.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e0480c452217edd35eca56fafe2029fb4d368b7c0475f8dfa3c5c9c400a7456"}, + {file = "tokenizers-0.15.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:a33ab881c8fe70474980577e033d0bc9a27b7ab8272896e500708b212995d834"}, + {file = "tokenizers-0.15.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a308a607ca9de2c64c1b9ba79ec9a403969715a1b8ba5f998a676826f1a7039d"}, + {file = "tokenizers-0.15.2-cp312-none-win32.whl", hash = "sha256:b8fcfa81bcb9447df582c5bc96a031e6df4da2a774b8080d4f02c0c16b42be0b"}, + {file = "tokenizers-0.15.2-cp312-none-win_amd64.whl", hash = "sha256:38d7ab43c6825abfc0b661d95f39c7f8af2449364f01d331f3b51c94dcff7221"}, + {file = "tokenizers-0.15.2-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:38bfb0204ff3246ca4d5e726e8cc8403bfc931090151e6eede54d0e0cf162ef0"}, + {file = "tokenizers-0.15.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:9c861d35e8286a53e06e9e28d030b5a05bcbf5ac9d7229e561e53c352a85b1fc"}, + {file = "tokenizers-0.15.2-cp313-cp313-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:936bf3842db5b2048eaa53dade907b1160f318e7c90c74bfab86f1e47720bdd6"}, + {file = "tokenizers-0.15.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:620beacc3373277700d0e27718aa8b25f7b383eb8001fba94ee00aeea1459d89"}, + {file = "tokenizers-0.15.2-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2735ecbbf37e52db4ea970e539fd2d450d213517b77745114f92867f3fc246eb"}, + {file = "tokenizers-0.15.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:473c83c5e2359bb81b0b6fde870b41b2764fcdd36d997485e07e72cc3a62264a"}, + {file = "tokenizers-0.15.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:968fa1fb3c27398b28a4eca1cbd1e19355c4d3a6007f7398d48826bbe3a0f728"}, + {file = "tokenizers-0.15.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:865c60ae6eaebdde7da66191ee9b7db52e542ed8ee9d2c653b6d190a9351b980"}, + {file = "tokenizers-0.15.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:7c0d8b52664ab2d4a8d6686eb5effc68b78608a9008f086a122a7b2996befbab"}, + {file = "tokenizers-0.15.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:f33dfbdec3784093a9aebb3680d1f91336c56d86cc70ddf88708251da1fe9064"}, + {file = "tokenizers-0.15.2-cp37-cp37m-macosx_10_12_x86_64.whl", hash = "sha256:d44ba80988ff9424e33e0a49445072ac7029d8c0e1601ad25a0ca5f41ed0c1d6"}, + {file = "tokenizers-0.15.2-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:dce74266919b892f82b1b86025a613956ea0ea62a4843d4c4237be2c5498ed3a"}, + {file = "tokenizers-0.15.2-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.whl", hash = 
"sha256:0ef06b9707baeb98b316577acb04f4852239d856b93e9ec3a299622f6084e4be"}, + {file = "tokenizers-0.15.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c73e2e74bbb07910da0d37c326869f34113137b23eadad3fc00856e6b3d9930c"}, + {file = "tokenizers-0.15.2-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4eeb12daf02a59e29f578a865f55d87cd103ce62bd8a3a5874f8fdeaa82e336b"}, + {file = "tokenizers-0.15.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9ba9f6895af58487ca4f54e8a664a322f16c26bbb442effd01087eba391a719e"}, + {file = "tokenizers-0.15.2-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ccec77aa7150e38eec6878a493bf8c263ff1fa8a62404e16c6203c64c1f16a26"}, + {file = "tokenizers-0.15.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f3f40604f5042ff210ba82743dda2b6aa3e55aa12df4e9f2378ee01a17e2855e"}, + {file = "tokenizers-0.15.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:5645938a42d78c4885086767c70923abad047163d809c16da75d6b290cb30bbe"}, + {file = "tokenizers-0.15.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:05a77cbfebe28a61ab5c3891f9939cc24798b63fa236d84e5f29f3a85a200c00"}, + {file = "tokenizers-0.15.2-cp37-none-win32.whl", hash = "sha256:361abdc068e8afe9c5b818769a48624687fb6aaed49636ee39bec4e95e1a215b"}, + {file = "tokenizers-0.15.2-cp37-none-win_amd64.whl", hash = "sha256:7ef789f83eb0f9baeb4d09a86cd639c0a5518528f9992f38b28e819df397eb06"}, + {file = "tokenizers-0.15.2-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:4fe1f74a902bee74a3b25aff180fbfbf4f8b444ab37c4d496af7afd13a784ed2"}, + {file = "tokenizers-0.15.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4c4b89038a684f40a6b15d6b09f49650ac64d951ad0f2a3ea9169687bbf2a8ba"}, + {file = "tokenizers-0.15.2-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:d05a1b06f986d41aed5f2de464c003004b2df8aaf66f2b7628254bcbfb72a438"}, + {file = "tokenizers-0.15.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:508711a108684111ec8af89d3a9e9e08755247eda27d0ba5e3c50e9da1600f6d"}, + {file = "tokenizers-0.15.2-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:daa348f02d15160cb35439098ac96e3a53bacf35885072611cd9e5be7d333daa"}, + {file = "tokenizers-0.15.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:494fdbe5932d3416de2a85fc2470b797e6f3226c12845cadf054dd906afd0442"}, + {file = "tokenizers-0.15.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c2d60f5246f4da9373f75ff18d64c69cbf60c3bca597290cea01059c336d2470"}, + {file = "tokenizers-0.15.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:93268e788825f52de4c7bdcb6ebc1fcd4a5442c02e730faa9b6b08f23ead0e24"}, + {file = "tokenizers-0.15.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:6fc7083ab404019fc9acafe78662c192673c1e696bd598d16dc005bd663a5cf9"}, + {file = "tokenizers-0.15.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:41e39b41e5531d6b2122a77532dbea60e171ef87a3820b5a3888daa847df4153"}, + {file = "tokenizers-0.15.2-cp38-none-win32.whl", hash = "sha256:06cd0487b1cbfabefb2cc52fbd6b1f8d4c37799bd6c6e1641281adaa6b2504a7"}, + {file = "tokenizers-0.15.2-cp38-none-win_amd64.whl", hash = "sha256:5179c271aa5de9c71712e31cb5a79e436ecd0d7532a408fa42a8dbfa4bc23fd9"}, + {file = "tokenizers-0.15.2-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:82f8652a74cc107052328b87ea8b34291c0f55b96d8fb261b3880216a9f9e48e"}, + 
{file = "tokenizers-0.15.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:02458bee6f5f3139f1ebbb6d042b283af712c0981f5bc50edf771d6b762d5e4f"}, + {file = "tokenizers-0.15.2-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:c9a09cd26cca2e1c349f91aa665309ddb48d71636370749414fbf67bc83c5343"}, + {file = "tokenizers-0.15.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:158be8ea8554e5ed69acc1ce3fbb23a06060bd4bbb09029431ad6b9a466a7121"}, + {file = "tokenizers-0.15.2-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1ddba9a2b0c8c81633eca0bb2e1aa5b3a15362b1277f1ae64176d0f6eba78ab1"}, + {file = "tokenizers-0.15.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3ef5dd1d39797044642dbe53eb2bc56435308432e9c7907728da74c69ee2adca"}, + {file = "tokenizers-0.15.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:454c203164e07a860dbeb3b1f4a733be52b0edbb4dd2e5bd75023ffa8b49403a"}, + {file = "tokenizers-0.15.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0cf6b7f1d4dc59af960e6ffdc4faffe6460bbfa8dce27a58bf75755ffdb2526d"}, + {file = "tokenizers-0.15.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:2ef09bbc16519f6c25d0c7fc0c6a33a6f62923e263c9d7cca4e58b8c61572afb"}, + {file = "tokenizers-0.15.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:c9a2ebdd2ad4ec7a68e7615086e633857c85e2f18025bd05d2a4399e6c5f7169"}, + {file = "tokenizers-0.15.2-cp39-none-win32.whl", hash = "sha256:918fbb0eab96fe08e72a8c2b5461e9cce95585d82a58688e7f01c2bd546c79d0"}, + {file = "tokenizers-0.15.2-cp39-none-win_amd64.whl", hash = "sha256:524e60da0135e106b254bd71f0659be9f89d83f006ea9093ce4d1fab498c6d0d"}, + {file = "tokenizers-0.15.2-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:6a9b648a58281c4672212fab04e60648fde574877d0139cd4b4f93fe28ca8944"}, + {file = "tokenizers-0.15.2-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:7c7d18b733be6bbca8a55084027f7be428c947ddf871c500ee603e375013ffba"}, + {file = "tokenizers-0.15.2-pp310-pypy310_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:13ca3611de8d9ddfbc4dc39ef54ab1d2d4aaa114ac8727dfdc6a6ec4be017378"}, + {file = "tokenizers-0.15.2-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:237d1bf3361cf2e6463e6c140628e6406766e8b27274f5fcc62c747ae3c6f094"}, + {file = "tokenizers-0.15.2-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:67a0fe1e49e60c664915e9fb6b0cb19bac082ab1f309188230e4b2920230edb3"}, + {file = "tokenizers-0.15.2-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:4e022fe65e99230b8fd89ebdfea138c24421f91c1a4f4781a8f5016fd5cdfb4d"}, + {file = "tokenizers-0.15.2-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:d857be2df69763362ac699f8b251a8cd3fac9d21893de129bc788f8baaef2693"}, + {file = "tokenizers-0.15.2-pp37-pypy37_pp73-macosx_10_12_x86_64.whl", hash = "sha256:708bb3e4283177236309e698da5fcd0879ce8fd37457d7c266d16b550bcbbd18"}, + {file = "tokenizers-0.15.2-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:64c35e09e9899b72a76e762f9854e8750213f67567787d45f37ce06daf57ca78"}, + {file = "tokenizers-0.15.2-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c1257f4394be0d3b00de8c9e840ca5601d0a4a8438361ce9c2b05c7d25f6057b"}, + {file = "tokenizers-0.15.2-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:02272fe48280e0293a04245ca5d919b2c94a48b408b55e858feae9618138aeda"}, + {file = "tokenizers-0.15.2-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:dc3ad9ebc76eabe8b1d7c04d38be884b8f9d60c0cdc09b0aa4e3bcf746de0388"}, + {file = "tokenizers-0.15.2-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:32e16bdeffa7c4f46bf2152172ca511808b952701d13e7c18833c0b73cb5c23f"}, + {file = "tokenizers-0.15.2-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:fb16ba563d59003028b678d2361a27f7e4ae0ab29c7a80690efa20d829c81fdb"}, + {file = "tokenizers-0.15.2-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:2277c36d2d6cdb7876c274547921a42425b6810d38354327dd65a8009acf870c"}, + {file = "tokenizers-0.15.2-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:1cf75d32e8d250781940d07f7eece253f2fe9ecdb1dc7ba6e3833fa17b82fcbc"}, + {file = "tokenizers-0.15.2-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f1b3b31884dc8e9b21508bb76da80ebf7308fdb947a17affce815665d5c4d028"}, + {file = "tokenizers-0.15.2-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b10122d8d8e30afb43bb1fe21a3619f62c3e2574bff2699cf8af8b0b6c5dc4a3"}, + {file = "tokenizers-0.15.2-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:d88b96ff0fe8e91f6ef01ba50b0d71db5017fa4e3b1d99681cec89a85faf7bf7"}, + {file = "tokenizers-0.15.2-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:37aaec5a52e959892870a7c47cef80c53797c0db9149d458460f4f31e2fb250e"}, + {file = "tokenizers-0.15.2-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:e2ea752f2b0fe96eb6e2f3adbbf4d72aaa1272079b0dfa1145507bd6a5d537e6"}, + {file = "tokenizers-0.15.2-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:4b19a808d8799fda23504a5cd31d2f58e6f52f140380082b352f877017d6342b"}, + {file = "tokenizers-0.15.2-pp39-pypy39_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:64c86e5e068ac8b19204419ed8ca90f9d25db20578f5881e337d203b314f4104"}, + {file = "tokenizers-0.15.2-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:de19c4dc503c612847edf833c82e9f73cd79926a384af9d801dcf93f110cea4e"}, + {file = "tokenizers-0.15.2-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ea09acd2fe3324174063d61ad620dec3bcf042b495515f27f638270a7d466e8b"}, + {file = "tokenizers-0.15.2-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:cf27fd43472e07b57cf420eee1e814549203d56de00b5af8659cb99885472f1f"}, + {file = "tokenizers-0.15.2-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:7ca22bd897537a0080521445d91a58886c8c04084a6a19e6c78c586e0cfa92a5"}, + {file = "tokenizers-0.15.2.tar.gz", hash = "sha256:e6e9c6e019dd5484be5beafc775ae6c925f4c69a3487040ed09b45e13df2cb91"}, +] + +[package.dependencies] +huggingface_hub = ">=0.16.4,<1.0" [package.extras] dev = ["tokenizers[testing]"] -docs = ["setuptools-rust", "sphinx", "sphinx-rtd-theme"] -testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests", "ruff"] +docs = ["setuptools_rust", "sphinx", "sphinx_rtd_theme"] +testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"] [[package]] name = "tomli" @@ -5515,13 +5539,13 @@ test = ["argcomplete (>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0, [[package]] name = "transformers" -version = "4.40.2" +version = "4.39.3" description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" optional = true python-versions = ">=3.8.0" files = [ - {file = 
"transformers-4.40.2-py3-none-any.whl", hash = "sha256:71cb94301ec211a2e1d4b8c8d18dcfaa902dfa00a089dceca167a8aa265d6f2d"}, - {file = "transformers-4.40.2.tar.gz", hash = "sha256:657b6054a2097671398d976ad46e60836e7e15f9ea9551631a96e33cb9240649"}, + {file = "transformers-4.39.3-py3-none-any.whl", hash = "sha256:7838034a12cca3168247f9d2d1dba6724c9de3ae0f73a108258c6b8fc5912601"}, + {file = "transformers-4.39.3.tar.gz", hash = "sha256:2586e5ff4150f122716fc40f5530e92871befc051848fbe82600969c535b762d"}, ] [package.dependencies] @@ -5533,21 +5557,21 @@ pyyaml = ">=5.1" regex = "!=2019.12.17" requests = "*" safetensors = ">=0.4.1" -tokenizers = ">=0.19,<0.20" +tokenizers = ">=0.14,<0.19" tqdm = ">=4.27" [package.extras] accelerate = ["accelerate (>=0.21.0)"] agents = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch"] -all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision"] +all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm", "tokenizers (>=0.14,<0.19)", "torch", "torchaudio", "torchvision"] audio = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] codecarbon = ["codecarbon (==1.2.0)"] deepspeed = ["accelerate (>=0.21.0)", "deepspeed (>=0.9.3)"] deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.21.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "optuna", "parameterized", "protobuf", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] -dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "decord (==0.6.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", 
"rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] -dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.19,<0.20)", "urllib3 (<2.0.0)"] -dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "librosa", "nltk", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] -docs = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "hf-doc-builder", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision"] +dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "decord (==0.6.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", 
"parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm", "tokenizers (>=0.14,<0.19)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.14,<0.19)", "urllib3 (<2.0.0)"] +dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "librosa", "nltk", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm", "tokenizers (>=0.14,<0.19)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +docs = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "hf-doc-builder", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm", "tokenizers (>=0.14,<0.19)", "torch", "torchaudio", "torchvision"] docs-specific = ["hf-doc-builder"] flax = ["flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "optax (>=0.0.8,<=0.1.4)"] flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] @@ -5568,16 +5592,16 @@ serving = ["fastapi", "pydantic", "starlette", "uvicorn"] sigopt = ["sigopt"] sklearn = ["scikit-learn"] speech = ["kenlm", 
"librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] -testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "parameterized", "protobuf", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] +testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "parameterized", "protobuf", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "tensorboard", "timeout-decorator"] tf = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"] tf-cpu = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow-cpu (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"] tf-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] timm = ["timm"] -tokenizers = ["tokenizers (>=0.19,<0.20)"] +tokenizers = ["tokenizers (>=0.14,<0.19)"] torch = ["accelerate (>=0.21.0)", "torch"] torch-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] torch-vision = ["Pillow (>=10.0.1,<=15.0)", "torchvision"] -torchhub = ["filelock", "huggingface-hub (>=0.19.3,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.19,<0.20)", "torch", "tqdm (>=4.27)"] +torchhub = ["filelock", "huggingface-hub (>=0.19.3,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.14,<0.19)", "torch", "tqdm (>=4.27)"] video = ["av (==9.2.0)", "decord (==0.6.0)"] vision = ["Pillow (>=10.0.1,<=15.0)"] @@ -5688,7 +5712,7 @@ files = [ name = "tzdata" version = "2024.1" description = "Provider of IANA time zone data" -optional = false +optional = true python-versions = ">=2" files = [ {file = "tzdata-2024.1-py2.py3-none-any.whl", hash = "sha256:9068bc196136463f5245e51efda838afa15aaeca9903f49050dfa2679db4d252"}, @@ -5699,7 +5723,7 @@ files = [ name = "tzlocal" version = "5.2" description = "tzinfo object for the local timezone" -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "tzlocal-5.2-py3-none-any.whl", hash = "sha256:49816ef2fe65ea8ac19d19aa7a1ae0551c834303d5014c6d5a62e4cbda8047b8"}, @@ -6046,9 +6070,10 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.link testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"] [extras] -all = ["anthropic", "beautifulsoup4", "boto3", "cohere", "elevenlabs", "filetype", "google-generativeai", "mail-parser", "markdownify", "marqo", "opensearch-py", "pandas", "pgvector", "pillow", "pinecone-client", "playwright", "psycopg2-binary", "pusher", "pymongo", "pypdf", "redis", "snowflake-sqlalchemy", "sqlalchemy-redshift", "torch", "trafilatura", 
"transformers", "voyageai"] +all = ["anthropic", "beautifulsoup4", "boto3", "cohere", "elevenlabs", "filetype", "google-generativeai", "mail-parser", "markdownify", "marqo", "ollama", "opensearch-py", "pandas", "pgvector", "pillow", "pinecone-client", "playwright", "psycopg2-binary", "pusher", "pymongo", "pypdf", "redis", "snowflake-sqlalchemy", "sqlalchemy-redshift", "torch", "trafilatura", "transformers", "voyageai"] drivers-embedding-amazon-bedrock = ["boto3"] drivers-embedding-amazon-sagemaker = ["boto3"] +drivers-embedding-cohere = ["cohere"] drivers-embedding-google = ["google-generativeai"] drivers-embedding-huggingface = ["huggingface-hub", "transformers"] drivers-embedding-voyageai = ["voyageai"] @@ -6064,6 +6089,7 @@ drivers-prompt-cohere = ["cohere"] drivers-prompt-google = ["google-generativeai"] drivers-prompt-huggingface = ["huggingface-hub", "transformers"] drivers-prompt-huggingface-pipeline = ["huggingface-hub", "torch", "transformers"] +drivers-prompt-ollama = ["ollama"] drivers-sql-postgres = ["pgvector", "psycopg2-binary"] drivers-sql-redshift = ["boto3", "sqlalchemy-redshift"] drivers-sql-snowflake = ["snowflake-sqlalchemy"] @@ -6085,4 +6111,4 @@ loaders-pdf = ["pypdf"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "4e98cb17a098a86dc3c109bad3cb8b5299e333d20e7c12903a7fe9472b7a3d31" +content-hash = "6ccbbba60b534e4756d1c36e37ce1379ee2126d37b822f186b5dbb8e8f701ff3" diff --git a/pyproject.toml b/pyproject.toml index 60fc345f2..0869256a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "griptape" -version = "0.26.0" +version = "0.27.0" description = "Modular Python framework for LLM workflows, tools, memory, and data." authors = ["Griptape "] license = "Apache 2.0" @@ -27,15 +27,14 @@ numpy = ">=1" stringcase = "^1.2.0" docker = "^7.1.0" sqlalchemy = "~=1.0" -dateparser = "^1.1.8" requests = "^2" # drivers -cohere = { version = ">=4", optional = true } +cohere = { version = "^5.5.4", optional = true } anthropic = { version = "^0.20.0", optional = true } -transformers = { version = "^4.30", optional = true } +transformers = { version = "^4.39.3", optional = true } huggingface-hub = { version = ">=0.13", optional = true } -boto3 = { version = "^1.28.2", optional = true } +boto3 = { version = "^1.34.119", optional = true } sqlalchemy-redshift = { version = "*", optional = true } snowflake-sqlalchemy = { version = "^1.4.7", optional = true } pinecone-client = { version = "^3", optional = true } @@ -54,6 +53,7 @@ voyageai = {version = "^0.2.1", optional = true} elevenlabs = {version = "^1.1.2", optional = true} torch = {version = "^2.3.0", optional = true} pusher = {version = "^3.3.2", optional = true} +ollama = {version = "^0.2.1", optional = true} # loaders pandas = {version = "^1.3", optional = true} @@ -70,6 +70,7 @@ drivers-prompt-huggingface-pipeline = ["huggingface-hub", "transformers", "torch drivers-prompt-amazon-bedrock = ["boto3", "anthropic"] drivers-prompt-amazon-sagemaker = ["boto3", "transformers"] drivers-prompt-google = ["google-generativeai"] +drivers-prompt-ollama = ["ollama"] drivers-sql-redshift = ["sqlalchemy-redshift", "boto3"] drivers-sql-snowflake = ["snowflake-sqlalchemy", "snowflake", "snowflake-connector-python"] @@ -91,6 +92,7 @@ drivers-embedding-amazon-sagemaker = ["boto3"] drivers-embedding-huggingface = ["huggingface-hub", "transformers"] drivers-embedding-voyageai = ["voyageai"] drivers-embedding-google = ["google-generativeai"] +drivers-embedding-cohere = ["cohere"] 
drivers-web-scraper-trafilatura = ["trafilatura"] drivers-web-scraper-markdownify = ["playwright", "beautifulsoup4", "markdownify"] @@ -131,6 +133,7 @@ all = [ "elevenlabs", "torch", "pusher", + "ollama", # loaders "pandas", diff --git a/tests/mocks/mock_audio_input_task.py b/tests/mocks/mock_audio_input_task.py new file mode 100644 index 000000000..d6a27d968 --- /dev/null +++ b/tests/mocks/mock_audio_input_task.py @@ -0,0 +1,9 @@ +from attrs import define +from griptape.artifacts import TextArtifact +from griptape.tasks.base_audio_input_task import BaseAudioInputTask + + +@define +class MockAudioInputTask(BaseAudioInputTask): + def run(self) -> TextArtifact: + return TextArtifact(self.input.to_text()) diff --git a/tests/mocks/mock_prompt_driver.py b/tests/mocks/mock_prompt_driver.py index dc4cde69e..3235f7cd5 100644 --- a/tests/mocks/mock_prompt_driver.py +++ b/tests/mocks/mock_prompt_driver.py @@ -13,10 +13,15 @@ class MockPromptDriver(BasePromptDriver): model: str = "test-model" tokenizer: BaseTokenizer = MockTokenizer(model="test-model", max_input_tokens=4096, max_output_tokens=4096) - mock_output: str | Callable[[], str] = field(default="mock output", kw_only=True) + mock_input: str | Callable[[], str] = field(default="mock input", kw_only=True) + mock_output: str | Callable[[PromptStack], str] = field(default="mock output", kw_only=True) def try_run(self, prompt_stack: PromptStack) -> TextArtifact: - return TextArtifact(value=self.mock_output() if isinstance(self.mock_output, Callable) else self.mock_output) + return TextArtifact( + value=self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output + ) def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]: - yield TextArtifact(value=self.mock_output() if isinstance(self.mock_output, Callable) else self.mock_output) + yield TextArtifact( + value=self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output + ) diff --git a/tests/mocks/mock_tokenizer.py b/tests/mocks/mock_tokenizer.py index a333f9a13..eff103e99 100644 --- a/tests/mocks/mock_tokenizer.py +++ b/tests/mocks/mock_tokenizer.py @@ -1,13 +1,9 @@ from __future__ import annotations -from attrs import define, field +from attrs import define from griptape.tokenizers import BaseTokenizer @define() class MockTokenizer(BaseTokenizer): - model: str = field(kw_only=True) - max_input_tokens: int = field(default=1000, kw_only=True) - max_output_tokens: int = field(default=1000, kw_only=True) - - def count_tokens(self, text: str | list[dict]) -> int: + def count_tokens(self, text: str) -> int: return len(text) diff --git a/tests/unit/chunkers/test_text_chunker.py b/tests/unit/chunkers/test_text_chunker.py index c1fb40137..243b287e1 100644 --- a/tests/unit/chunkers/test_text_chunker.py +++ b/tests/unit/chunkers/test_text_chunker.py @@ -106,3 +106,7 @@ def test_separators(self, chunker): assert chunks[5].value.endswith("? 
foo-12?") assert chunks[6].value.endswith(" foo-5") assert chunks[7].value.endswith(" foo-16") + + def test_chunk_with_max_tokens(self, chunker): + with pytest.raises(ValueError): + TextChunker(max_tokens=-1) diff --git a/tests/unit/config/test_amazon_bedrock_structure_config.py b/tests/unit/config/test_amazon_bedrock_structure_config.py index 5b8c63a98..33b286f94 100644 --- a/tests/unit/config/test_amazon_bedrock_structure_config.py +++ b/tests/unit/config/test_amazon_bedrock_structure_config.py @@ -38,7 +38,6 @@ def test_to_dict(self, config): "prompt_driver": { "max_tokens": None, "model": "anthropic.claude-3-sonnet-20240229-v1:0", - "prompt_model_driver": {"type": "BedrockClaudePromptModelDriver", "top_k": 250, "top_p": 0.999}, "stream": False, "temperature": 0.1, "type": "AmazonBedrockPromptDriver", diff --git a/tests/unit/config/test_anthropic_structure_config.py b/tests/unit/config/test_anthropic_structure_config.py index 8279fb091..1dd83f96c 100644 --- a/tests/unit/config/test_anthropic_structure_config.py +++ b/tests/unit/config/test_anthropic_structure_config.py @@ -18,7 +18,7 @@ def test_to_dict(self, config): "prompt_driver": { "type": "AnthropicPromptDriver", "temperature": 0.1, - "max_tokens": None, + "max_tokens": 1000, "stream": False, "model": "claude-3-opus-20240229", "top_p": 0.999, diff --git a/tests/unit/config/test_cohere_structure_config.py b/tests/unit/config/test_cohere_structure_config.py new file mode 100644 index 000000000..1dc585421 --- /dev/null +++ b/tests/unit/config/test_cohere_structure_config.py @@ -0,0 +1,38 @@ +from pytest import fixture +from griptape.config import CohereStructureConfig + + +class TestCohereStructureConfig: + @fixture + def config(self): + return CohereStructureConfig(api_key="api_key") + + def test_to_dict(self, config): + assert config.to_dict() == { + "type": "CohereStructureConfig", + "image_generation_driver": {"type": "DummyImageGenerationDriver"}, + "image_query_driver": {"type": "DummyImageQueryDriver"}, + "conversation_memory_driver": None, + "text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, + "audio_transcription_driver": {"type": "DummyAudioTranscriptionDriver"}, + "prompt_driver": { + "type": "CoherePromptDriver", + "temperature": 0.1, + "max_tokens": None, + "stream": False, + "model": "command-r", + }, + "embedding_driver": { + "type": "CohereEmbeddingDriver", + "model": "embed-english-v3.0", + "input_type": "search_document", + }, + "vector_store_driver": { + "type": "LocalVectorStoreDriver", + "embedding_driver": { + "type": "CohereEmbeddingDriver", + "model": "embed-english-v3.0", + "input_type": "search_document", + }, + }, + } diff --git a/tests/unit/drivers/embedding/test_cohere_embedding_driver.py b/tests/unit/drivers/embedding/test_cohere_embedding_driver.py new file mode 100644 index 000000000..af6a5576d --- /dev/null +++ b/tests/unit/drivers/embedding/test_cohere_embedding_driver.py @@ -0,0 +1,21 @@ +from unittest.mock import Mock +import pytest +from griptape.drivers import CohereEmbeddingDriver + + +class TestCohereEmbeddingDriver: + @pytest.fixture(autouse=True) + def mock_client(self, mocker): + mock_client = mocker.patch("cohere.Client").return_value + + mock_client.embed.return_value = Mock(embeddings=[[0, 1, 0]]) + + return mock_client + + def test_init(self): + assert CohereEmbeddingDriver(model="embed-english-v3.0", api_key="bar", input_type="search_document") + + def test_try_embed_chunk(self): + assert CohereEmbeddingDriver( + model="embed-english-v3.0", api_key="bar", 
input_type="search_document" + ).try_embed_chunk("foobar") == [0, 1, 0] diff --git a/tests/unit/drivers/embedding/test_sagemaker_embedding_driver.py b/tests/unit/drivers/embedding/test_sagemaker_embedding_driver.py deleted file mode 100644 index 9ceb98557..000000000 --- a/tests/unit/drivers/embedding/test_sagemaker_embedding_driver.py +++ /dev/null @@ -1,33 +0,0 @@ -import pytest -from unittest import mock -from griptape.drivers import AmazonSageMakerEmbeddingDriver, SageMakerHuggingFaceEmbeddingModelDriver -from griptape.tokenizers.openai_tokenizer import OpenAiTokenizer - - -class TestAmazonSagemakerEmbeddingDriver: - @pytest.fixture(autouse=True) - def mock_session(self, mocker): - fake_embeddings = b'{"embedding": [[0, 1, 0]]}' - mock_session_class = mocker.patch("boto3.Session") - mock_session_object = mock.Mock() - mock_client = mock.Mock() - mock_response = mock.Mock() - - mock_response.get().read.return_value = fake_embeddings - mock_client.invoke_endpoint.return_value = mock_response - mock_session_object.client.return_value = mock_client - mock_session_class.return_value = mock_session_object - - def test_init(self): - assert AmazonSageMakerEmbeddingDriver( - model="test-endpoint", - tokenizer=OpenAiTokenizer(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL), - embedding_model_driver=SageMakerHuggingFaceEmbeddingModelDriver(), - ) - - def test_try_embed_chunk(self): - assert AmazonSageMakerEmbeddingDriver( - model="test-endpoint", - tokenizer=OpenAiTokenizer(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL), - embedding_model_driver=SageMakerHuggingFaceEmbeddingModelDriver(), - ).try_embed_chunk("foobar") == [0, 1, 0] diff --git a/tests/unit/drivers/embedding/test_sagemaker_jumpstart_embedding_driver.py b/tests/unit/drivers/embedding/test_sagemaker_jumpstart_embedding_driver.py new file mode 100644 index 000000000..268b47c54 --- /dev/null +++ b/tests/unit/drivers/embedding/test_sagemaker_jumpstart_embedding_driver.py @@ -0,0 +1,59 @@ +import pytest +from unittest import mock +from griptape.drivers import AmazonSageMakerJumpstartEmbeddingDriver +from griptape.tokenizers.openai_tokenizer import OpenAiTokenizer + + +class TestAmazonSageMakerJumpstartEmbeddingDriver: + @pytest.fixture(autouse=True) + def mock_client(self, mocker): + mock_session_class = mocker.patch("boto3.Session") + mock_session_object = mock.Mock() + mock_client = mock.Mock() + mock_response = mock.Mock() + + mock_client.invoke_endpoint.return_value = mock_response + mock_session_object.client.return_value = mock_client + mock_session_class.return_value = mock_session_object + + return mock_response + + def test_init(self): + assert AmazonSageMakerJumpstartEmbeddingDriver( + endpoint="test-endpoint", + model="test-endpoint", + tokenizer=OpenAiTokenizer(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL), + ) + + def test_try_embed_chunk(self, mock_client): + mock_client.get().read.return_value = b'{"embedding": [[0, 1, 0]]}' + assert AmazonSageMakerJumpstartEmbeddingDriver( + endpoint="test-endpoint", + model="test-model", + tokenizer=OpenAiTokenizer(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL), + ).try_embed_chunk("foobar") == [0, 1, 0] + + mock_client.get().read.return_value = b'{"embedding": [0, 2, 0]}' + assert AmazonSageMakerJumpstartEmbeddingDriver( + endpoint="test-endpoint", + model="test-model", + tokenizer=OpenAiTokenizer(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL), + ).try_embed_chunk("foobar") == [0, 2, 0] + + mock_client.get().read.return_value = b'{"embedding": []}' 
+ with pytest.raises(ValueError) as e: + assert AmazonSageMakerJumpstartEmbeddingDriver( + endpoint="test-endpoint", + model="test-model", + tokenizer=OpenAiTokenizer(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL), + ).try_embed_chunk("foobar") == [0, 2, 0] + assert str(e) == "model response is empty" + + mock_client.get().read.return_value = b"{}" + with pytest.raises(ValueError) as e: + assert AmazonSageMakerJumpstartEmbeddingDriver( + endpoint="test-endpoint", + model="test-model", + tokenizer=OpenAiTokenizer(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL), + ).try_embed_chunk("foobar") == [0, 2, 0] + assert str(e) == "invalid response from model" diff --git a/tests/unit/drivers/embedding_model/__init__.py b/tests/unit/drivers/embedding_model/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/unit/drivers/embedding_model/test_sagemaker_huggingface_embedding_model_driver.py b/tests/unit/drivers/embedding_model/test_sagemaker_huggingface_embedding_model_driver.py deleted file mode 100644 index b63e593b3..000000000 --- a/tests/unit/drivers/embedding_model/test_sagemaker_huggingface_embedding_model_driver.py +++ /dev/null @@ -1,21 +0,0 @@ -import boto3 -import pytest -from griptape.drivers import AmazonSageMakerEmbeddingDriver, SageMakerHuggingFaceEmbeddingModelDriver -from tests.mocks.mock_tokenizer import MockTokenizer - - -class TestSageMakerHuggingFaceEmbeddingModelDriver: - @pytest.fixture - def driver(self): - return AmazonSageMakerEmbeddingDriver( - model="foo", - session=boto3.Session(region_name="us-east-1"), - tokenizer=MockTokenizer(model="foo"), - embedding_model_driver=SageMakerHuggingFaceEmbeddingModelDriver(), - ).embedding_model_driver - - def test_chunk_to_model_params(self, driver): - assert driver.chunk_to_model_params("foobar")["text_inputs"] == "foobar" - - def test_process_output(self, driver): - assert driver.process_output({"embedding": [["foobar"]]}) == ["foobar"] diff --git a/tests/unit/drivers/embedding_model/test_sagemaker_tensorflow_hub_embedding_model_driver.py b/tests/unit/drivers/embedding_model/test_sagemaker_tensorflow_hub_embedding_model_driver.py deleted file mode 100644 index 7080b93fb..000000000 --- a/tests/unit/drivers/embedding_model/test_sagemaker_tensorflow_hub_embedding_model_driver.py +++ /dev/null @@ -1,21 +0,0 @@ -import boto3 -import pytest -from griptape.drivers import AmazonSageMakerEmbeddingDriver, SageMakerTensorFlowHubEmbeddingModelDriver -from tests.mocks.mock_tokenizer import MockTokenizer - - -class TestSageMakerTensorFlowHubFaceEmbeddingModelDriver: - @pytest.fixture - def driver(self): - return AmazonSageMakerEmbeddingDriver( - model="foo", - session=boto3.Session(region_name="us-east-1"), - tokenizer=MockTokenizer(model="foo"), - embedding_model_driver=SageMakerTensorFlowHubEmbeddingModelDriver(), - ).embedding_model_driver - - def test_chunk_to_model_params(self, driver): - assert driver.chunk_to_model_params("foobar")["text_inputs"] == "foobar" - - def test_process_output(self, driver): - assert driver.process_output({"embedding": ["foobar"]}) == ["foobar"] diff --git a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py index 1bd94d3e9..8aa345595 100644 --- a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py @@ -1,117 +1,83 @@ -from botocore.response import StreamingBody -from griptape.artifacts import TextArtifact -from 
griptape.drivers import AmazonBedrockPromptDriver -from griptape.drivers import BedrockClaudePromptModelDriver, BedrockTitanPromptModelDriver -from griptape.tokenizers import AnthropicTokenizer, BedrockTitanTokenizer -from io import StringIO -from unittest.mock import Mock -import json import pytest +from griptape.utils import PromptStack +from griptape.drivers import AmazonBedrockPromptDriver + class TestAmazonBedrockPromptDriver: @pytest.fixture - def mock_prompt_model_driver(self): - mock_prompt_model_driver = Mock() - mock_prompt_model_driver.prompt_stack_to_model_params.return_value = {"model-param-key": "model-param-value"} - mock_prompt_model_driver.process_output.return_value = TextArtifact("model-output") - return mock_prompt_model_driver - - @pytest.fixture(autouse=True) - def mock_client(self, mocker): - return mocker.patch("boto3.Session").return_value.client.return_value - - def test_init(self): - assert AmazonBedrockPromptDriver(model="anthropic.claude", prompt_model_driver=BedrockClaudePromptModelDriver()) - - def test_custom_tokenizer(self): - assert isinstance( - AmazonBedrockPromptDriver( - model="anthropic.claude", prompt_model_driver=BedrockClaudePromptModelDriver() - ).tokenizer, - AnthropicTokenizer, - ) + def mock_converse(self, mocker): + mock_converse = mocker.patch("boto3.Session").return_value.client.return_value.converse - assert isinstance( - AmazonBedrockPromptDriver( - model="titan", - tokenizer=BedrockTitanTokenizer(model="amazon"), - prompt_model_driver=BedrockTitanPromptModelDriver(), - ).tokenizer, - BedrockTitanTokenizer, - ) + mock_converse.return_value = {"output": {"message": {"content": [{"text": "model-output"}]}}} + + return mock_converse + + @pytest.fixture + def mock_converse_stream(self, mocker): + mock_converse_stream = mocker.patch("boto3.Session").return_value.client.return_value.converse_stream + + mock_converse_stream.return_value = {"stream": [{"contentBlockDelta": {"delta": {"text": "model-output"}}}]} + + return mock_converse_stream + + @pytest.fixture + def prompt_stack(self): + prompt_stack = PromptStack() + prompt_stack.add_generic_input("generic-input") + prompt_stack.add_system_input("system-input") + prompt_stack.add_user_input("user-input") + prompt_stack.add_assistant_input("assistant-input") + + return prompt_stack - @pytest.mark.parametrize("model_inputs", [{"model-input-key": "model-input-value"}, "not-a-dict"]) - def test_try_run(self, model_inputs, mock_prompt_model_driver, mock_client): + @pytest.fixture + def messages(self): + return [ + {"role": "user", "content": [{"text": "generic-input"}]}, + {"role": "system", "content": [{"text": "system-input"}]}, + {"role": "user", "content": [{"text": "user-input"}]}, + {"role": "assistant", "content": [{"text": "assistant-input"}]}, + ] + + def test_try_run(self, mock_converse, prompt_stack, messages): # Given - driver = AmazonBedrockPromptDriver(model="model", prompt_model_driver=mock_prompt_model_driver) - prompt_stack = "prompt-stack" - response_body = "invoke-model-response-body" - mock_prompt_model_driver.prompt_stack_to_model_input.return_value = model_inputs - mock_client.invoke_model.return_value = {"body": to_streaming_body(response_body)} + driver = AmazonBedrockPromptDriver(model="ai21.j2") # When text_artifact = driver.try_run(prompt_stack) # Then - mock_prompt_model_driver.prompt_stack_to_model_input.assert_called_once_with(prompt_stack) - mock_prompt_model_driver.prompt_stack_to_model_params.assert_called_once_with(prompt_stack) - 
mock_client.invoke_model.assert_called_once_with( + mock_converse.assert_called_once_with( modelId=driver.model, - contentType="application/json", - accept="application/json", - body=json.dumps( - { - **mock_prompt_model_driver.prompt_stack_to_model_params.return_value, - **(model_inputs if isinstance(model_inputs, dict) else {}), - } - ), + messages=[ + {"role": "user", "content": [{"text": "generic-input"}]}, + {"role": "user", "content": [{"text": "user-input"}]}, + {"role": "assistant", "content": [{"text": "assistant-input"}]}, + ], + system=[{"text": "system-input"}], + inferenceConfig={"temperature": driver.temperature}, + additionalModelRequestFields={}, ) - mock_prompt_model_driver.process_output.assert_called_once_with(response_body) - assert text_artifact == mock_prompt_model_driver.process_output.return_value + assert text_artifact.value == "model-output" - @pytest.mark.parametrize("model_inputs", [{"model-input-key": "model-input-value"}, "not-a-dict"]) - def test_try_stream_run(self, model_inputs, mock_prompt_model_driver, mock_client): + def test_try_stream_run(self, mock_converse_stream, prompt_stack, messages): # Given - driver = AmazonBedrockPromptDriver(model="model", prompt_model_driver=mock_prompt_model_driver, stream=True) - prompt_stack = "prompt-stack" - model_response = "invoke-model-response-body" - response_body = [{"chunk": {"bytes": model_response}}] - mock_prompt_model_driver.prompt_stack_to_model_input.return_value = model_inputs - mock_client.invoke_model_with_response_stream.return_value = {"body": response_body} + driver = AmazonBedrockPromptDriver(model="ai21.j2", stream=True) # When text_artifact = next(driver.try_stream(prompt_stack)) # Then - mock_prompt_model_driver.prompt_stack_to_model_input.assert_called_once_with(prompt_stack) - mock_prompt_model_driver.prompt_stack_to_model_params.assert_called_once_with(prompt_stack) - mock_client.invoke_model_with_response_stream.assert_called_once_with( + mock_converse_stream.assert_called_once_with( modelId=driver.model, - contentType="application/json", - accept="application/json", - body=json.dumps( - { - **mock_prompt_model_driver.prompt_stack_to_model_params.return_value, - **(model_inputs if isinstance(model_inputs, dict) else {}), - } - ), + messages=[ + {"role": "user", "content": [{"text": "generic-input"}]}, + {"role": "user", "content": [{"text": "user-input"}]}, + {"role": "assistant", "content": [{"text": "assistant-input"}]}, + ], + system=[{"text": "system-input"}], + inferenceConfig={"temperature": driver.temperature}, + additionalModelRequestFields={}, ) - mock_prompt_model_driver.process_output.assert_called_once_with(model_response) - assert text_artifact.value == mock_prompt_model_driver.process_output.return_value.value - - def test_try_run_throws_on_empty_response(self, mock_prompt_model_driver, mock_client): - # Given - driver = AmazonBedrockPromptDriver(model="model", prompt_model_driver=mock_prompt_model_driver) - mock_client.invoke_model.return_value = {"body": to_streaming_body("")} - - # When - with pytest.raises(Exception) as e: - driver.try_run("prompt-stack") - - # Then - assert e.value.args[0] == "model response is empty" - - -def to_streaming_body(text: str) -> StreamingBody: - return StreamingBody(StringIO(text), len(text)) + assert text_artifact.value == "model-output" diff --git a/tests/unit/drivers/prompt/test_amazon_sagemaker_jumpstart_prompt_driver.py b/tests/unit/drivers/prompt/test_amazon_sagemaker_jumpstart_prompt_driver.py new file mode 100644 index 
000000000..4ae8fe944 --- /dev/null +++ b/tests/unit/drivers/prompt/test_amazon_sagemaker_jumpstart_prompt_driver.py @@ -0,0 +1,132 @@ +from typing import Any +from botocore.response import StreamingBody +from griptape.tokenizers import HuggingFaceTokenizer +from griptape.drivers.prompt.amazon_sagemaker_jumpstart_prompt_driver import AmazonSageMakerJumpstartPromptDriver +from griptape.utils import PromptStack +from io import BytesIO +import json +import pytest + + +def to_streaming_body(data: Any) -> StreamingBody: + bytes = json.dumps(data).encode("utf-8") + + return StreamingBody(BytesIO(bytes), len(bytes)) + + +class TestAmazonSageMakerJumpstartPromptDriver: + @pytest.fixture(autouse=True) + def tokenizer(self, mocker): + from_pretrained = mocker.patch("transformers.AutoTokenizer").from_pretrained + from_pretrained.return_value.apply_chat_template.return_value = "foo\n\nUser: bar" + from_pretrained.return_value.model_max_length = 8000 + from_pretrained.return_value.eos_token_id = 1 + + return from_pretrained + + @pytest.fixture(autouse=True) + def mock_client(self, mocker): + return mocker.patch("boto3.Session").return_value.client.return_value + + def test_init(self): + assert AmazonSageMakerJumpstartPromptDriver(endpoint="foo", model="bar") + + def test_try_run(self, mock_client): + # Given + driver = AmazonSageMakerJumpstartPromptDriver(endpoint="model", model="model") + prompt_stack = PromptStack() + prompt_stack.add_user_input("prompt-stack") + + # When + response_body = [{"generated_text": "foobar"}] + mock_client.invoke_endpoint.return_value = {"Body": to_streaming_body(response_body)} + text_artifact = driver.try_run(prompt_stack) + assert isinstance(driver.tokenizer, HuggingFaceTokenizer) + + # Then + mock_client.invoke_endpoint.assert_called_with( + EndpointName=driver.endpoint, + ContentType="application/json", + Body=json.dumps( + { + "inputs": "foo\n\nUser: bar", + "parameters": { + "temperature": driver.temperature, + "max_new_tokens": 250, + "do_sample": True, + "eos_token_id": 1, + "stop_strings": [], + "return_full_text": False, + }, + } + ), + CustomAttributes="accept_eula=true", + ) + + assert text_artifact.value == "foobar" + + # When + response_body = {"generated_text": "foobar"} + mock_client.invoke_endpoint.return_value = {"Body": to_streaming_body(response_body)} + text_artifact = driver.try_run(prompt_stack) + assert isinstance(driver.tokenizer, HuggingFaceTokenizer) + + # Then + mock_client.invoke_endpoint.assert_called_with( + EndpointName=driver.endpoint, + ContentType="application/json", + Body=json.dumps( + { + "inputs": "foo\n\nUser: bar", + "parameters": { + "temperature": driver.temperature, + "max_new_tokens": 250, + "do_sample": True, + "eos_token_id": 1, + "stop_strings": [], + "return_full_text": False, + }, + } + ), + CustomAttributes="accept_eula=true", + ) + + assert text_artifact.value == "foobar" + + def test_try_stream(self, mock_client): + # Given + driver = AmazonSageMakerJumpstartPromptDriver(endpoint="model", model="model") + prompt_stack = PromptStack() + prompt_stack.add_user_input("prompt-stack") + + # When + with pytest.raises(NotImplementedError) as e: + driver.try_stream(prompt_stack) + + # Then + assert e.value.args[0] == "streaming is not supported" + + def test_stream_init(self): + # Given + driver = AmazonSageMakerJumpstartPromptDriver(endpoint="model", model="model") + + # When + with pytest.raises(ValueError) as e: + driver.stream = True + + # Then + assert e.value.args[0] == "streaming is not supported" + + def 
test_try_run_throws_on_empty_response(self, mock_client): + # Given + driver = AmazonSageMakerJumpstartPromptDriver(endpoint="model", model="model") + mock_client.invoke_endpoint.return_value = {"Body": to_streaming_body([])} + prompt_stack = PromptStack() + prompt_stack.add_user_input("prompt-stack") + + # When + with pytest.raises(Exception) as e: + driver.try_run(prompt_stack) + + # Then + assert e.value.args[0] == "model response is empty" diff --git a/tests/unit/drivers/prompt/test_amazon_sagemaker_prompt_driver.py b/tests/unit/drivers/prompt/test_amazon_sagemaker_prompt_driver.py deleted file mode 100644 index c6692e1ba..000000000 --- a/tests/unit/drivers/prompt/test_amazon_sagemaker_prompt_driver.py +++ /dev/null @@ -1,90 +0,0 @@ -from botocore.response import StreamingBody -from griptape.artifacts import TextArtifact -from griptape.drivers import AmazonSageMakerPromptDriver, SageMakerFalconPromptModelDriver -from griptape.tokenizers import HuggingFaceTokenizer, OpenAiTokenizer -from griptape.utils import PromptStack -from io import BytesIO -from unittest.mock import Mock -import json -import pytest - - -class TestAmazonSageMakerPromptDriver: - @pytest.fixture - def mock_model_driver(self): - mock_model_driver = Mock() - mock_model_driver.prompt_stack_to_model_input.return_value = "model-inputs" - mock_model_driver.prompt_stack_to_model_params.return_value = "model-params" - mock_model_driver.process_output.return_value = TextArtifact("model-output") - return mock_model_driver - - @pytest.fixture(autouse=True) - def mock_client(self, mocker): - return mocker.patch("boto3.Session").return_value.client.return_value - - def test_init(self): - assert AmazonSageMakerPromptDriver(endpoint="foo", prompt_model_driver=SageMakerFalconPromptModelDriver()) - - def test_custom_tokenizer(self): - assert isinstance( - AmazonSageMakerPromptDriver( - endpoint="foo", prompt_model_driver=SageMakerFalconPromptModelDriver() - ).tokenizer, - HuggingFaceTokenizer, - ) - - assert isinstance( - AmazonSageMakerPromptDriver( - endpoint="foo", - tokenizer=OpenAiTokenizer(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL), - prompt_model_driver=SageMakerFalconPromptModelDriver(), - ).tokenizer, - OpenAiTokenizer, - ) - - def test_try_run(self, mock_model_driver, mock_client): - # Given - driver = AmazonSageMakerPromptDriver(endpoint="model", prompt_model_driver=mock_model_driver) - prompt_stack = PromptStack() - prompt_stack.add_user_input("prompt-stack") - response_body = "invoke-endpoint-response-body" - mock_client.invoke_endpoint.return_value = {"Body": to_streaming_body(response_body)} - - # When - text_artifact = driver.try_run(prompt_stack) - - # Then - mock_model_driver.prompt_stack_to_model_input.assert_called_once_with(prompt_stack) - mock_model_driver.prompt_stack_to_model_params.assert_called_once_with(prompt_stack) - mock_client.invoke_endpoint.assert_called_once_with( - EndpointName=driver.endpoint, - ContentType="application/json", - Body=json.dumps( - { - "inputs": mock_model_driver.prompt_stack_to_model_input.return_value, - "parameters": mock_model_driver.prompt_stack_to_model_params.return_value, - } - ), - CustomAttributes="accept_eula=true", - ) - mock_model_driver.process_output.assert_called_once_with(response_body) - assert text_artifact == mock_model_driver.process_output.return_value - - def test_try_run_throws_on_empty_response(self, mock_model_driver, mock_client): - # Given - driver = AmazonSageMakerPromptDriver(endpoint="model", prompt_model_driver=mock_model_driver) - 
mock_client.invoke_endpoint.return_value = {"Body": to_streaming_body("")} - prompt_stack = PromptStack() - prompt_stack.add_user_input("prompt-stack") - - # When - with pytest.raises(Exception) as e: - driver.try_run(prompt_stack) - - # Then - assert e.value.args[0] == "model response is empty" - - -def to_streaming_body(text: str) -> StreamingBody: - bytes = json.dumps(text).encode("utf-8") - return StreamingBody(BytesIO(bytes), len(bytes)) diff --git a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py index c5009afac..22178bbf3 100644 --- a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py @@ -63,9 +63,9 @@ def test_try_run(self, mock_client, model, system_enabled): # Then mock_client.return_value.messages.create.assert_called_once_with( messages=expected_messages, - stop_sequences=["<|Response|>"], + stop_sequences=[], model=driver.model, - max_tokens=4091, + max_tokens=1000, temperature=0.1, top_p=0.999, top_k=250, @@ -106,9 +106,9 @@ def test_try_stream_run(self, mock_stream_client, model, system_enabled): # Then mock_stream_client.return_value.messages.create.assert_called_once_with( messages=expected_messages, - stop_sequences=["<|Response|>"], + stop_sequences=[], model=driver.model, - max_tokens=4091, + max_tokens=1000, temperature=0.1, stream=True, top_p=0.999, diff --git a/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py b/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py index 9446d0520..f6bd12d80 100644 --- a/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_azure_openai_chat_prompt_driver.py @@ -7,11 +7,11 @@ class TestAzureOpenAiChatPromptDriver(TestOpenAiChatPromptDriverFixtureMixin): @pytest.fixture def mock_chat_completion_create(self, mocker): - mock_chat_create = mocker.patch("openai.AzureOpenAI").return_value.chat.completions.with_raw_response.create + mock_chat_create = mocker.patch("openai.AzureOpenAI").return_value.chat.completions.create mock_choice = Mock() mock_choice.message.content = "model-output" mock_chat_create.return_value.headers = {} - mock_chat_create.return_value.parse.return_value.choices = [mock_choice] + mock_chat_create.return_value.choices = [mock_choice] return mock_chat_create @pytest.fixture diff --git a/tests/unit/drivers/prompt/test_azure_openai_completion_prompt_driver.py b/tests/unit/drivers/prompt/test_azure_openai_completion_prompt_driver.py deleted file mode 100644 index 65758843a..000000000 --- a/tests/unit/drivers/prompt/test_azure_openai_completion_prompt_driver.py +++ /dev/null @@ -1,75 +0,0 @@ -import pytest -from unittest.mock import Mock -from griptape.drivers import AzureOpenAiCompletionPromptDriver -from tests.unit.drivers.prompt.test_openai_completion_prompt_driver import TestOpenAiCompletionPromptDriverFixtureMixin -from unittest.mock import ANY - - -class TestAzureOpenAiCompletionPromptDriver(TestOpenAiCompletionPromptDriverFixtureMixin): - @pytest.fixture - def mock_completion_create(self, mocker): - mock_chat_create = mocker.patch("openai.AzureOpenAI").return_value.completions.create - mock_choice = Mock() - mock_choice.text = "model-output" - mock_chat_create.return_value.choices = [mock_choice] - return mock_chat_create - - @pytest.fixture - def mock_completion_stream_create(self, mocker): - mock_chat_create = mocker.patch("openai.AzureOpenAI").return_value.completions.create - mock_chunk = Mock() - 
mock_choice = Mock() - mock_choice.text = "model-output" - mock_chunk.choices = [mock_choice] - mock_chat_create.return_value = iter([mock_chunk]) - return mock_chat_create - - def test_init(self): - assert AzureOpenAiCompletionPromptDriver( - azure_endpoint="endpoint", azure_deployment="deployment", model="text-davinci-003" - ) - assert ( - AzureOpenAiCompletionPromptDriver(azure_endpoint="endpoint", model="text-davinci-003").azure_deployment - == "text-davinci-003" - ) - - def test_try_run(self, mock_completion_create, prompt_stack, prompt): - # Given - driver = AzureOpenAiCompletionPromptDriver( - azure_endpoint="endpoint", azure_deployment="deployment", model="text-davinci-003" - ) - - # When - text_artifact = driver.try_run(prompt_stack) - - # Then - mock_completion_create.assert_called_once_with( - model=driver.model, - max_tokens=ANY, - temperature=driver.temperature, - stop=driver.tokenizer.stop_sequences, - user=driver.user, - prompt=prompt, - ) - assert text_artifact.value == "model-output" - - def test_try_stream_run(self, mock_completion_stream_create, prompt_stack, prompt): - # Given - driver = AzureOpenAiCompletionPromptDriver( - azure_endpoint="endpoint", azure_deployment="deployment", model="text-davinci-003", stream=True - ) - - # When - text_artifact = next(driver.try_stream(prompt_stack)) - - # Then - mock_completion_stream_create.assert_called_once_with( - model=driver.model, - max_tokens=ANY, - temperature=driver.temperature, - stop=driver.tokenizer.stop_sequences, - user=driver.user, - stream=True, - prompt=prompt, - ) - assert text_artifact.value == "model-output" diff --git a/tests/unit/drivers/prompt/test_base_prompt_driver.py b/tests/unit/drivers/prompt/test_base_prompt_driver.py index 1a06b5907..0743402aa 100644 --- a/tests/unit/drivers/prompt/test_base_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_base_prompt_driver.py @@ -39,35 +39,6 @@ def test_run_via_pipeline_publishes_events(self, mocker): def test_run(self): assert isinstance(MockPromptDriver().run(PromptStack(inputs=[])), TextArtifact) - def test_token_count(self): - assert ( - MockPromptDriver().token_count( - PromptStack(inputs=[PromptStack.Input("foobar", role=PromptStack.USER_ROLE)]) - ) - == 24 - ) - - def test_max_output_tokens(self): - assert MockPromptDriver().max_output_tokens("foobar") == 4090 - assert MockPromptDriver(max_tokens=5000).max_output_tokens("foobar") == 4090 - assert MockPromptDriver(max_tokens=100).max_output_tokens("foobar") == 100 - - def test_prompt_stack_to_string(self): - assert ( - MockPromptDriver().prompt_stack_to_string( - PromptStack(inputs=[PromptStack.Input("foobar", role=PromptStack.USER_ROLE)]) - ) - == "User: foobar\n\nAssistant:" - ) - - def test_custom_prompt_stack_to_string(self): - assert ( - MockPromptDriver( - prompt_stack_to_string=lambda stack: f"Foo: {stack.inputs[0].content}" - ).prompt_stack_to_string(PromptStack(inputs=[PromptStack.Input("foobar", role=PromptStack.USER_ROLE)])) - == "Foo: foobar" - ) - def instance_count(instances, clazz): return len([instance for instance in instances if isinstance(instance, clazz)]) diff --git a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py index 907e27325..b3ceb11a4 100644 --- a/tests/unit/drivers/prompt/test_cohere_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_cohere_prompt_driver.py @@ -8,16 +8,16 @@ class TestCoherePromptDriver: @pytest.fixture def mock_client(self, mocker): mock_client = mocker.patch("cohere.Client").return_value - 
mock_client.generate.return_value.generations = [Mock()] - mock_client.generate.return_value.generations[0].text = "model-output" + mock_client.chat.return_value = Mock(text="model-output") + return mock_client @pytest.fixture def mock_stream_client(self, mocker): mock_client = mocker.patch("cohere.Client").return_value - mock_chunk = Mock() - mock_chunk.text = "model-output" - mock_client.generate.return_value = iter([mock_chunk]) + mock_chunk = Mock(text="model-output", event_type="text-generation") + mock_client.chat_stream.return_value = iter([mock_chunk]) + return mock_client @pytest.fixture(autouse=True) @@ -55,16 +55,3 @@ def test_try_stream_run(self, mock_stream_client, prompt_stack): # pyright: ign # Then assert text_artifact.value == "model-output" - - @pytest.mark.parametrize("choices", [[], [1, 2]]) - def test_try_run_throws_when_multiple_choices_returned(self, choices, mock_client, prompt_stack): - # Given - driver = CoherePromptDriver(model="command", api_key="api-key") - mock_client.generate.return_value.generations = choices - - # When - with pytest.raises(Exception) as e: - driver.try_run(prompt_stack) - - # Then - e.value.args[0] == "Completion with more than one choice is not supported yet." diff --git a/tests/unit/drivers/prompt/test_google_prompt_driver.py b/tests/unit/drivers/prompt/test_google_prompt_driver.py index 6e38bd503..f655d3e51 100644 --- a/tests/unit/drivers/prompt/test_google_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_google_prompt_driver.py @@ -2,7 +2,6 @@ from griptape.drivers import GooglePromptDriver from griptape.utils import PromptStack from unittest.mock import Mock -from tests.mocks.mock_tokenizer import MockTokenizer import pytest @@ -32,9 +31,7 @@ def test_try_run(self, mock_generative_model): prompt_stack.add_user_input("user-input") prompt_stack.add_assistant_input("assistant-input") prompt_stack.add_generic_input("generic-input") - driver = GooglePromptDriver( - model="gemini-pro", api_key="api-key", tokenizer=MockTokenizer(model="gemini-pro"), top_p=0.5, top_k=50 - ) + driver = GooglePromptDriver(model="gemini-pro", api_key="api-key", top_p=0.5, top_k=50) # When text_artifact = driver.try_run(prompt_stack) @@ -47,7 +44,7 @@ def test_try_run(self, mock_generative_model): {"parts": ["generic-input"], "role": "user"}, ], generation_config=GenerationConfig( - max_output_tokens=997, temperature=0.1, top_p=0.5, top_k=50, stop_sequences=["<|Response|>"] + max_output_tokens=None, temperature=0.1, top_p=0.5, top_k=50, stop_sequences=[] ), ) assert text_artifact.value == "model-output" @@ -59,14 +56,7 @@ def test_try_stream(self, mock_stream_generative_model): prompt_stack.add_user_input("user-input") prompt_stack.add_assistant_input("assistant-input") prompt_stack.add_generic_input("generic-input") - driver = GooglePromptDriver( - model="gemini-pro", - api_key="api-key", - stream=True, - tokenizer=MockTokenizer(model="gemini-pro"), - top_p=0.5, - top_k=50, - ) + driver = GooglePromptDriver(model="gemini-pro", api_key="api-key", stream=True, top_p=0.5, top_k=50) # When text_artifact_stream = driver.try_stream(prompt_stack) @@ -80,9 +70,7 @@ def test_try_stream(self, mock_stream_generative_model): {"parts": ["generic-input"], "role": "user"}, ], stream=True, - generation_config=GenerationConfig( - max_output_tokens=997, temperature=0.1, top_p=0.5, top_k=50, stop_sequences=["<|Response|>"] - ), + generation_config=GenerationConfig(temperature=0.1, top_p=0.5, top_k=50, stop_sequences=[]), ) assert text_artifact.value == "model-output" @@ 
-108,26 +96,3 @@ def test_prompt_stack_to_model_input(self): {"role": "model", "parts": ["assistant-input"]}, {"role": "user", "parts": ["user-input"]}, ] - - def test_to_content_dict(self): - # Given - driver = GooglePromptDriver(model="gemini-pro", api_key="1234") - - # When - assert driver._GooglePromptDriver__to_content_dict(PromptStack.Input("system-input", "system")) == { - "role": "user", - "parts": ["system-input"], - } - assert driver._GooglePromptDriver__to_content_dict(PromptStack.Input("user-input", "user")) == { - "role": "user", - "parts": ["user-input"], - } - assert driver._GooglePromptDriver__to_content_dict(PromptStack.Input("assistant-input", "assistant")) == { - "role": "model", - "parts": ["assistant-input"], - } - - assert driver._GooglePromptDriver__to_content_dict(PromptStack.Input("generic-input", "generic")) == { - "role": "user", - "parts": ["generic-input"], - } diff --git a/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py b/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py index 6b91b56cb..15bbb4ead 100644 --- a/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_hugging_face_hub_prompt_driver.py @@ -10,6 +10,13 @@ def mock_client(self, mocker): mock_client.text_generation.return_value = "model-output" return mock_client + @pytest.fixture(autouse=True) + def tokenizer(self, mocker): + from_pretrained = tokenizer = mocker.patch("transformers.AutoTokenizer").from_pretrained + from_pretrained.return_value.apply_chat_template.return_value = [1, 2, 3] + + return tokenizer + @pytest.fixture def mock_client_stream(self, mocker): mock_client = mocker.patch("huggingface_hub.InferenceClient").return_value diff --git a/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py b/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py index fec39da4d..b2746ca58 100644 --- a/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py @@ -13,13 +13,15 @@ def mock_pipeline(self, mocker): def mock_generator(self, mock_pipeline): mock_generator = mock_pipeline.return_value mock_generator.task = "text-generation" - mock_generator.return_value = [{"generated_text": "model-output"}] + mock_generator.return_value = [{"generated_text": [{"content": "model-output"}]}] return mock_generator @pytest.fixture(autouse=True) def mock_autotokenizer(self, mocker): mock_autotokenizer = mocker.patch("transformers.AutoTokenizer.from_pretrained").return_value mock_autotokenizer.model_max_length = 42 + mock_autotokenizer.apply_chat_template.return_value = [1, 2, 3] + mock_autotokenizer.decode.return_value = "model-output" return mock_autotokenizer @pytest.fixture @@ -44,6 +46,16 @@ def test_try_run(self, prompt_stack): # Then assert text_artifact.value == "model-output" + def test_try_stream(self, prompt_stack): + # Given + driver = HuggingFacePipelinePromptDriver(model="foo", max_tokens=42) + + # When + with pytest.raises(Exception) as e: + driver.try_stream(prompt_stack) + + assert e.value.args[0] == "streaming is not supported" + @pytest.mark.parametrize("choices", [[], [1, 2]]) def test_try_run_throws_when_multiple_choices_returned(self, choices, mock_generator, prompt_stack): # Given @@ -55,16 +67,26 @@ def test_try_run_throws_when_multiple_choices_returned(self, choices, mock_gener driver.try_run(prompt_stack) # Then - e.value.args[0] == "completion with more than one choice is not supported yet" + 
assert e.value.args[0] == "completion with more than one choice is not supported yet" - def test_try_run_throws_when_unsupported_task_returned(self, prompt_stack, mock_generator): + def test_try_run_throws_when_non_list(self, mock_generator, prompt_stack): # Given driver = HuggingFacePipelinePromptDriver(model="foo", max_tokens=42) - mock_generator.task = "obviously-an-unsupported-task" + mock_generator.return_value = {} # When with pytest.raises(Exception) as e: driver.try_run(prompt_stack) # Then - assert e.value.args[0].startswith("only models with the following tasks are supported: ") + assert e.value.args[0] == "invalid output format" + + def test_prompt_stack_to_string(self, prompt_stack): + # Given + driver = HuggingFacePipelinePromptDriver(model="foo", max_tokens=42) + + # When + result = driver.prompt_stack_to_string(prompt_stack) + + # Then + assert result == "model-output" diff --git a/tests/unit/drivers/prompt/test_ollama_prompt_driver.py b/tests/unit/drivers/prompt/test_ollama_prompt_driver.py new file mode 100644 index 000000000..d42a8b45d --- /dev/null +++ b/tests/unit/drivers/prompt/test_ollama_prompt_driver.py @@ -0,0 +1,96 @@ +from griptape.drivers import OllamaPromptDriver +from griptape.utils import PromptStack +import pytest + + +class TestOllamaPromptDriver: + @pytest.fixture + def mock_client(self, mocker): + mock_client = mocker.patch("ollama.Client") + + mock_client.return_value.chat.return_value = {"message": {"content": "model-output"}} + + return mock_client + + @pytest.fixture + def mock_stream_client(self, mocker): + mock_stream_client = mocker.patch("ollama.Client") + mock_stream_client.return_value.chat.return_value = iter([{"message": {"content": "model-output"}}]) + + return mock_stream_client + + def test_init(self): + assert OllamaPromptDriver(model="llama") + + def test_try_run(self, mock_client): + # Given + prompt_stack = PromptStack() + prompt_stack.add_generic_input("generic-input") + prompt_stack.add_system_input("system-input") + prompt_stack.add_user_input("user-input") + prompt_stack.add_assistant_input("assistant-input") + driver = OllamaPromptDriver(model="llama") + expected_messages = [ + {"role": "generic", "content": "generic-input"}, + {"role": "system", "content": "system-input"}, + {"role": "user", "content": "user-input"}, + {"role": "assistant", "content": "assistant-input"}, + ] + + # When + text_artifact = driver.try_run(prompt_stack) + + # Then + mock_client.return_value.chat.assert_called_once_with( + messages=expected_messages, + model=driver.model, + options={"temperature": driver.temperature, "stop": [], "num_predict": driver.max_tokens}, + ) + assert text_artifact.value == "model-output" + + def test_try_run_bad_response(self, mock_client): + # Given + prompt_stack = PromptStack() + driver = OllamaPromptDriver(model="llama") + mock_client.return_value.chat.return_value = "bad-response" + + # When/Then + with pytest.raises(Exception, match="invalid model response"): + driver.try_run(prompt_stack) + + def test_try_stream_run(self, mock_stream_client): + # Given + prompt_stack = PromptStack() + prompt_stack.add_generic_input("generic-input") + prompt_stack.add_system_input("system-input") + prompt_stack.add_user_input("user-input") + prompt_stack.add_assistant_input("assistant-input") + expected_messages = [ + {"role": "generic", "content": "generic-input"}, + {"role": "system", "content": "system-input"}, + {"role": "user", "content": "user-input"}, + {"role": "assistant", "content": "assistant-input"}, + ] + driver = 
OllamaPromptDriver(model="llama", stream=True) + + # When + text_artifact = next(driver.try_stream(prompt_stack)) + + # Then + mock_stream_client.return_value.chat.assert_called_once_with( + messages=expected_messages, + model=driver.model, + options={"temperature": driver.temperature, "stop": [], "num_predict": driver.max_tokens}, + stream=True, + ) + assert text_artifact.value == "model-output" + + def test_try_stream_bad_response(self, mock_stream_client): + # Given + prompt_stack = PromptStack() + driver = OllamaPromptDriver(model="llama", stream=True) + mock_stream_client.return_value.chat.return_value = "bad-response" + + # When/Then + with pytest.raises(Exception, match="invalid model response"): + next(driver.try_stream(prompt_stack)) diff --git a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py index fbc939005..a2900d4d3 100644 --- a/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py @@ -1,7 +1,3 @@ -import datetime - -from transformers import AutoTokenizer - from griptape.drivers import OpenAiChatPromptDriver from griptape.tokenizers.huggingface_tokenizer import HuggingFaceTokenizer from griptape.utils import PromptStack @@ -13,11 +9,11 @@ class TestOpenAiChatPromptDriverFixtureMixin: @pytest.fixture def mock_chat_completion_create(self, mocker): - mock_chat_create = mocker.patch("openai.OpenAI").return_value.chat.completions.with_raw_response.create + mock_chat_create = mocker.patch("openai.OpenAI").return_value.chat.completions.create mock_choice = Mock() mock_choice.message.content = "model-output" mock_chat_create.return_value.headers = {} - mock_chat_create.return_value.parse.return_value.choices = [mock_choice] + mock_chat_create.return_value.choices = [mock_choice] return mock_chat_create @pytest.fixture @@ -163,30 +159,6 @@ def test_try_run_with_max_tokens(self, mock_chat_completion_create, prompt_stack ) assert text_artifact.value == "model-output" - def test_try_run_max_tokens_limited_by_tokenizer(self, mock_chat_completion_create, prompt_stack, messages): - # Given - max_tokens_request = 9999999 - driver = OpenAiChatPromptDriver( - model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, max_tokens=max_tokens_request - ) - tokens_left = driver.tokenizer.count_input_tokens_left(driver._prompt_stack_to_messages(prompt_stack)) - - # When - text_artifact = driver.try_run(prompt_stack) - - # Then - mock_chat_completion_create.assert_called_once_with( - model=driver.model, - temperature=driver.temperature, - stop=driver.tokenizer.stop_sequences, - user=driver.user, - messages=messages, - max_tokens=max_tokens_request, - seed=driver.seed, - ) - assert max_tokens_request > tokens_left - assert text_artifact.value == "model-output" - def test_try_run_throws_when_prompt_stack_is_string(self): # Given driver = OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL) @@ -202,7 +174,7 @@ def test_try_run_throws_when_prompt_stack_is_string(self): def test_try_run_throws_when_multiple_choices_returned(self, choices, mock_chat_completion_create, prompt_stack): # Given driver = OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, api_key="api-key") - mock_chat_completion_create.return_value.parse.return_value.choices = [choices] + mock_chat_completion_create.return_value.choices = [choices] # When with pytest.raises(Exception) as e: @@ -211,116 +183,10 @@ def 
test_try_run_throws_when_multiple_choices_returned(self, choices, mock_chat_ # Then e.value.args[0] == "Completion with more than one choice is not supported yet." - def test_token_count(self, prompt_stack, messages): - # Given - mock_tokenizer = Mock(spec=OpenAiTokenizer) - mock_tokenizer.count_tokens.return_value = 42 - driver = OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, tokenizer=mock_tokenizer) - - # When - token_count = driver.token_count(prompt_stack) - - # Then - mock_tokenizer.count_tokens.assert_called_once_with(messages) - assert token_count == 42 - - # Given - mock_tokenizer = Mock() - mock_tokenizer.count_tokens.return_value = 42 - driver = OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, tokenizer=mock_tokenizer) - - # When - token_count = driver.token_count(prompt_stack) - - # Then - mock_tokenizer.count_tokens.assert_called_once_with(driver.prompt_stack_to_string(prompt_stack)) - assert token_count == 42 - - def test_max_output_tokens(self, messages): - # Given - mock_tokenizer = Mock() - mock_tokenizer.count_output_tokens_left.return_value = 42 - driver = OpenAiChatPromptDriver( - model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, tokenizer=mock_tokenizer, max_tokens=45 - ) - - # When - max_output_tokens = driver.max_output_tokens(messages) - - # Then - mock_tokenizer.count_output_tokens_left.assert_called_once_with(messages) - assert max_output_tokens == 42 - - def test_max_output_tokens_with_max_tokens(self, messages): - max_tokens = OpenAiChatPromptDriver( - model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, max_tokens=42 - ).max_output_tokens(messages) - - assert max_tokens == 42 - - def test_extract_ratelimit_metadata(self): - response_with_headers = OpenAiApiResponseWithHeaders() - driver = OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL) - driver._extract_ratelimit_metadata(response_with_headers) - - assert driver._ratelimit_requests_remaining == response_with_headers.remaining_requests - assert driver._ratelimit_tokens_remaining == response_with_headers.remaining_tokens - assert driver._ratelimit_request_limit == response_with_headers.limit_requests - assert driver._ratelimit_token_limit == response_with_headers.limit_tokens - - # Assert that the reset times are within one second of the expected value. - expected_request_reset_time = datetime.datetime.now() + datetime.timedelta( - seconds=response_with_headers.reset_requests_in - ) - expected_token_reset_time = datetime.datetime.now() + datetime.timedelta( - seconds=response_with_headers.reset_tokens_in - ) - - assert driver._ratelimit_requests_reset_at is not None - assert abs(driver._ratelimit_requests_reset_at - expected_request_reset_time) < datetime.timedelta(seconds=1) - assert driver._ratelimit_tokens_reset_at is not None - assert abs(driver._ratelimit_tokens_reset_at - expected_token_reset_time) < datetime.timedelta(seconds=1) - - def test_extract_ratelimit_metadata_with_subsecond_reset_times(self): - response_with_headers = OpenAiApiResponseWithHeaders( - reset_requests_in=1, reset_requests_in_unit="ms", reset_tokens_in=10, reset_tokens_in_unit="ms" - ) - driver = OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, api_key="api-key") - driver = OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL) - driver._extract_ratelimit_metadata(response_with_headers) - - # Assert that the reset times are within one second of the expected value. 
With a sub-second reset time, - # this is rounded up to one second in the future. - expected_request_reset_time = datetime.datetime.now() + datetime.timedelta(seconds=1) - expected_token_reset_time = datetime.datetime.now() + datetime.timedelta(seconds=1) - - assert driver._ratelimit_requests_reset_at is not None - assert abs(driver._ratelimit_requests_reset_at - expected_request_reset_time) < datetime.timedelta(seconds=1) - assert driver._ratelimit_tokens_reset_at is not None - assert abs(driver._ratelimit_tokens_reset_at - expected_token_reset_time) < datetime.timedelta(seconds=1) - - def test_extract_ratelimit_metadata_missing_headers(self): - class OpenAiApiResponseNoHeaders: - @property - def headers(self): - return {} - - response_without_headers = OpenAiApiResponseNoHeaders() - - driver = OpenAiChatPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL) - driver._extract_ratelimit_metadata(response_without_headers) - - assert driver._ratelimit_request_limit is None - assert driver._ratelimit_requests_remaining is None - assert driver._ratelimit_requests_reset_at is None - assert driver._ratelimit_token_limit is None - assert driver._ratelimit_tokens_remaining is None - assert driver._ratelimit_tokens_reset_at is None - def test_custom_tokenizer(self, mock_chat_completion_create, prompt_stack, messages): driver = OpenAiChatPromptDriver( model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, - tokenizer=HuggingFaceTokenizer(tokenizer=AutoTokenizer.from_pretrained("gpt2"), max_output_tokens=1000), + tokenizer=HuggingFaceTokenizer(model="gpt2", max_output_tokens=1000), max_tokens=1, ) @@ -333,7 +199,12 @@ def test_custom_tokenizer(self, mock_chat_completion_create, prompt_stack, messa temperature=driver.temperature, stop=driver.tokenizer.stop_sequences, user=driver.user, - messages=messages, + messages=[ + {"role": "user", "content": "generic-input"}, + {"role": "system", "content": "system-input"}, + {"role": "user", "content": "user-input"}, + {"role": "assistant", "content": "assistant-input"}, + ], seed=driver.seed, max_tokens=1, ) diff --git a/tests/unit/drivers/prompt/test_openai_completion_prompt_driver.py b/tests/unit/drivers/prompt/test_openai_completion_prompt_driver.py deleted file mode 100644 index 66998c261..000000000 --- a/tests/unit/drivers/prompt/test_openai_completion_prompt_driver.py +++ /dev/null @@ -1,112 +0,0 @@ -from griptape.drivers import OpenAiCompletionPromptDriver -from griptape.utils import PromptStack -from unittest.mock import ANY, Mock -from griptape.tokenizers import OpenAiTokenizer -import pytest - - -class TestOpenAiCompletionPromptDriverFixtureMixin: - @pytest.fixture - def mock_completion_create(self, mocker): - mock_chat_create = mocker.patch("openai.OpenAI").return_value.completions.create - mock_choice = Mock() - mock_choice.text = "model-output" - mock_chat_create.return_value.choices = [mock_choice] - return mock_chat_create - - @pytest.fixture - def mock_completion_stream_create(self, mocker): - mock_chat_create = mocker.patch("openai.OpenAI").return_value.completions.create - mock_chunk = Mock() - mock_choice = Mock() - mock_choice.text = "model-output" - mock_chunk.choices = [mock_choice] - mock_chat_create.return_value = iter([mock_chunk]) - return mock_chat_create - - @pytest.fixture - def prompt_stack(self): - prompt_stack = PromptStack() - prompt_stack.add_generic_input("generic-input") - prompt_stack.add_system_input("system-input") - prompt_stack.add_user_input("user-input") - 
prompt_stack.add_assistant_input("assistant-input") - return prompt_stack - - @pytest.fixture - def prompt(self): - return "".join( - [ - "generic-input\n\n", - "system-input\n\n", - "User: user-input\n\n", - "Assistant: assistant-input\n\n", - "Assistant:", - ] - ) - - -class TestOpenAiCompletionPromptDriver(TestOpenAiCompletionPromptDriverFixtureMixin): - def test_init(self): - assert OpenAiCompletionPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL) - - def test_try_run(self, mock_completion_create, prompt_stack, prompt): - # Given - driver = OpenAiCompletionPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL) - - # When - text_artifact = driver.try_run(prompt_stack) - - # Then - mock_completion_create.assert_called_once_with( - model=driver.model, - max_tokens=ANY, - temperature=driver.temperature, - stop=driver.tokenizer.stop_sequences, - user=driver.user, - prompt=prompt, - ) - assert text_artifact.value == "model-output" - - def test_try_stream_run(self, mock_completion_stream_create, prompt_stack, prompt): - # Given - driver = OpenAiCompletionPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, stream=True) - - # When - text_artifact = next(driver.try_stream(prompt_stack)) - - # Then - mock_completion_stream_create.assert_called_once_with( - model=driver.model, - max_tokens=ANY, - temperature=driver.temperature, - stop=driver.tokenizer.stop_sequences, - user=driver.user, - stream=True, - prompt=prompt, - ) - assert text_artifact.value == "model-output" - - def test_try_run_throws_when_prompt_stack_is_string(self): - # Given - driver = OpenAiCompletionPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL) - - # When - with pytest.raises(Exception) as e: - driver.try_run("prompt-stack") # pyright: ignore - - # Then - assert e.value.args[0] == "'str' object has no attribute 'inputs'" - - @pytest.mark.parametrize("choices", [[], [1, 2]]) - def test_try_run_throws_when_multiple_choices_returned(self, choices, mock_completion_create, prompt_stack): - # Given - driver = OpenAiCompletionPromptDriver(model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL) - mock_completion_create.return_value.choices = choices - - # When - with pytest.raises(Exception) as e: - driver.try_run(prompt_stack) - - # Then - e.value.args[0] == "Completion with more than one choice is not supported yet." 
diff --git a/tests/unit/drivers/prompt_models/__init__.py b/tests/unit/drivers/prompt_models/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/unit/drivers/prompt_models/test_bedrock_claude_prompt_model_driver.py b/tests/unit/drivers/prompt_models/test_bedrock_claude_prompt_model_driver.py deleted file mode 100644 index 94bc97122..000000000 --- a/tests/unit/drivers/prompt_models/test_bedrock_claude_prompt_model_driver.py +++ /dev/null @@ -1,129 +0,0 @@ -from unittest import mock -import json -import boto3 -import pytest -from griptape.utils import PromptStack -from griptape.drivers import AmazonBedrockPromptDriver, BedrockClaudePromptModelDriver - - -class TestBedrockClaudePromptModelDriver: - @pytest.fixture(autouse=True) - def mock_session(self, mocker): - mock_session_class = mocker.patch("boto3.Session") - - mock_session_object = mock.Mock() - mock_client = mock.Mock() - - mock_session_object.client.return_value = mock_client - mock_session_class.return_value = mock_session_object - - @pytest.fixture - def driver(self, request): - return AmazonBedrockPromptDriver( - model=request.param, - session=boto3.Session(region_name="us-east-1"), - prompt_model_driver=BedrockClaudePromptModelDriver(), - temperature=0.12345, - ).prompt_model_driver - - @pytest.mark.parametrize( - "driver,", - [ - ("anthropic.claude-v2"), - ("anthropic.claude-v2:1"), - ("anthropic.claude-3-sonnet-20240229-v1:0"), - ("anthropic.claude-3-haiku-20240307-v1:0"), - ], - indirect=["driver"], - ) - def test_init(self, driver): - assert driver.prompt_driver is not None - - @pytest.mark.parametrize( - "driver,", - [ - ("anthropic.claude-v2"), - ("anthropic.claude-v2:1"), - ("anthropic.claude-3-sonnet-20240229-v1:0"), - ("anthropic.claude-3-haiku-20240307-v1:0"), - ], - indirect=["driver"], - ) - @pytest.mark.parametrize("system_enabled", [True, False]) - def test_prompt_stack_to_model_input(self, driver, system_enabled): - stack = PromptStack() - if system_enabled: - stack.add_system_input("foo") - stack.add_user_input("bar") - stack.add_assistant_input("baz") - stack.add_generic_input("qux") - - expected_messages = [ - {"role": "user", "content": "bar"}, - {"role": "assistant", "content": "baz"}, - {"role": "user", "content": "qux"}, - ] - actual = driver.prompt_stack_to_model_input(stack) - expected = {"messages": expected_messages, **({"system": "foo"} if system_enabled else {})} - - assert actual == expected - - @pytest.mark.parametrize( - "driver,", - [ - ("anthropic.claude-v2"), - ("anthropic.claude-v2:1"), - ("anthropic.claude-3-sonnet-20240229-v1:0"), - ("anthropic.claude-3-haiku-20240307-v1:0"), - ], - indirect=["driver"], - ) - @pytest.mark.parametrize("system_enabled", [True, False]) - def test_prompt_stack_to_model_params(self, driver, system_enabled): - stack = PromptStack() - if system_enabled: - stack.add_system_input("foo") - stack.add_user_input("bar") - stack.add_assistant_input("baz") - stack.add_generic_input("qux") - - max_tokens = driver.prompt_driver.max_output_tokens(driver.prompt_driver.prompt_stack_to_string(stack)) - - expected = { - "temperature": 0.12345, - "max_tokens": max_tokens, - "anthropic_version": driver.ANTHROPIC_VERSION, - "messages": [ - {"role": "user", "content": "bar"}, - {"role": "assistant", "content": "baz"}, - {"role": "user", "content": "qux"}, - ], - "top_p": 0.999, - "top_k": 250, - "stop_sequences": ["<|Response|>"], - **({"system": "foo"} if system_enabled else {}), - } - - assert driver.prompt_stack_to_model_params(stack) == expected - - 
@pytest.mark.parametrize( - "driver,", - [ - ("anthropic.claude-v2"), - ("anthropic.claude-v2:1"), - ("anthropic.claude-3-sonnet-20240229-v1:0"), - ("anthropic.claude-3-haiku-20240307-v1:0"), - ], - indirect=["driver"], - ) - def test_process_output(self, driver): - assert ( - driver.process_output(json.dumps({"type": "message", "content": [{"text": "foobar"}]}).encode()).value - == "foobar" - ) - assert ( - driver.process_output( - json.dumps({"type": "content_block_delta", "delta": {"text": "foobar"}}).encode() - ).value - == "foobar" - ) diff --git a/tests/unit/drivers/prompt_models/test_bedrock_jurassic_prompt_model_driver.py b/tests/unit/drivers/prompt_models/test_bedrock_jurassic_prompt_model_driver.py deleted file mode 100644 index e0d6f1f02..000000000 --- a/tests/unit/drivers/prompt_models/test_bedrock_jurassic_prompt_model_driver.py +++ /dev/null @@ -1,71 +0,0 @@ -from unittest import mock -import json -import boto3 -import pytest -from griptape.utils import PromptStack -from griptape.drivers import AmazonBedrockPromptDriver, BedrockJurassicPromptModelDriver - - -class TestBedrockJurassicPromptModelDriver: - @pytest.fixture(autouse=True) - def mock_session(self, mocker): - fake_tokenization = '{"prompt": {"tokens": [{}, {}, {}]}}' - mock_session_class = mocker.patch("boto3.Session") - - mock_session_object = mock.Mock() - mock_client = mock.Mock() - mock_response = mock.Mock() - - mock_response.get().read.return_value = fake_tokenization - mock_client.invoke_model.return_value = mock_response - mock_session_object.client.return_value = mock_client - mock_session_class.return_value = mock_session_object - - return mock_session_object - - @pytest.fixture - def driver(self): - return AmazonBedrockPromptDriver( - model="ai21.j2-ultra", - session=boto3.Session(region_name="us-east-1"), - prompt_model_driver=BedrockJurassicPromptModelDriver(), - temperature=0.12345, - ).prompt_model_driver - - @pytest.fixture - def stack(self): - stack = PromptStack() - - stack.add_system_input("foo") - stack.add_user_input("bar") - - return stack - - def test_driver_stream(self): - with pytest.raises(ValueError): - AmazonBedrockPromptDriver( - model="ai21.j2-ultra", - session=boto3.Session(region_name="us-east-1"), - prompt_model_driver=BedrockJurassicPromptModelDriver(), - temperature=0.12345, - stream=True, - ).prompt_model_driver - - def test_init(self, driver): - assert driver.prompt_driver is not None - - def test_prompt_stack_to_model_input(self, driver, stack): - model_input = driver.prompt_stack_to_model_input(stack) - - assert isinstance(model_input, dict) - assert model_input["prompt"].startswith("System: foo\nUser: bar\nAssistant:") - - def test_prompt_stack_to_model_params(self, driver, stack): - assert driver.prompt_stack_to_model_params(stack)["maxTokens"] == 2042 - assert driver.prompt_stack_to_model_params(stack)["temperature"] == 0.12345 - - def test_process_output(self, driver): - assert ( - driver.process_output(json.dumps({"completions": [{"data": {"text": "foobar"}}]}).encode()).value - == "foobar" - ) diff --git a/tests/unit/drivers/prompt_models/test_bedrock_llama_prompt_model_driver.py b/tests/unit/drivers/prompt_models/test_bedrock_llama_prompt_model_driver.py deleted file mode 100644 index 8cb4b2c94..000000000 --- a/tests/unit/drivers/prompt_models/test_bedrock_llama_prompt_model_driver.py +++ /dev/null @@ -1,64 +0,0 @@ -from unittest import mock -import json -import boto3 -import pytest -from griptape.tokenizers import BedrockLlamaTokenizer -from griptape.utils import 
PromptStack -from griptape.drivers import AmazonBedrockPromptDriver, BedrockLlamaPromptModelDriver - - -class TestBedrockLlamaPromptModelDriver: - @pytest.fixture(autouse=True) - def mock_session(self, mocker): - fake_tokenization = '{"generation_token_count": 13}' - mock_session_class = mocker.patch("boto3.Session") - - mock_session_object = mock.Mock() - mock_client = mock.Mock() - mock_response = mock.Mock() - - mock_response.get().read.return_value = fake_tokenization - mock_client.invoke_model.return_value = mock_response - mock_session_object.client.return_value = mock_client - mock_session_class.return_value = mock_session_object - - return mock_session_object - - @pytest.fixture - def driver(self): - return AmazonBedrockPromptDriver( - model="meta.llama", - session=boto3.Session(region_name="us-east-1"), - prompt_model_driver=BedrockLlamaPromptModelDriver(), - temperature=0.12345, - ).prompt_model_driver - - @pytest.fixture - def stack(self): - stack = PromptStack() - - stack.add_system_input("{{ system_prompt }}") - stack.add_user_input("{{ usr_msg_1 }}") - stack.add_assistant_input("{{ model_msg_1 }}") - stack.add_user_input("{{ usr_msg_2 }}") - - return stack - - def test_init(self, driver): - assert driver.prompt_driver is not None - - def test_prompt_stack_to_model_input(self, driver, stack): - model_input = driver.prompt_stack_to_model_input(stack) - - assert isinstance(model_input, str) - assert ( - model_input - == "[INST] <>\n{{ system_prompt }}\n<>\n\n{{ usr_msg_1 }} [/INST] {{ model_msg_1 }} [INST] {{ usr_msg_2 }} [/INST]" - ) - - def test_prompt_stack_to_model_params(self, driver, stack): - assert driver.prompt_stack_to_model_params(stack)["max_gen_len"] == 2026 - assert driver.prompt_stack_to_model_params(stack)["temperature"] == 0.12345 - - def test_process_output(self, driver): - assert driver.process_output(json.dumps({"generation": "foobar"})).value == "foobar" diff --git a/tests/unit/drivers/prompt_models/test_bedrock_titan_prompt_model_driver.py b/tests/unit/drivers/prompt_models/test_bedrock_titan_prompt_model_driver.py deleted file mode 100644 index ae6d436b5..000000000 --- a/tests/unit/drivers/prompt_models/test_bedrock_titan_prompt_model_driver.py +++ /dev/null @@ -1,58 +0,0 @@ -from unittest import mock -import json -import boto3 -import pytest -from griptape.utils import PromptStack -from griptape.drivers import AmazonBedrockPromptDriver, BedrockTitanPromptModelDriver - - -class TestBedrockTitanPromptModelDriver: - @pytest.fixture(autouse=True) - def mock_session(self, mocker): - fake_tokenization = '{"inputTextTokenCount": 13}' - mock_session_class = mocker.patch("boto3.Session") - - mock_session_object = mock.Mock() - mock_client = mock.Mock() - mock_response = mock.Mock() - - mock_response.get().read.return_value = fake_tokenization - mock_client.invoke_model.return_value = mock_response - mock_session_object.client.return_value = mock_client - mock_session_class.return_value = mock_session_object - - return mock_session_object - - @pytest.fixture - def driver(self): - return AmazonBedrockPromptDriver( - model="amazon.titan", - session=boto3.Session(region_name="us-east-1"), - prompt_model_driver=BedrockTitanPromptModelDriver(), - temperature=0.12345, - ).prompt_model_driver - - @pytest.fixture - def stack(self): - stack = PromptStack() - - stack.add_system_input("foo") - stack.add_user_input("bar") - - return stack - - def test_init(self, driver): - assert driver.prompt_driver is not None - - def test_prompt_stack_to_model_input(self, driver, stack): - 
model_input = driver.prompt_stack_to_model_input(stack) - - assert isinstance(model_input, dict) - assert model_input["inputText"].startswith("Instructions: foo\n\nUser: bar\n\nBot:") - - def test_prompt_stack_to_model_params(self, driver, stack): - assert driver.prompt_stack_to_model_params(stack)["textGenerationConfig"]["maxTokenCount"] == 7994 - assert driver.prompt_stack_to_model_params(stack)["textGenerationConfig"]["temperature"] == 0.12345 - - def test_process_output(self, driver): - assert driver.process_output(json.dumps({"results": [{"outputText": "foobar"}]})).value == "foobar" diff --git a/tests/unit/drivers/prompt_models/test_sagemaker_falcon_prompt_model_driver.py b/tests/unit/drivers/prompt_models/test_sagemaker_falcon_prompt_model_driver.py deleted file mode 100644 index 78d990229..000000000 --- a/tests/unit/drivers/prompt_models/test_sagemaker_falcon_prompt_model_driver.py +++ /dev/null @@ -1,43 +0,0 @@ -import boto3 -import pytest -from griptape.utils import PromptStack -from griptape.drivers import AmazonSageMakerPromptDriver, SageMakerFalconPromptModelDriver - - -class TestSageMakerFalconPromptModelDriver: - @pytest.fixture - def driver(self): - return AmazonSageMakerPromptDriver( - endpoint="endpoint-name", - session=boto3.Session(region_name="us-east-1"), - prompt_model_driver=SageMakerFalconPromptModelDriver(), - temperature=0.12345, - ).prompt_model_driver - - @pytest.fixture - def stack(self): - stack = PromptStack() - - stack.add_system_input("foo") - stack.add_user_input("bar") - - return stack - - def test_init(self, driver): - assert driver.prompt_driver is not None - - def test_prompt_stack_to_model_input(self, driver, stack): - model_input = driver.prompt_stack_to_model_input(stack) - - assert isinstance(model_input, str) - assert model_input.startswith("foo\n\nUser: bar") - - def test_prompt_stack_to_model_params(self, driver, stack): - assert driver.prompt_stack_to_model_params(stack)["max_new_tokens"] == 590 - assert driver.prompt_stack_to_model_params(stack)["temperature"] == 0.12345 - - def test_process_output(self, driver, stack): - assert driver.process_output([{"generated_text": "foobar"}]).value == "foobar" - - def test_tokenizer_max_model_length(self, driver): - assert driver.tokenizer.tokenizer.model_max_length == 2048 diff --git a/tests/unit/drivers/prompt_models/test_sagemaker_llama_prompt_model_driver.py b/tests/unit/drivers/prompt_models/test_sagemaker_llama_prompt_model_driver.py deleted file mode 100644 index b39ce458e..000000000 --- a/tests/unit/drivers/prompt_models/test_sagemaker_llama_prompt_model_driver.py +++ /dev/null @@ -1,67 +0,0 @@ -import boto3 -import pytest -from griptape.utils import PromptStack -from griptape.drivers import AmazonSageMakerPromptDriver, SageMakerLlamaPromptModelDriver - - -class TestSageMakerLlamaPromptModelDriver: - @pytest.fixture(autouse=True) - def llama3_instruct_tokenizer(self, mocker): - tokenizer = mocker.patch("transformers.AutoTokenizer").return_value - tokenizer.model_max_length = 8000 - - return tokenizer - - @pytest.fixture(autouse=True) - def hugging_face_tokenizer(self, mocker, llama3_instruct_tokenizer): - tokenizer = mocker.patch( - "griptape.drivers.prompt_model.sagemaker_llama_prompt_model_driver.HuggingFaceTokenizer" - ).return_value - tokenizer.count_output_tokens_left.return_value = 7991 - tokenizer.tokenizer = llama3_instruct_tokenizer - return tokenizer - - @pytest.fixture - def driver(self): - return AmazonSageMakerPromptDriver( - endpoint="endpoint-name", - 
model="inference-component-name", - session=boto3.Session(region_name="us-east-1"), - prompt_model_driver=SageMakerLlamaPromptModelDriver(), - temperature=0.12345, - ).prompt_model_driver - - @pytest.fixture - def stack(self): - stack = PromptStack() - - stack.add_system_input("foo") - stack.add_user_input("bar") - - return stack - - def test_init(self, driver): - assert driver.prompt_driver is not None - - def test_prompt_stack_to_model_input(self, driver, stack, hugging_face_tokenizer): - driver.prompt_stack_to_model_input(stack) - - hugging_face_tokenizer.tokenizer.apply_chat_template.assert_called_once_with( - [{"role": "system", "content": "foo"}, {"role": "user", "content": "bar"}], - tokenize=False, - add_generation_prompt=True, - ) - - def test_prompt_stack_to_model_params(self, driver, stack): - assert driver.prompt_stack_to_model_params(stack)["max_new_tokens"] == 7991 - assert driver.prompt_stack_to_model_params(stack)["temperature"] == 0.12345 - - def test_process_output(self, driver, stack): - assert driver.process_output({"generated_text": "foobar"}).value == "foobar" - - def test_process_output_invalid_format(self, driver, stack): - with pytest.raises(ValueError): - assert driver.process_output([{"generated_text": "foobar"}]) - - def test_tokenizer_max_model_length(self, driver): - assert driver.tokenizer.tokenizer.model_max_length == 8000 diff --git a/tests/unit/drivers/structure_run/test_local_structure_run_driver.py b/tests/unit/drivers/structure_run/test_local_structure_run_driver.py index 921a4b013..cb7b3058e 100644 --- a/tests/unit/drivers/structure_run/test_local_structure_run_driver.py +++ b/tests/unit/drivers/structure_run/test_local_structure_run_driver.py @@ -22,7 +22,7 @@ def test_run(self): def test_run_with_env(self): pipeline = Pipeline() - agent = Agent(prompt_driver=MockPromptDriver(mock_output=lambda: os.environ["key"])) + agent = Agent(prompt_driver=MockPromptDriver(mock_output=lambda _: os.environ["key"])) driver = LocalStructureRunDriver(structure_factory_fn=lambda: agent, env={"key": "value"}) task = StructureRunTask(driver=driver) diff --git a/tests/unit/engines/summary/test_prompt_summary_engine.py b/tests/unit/engines/summary/test_prompt_summary_engine.py index 52179bdfb..59b36f48e 100644 --- a/tests/unit/engines/summary/test_prompt_summary_engine.py +++ b/tests/unit/engines/summary/test_prompt_summary_engine.py @@ -1,7 +1,9 @@ import pytest from griptape.artifacts import TextArtifact, ListArtifact from griptape.engines import PromptSummaryEngine +from griptape.utils import PromptStack from tests.mocks.mock_prompt_driver import MockPromptDriver +import os class TestPromptSummaryEngine: @@ -16,3 +18,25 @@ def test_summarize_artifacts(self, engine): assert ( engine.summarize_artifacts(ListArtifact([TextArtifact("foo"), TextArtifact("bar")])).value == "mock output" ) + + def test_max_token_multiplier_invalid(self, engine): + with pytest.raises(ValueError): + PromptSummaryEngine(prompt_driver=MockPromptDriver(), max_token_multiplier=0) + + with pytest.raises(ValueError): + PromptSummaryEngine(prompt_driver=MockPromptDriver(), max_token_multiplier=10000) + + def test_chunked_summary(self, engine): + def smaller_input(prompt_stack: PromptStack): + return prompt_stack.inputs[0].content[: (len(prompt_stack.inputs[0].content) // 2)] + + engine = PromptSummaryEngine(prompt_driver=MockPromptDriver(mock_output="smaller_input")) + + def copy_test_resource(resource_path: str): + file_dir = os.path.dirname(__file__) + full_path = os.path.join(file_dir, 
"../../../resources", resource_path) + full_path = os.path.normpath(full_path) + with open(full_path) as f: + return f.read() + + assert engine.summarize_text(copy_test_resource("test.txt") * 50) diff --git a/tests/unit/loaders/test_web_loader.py b/tests/unit/loaders/test_web_loader.py index fb63e87eb..e26573539 100644 --- a/tests/unit/loaders/test_web_loader.py +++ b/tests/unit/loaders/test_web_loader.py @@ -1,4 +1,5 @@ import pytest +from griptape.artifacts.error_artifact import ErrorArtifact from griptape.loaders import WebLoader from tests.mocks.mock_embedding_driver import MockEmbeddingDriver @@ -22,6 +23,14 @@ def test_load(self, loader): assert artifacts[0].embedding == [0, 1] + def test_load_exception(self, mocker, loader): + mocker.patch("trafilatura.fetch_url", side_effect=Exception("error")) + source = "https://github.com/griptape-ai/griptape" + artifact = loader.load(source) + + assert isinstance(artifact, ErrorArtifact) + assert f"Error loading from source: {source}" == artifact.value + def test_load_collection(self, loader): artifacts = loader.load_collection( ["https://github.com/griptape-ai/griptape", "https://github.com/griptape-ai/griptape-docs"] @@ -38,11 +47,13 @@ def test_load_collection(self, loader): def test_empty_page_string_response(self, loader, mocker): mocker.patch("trafilatura.extract", return_value="") - with pytest.raises(Exception, match="can't extract page"): - loader.load("https://example.com/") + artifact = loader.load("https://example.com/") + assert isinstance(artifact, ErrorArtifact) + assert str(artifact.exception) == "can't extract page" def test_empty_page_none_response(self, loader, mocker): mocker.patch("trafilatura.extract", return_value=None) - with pytest.raises(Exception, match="can't extract page"): - loader.load("https://example.com/") + artifact = loader.load("https://example.com/") + assert isinstance(artifact, ErrorArtifact) + assert str(artifact.exception) == "can't extract page" diff --git a/tests/unit/memory/structure/test_conversation_memory.py b/tests/unit/memory/structure/test_conversation_memory.py index 665dca9b2..298e5ac3f 100644 --- a/tests/unit/memory/structure/test_conversation_memory.py +++ b/tests/unit/memory/structure/test_conversation_memory.py @@ -1,7 +1,10 @@ import json +from griptape.structures import Agent +from griptape.utils import PromptStack from griptape.memory.structure import ConversationMemory, Run, BaseConversationMemory from griptape.structures import Pipeline from tests.mocks.mock_prompt_driver import MockPromptDriver +from tests.mocks.mock_tokenizer import MockTokenizer from griptape.tasks import PromptTask @@ -69,3 +72,97 @@ def test_buffering(self): assert len(pipeline.conversation_memory.runs) == 2 assert pipeline.conversation_memory.runs[0].input == "run4" assert pipeline.conversation_memory.runs[1].input == "run5" + + def test_add_to_prompt_stack_autopruing_disabled(self): + agent = Agent(prompt_driver=MockPromptDriver()) + memory = ConversationMemory( + autoprune=False, + runs=[ + Run(input="foo1", output="bar1"), + Run(input="foo2", output="bar2"), + Run(input="foo3", output="bar3"), + Run(input="foo4", output="bar4"), + Run(input="foo5", output="bar5"), + ], + ) + memory.structure = agent + prompt_stack = PromptStack() + prompt_stack.add_user_input("foo") + prompt_stack.add_assistant_input("bar") + memory.add_to_prompt_stack(prompt_stack) + + assert len(prompt_stack.inputs) == 12 + + def test_add_to_prompt_stack_autopruning_enabled(self): + # All memory is pruned. 
+ agent = Agent(prompt_driver=MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=0))) + memory = ConversationMemory( + autoprune=True, + runs=[ + Run(input="foo1", output="bar1"), + Run(input="foo2", output="bar2"), + Run(input="foo3", output="bar3"), + Run(input="foo4", output="bar4"), + Run(input="foo5", output="bar5"), + ], + ) + memory.structure = agent + prompt_stack = PromptStack() + prompt_stack.add_system_input("fizz") + prompt_stack.add_user_input("foo") + prompt_stack.add_assistant_input("bar") + memory.add_to_prompt_stack(prompt_stack) + + assert len(prompt_stack.inputs) == 3 + + # No memory is pruned. + agent = Agent(prompt_driver=MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=1000))) + memory = ConversationMemory( + autoprune=True, + runs=[ + Run(input="foo1", output="bar1"), + Run(input="foo2", output="bar2"), + Run(input="foo3", output="bar3"), + Run(input="foo4", output="bar4"), + Run(input="foo5", output="bar5"), + ], + ) + memory.structure = agent + prompt_stack = PromptStack() + prompt_stack.add_system_input("fizz") + prompt_stack.add_user_input("foo") + prompt_stack.add_assistant_input("bar") + memory.add_to_prompt_stack(prompt_stack) + + assert len(prompt_stack.inputs) == 13 + + # One memory is pruned. + # MockTokenizer's max_input_tokens set to one below the sum of memory + system prompt tokens + # so that a single memory is pruned. + agent = Agent(prompt_driver=MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=160))) + memory = ConversationMemory( + autoprune=True, + runs=[ + # All of these sum to 155 tokens with the MockTokenizer. + Run(input="foo1", output="bar1"), + Run(input="foo2", output="bar2"), + Run(input="foo3", output="bar3"), + Run(input="foo4", output="bar4"), + Run(input="foo5", output="bar5"), + ], + ) + memory.structure = agent + prompt_stack = PromptStack() + # And then another 6 tokens from fizz for a total of 161 tokens. + prompt_stack.add_system_input("fizz") + prompt_stack.add_user_input("foo") + prompt_stack.add_assistant_input("bar") + memory.add_to_prompt_stack(prompt_stack, 1) + + # We expect one run (2 prompt stack inputs) to be pruned. 
+        assert len(prompt_stack.inputs) == 11
+        assert prompt_stack.inputs[0].content == "fizz"
+        assert prompt_stack.inputs[1].content == "foo2"
+        assert prompt_stack.inputs[2].content == "bar2"
+        assert prompt_stack.inputs[-2].content == "foo"
+        assert prompt_stack.inputs[-1].content == "bar"
diff --git a/tests/unit/structures/test_agent.py b/tests/unit/structures/test_agent.py
index 371aee8a8..e6c2a1f01 100644
--- a/tests/unit/structures/test_agent.py
+++ b/tests/unit/structures/test_agent.py
@@ -252,3 +252,13 @@ def test_deprecation(self):
         with pytest.deprecated_call():
             Agent(stream=True)
+
+    def test_finished_tasks(self):
+        task = PromptTask("test prompt")
+        agent = Agent(prompt_driver=MockPromptDriver())
+
+        agent.add_task(task)
+
+        agent.run("hello")
+
+        assert len(agent.finished_tasks) == 1
diff --git a/tests/unit/structures/test_workflow.py b/tests/unit/structures/test_workflow.py
index 970d43e74..bbcf5138e 100644
--- a/tests/unit/structures/test_workflow.py
+++ b/tests/unit/structures/test_workflow.py
@@ -1,16 +1,34 @@
+import time
 import pytest
+from pytest import fixture
 from griptape.memory.task.storage import TextArtifactStorage
 from tests.mocks.mock_prompt_driver import MockPromptDriver
 from griptape.rules import Rule, Ruleset
-from griptape.tasks import PromptTask, BaseTask, ToolkitTask
+from griptape.tasks import PromptTask, BaseTask, ToolkitTask, CodeExecutionTask
 from griptape.structures import Workflow
+from griptape.artifacts import ErrorArtifact, TextArtifact
 from griptape.memory.structure import ConversationMemory
 from tests.mocks.mock_tool.tool import MockTool
 from tests.mocks.mock_embedding_driver import MockEmbeddingDriver


 class TestWorkflow:
+    @fixture
+    def waiting_task(self):
+        def fn(task):
+            time.sleep(10)
+            return TextArtifact("done")
+
+        return CodeExecutionTask(run_fn=fn)
+
+    @fixture
+    def error_artifact_task(self):
+        def fn(task):
+            return ErrorArtifact("error")
+
+        return CodeExecutionTask(run_fn=fn)
+
     def test_init(self):
         driver = MockPromptDriver()
         workflow = Workflow(prompt_driver=driver, rulesets=[Ruleset("TestRuleset", [Rule("test")])])
@@ -140,10 +158,10 @@ def test_tasks_initialization(self):
         assert workflow.tasks[1].id == "test2"
         assert workflow.tasks[2].id == "test3"
         assert len(first_task.parents) == 0
-        assert len(first_task.children) == 1
-        assert len(second_task.parents) == 1
-        assert len(second_task.children) == 1
-        assert len(third_task.parents) == 1
+        assert len(first_task.children) == 0
+        assert len(second_task.parents) == 0
+        assert len(second_task.children) == 0
+        assert len(third_task.parents) == 0
         assert len(third_task.children) == 0

     def test_add_task(self):
@@ -161,8 +179,8 @@ def test_add_task(self):
         assert first_task.structure == workflow
         assert second_task.structure == workflow
         assert len(first_task.parents) == 0
-        assert len(first_task.children) == 1
-        assert len(second_task.parents) == 1
+        assert len(first_task.children) == 0
+        assert len(second_task.parents) == 0
         assert len(second_task.children) == 0

     def test_add_tasks(self):
@@ -179,8 +197,8 @@ def test_add_tasks(self):
         assert first_task.structure == workflow
         assert second_task.structure == workflow
         assert len(first_task.parents) == 0
-        assert len(first_task.children) == 1
-        assert len(second_task.parents) == 1
+        assert len(first_task.children) == 0
+        assert len(second_task.parents) == 0
         assert len(second_task.children) == 0

     def test_run(self):
@@ -210,7 +228,111 @@ def test_run_with_args(self):

         assert task.input.to_text() == "-"

-    def test_run_topology_1(self):
+
@pytest.mark.parametrize( + "tasks", + [ + [PromptTask(id="task1", parent_ids=["missing"])], + [PromptTask(id="task1", child_ids=["missing"])], + [PromptTask(id="task1"), PromptTask(id="task2", parent_ids=["missing"])], + [PromptTask(id="task1"), PromptTask(id="task2", parent_ids=["task1", "missing"])], + [PromptTask(id="task1"), PromptTask(id="task2", parent_ids=["task1"], child_ids=["missing"])], + ], + ) + def test_run_raises_on_missing_parent_or_child_id(self, tasks): + workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=tasks) + + with pytest.raises(ValueError) as e: + workflow.run() + + assert e.value.args[0] == "Task with id missing doesn't exist." + + def test_run_topology_1_declarative_parents(self): + workflow = Workflow( + prompt_driver=MockPromptDriver(), + tasks=[ + PromptTask("test1", id="task1"), + PromptTask("test2", id="task2", parent_ids=["task1"]), + PromptTask("test3", id="task3", parent_ids=["task1"]), + PromptTask("test4", id="task4", parent_ids=["task2", "task3"]), + ], + ) + + workflow.run() + + self._validate_topology_1(workflow) + + def test_run_topology_1_declarative_children(self): + workflow = Workflow( + prompt_driver=MockPromptDriver(), + tasks=[ + PromptTask("test1", id="task1", child_ids=["task2", "task3"]), + PromptTask("test2", id="task2", child_ids=["task4"]), + PromptTask("test3", id="task3", child_ids=["task4"]), + PromptTask("test4", id="task4"), + ], + ) + + workflow.run() + + self._validate_topology_1(workflow) + + def test_run_topology_1_declarative_mixed(self): + workflow = Workflow( + prompt_driver=MockPromptDriver(), + tasks=[ + PromptTask("test1", id="task1", child_ids=["task3"]), + PromptTask("test2", id="task2", parent_ids=["task1"], child_ids=["task4"]), + PromptTask("test3", id="task3"), + PromptTask("test4", id="task4", parent_ids=["task2", "task3"]), + ], + ) + + workflow.run() + + self._validate_topology_1(workflow) + + def test_run_topology_1_imperative_parents(self): + task1 = PromptTask("test1", id="task1") + task2 = PromptTask("test2", id="task2") + task3 = PromptTask("test3", id="task3") + task4 = PromptTask("test4", id="task4") + task2.add_parent(task1) + task3.add_parent("task1") + task4.add_parents([task2, "task3"]) + workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=[task1, task2, task3, task4]) + + workflow.run() + + self._validate_topology_1(workflow) + + def test_run_topology_1_imperative_children(self): + task1 = PromptTask("test1", id="task1") + task2 = PromptTask("test2", id="task2") + task3 = PromptTask("test3", id="task3") + task4 = PromptTask("test4", id="task4") + task1.add_children([task2, task3]) + task2.add_child(task4) + task3.add_child(task4) + workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=[task1, task2, task3, task4]) + + workflow.run() + + self._validate_topology_1(workflow) + + def test_run_topology_1_imperative_mixed(self): + task1 = PromptTask("test1", id="task1") + task2 = PromptTask("test2", id="task2") + task3 = PromptTask("test3", id="task3") + task4 = PromptTask("test4", id="task4") + task1.add_children([task2, task3]) + task4.add_parents([task2, task3]) + workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=[task1, task2, task3, task4]) + + workflow.run() + + self._validate_topology_1(workflow) + + def test_run_topology_1_imperative_insert(self): task1 = PromptTask("test1", id="task1") task2 = PromptTask("test2", id="task2") task3 = PromptTask("test3", id="task3") @@ -225,60 +347,152 @@ def test_run_topology_1(self): workflow.run() - assert task1.state == 
BaseTask.State.FINISHED - assert task1.parent_ids == [] - assert task1.child_ids == ["task2", "task3"] + self._validate_topology_1(workflow) + + def test_run_topology_2_declarative_parents(self): + workflow = Workflow( + prompt_driver=MockPromptDriver(), + tasks=[ + PromptTask("testa", id="taska"), + PromptTask("testb", id="taskb", parent_ids=["taska"]), + PromptTask("testc", id="taskc", parent_ids=["taska"]), + PromptTask("testd", id="taskd", parent_ids=["taska", "taskb", "taskc"]), + PromptTask("teste", id="taske", parent_ids=["taska", "taskd", "taskc"]), + ], + ) - assert task2.state == BaseTask.State.FINISHED - assert task2.parent_ids == ["task1"] - assert task2.child_ids == ["task4"] + workflow.run() - assert task3.state == BaseTask.State.FINISHED - assert task3.parent_ids == ["task1"] - assert task3.child_ids == ["task4"] + self._validate_topology_2(workflow) + + def test_run_topology_2_declarative_children(self): + workflow = Workflow( + prompt_driver=MockPromptDriver(), + tasks=[ + PromptTask("testa", id="taska", child_ids=["taskb", "taskc", "taskd", "taske"]), + PromptTask("testb", id="taskb", child_ids=["taskd"]), + PromptTask("testc", id="taskc", child_ids=["taskd", "taske"]), + PromptTask("testd", id="taskd", child_ids=["taske"]), + PromptTask("teste", id="taske", child_ids=[]), + ], + ) - assert task4.state == BaseTask.State.FINISHED - assert task4.parent_ids == ["task2", "task3"] - assert task4.child_ids == [] + workflow.run() - def test_run_topology_2(self): - """Adapted from https://en.wikipedia.org/wiki/Directed_acyclic_graph#/media/File:Tred-G.svg""" + self._validate_topology_2(workflow) + + def test_run_topology_2_imperative_parents(self): taska = PromptTask("testa", id="taska") taskb = PromptTask("testb", id="taskb") taskc = PromptTask("testc", id="taskc") taskd = PromptTask("testd", id="taskd") taske = PromptTask("teste", id="taske") - workflow = Workflow(prompt_driver=MockPromptDriver()) + taskb.add_parent(taska) + taskc.add_parent("taska") + taskd.add_parents([taska, taskb, taskc]) + taske.add_parents(["taska", taskd, "taskc"]) + workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=[taska, taskb, taskc, taskd, taske]) + + workflow.run() + + self._validate_topology_2(workflow) + + def test_run_topology_2_imperative_children(self): + taska = PromptTask("testa", id="taska") + taskb = PromptTask("testb", id="taskb") + taskc = PromptTask("testc", id="taskc") + taskd = PromptTask("testd", id="taskd") + taske = PromptTask("teste", id="taske") + taska.add_children([taskb, taskc, taskd, taske]) + taskb.add_child(taskd) + taskc.add_children([taskd, taske]) + taskd.add_child(taske) + workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=[taska, taskb, taskc, taskd, taske]) + + workflow.run() + + self._validate_topology_2(workflow) + + def test_run_topology_2_imperative_mixed(self): + taska = PromptTask("testa", id="taska") + taskb = PromptTask("testb", id="taskb") + taskc = PromptTask("testc", id="taskc") + taskd = PromptTask("testd", id="taskd") + taske = PromptTask("teste", id="taske") + taska.add_children([taskb, taskc, taskd, taske]) + taskb.add_child(taskd) + taskd.add_parent(taskc) + taske.add_parents(["taska", taskd, "taskc"]) + workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=[taska, taskb, taskc, taskd, taske]) + + workflow.run() + + self._validate_topology_2(workflow) + def test_run_topology_2_imperative_insert(self): + taska = PromptTask("testa", id="taska") + taskb = PromptTask("testb", id="taskb") + taskc = PromptTask("testc", id="taskc") + 
taskd = PromptTask("testd", id="taskd") + taske = PromptTask("teste", id="taske") + workflow = Workflow(prompt_driver=MockPromptDriver()) workflow.add_task(taska) workflow.add_task(taske) + taske.add_parent(taska) workflow.insert_tasks(taska, taskd, taske, preserve_relationship=True) workflow.insert_tasks(taska, [taskc], [taskd, taske], preserve_relationship=True) workflow.insert_tasks(taska, taskb, taskd, preserve_relationship=True) workflow.run() - assert taska.state == BaseTask.State.FINISHED - assert taska.parent_ids == [] - assert set(taska.child_ids) == {"taskb", "taskd", "taskc", "taske"} + self._validate_topology_2(workflow) + + def test_run_topology_3_declarative_parents(self): + workflow = Workflow( + prompt_driver=MockPromptDriver(), + tasks=[ + PromptTask("test1", id="task1"), + PromptTask("test2", id="task2", parent_ids=["task4"]), + PromptTask("test4", id="task4", parent_ids=["task1"]), + PromptTask("test3", id="task3", parent_ids=["task2"]), + ], + ) - assert taskb.state == BaseTask.State.FINISHED - assert taskb.parent_ids == ["taska"] - assert taskb.child_ids == ["taskd"] + workflow.run() - assert taskc.state == BaseTask.State.FINISHED - assert taskc.parent_ids == ["taska"] - assert set(taskc.child_ids) == {"taskd", "taske"} + self._validate_topology_3(workflow) + + def test_run_topology_3_declarative_children(self): + workflow = Workflow( + prompt_driver=MockPromptDriver(), + tasks=[ + PromptTask("test1", id="task1", child_ids=["task4"]), + PromptTask("test2", id="task2", child_ids=["task3"]), + PromptTask("test4", id="task4", child_ids=["task2"]), + PromptTask("test3", id="task3", child_ids=[]), + ], + ) - assert taskd.state == BaseTask.State.FINISHED - assert set(taskd.parent_ids) == {"taskb", "taska", "taskc"} - assert taskd.child_ids == ["taske"] + workflow.run() - assert taske.state == BaseTask.State.FINISHED - assert set(taske.parent_ids) == {"taskd", "taskc", "taska"} - assert taske.child_ids == [] + self._validate_topology_3(workflow) + + def test_run_topology_3_declarative_mixed(self): + workflow = Workflow( + prompt_driver=MockPromptDriver(), + tasks=[ + PromptTask("test1", id="task1"), + PromptTask("test2", id="task2", parent_ids=["task4"], child_ids=["task3"]), + PromptTask("test4", id="task4", parent_ids=["task1"], child_ids=["task2"]), + PromptTask("test3", id="task3"), + ], + ) - def test_run_topology_3(self): + workflow.run() + + self._validate_topology_3(workflow) + + def test_run_topology_3_imperative_insert(self): task1 = PromptTask("test1", id="task1") task2 = PromptTask("test2", id="task2") task3 = PromptTask("test3", id="task3") @@ -288,28 +502,75 @@ def test_run_topology_3(self): workflow + task1 workflow + task2 workflow + task3 + task2.add_parent(task1) + task3.add_parent(task2) workflow.insert_tasks(task1, task4, task2) workflow.run() - assert task1.state == BaseTask.State.FINISHED - assert task1.parent_ids == [] - assert task1.child_ids == ["task4"] + self._validate_topology_3(workflow) + + def test_run_topology_4_declarative_parents(self): + workflow = Workflow( + prompt_driver=MockPromptDriver(), + tasks=[ + PromptTask(id="collect_movie_info"), + PromptTask(id="movie_info_1", parent_ids=["collect_movie_info"]), + PromptTask(id="movie_info_2", parent_ids=["collect_movie_info"]), + PromptTask(id="movie_info_3", parent_ids=["collect_movie_info"]), + PromptTask(id="compare_movies", parent_ids=["movie_info_1", "movie_info_2", "movie_info_3"]), + PromptTask(id="send_email_task", parent_ids=["compare_movies"]), + PromptTask(id="save_to_disk", 
parent_ids=["compare_movies"]), + PromptTask(id="publish_website", parent_ids=["compare_movies"]), + PromptTask(id="summarize_to_slack", parent_ids=["send_email_task", "save_to_disk", "publish_website"]), + ], + ) - assert task2.state == BaseTask.State.FINISHED - assert task2.parent_ids == ["task4"] - assert task2.child_ids == ["task3"] + workflow.run() - assert task3.state == BaseTask.State.FINISHED - assert task3.parent_ids == ["task2"] - assert task3.child_ids == [] + self._validate_topology_4(workflow) + + def test_run_topology_4_declarative_children(self): + workflow = Workflow( + prompt_driver=MockPromptDriver(), + tasks=[ + PromptTask(id="collect_movie_info", child_ids=["movie_info_1", "movie_info_2", "movie_info_3"]), + PromptTask(id="movie_info_1", child_ids=["compare_movies"]), + PromptTask(id="movie_info_2", child_ids=["compare_movies"]), + PromptTask(id="movie_info_3", child_ids=["compare_movies"]), + PromptTask(id="compare_movies", child_ids=["send_email_task", "save_to_disk", "publish_website"]), + PromptTask(id="send_email_task", child_ids=["summarize_to_slack"]), + PromptTask(id="save_to_disk", child_ids=["summarize_to_slack"]), + PromptTask(id="publish_website", child_ids=["summarize_to_slack"]), + PromptTask(id="summarize_to_slack", child_ids=[]), + ], + ) - assert task4.state == BaseTask.State.FINISHED - assert task4.parent_ids == ["task1"] - assert task4.child_ids == ["task2"] + workflow.run() - def test_run_topology_4(self): - workflow = Workflow(prompt_driver=MockPromptDriver()) + self._validate_topology_4(workflow) + + def test_run_topology_4_declarative_mixed(self): + workflow = Workflow( + prompt_driver=MockPromptDriver(), + tasks=[ + PromptTask(id="collect_movie_info"), + PromptTask(id="movie_info_1", parent_ids=["collect_movie_info"], child_ids=["compare_movies"]), + PromptTask(id="movie_info_2", parent_ids=["collect_movie_info"], child_ids=["compare_movies"]), + PromptTask(id="movie_info_3", parent_ids=["collect_movie_info"], child_ids=["compare_movies"]), + PromptTask(id="compare_movies"), + PromptTask(id="send_email_task", parent_ids=["compare_movies"], child_ids=["summarize_to_slack"]), + PromptTask(id="save_to_disk", parent_ids=["compare_movies"], child_ids=["summarize_to_slack"]), + PromptTask(id="publish_website", parent_ids=["compare_movies"], child_ids=["summarize_to_slack"]), + PromptTask(id="summarize_to_slack"), + ], + ) + + workflow.run() + + self._validate_topology_4(workflow) + + def test_run_topology_4_imperative_insert(self): collect_movie_info = PromptTask(id="collect_movie_info") summarize_to_slack = PromptTask(id="summarize_to_slack") movie_info_1 = PromptTask(id="movie_info_1") @@ -321,30 +582,34 @@ def test_run_topology_4(self): publish_website = PromptTask(id="publish_website") movie_info_3 = PromptTask(id="movie_info_3") + workflow = Workflow(prompt_driver=MockPromptDriver()) workflow.add_tasks(collect_movie_info, summarize_to_slack) workflow.insert_tasks(collect_movie_info, [movie_info_1, movie_info_2, movie_info_3], summarize_to_slack) workflow.insert_tasks([movie_info_1, movie_info_2, movie_info_3], compare_movies, summarize_to_slack) workflow.insert_tasks(compare_movies, [send_email_task, save_to_disk, publish_website], summarize_to_slack) - assert set(collect_movie_info.child_ids) == {"movie_info_1", "movie_info_2", "movie_info_3"} - - assert set(movie_info_1.parent_ids) == {"collect_movie_info"} - assert set(movie_info_2.parent_ids) == {"collect_movie_info"} - assert set(movie_info_3.parent_ids) == {"collect_movie_info"} - assert 
set(movie_info_1.child_ids) == {"compare_movies"}
-        assert set(movie_info_2.child_ids) == {"compare_movies"}
-        assert set(movie_info_3.child_ids) == {"compare_movies"}
-
-        assert set(compare_movies.parent_ids) == {"movie_info_1", "movie_info_2", "movie_info_3"}
-        assert set(compare_movies.child_ids) == {"send_email_task", "save_to_disk", "publish_website"}
-
-        assert set(send_email_task.parent_ids) == {"compare_movies"}
-        assert set(save_to_disk.parent_ids) == {"compare_movies"}
-        assert set(publish_website.parent_ids) == {"compare_movies"}
-
-        assert set(send_email_task.child_ids) == {"summarize_to_slack"}
-        assert set(save_to_disk.child_ids) == {"summarize_to_slack"}
-        assert set(publish_website.child_ids) == {"summarize_to_slack"}
+        self._validate_topology_4(workflow)
+
+    @pytest.mark.parametrize(
+        "tasks",
+        [
+            [PromptTask(id="a", parent_ids=["a"])],
+            [PromptTask(id="a"), PromptTask(id="b", parent_ids=["a", "b"])],
+            [PromptTask(id="a", parent_ids=["b"]), PromptTask(id="b", parent_ids=["a"])],
+            [
+                PromptTask(id="a", parent_ids=["c"]),
+                PromptTask(id="b", parent_ids=["a"]),
+                PromptTask(id="c", parent_ids=["b"]),
+            ],
+        ],
+    )
+    def test_run_raises_on_cycle(self, tasks):
+        workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=tasks)
+
+        with pytest.raises(ValueError) as e:
+            workflow.run()
+
+        assert e.value.args[0] == "nodes are in a cycle"

     def test_input_task(self):
         task1 = PromptTask("prompt1")
@@ -372,6 +637,15 @@ def test_output_task(self):

         assert task4 == workflow.output_task

+        task4.add_parents([task2, task3])
+        task1.add_children([task2, task3])
+
+        # task4 is the final task, but it's defined at index 0
+        workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=[task4, task1, task2, task3])
+
+        # output_task topologically should be task4
+        assert task4 == workflow.output_task
+
     def test_to_graph(self):
         task1 = PromptTask("prompt1", id="task1")
         task2 = PromptTask("prompt2", id="task2")
@@ -417,6 +691,9 @@ def test_context(self):
         workflow + task
         workflow + child

+        task.add_parent(parent)
+        task.add_child(child)
+
         context = workflow.context(task)

         assert context["parent_outputs"] == {parent.id: ""}
@@ -426,6 +703,7 @@ def test_context(self):
         context = workflow.context(task)

         assert context["parent_outputs"] == {parent.id: parent.output.to_text()}
+        assert context["parents_output_text"] == "mock output"
         assert context["structure"] == workflow
         assert context["parents"] == {parent.id: parent}
         assert context["children"] == {child.id: child}
@@ -439,3 +717,141 @@ def test_deprecation(self):
         with pytest.deprecated_call():
             Workflow(stream=True)
+
+    def test_run_with_error_artifact(self, error_artifact_task, waiting_task):
+        end_task = PromptTask("end")
+        end_task.add_parents([error_artifact_task, waiting_task])
+        workflow = Workflow(prompt_driver=MockPromptDriver(), tasks=[waiting_task, error_artifact_task, end_task])
+        workflow.run()
+
+        assert workflow.output is None
+
+    @staticmethod
+    def _validate_topology_1(workflow):
+        assert len(workflow.tasks) == 4
+        assert workflow.input_task.id == "task1"
+        assert workflow.output_task.id == "task4"
+        assert workflow.input_task.id == workflow.tasks[0].id
+        assert workflow.output_task.id == workflow.tasks[-1].id
+
+        task1 = workflow.find_task("task1")
+        assert task1.state == BaseTask.State.FINISHED
+        assert task1.parent_ids == []
+        assert sorted(task1.child_ids) == ["task2", "task3"]
+
+        task2 = workflow.find_task("task2")
+        assert task2.state == BaseTask.State.FINISHED
+        assert task2.parent_ids == ["task1"]
+        assert task2.child_ids ==
["task4"] + + task3 = workflow.find_task("task3") + assert task3.state == BaseTask.State.FINISHED + assert task3.parent_ids == ["task1"] + assert task3.child_ids == ["task4"] + + task4 = workflow.find_task("task4") + assert task4.state == BaseTask.State.FINISHED + assert sorted(task4.parent_ids) == ["task2", "task3"] + assert task4.child_ids == [] + + @staticmethod + def _validate_topology_2(workflow): + """Adapted from https://en.wikipedia.org/wiki/Directed_acyclic_graph#/media/File:Tred-G.svg""" + assert len(workflow.tasks) == 5 + assert workflow.input_task.id == "taska" + assert workflow.output_task.id == "taske" + assert workflow.input_task.id == workflow.tasks[0].id + assert workflow.output_task.id == workflow.tasks[-1].id + + taska = workflow.find_task("taska") + assert taska.state == BaseTask.State.FINISHED + assert taska.parent_ids == [] + assert sorted(taska.child_ids) == ["taskb", "taskc", "taskd", "taske"] + + taskb = workflow.find_task("taskb") + assert taskb.state == BaseTask.State.FINISHED + assert taskb.parent_ids == ["taska"] + assert taskb.child_ids == ["taskd"] + + taskc = workflow.find_task("taskc") + assert taskc.state == BaseTask.State.FINISHED + assert taskc.parent_ids == ["taska"] + assert sorted(taskc.child_ids) == ["taskd", "taske"] + + taskd = workflow.find_task("taskd") + assert taskd.state == BaseTask.State.FINISHED + assert sorted(taskd.parent_ids) == ["taska", "taskb", "taskc"] + assert taskd.child_ids == ["taske"] + + taske = workflow.find_task("taske") + assert taske.state == BaseTask.State.FINISHED + assert sorted(taske.parent_ids) == ["taska", "taskc", "taskd"] + assert taske.child_ids == [] + + @staticmethod + def _validate_topology_3(workflow): + assert len(workflow.tasks) == 4 + assert workflow.input_task.id == "task1" + assert workflow.output_task.id == "task3" + assert workflow.input_task.id == workflow.tasks[0].id + assert workflow.output_task.id == workflow.tasks[-1].id + + task1 = workflow.find_task("task1") + assert task1.state == BaseTask.State.FINISHED + assert task1.parent_ids == [] + assert task1.child_ids == ["task4"] + + task2 = workflow.find_task("task2") + assert task2.state == BaseTask.State.FINISHED + assert task2.parent_ids == ["task4"] + assert task2.child_ids == ["task3"] + + task3 = workflow.find_task("task3") + assert task3.state == BaseTask.State.FINISHED + assert task3.parent_ids == ["task2"] + assert task3.child_ids == [] + + task4 = workflow.find_task("task4") + assert task4.state == BaseTask.State.FINISHED + assert task4.parent_ids == ["task1"] + assert task4.child_ids == ["task2"] + + @staticmethod + def _validate_topology_4(workflow): + assert len(workflow.tasks) == 9 + assert workflow.input_task.id == "collect_movie_info" + assert workflow.output_task.id == "summarize_to_slack" + assert workflow.input_task.id == workflow.tasks[0].id + assert workflow.output_task.id == workflow.tasks[-1].id + + collect_movie_info = workflow.find_task("collect_movie_info") + assert collect_movie_info.parent_ids == [] + assert sorted(collect_movie_info.child_ids) == ["movie_info_1", "movie_info_2", "movie_info_3"] + + movie_info_1 = workflow.find_task("movie_info_1") + assert movie_info_1.parent_ids == ["collect_movie_info"] + assert movie_info_1.child_ids == ["compare_movies"] + + movie_info_2 = workflow.find_task("movie_info_2") + assert movie_info_2.parent_ids == ["collect_movie_info"] + assert movie_info_2.child_ids == ["compare_movies"] + + movie_info_3 = workflow.find_task("movie_info_3") + assert movie_info_3.parent_ids == 
["collect_movie_info"] + assert movie_info_3.child_ids == ["compare_movies"] + + compare_movies = workflow.find_task("compare_movies") + assert sorted(compare_movies.parent_ids) == ["movie_info_1", "movie_info_2", "movie_info_3"] + assert sorted(compare_movies.child_ids) == ["publish_website", "save_to_disk", "send_email_task"] + + send_email_task = workflow.find_task("send_email_task") + assert send_email_task.parent_ids == ["compare_movies"] + assert send_email_task.child_ids == ["summarize_to_slack"] + + save_to_disk = workflow.find_task("save_to_disk") + assert save_to_disk.parent_ids == ["compare_movies"] + assert save_to_disk.child_ids == ["summarize_to_slack"] + + publish_website = workflow.find_task("publish_website") + assert publish_website.parent_ids == ["compare_movies"] + assert publish_website.child_ids == ["summarize_to_slack"] diff --git a/tests/unit/tasks/test_audio_transcription_task.py b/tests/unit/tasks/test_audio_transcription_task.py index fdab5f730..3a53fd49d 100644 --- a/tests/unit/tasks/test_audio_transcription_task.py +++ b/tests/unit/tasks/test_audio_transcription_task.py @@ -2,10 +2,11 @@ import pytest -from griptape.artifacts import AudioArtifact +from griptape.artifacts import AudioArtifact, TextArtifact from griptape.engines import AudioTranscriptionEngine -from griptape.structures import Agent +from griptape.structures import Agent, Pipeline from griptape.tasks import BaseTask, AudioTranscriptionTask +from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_structure_config import MockStructureConfig @@ -36,3 +37,17 @@ def test_config_audio_transcription_engine(self, audio_artifact): Agent(config=MockStructureConfig()).add_task(task) assert isinstance(task.audio_transcription_engine, AudioTranscriptionEngine) + + def test_run(self, audio_artifact, audio_transcription_engine): + audio_transcription_engine.run.return_value = TextArtifact("mock transcription") + logger = Mock() + + task = AudioTranscriptionTask(audio_artifact, audio_transcription_engine=audio_transcription_engine) + pipeline = Pipeline(prompt_driver=MockPromptDriver(), logger=logger) + pipeline.add_task(task) + + assert pipeline.run().output.to_text() == "mock transcription" + + def test_before_run(self, audio_artifact, audio_transcription_engine): + task = AudioTranscriptionTask(audio_artifact, audio_transcription_engine=audio_transcription_engine) + task diff --git a/tests/unit/tasks/test_base_audio_input_task.py b/tests/unit/tasks/test_base_audio_input_task.py new file mode 100644 index 000000000..e11074880 --- /dev/null +++ b/tests/unit/tasks/test_base_audio_input_task.py @@ -0,0 +1,26 @@ +import pytest + +from tests.mocks.mock_audio_input_task import MockAudioInputTask +from griptape.artifacts import AudioArtifact, TextArtifact +from tests.mocks.mock_text_input_task import MockTextInputTask + + +class TestBaseAudioInputTask: + @pytest.fixture + def audio_artifact(self): + return AudioArtifact(b"audio content", format="mp3") + + def test_audio_artifact_input(self, audio_artifact): + task = MockAudioInputTask(audio_artifact) + assert task.input.value == audio_artifact.value + + audio_artifact.value = b"new audio content" + task.input = audio_artifact + assert task.input.value == audio_artifact.value + + def test_callable_input(self, audio_artifact): + assert MockTextInputTask(lambda _: audio_artifact).input.value == audio_artifact.value + + def test_bad_input(self): + with pytest.raises(ValueError): + assert MockAudioInputTask(TextArtifact("foobar")).input.value == 
"foobar" diff --git a/tests/unit/tasks/test_base_task.py b/tests/unit/tasks/test_base_task.py index 3048b7d7a..7fe2810f5 100644 --- a/tests/unit/tasks/test_base_task.py +++ b/tests/unit/tasks/test_base_task.py @@ -3,6 +3,7 @@ from griptape.artifacts import TextArtifact from griptape.structures import Agent from griptape.tasks import ActionsSubtask +from griptape.structures import Workflow from tests.mocks.mock_embedding_driver import MockEmbeddingDriver from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_task import MockTask @@ -30,3 +31,40 @@ def test_meta_memories(self, task): task.structure.task_memory.process_output(MockTool().test, subtask, TextArtifact("foo")) assert len(task.meta_memories) == 2 + + def test_parent_outputs(self, task): + parent_1 = MockTask("foobar1", id="foobar1") + parent_2 = MockTask("foobar2", id="foobar2") + parent_3 = MockTask("foobar3", id="foobar3") + child = MockTask("foobar", id="foobar") + + child.add_parent(parent_1) + child.add_parent(parent_2) + child.add_parent(parent_3) + + workflow = Workflow(tasks=[parent_1, parent_2, parent_3, child]) + workflow.run() + + parent_3.output = None + assert child.parent_outputs == { + parent_1.id: parent_1.output.to_text(), + parent_2.id: parent_2.output.to_text(), + parent_3.id: "", + } + + def test_parents_output(self, task): + parent_1 = MockTask("foobar1", id="foobar1") + parent_2 = MockTask("foobar2", id="foobar2") + parent_3 = MockTask("foobar3", id="foobar3") + child = MockTask("foobar", id="foobar") + + child.add_parent(parent_1) + child.add_parent(parent_2) + child.add_parent(parent_3) + + workflow = Workflow(tasks=[parent_1, parent_2, parent_3, child]) + workflow.run() + + parent_2.output = None + + assert child.parents_output_text == "foobar1\nfoobar3" diff --git a/tests/unit/tasks/test_text_to_speech_task.py b/tests/unit/tasks/test_text_to_speech_task.py index f30893e8a..7a8e49364 100644 --- a/tests/unit/tasks/test_text_to_speech_task.py +++ b/tests/unit/tasks/test_text_to_speech_task.py @@ -1,9 +1,10 @@ from unittest.mock import Mock -from griptape.artifacts import TextArtifact +from griptape.artifacts import TextArtifact, AudioArtifact from griptape.engines import TextToSpeechEngine -from griptape.structures import Agent +from griptape.structures import Agent, Pipeline from griptape.tasks import BaseTask, TextToSpeechTask +from tests.mocks.mock_prompt_driver import MockPromptDriver from tests.mocks.mock_structure_config import MockStructureConfig @@ -28,3 +29,19 @@ def test_config_text_to_speech_engine(self): Agent(config=MockStructureConfig()).add_task(task) assert isinstance(task.text_to_speech_engine, TextToSpeechEngine) + + def test_calls(self): + text_to_speech_engine = Mock() + text_to_speech_engine.run.return_value = AudioArtifact(b"audio content", format="mp3") + + assert TextToSpeechTask("test", text_to_speech_engine=text_to_speech_engine).run().value == b"audio content" + + def test_run(self): + text_to_speech_engine = Mock() + text_to_speech_engine.run.return_value = AudioArtifact(b"audio content", format="mp3") + + task = TextToSpeechTask("some text", text_to_speech_engine=text_to_speech_engine) + pipeline = Pipeline(prompt_driver=MockPromptDriver()) + pipeline.add_task(task) + + assert isinstance(pipeline.run().output, AudioArtifact) diff --git a/griptape/drivers/embedding_model/__init__.py b/tests/unit/tokenizers/__init__.py similarity index 100% rename from griptape/drivers/embedding_model/__init__.py rename to tests/unit/tokenizers/__init__.py diff --git 
a/tests/unit/tokenizers/test_bedrock_claude_tokenizer.py b/tests/unit/tokenizers/test_amazon_bedrock_tokenizer.py similarity index 51% rename from tests/unit/tokenizers/test_bedrock_claude_tokenizer.py rename to tests/unit/tokenizers/test_amazon_bedrock_tokenizer.py index 4f6e6723f..2b77ba3dc 100644 --- a/tests/unit/tokenizers/test_bedrock_claude_tokenizer.py +++ b/tests/unit/tokenizers/test_amazon_bedrock_tokenizer.py @@ -1,19 +1,19 @@ +from griptape.tokenizers import AmazonBedrockTokenizer import pytest -from griptape.tokenizers import BedrockClaudeTokenizer -class TestBedrockClaudeTokenizer: +class TestAmazonBedrockTokenizer: @pytest.fixture def tokenizer(self, request): - return BedrockClaudeTokenizer(model=request.param) + return AmazonBedrockTokenizer(model=request.param) @pytest.mark.parametrize( "tokenizer,expected", [ - ("anthropic.claude-v2:1", 5), - ("anthropic.claude-v2", 5), - ("anthropic.claude-3-sonnet-20240229-v1:0", 5), - ("anthropic.claude-3-haiku-20240307-v1:0", 5), + ("anthropic.claude-v2:1", 4), + ("anthropic.claude-v2", 4), + ("anthropic.claude-3-sonnet-20240229-v1:0", 4), + ("anthropic.claude-3-haiku-20240307-v1:0", 4), ], indirect=["tokenizer"], ) @@ -23,10 +23,10 @@ def test_token_count(self, tokenizer, expected): @pytest.mark.parametrize( "tokenizer,expected", [ - ("anthropic.claude-v2", 99995), - ("anthropic.claude-v2:1", 199995), - ("anthropic.claude-3-sonnet-20240229-v1:0", 199995), - ("anthropic.claude-3-haiku-20240307-v1:0", 199995), + ("anthropic.claude-v2", 99996), + ("anthropic.claude-v2:1", 199996), + ("anthropic.claude-3-sonnet-20240229-v1:0", 199996), + ("anthropic.claude-3-haiku-20240307-v1:0", 199996), ], indirect=["tokenizer"], ) @@ -36,10 +36,10 @@ def test_input_tokens_left(self, tokenizer, expected): @pytest.mark.parametrize( "tokenizer,expected", [ - ("anthropic.claude-v2", 4091), - ("anthropic.claude-v2:1", 4091), - ("anthropic.claude-3-sonnet-20240229-v1:0", 4091), - ("anthropic.claude-3-haiku-20240307-v1:0", 4091), + ("anthropic.claude-v2", 4092), + ("anthropic.claude-v2:1", 4092), + ("anthropic.claude-3-sonnet-20240229-v1:0", 4092), + ("anthropic.claude-3-haiku-20240307-v1:0", 4092), ], indirect=["tokenizer"], ) diff --git a/tests/unit/tokenizers/test_base_tokenizer.py b/tests/unit/tokenizers/test_base_tokenizer.py new file mode 100644 index 000000000..eed15b9b2 --- /dev/null +++ b/tests/unit/tokenizers/test_base_tokenizer.py @@ -0,0 +1,13 @@ +import logging +from tests.mocks.mock_tokenizer import MockTokenizer + + +class TestBaseTokenizer: + def test_default_tokens(self, caplog): + with caplog.at_level(logging.WARNING): + tokenizer = MockTokenizer(model="gpt2") + + assert tokenizer.max_input_tokens == 4096 + assert tokenizer.max_output_tokens == 1000 + + assert "gpt2 not found" in caplog.text diff --git a/tests/unit/tokenizers/test_bedrock_cohere_tokenizer.py b/tests/unit/tokenizers/test_bedrock_cohere_tokenizer.py deleted file mode 100644 index 6238b0e54..000000000 --- a/tests/unit/tokenizers/test_bedrock_cohere_tokenizer.py +++ /dev/null @@ -1,25 +0,0 @@ -import pytest -from unittest import mock -from griptape.tokenizers import BedrockCohereTokenizer - - -class TestBedrockCohereTokenizer: - @pytest.fixture(autouse=True) - def mock_session(self, mocker): - fake_tokenization = '{"inputTextTokenCount": 2}' - mock_session_class = mocker.patch("boto3.Session") - - mock_session_object = mock.Mock() - mock_client = mock.Mock() - mock_response = mock.Mock() - - mock_response.get().read.return_value = fake_tokenization - 
mock_client.invoke_model.return_value = mock_response - mock_session_object.client.return_value = mock_client - mock_session_class.return_value = mock_session_object - - def test_input_tokens_left(self): - assert BedrockCohereTokenizer(model="cohere").count_input_tokens_left("foo bar") == 1022 - - def test_output_tokens_left(self): - assert BedrockCohereTokenizer(model="cohere").count_output_tokens_left("foo bar") == 4094 diff --git a/tests/unit/tokenizers/test_bedrock_jurassic_tokenizer.py b/tests/unit/tokenizers/test_bedrock_jurassic_tokenizer.py deleted file mode 100644 index 59c42493b..000000000 --- a/tests/unit/tokenizers/test_bedrock_jurassic_tokenizer.py +++ /dev/null @@ -1,39 +0,0 @@ -import pytest -from unittest import mock -from griptape.tokenizers import BedrockJurassicTokenizer - - -class TestBedrockJurassicTokenizer: - @pytest.fixture(autouse=True) - def mock_session(self, mocker): - fake_tokenization = '{"prompt": {"tokens": [{}, {}, {}]}}' - mock_session_class = mocker.patch("boto3.Session") - - mock_session_object = mock.Mock() - mock_client = mock.Mock() - mock_response = mock.Mock() - - mock_response.get().read.return_value = fake_tokenization - mock_client.invoke_model.return_value = mock_response - mock_session_object.client.return_value = mock_client - mock_session_class.return_value = mock_session_object - - @pytest.fixture - def tokenizer(self, request): - return BedrockJurassicTokenizer(model=request.param) - - @pytest.mark.parametrize( - "tokenizer,expected", - [("ai21.j2-mid-v1", 8186), ("ai21.j2-ultra-v1", 8186), ("ai21.j2-large-v1", 8186), ("ai21.j2-large-v2", 8186)], - indirect=["tokenizer"], - ) - def test_input_tokens_left(self, tokenizer, expected): - assert tokenizer.count_input_tokens_left("System: foo\nUser: bar\nAssistant:") == expected - - @pytest.mark.parametrize( - "tokenizer,expected", - [("ai21.j2-mid-v1", 8185), ("ai21.j2-ultra-v1", 8185), ("ai21.j2-large-v1", 8185), ("ai21.j2-large-v2", 2042)], - indirect=["tokenizer"], - ) - def test_output_tokens_left(self, tokenizer, expected): - assert tokenizer.count_output_tokens_left("System: foo\nUser: bar\nAssistant:") == expected diff --git a/tests/unit/tokenizers/test_bedrock_llama_tokenizer.py b/tests/unit/tokenizers/test_bedrock_llama_tokenizer.py deleted file mode 100644 index da842f0f3..000000000 --- a/tests/unit/tokenizers/test_bedrock_llama_tokenizer.py +++ /dev/null @@ -1,35 +0,0 @@ -import pytest -from unittest import mock -from griptape.tokenizers import BedrockLlamaTokenizer - - -class TestBedrockLlamaTokenizer: - @pytest.fixture(autouse=True) - def mock_session(self, mocker): - fake_tokenization = '{"generation_token_count": 13}' - mock_session_class = mocker.patch("boto3.Session") - - mock_session_object = mock.Mock() - mock_client = mock.Mock() - mock_response = mock.Mock() - - mock_response.get().read.return_value = fake_tokenization - mock_client.invoke_model.return_value = mock_response - mock_session_object.client.return_value = mock_client - mock_session_class.return_value = mock_session_object - - def test_input_tokens_left(self): - assert ( - BedrockLlamaTokenizer(model="meta.llama").count_input_tokens_left( - "[INST] <>\n{{ system_prompt }}\n<>\n\n{{ usr_msg_1 }} [/INST] {{ model_msg_1 }} [INST] {{ usr_msg_2 }} [/INST]" - ) - == 2026 - ) - - def test_ouput_tokens_left(self): - assert ( - BedrockLlamaTokenizer(model="meta.llama").count_output_tokens_left( - "[INST] <>\n{{ system_prompt }}\n<>\n\n{{ usr_msg_1 }} [/INST] {{ model_msg_1 }} [INST] {{ usr_msg_2 }} [/INST]" - ) - == 2026 - 
) diff --git a/tests/unit/tokenizers/test_bedrock_titan_tokenizer.py b/tests/unit/tokenizers/test_bedrock_titan_tokenizer.py deleted file mode 100644 index c4f4f42ad..000000000 --- a/tests/unit/tokenizers/test_bedrock_titan_tokenizer.py +++ /dev/null @@ -1,31 +0,0 @@ -import pytest -from unittest import mock -from griptape.tokenizers import BedrockTitanTokenizer - - -class TestBedrockTitanTokenizer: - @pytest.fixture(autouse=True) - def mock_session(self, mocker): - fake_tokenization = '{"inputTextTokenCount": 13}' - mock_session_class = mocker.patch("boto3.Session") - - mock_session_object = mock.Mock() - mock_client = mock.Mock() - mock_response = mock.Mock() - - mock_response.get().read.return_value = fake_tokenization - mock_client.invoke_model.return_value = mock_response - mock_session_object.client.return_value = mock_client - mock_session_class.return_value = mock_session_object - - def test_input_tokens_left(self): - assert ( - BedrockTitanTokenizer(model="amazon.titan").count_input_tokens_left("Instructions: foo\nUser: bar\nBot:") - == 4090 - ) - - def test_output_tokens_left(self): - assert ( - BedrockTitanTokenizer(model="amazon.titan").count_output_tokens_left("Instructions: foo\nUser: bar\nBot:") - == 7994 - ) diff --git a/tests/unit/tokenizers/test_cohere_tokenizer.py b/tests/unit/tokenizers/test_cohere_tokenizer.py index ca724cee6..9ca23f4f0 100644 --- a/tests/unit/tokenizers/test_cohere_tokenizer.py +++ b/tests/unit/tokenizers/test_cohere_tokenizer.py @@ -7,6 +7,7 @@ class TestCohereTokenizer: @pytest.fixture(autouse=True) def mock_client(self, mocker): mock_client = mocker.patch("cohere.Client").return_value.tokenize.return_value.tokens = ["foo", "bar"] + return mock_client @pytest.fixture diff --git a/tests/unit/tokenizers/test_google_tokenizer.py b/tests/unit/tokenizers/test_google_tokenizer.py index 2b26ec6a6..955a0517f 100644 --- a/tests/unit/tokenizers/test_google_tokenizer.py +++ b/tests/unit/tokenizers/test_google_tokenizer.py @@ -1,5 +1,6 @@ import pytest from unittest.mock import Mock +from griptape.utils import PromptStack from griptape.tokenizers import GoogleTokenizer @@ -18,6 +19,7 @@ def tokenizer(self, request): @pytest.mark.parametrize("tokenizer,expected", [("gemini-pro", 5)], indirect=["tokenizer"]) def test_token_count(self, tokenizer, expected): assert tokenizer.count_tokens("foo bar huzzah") == expected + assert tokenizer.count_tokens(PromptStack(inputs=[PromptStack.Input(content="foo", role="user")])) == expected assert tokenizer.count_tokens(["foo", "bar", "huzzah"]) == expected @pytest.mark.parametrize("tokenizer,expected", [("gemini-pro", 30715)], indirect=["tokenizer"]) diff --git a/tests/unit/tokenizers/test_hugging_face_tokenizer.py b/tests/unit/tokenizers/test_hugging_face_tokenizer.py index 5be74f667..dcb309a84 100644 --- a/tests/unit/tokenizers/test_hugging_face_tokenizer.py +++ b/tests/unit/tokenizers/test_hugging_face_tokenizer.py @@ -2,15 +2,14 @@ environ["TRANSFORMERS_VERBOSITY"] = "error" -import pytest -from transformers import GPT2Tokenizer -from griptape.tokenizers import HuggingFaceTokenizer +import pytest # noqa: E402 +from griptape.tokenizers import HuggingFaceTokenizer # noqa: E402 class TestHuggingFaceTokenizer: @pytest.fixture def tokenizer(self): - return HuggingFaceTokenizer(tokenizer=GPT2Tokenizer.from_pretrained("gpt2"), max_output_tokens=1024) + return HuggingFaceTokenizer(model="gpt2", max_output_tokens=1024) def test_token_count(self, tokenizer): assert tokenizer.count_tokens("foo bar huzzah") == 5 diff --git 
a/tests/unit/tokenizers/test_openai_tokenizer.py b/tests/unit/tokenizers/test_openai_tokenizer.py index b27080aa3..4aa42a87a 100644 --- a/tests/unit/tokenizers/test_openai_tokenizer.py +++ b/tests/unit/tokenizers/test_openai_tokenizer.py @@ -13,6 +13,7 @@ def tokenizer(self, request): ("gpt-4-1106", 5), ("gpt-4-32k", 5), ("gpt-4", 5), + ("gpt-4o", 5), ("gpt-3.5-turbo-0301", 5), ("gpt-3.5-turbo-16k", 5), ("gpt-3.5-turbo", 5), @@ -38,11 +39,13 @@ def test_initialize_with_unknown_model(self): ("gpt-4-1106", 19), ("gpt-4-32k", 19), ("gpt-4", 19), + ("gpt-4o", 19), ("gpt-3.5-turbo-0301", 21), ("gpt-3.5-turbo-16k", 19), ("gpt-3.5-turbo", 19), ("gpt-35-turbo-16k", 19), ("gpt-35-turbo", 19), + ("gpt-35-turbo", 19), ], indirect=["tokenizer"], ) @@ -54,10 +57,18 @@ def test_token_count_for_messages(self, tokenizer, expected): == expected ) + @pytest.mark.parametrize("tokenizer,expected", [("not-real-model", 19)], indirect=["tokenizer"]) + def test_token_count_for_messages_unknown_model(self, tokenizer, expected): + with pytest.raises(NotImplementedError): + tokenizer.count_tokens( + [{"role": "system", "content": "foobar baz"}, {"role": "user", "content": "how foobar am I?"}] + ) + @pytest.mark.parametrize( "tokenizer,expected", [ ("gpt-4-1106", 127987), + ("gpt-4o", 127987), ("gpt-4-32k", 32755), ("gpt-4", 8179), ("gpt-3.5-turbo-16k", 16371), @@ -80,6 +91,7 @@ def test_input_tokens_left(self, tokenizer, expected): ("gpt-4-1106", 4091), ("gpt-4-32k", 4091), ("gpt-4", 4091), + ("gpt-4o", 4091), ("gpt-3.5-turbo-16k", 4091), ("gpt-3.5-turbo", 4091), ("gpt-35-turbo-16k", 4091), diff --git a/tests/unit/tools/test_griptape_cloud_knowledge_base_client.py b/tests/unit/tools/test_griptape_cloud_knowledge_base_client.py index 46270d216..9feba9cbf 100644 --- a/tests/unit/tools/test_griptape_cloud_knowledge_base_client.py +++ b/tests/unit/tools/test_griptape_cloud_knowledge_base_client.py @@ -1,5 +1,6 @@ import pytest -from griptape.artifacts import TextArtifact +from requests import exceptions +from griptape.artifacts import TextArtifact, ErrorArtifact class TestGriptapeCloudKnowledgeBaseClient: @@ -8,10 +9,12 @@ def client(self, mocker): from griptape.tools import GriptapeCloudKnowledgeBaseClient mock_response = mocker.Mock() + mock_response.status_code = 201 mock_response.text.return_value = "foo bar" mocker.patch("requests.post", return_value=mock_response) mock_response = mocker.Mock() + mock_response.status_code = 200 mock_response.json.return_value = {"description": "fizz buzz"} mocker.patch("requests.get", return_value=mock_response) @@ -19,11 +22,63 @@ def client(self, mocker): base_url="https://api.griptape.ai", api_key="foo bar", knowledge_base_id="1" ) + @pytest.fixture + def client_no_description(self, mocker): + from griptape.tools import GriptapeCloudKnowledgeBaseClient + + mock_response = mocker.Mock() + mock_response.json.return_value = {} + mock_response.status_code = 200 + mocker.patch("requests.get", return_value=mock_response) + + return GriptapeCloudKnowledgeBaseClient( + base_url="https://api.griptape.ai", api_key="foo bar", knowledge_base_id="1" + ) + + @pytest.fixture + def client_kb_not_found(self, mocker): + from griptape.tools import GriptapeCloudKnowledgeBaseClient + + mock_response = mocker.Mock() + mock_response.json.return_value = {} + mock_response.status_code = 404 + mocker.patch("requests.get", return_value=mock_response) + + return GriptapeCloudKnowledgeBaseClient( + base_url="https://api.griptape.ai", api_key="foo bar", knowledge_base_id="1" + ) + + @pytest.fixture + def 
client_kb_error(self, mocker): + from griptape.tools import GriptapeCloudKnowledgeBaseClient + + mock_response = mocker.Mock() + mock_response.status_code = 500 + mocker.patch("requests.post", return_value=mock_response, side_effect=exceptions.RequestException("error")) + + return GriptapeCloudKnowledgeBaseClient( + base_url="https://api.griptape.ai", api_key="foo bar", knowledge_base_id="1" + ) + def test_query(self, client): assert isinstance(client.query({"values": {"query": "foo bar"}}), TextArtifact) + def test_query_error(self, client_kb_error): + assert isinstance(client_kb_error.query({"values": {"query": "foo bar"}}), ErrorArtifact) + assert client_kb_error.query({"values": {"query": "foo bar"}}).value == "error" + def test_get_knowledge_base_description(self, client): assert client._get_knowledge_base_description() == "fizz buzz" client.description = "foo bar" assert client._get_knowledge_base_description() == "foo bar" + + def test_get_knowledge_base_description_error(self, client_no_description): + exception_match_text = f"No description found for Knowledge Base {client_no_description.knowledge_base_id}. Please set a description, or manually set the `GriptapeCloudKnowledgeBaseClient.description` attribute." + with pytest.raises(ValueError, match=exception_match_text) as e: + client_no_description._get_knowledge_base_description() + + def test_get_knowledge_base_kb_error(self, client_kb_not_found): + exception_match_text = f"Error accessing Knowledge Base {client_kb_not_found.knowledge_base_id}." + with pytest.raises(ValueError, match=exception_match_text) as e: + client_kb_not_found._get_knowledge_base_description() diff --git a/tests/unit/tools/test_web_search.py b/tests/unit/tools/test_web_search.py index faa1056b7..c9a79b452 100644 --- a/tests/unit/tools/test_web_search.py +++ b/tests/unit/tools/test_web_search.py @@ -1,9 +1,34 @@ -from griptape.artifacts import BaseArtifact +from griptape.artifacts import BaseArtifact, ErrorArtifact from griptape.tools import WebSearch +from pytest import fixture +import json class TestWebSearch: - def test_search(self): - tool = WebSearch(google_api_key="foo", google_api_search_id="bar") + @fixture + def websearch_tool(self, mocker): + mock_response = mocker.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"items": [{"title": "foo", "link": "bar", "snippet": "baz"}]} + mocker.patch("requests.get", return_value=mock_response) - assert isinstance(tool.search({"values": {"query": "foo bar"}}), BaseArtifact) + return WebSearch(google_api_key="foo", google_api_search_id="bar") + + @fixture + def websearch_tool_with_error(self, mocker): + mock_response = mocker.Mock() + mock_response.status_code = 500 + mocker.patch("requests.get", return_value=mock_response) + + return WebSearch(google_api_key="foo", google_api_search_id="bar") + + def test_search(self, websearch_tool): + assert isinstance(websearch_tool.search({"values": {"query": "foo bar"}}), BaseArtifact) + assert json.loads(websearch_tool.search({"values": {"query": "foo bar"}}).value[0].value) == { + "title": "foo", + "url": "bar", + "description": "baz", + } + + def test_search_with_error(self, websearch_tool_with_error): + assert isinstance(websearch_tool_with_error.search({"values": {"query": "foo bar"}}), ErrorArtifact) diff --git a/tests/unit/utils/test_prompt_stack.py b/tests/unit/utils/test_prompt_stack.py index 253e8cd44..80010abec 100644 --- a/tests/unit/utils/test_prompt_stack.py +++ b/tests/unit/utils/test_prompt_stack.py @@ -1,9 +1,5 @@ import 
pytest from griptape.utils import PromptStack -from tests.mocks.mock_prompt_driver import MockPromptDriver -from tests.mocks.mock_tokenizer import MockTokenizer -from griptape.structures.agent import Agent -from griptape.memory.structure import ConversationMemory, Run class TestPromptStack: @@ -43,97 +39,3 @@ def test_add_assistant_input(self, prompt_stack): assert prompt_stack.inputs[0].role == "assistant" assert prompt_stack.inputs[0].content == "foo" - - def test_add_conversation_memory_autopruing_disabled(self): - agent = Agent(prompt_driver=MockPromptDriver()) - memory = ConversationMemory( - autoprune=False, - runs=[ - Run(input="foo1", output="bar1"), - Run(input="foo2", output="bar2"), - Run(input="foo3", output="bar3"), - Run(input="foo4", output="bar4"), - Run(input="foo5", output="bar5"), - ], - ) - memory.structure = agent - prompt_stack = PromptStack() - prompt_stack.add_user_input("foo") - prompt_stack.add_assistant_input("bar") - prompt_stack.add_conversation_memory(memory) - - assert len(prompt_stack.inputs) == 12 - - def test_add_conversation_memory_autopruing_enabled(self): - # All memory is pruned. - agent = Agent(prompt_driver=MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=0))) - memory = ConversationMemory( - autoprune=True, - runs=[ - Run(input="foo1", output="bar1"), - Run(input="foo2", output="bar2"), - Run(input="foo3", output="bar3"), - Run(input="foo4", output="bar4"), - Run(input="foo5", output="bar5"), - ], - ) - memory.structure = agent - prompt_stack = PromptStack() - prompt_stack.add_system_input("fizz") - prompt_stack.add_user_input("foo") - prompt_stack.add_assistant_input("bar") - prompt_stack.add_conversation_memory(memory) - - assert len(prompt_stack.inputs) == 3 - - # No memory is pruned. - agent = Agent(prompt_driver=MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=1000))) - memory = ConversationMemory( - autoprune=True, - runs=[ - Run(input="foo1", output="bar1"), - Run(input="foo2", output="bar2"), - Run(input="foo3", output="bar3"), - Run(input="foo4", output="bar4"), - Run(input="foo5", output="bar5"), - ], - ) - memory.structure = agent - prompt_stack = PromptStack() - prompt_stack.add_system_input("fizz") - prompt_stack.add_user_input("foo") - prompt_stack.add_assistant_input("bar") - prompt_stack.add_conversation_memory(memory) - - assert len(prompt_stack.inputs) == 13 - - # One memory is pruned. - # MockTokenizer's max_input_tokens set to one below the sum of memory + system prompt tokens - # so that a single memory is pruned. - agent = Agent(prompt_driver=MockPromptDriver(tokenizer=MockTokenizer(model="foo", max_input_tokens=160))) - memory = ConversationMemory( - autoprune=True, - runs=[ - # All of these sum to 155 tokens with the MockTokenizer. - Run(input="foo1", output="bar1"), - Run(input="foo2", output="bar2"), - Run(input="foo3", output="bar3"), - Run(input="foo4", output="bar4"), - Run(input="foo5", output="bar5"), - ], - ) - memory.structure = agent - prompt_stack = PromptStack() - # And then another 6 tokens from fizz for a total of 161 tokens. - prompt_stack.add_system_input("fizz") - prompt_stack.add_user_input("foo") - prompt_stack.add_assistant_input("bar") - prompt_stack.add_conversation_memory(memory, 1) - - # We expect one run (2 prompt stack inputs) to be pruned. 
- assert len(prompt_stack.inputs) == 11 - assert prompt_stack.inputs[0].content == "fizz" - assert prompt_stack.inputs[1].content == "foo2" - assert prompt_stack.inputs[2].content == "bar2" - assert prompt_stack.inputs[-2].content == "foo" - assert prompt_stack.inputs[-1].content == "bar" diff --git a/tests/unit/utils/test_structure_visualizer.py b/tests/unit/utils/test_structure_visualizer.py new file mode 100644 index 000000000..d7177bf30 --- /dev/null +++ b/tests/unit/utils/test_structure_visualizer.py @@ -0,0 +1,52 @@ +from tests.mocks.mock_prompt_driver import MockPromptDriver +from griptape.utils import StructureVisualizer +from griptape.tasks import PromptTask +from griptape.structures import Agent, Workflow, Pipeline + + +class TestStructureVisualizer: + def test_agent(self): + agent = Agent(prompt_driver=MockPromptDriver(), tasks=[PromptTask("test1", id="task1")]) + + visualizer = StructureVisualizer(agent) + result = visualizer.to_url() + + assert result == "https://mermaid.ink/svg/Z3JhcGggVEQ7Cgl0YXNrMTs=" + + def test_pipeline(self): + pipeline = Pipeline( + prompt_driver=MockPromptDriver(), + tasks=[ + PromptTask("test1", id="task1"), + PromptTask("test2", id="task2"), + PromptTask("test3", id="task3"), + PromptTask("test4", id="task4"), + ], + ) + + visualizer = StructureVisualizer(pipeline) + result = visualizer.to_url() + + assert ( + result + == "https://mermaid.ink/svg/Z3JhcGggVEQ7Cgl0YXNrMS0tPiB0YXNrMjsKCXRhc2syLS0+IHRhc2szOwoJdGFzazMtLT4gdGFzazQ7Cgl0YXNrNDs=" + ) + + def test_workflow(self): + workflow = Workflow( + prompt_driver=MockPromptDriver(), + tasks=[ + PromptTask("test1", id="task1"), + PromptTask("test2", id="task2", parent_ids=["task1"]), + PromptTask("test3", id="task3", parent_ids=["task1"]), + PromptTask("test4", id="task4", parent_ids=["task2", "task3"]), + ], + ) + + visualizer = StructureVisualizer(workflow) + result = visualizer.to_url() + + assert ( + result + == "https://mermaid.ink/svg/Z3JhcGggVEQ7Cgl0YXNrMS0tPiB0YXNrMiAmIHRhc2szOwoJdGFzazItLT4gdGFzazQ7Cgl0YXNrMy0tPiB0YXNrNDsKCXRhc2s0Ow==" + ) diff --git a/tests/utils/structure_tester.py b/tests/utils/structure_tester.py index 4f111b8d8..8d62bc835 100644 --- a/tests/utils/structure_tester.py +++ b/tests/utils/structure_tester.py @@ -14,17 +14,10 @@ BasePromptDriver, AmazonBedrockPromptDriver, AnthropicPromptDriver, - BedrockClaudePromptModelDriver, - BedrockJurassicPromptModelDriver, - BedrockTitanPromptModelDriver, - BedrockLlamaPromptModelDriver, CoherePromptDriver, OpenAiChatPromptDriver, - OpenAiCompletionPromptDriver, AzureOpenAiChatPromptDriver, - AmazonSageMakerPromptDriver, - SageMakerLlamaPromptModelDriver, - SageMakerFalconPromptModelDriver, + AmazonSageMakerJumpstartPromptDriver, GooglePromptDriver, ) @@ -53,12 +46,6 @@ class TesterPromptDriverOption: prompt_driver=OpenAiChatPromptDriver(model="gpt-3.5-turbo-1106", api_key=os.environ["OPENAI_API_KEY"]), enabled=True, ), - "OPENAI_CHAT_35_TURBO_INSTRUCT": TesterPromptDriverOption( - prompt_driver=OpenAiCompletionPromptDriver( - model="gpt-3.5-turbo-instruct", api_key=os.environ["OPENAI_API_KEY"] - ), - enabled=True, - ), "OPENAI_CHAT_4": TesterPromptDriverOption( prompt_driver=OpenAiChatPromptDriver(model="gpt-4", api_key=os.environ["OPENAI_API_KEY"]), enabled=True ), @@ -69,10 +56,6 @@ class TesterPromptDriverOption: prompt_driver=OpenAiChatPromptDriver(model="gpt-4-1106-preview", api_key=os.environ["OPENAI_API_KEY"]), enabled=True, ), - "OPENAI_COMPLETION_DAVINCI": TesterPromptDriverOption( - 
prompt_driver=OpenAiCompletionPromptDriver(api_key=os.environ["OPENAI_API_KEY"], model="text-davinci-003"), - enabled=True, - ), "AZURE_CHAT_35_TURBO": TesterPromptDriverOption( prompt_driver=AzureOpenAiChatPromptDriver( api_key=os.environ["AZURE_OPENAI_API_KEY_1"], @@ -84,10 +67,10 @@ class TesterPromptDriverOption: ), "AZURE_CHAT_35_TURBO_16K": TesterPromptDriverOption( prompt_driver=AzureOpenAiChatPromptDriver( - api_key=os.environ["AZURE_OPENAI_API_KEY_1"], + api_key=os.environ["AZURE_OPENAI_API_KEY_2"], model="gpt-35-turbo-16k", azure_deployment=os.environ["AZURE_OPENAI_35_TURBO_16K_DEPLOYMENT_ID"], - azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_1"], + azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_2"], ), enabled=True, ), @@ -142,71 +125,75 @@ class TesterPromptDriverOption: "COHERE_COMMAND": TesterPromptDriverOption( prompt_driver=CoherePromptDriver(model="command", api_key=os.environ["COHERE_API_KEY"]), enabled=True ), - "BEDROCK_TITAN": TesterPromptDriverOption( - prompt_driver=AmazonBedrockPromptDriver( - model="amazon.titan-tg1-large", prompt_model_driver=BedrockTitanPromptModelDriver() - ), - enabled=True, + "AMAZON_BEDROCK_ANTHROPIC_CLAUDE_3_SONNET": TesterPromptDriverOption( + prompt_driver=AmazonBedrockPromptDriver(model="anthropic.claude-3-sonnet-20240229-v1:0"), enabled=True ), - "BEDROCK_CLAUDE_INSTANT": TesterPromptDriverOption( - prompt_driver=AmazonBedrockPromptDriver( - model="anthropic.claude-instant-v1", prompt_model_driver=BedrockClaudePromptModelDriver() - ), - enabled=True, + "AMAZON_BEDROCK_ANTHROPIC_CLAUDE_3_HAIKU": TesterPromptDriverOption( + prompt_driver=AmazonBedrockPromptDriver(model="anthropic.claude-3-haiku-20240307-v1:0"), enabled=True ), - "BEDROCK_CLAUDE_2": TesterPromptDriverOption( - prompt_driver=AmazonBedrockPromptDriver( - model="anthropic.claude-v2", prompt_model_driver=BedrockClaudePromptModelDriver() - ), - enabled=True, + "AMAZON_BEDROCK_ANTHROPIC_CLAUDE_2": TesterPromptDriverOption( + prompt_driver=AmazonBedrockPromptDriver(model="anthropic.claude-v2"), enabled=True ), - "BEDROCK_CLAUDE_2.1": TesterPromptDriverOption( - prompt_driver=AmazonBedrockPromptDriver( - model="anthropic.claude-v2:1", prompt_model_driver=BedrockClaudePromptModelDriver() - ), - enabled=True, + "AMAZON_BEDROCK_ANTHROPIC_CLAUDE_2.1": TesterPromptDriverOption( + prompt_driver=AmazonBedrockPromptDriver(model="anthropic.claude-v2:1"), enabled=True ), - "BEDROCK_CLAUDE_3_SONNET": TesterPromptDriverOption( - prompt_driver=AmazonBedrockPromptDriver( - model="anthropic.claude-3-sonnet-20240229-v1:0", prompt_model_driver=BedrockClaudePromptModelDriver() - ), - enabled=True, + "AMAZON_BEDROCK_ANTHROPIC_CLAUDE_INSTANT": TesterPromptDriverOption( + prompt_driver=AmazonBedrockPromptDriver(model="anthropic.claude-instant-v1"), enabled=True ), - "BEDROCK_CLAUDE_3_HAIKU": TesterPromptDriverOption( - prompt_driver=AmazonBedrockPromptDriver( - model="anthropic.claude-3-haiku-20240307-v1:0", prompt_model_driver=BedrockClaudePromptModelDriver() - ), - enabled=True, + "AMAZON_BEDROCK_J2_ULTRA": TesterPromptDriverOption( + prompt_driver=AmazonBedrockPromptDriver(model="ai21.j2-ultra"), enabled=True ), - "BEDROCK_J2": TesterPromptDriverOption( - prompt_driver=AmazonBedrockPromptDriver( - model="ai21.j2-ultra", prompt_model_driver=BedrockJurassicPromptModelDriver() - ), - enabled=True, + "AMAZON_BEDROCK_J2_MID": TesterPromptDriverOption( + prompt_driver=AmazonBedrockPromptDriver(model="ai21.j2-mid"), enabled=True ), - "BEDROCK_LLAMA2_13B": TesterPromptDriverOption( - 
prompt_driver=AmazonBedrockPromptDriver( - model="meta.llama2-13b-chat-v1", prompt_model_driver=BedrockLlamaPromptModelDriver(), max_attempts=1 - ), - enabled=True, + "AMAZON_BEDROCK_TITAN_TEXT_LITE": TesterPromptDriverOption( + prompt_driver=AmazonBedrockPromptDriver(model="amazon.titan-text-lite-v1"), enabled=True ), - "BEDROCK_LLAMA2_70B": TesterPromptDriverOption( - prompt_driver=AmazonBedrockPromptDriver( - model="meta.llama2-70b-chat-v1", prompt_model_driver=BedrockLlamaPromptModelDriver(), max_attempts=1 - ), - enabled=True, + "AMAZON_BEDROCK_TITAN_TEXT_EXPRESS": TesterPromptDriverOption( + prompt_driver=AmazonBedrockPromptDriver(model="amazon.titan-text-express-v1"), enabled=True + ), + "AMAZON_BEDROCK_COHERE_COMMAND_R_PLUS": TesterPromptDriverOption( + prompt_driver=AmazonBedrockPromptDriver(model="cohere.command-r-plus-v1:0"), enabled=True + ), + "AMAZON_BEDROCK_COHERE_COMMAND_R": TesterPromptDriverOption( + prompt_driver=AmazonBedrockPromptDriver(model="cohere.command-r-v1:0"), enabled=True + ), + "AMAZON_BEDROCK_COHERE_COMMAND_TEXT": TesterPromptDriverOption( + prompt_driver=AmazonBedrockPromptDriver(model="cohere.command-text-v14"), enabled=True + ), + "AMAZON_BEDROCK_COHERE_COMMAND_LIGHT_TEXT": TesterPromptDriverOption( + prompt_driver=AmazonBedrockPromptDriver(model="cohere.command-light-text-v14"), enabled=True + ), + "AMAZON_BEDROCK_LLAMA3_8B_INSTRUCT": TesterPromptDriverOption( + prompt_driver=AmazonBedrockPromptDriver(model="meta.llama3-8b-instruct-v1:0"), enabled=True + ), + "AMAZON_BEDROCK_LLAMA2_13B_CHAT": TesterPromptDriverOption( + prompt_driver=AmazonBedrockPromptDriver(model="meta.llama2-13b-chat-v1"), enabled=True + ), + "AMAZON_BEDROCK_LLAMA2_70B_CHAT": TesterPromptDriverOption( + prompt_driver=AmazonBedrockPromptDriver(model="meta.llama2-70b-chat-v1"), enabled=True + ), + "AMAZON_BEDROCK_MISTRAL_7B_INSTRUCT": TesterPromptDriverOption( + prompt_driver=AmazonBedrockPromptDriver(model="mistral.mistral-7b-instruct-v0:2"), enabled=True + ), + "AMAZON_BEDROCK_MISTRAL_MIXTRAL_8X7B_INSTRUCT": TesterPromptDriverOption( + prompt_driver=AmazonBedrockPromptDriver(model="mistral.mixtral-8x7b-instruct-v0:1"), enabled=True + ), + "AMAZON_BEDROCK_MISTRAL_LARGE_2402": TesterPromptDriverOption( + prompt_driver=AmazonBedrockPromptDriver(model="mistral.mistral-large-2402-v1:0"), enabled=True + ), + "AMAZON_BEDROCK_MISTRAL_SMALL_2402": TesterPromptDriverOption( + prompt_driver=AmazonBedrockPromptDriver(model="mistral.mistral-small-2402-v1:0"), enabled=True ), "SAGEMAKER_LLAMA_7B": TesterPromptDriverOption( - prompt_driver=AmazonSageMakerPromptDriver( - model=os.environ["SAGEMAKER_LLAMA_ENDPOINT_NAME"], - prompt_model_driver=SageMakerLlamaPromptModelDriver(max_tokens=4096), + prompt_driver=AmazonSageMakerJumpstartPromptDriver( + endpoint=os.environ["SAGEMAKER_LLAMA_ENDPOINT_NAME"], model="meta-llama/Llama-2-7b-chat-hf" ), enabled=False, ), "SAGEMAKER_FALCON_7b": TesterPromptDriverOption( - prompt_driver=AmazonSageMakerPromptDriver( - model=os.environ["SAGEMAKER_FALCON_ENDPOINT_NAME"], - prompt_model_driver=SageMakerFalconPromptModelDriver(), + prompt_driver=AmazonSageMakerJumpstartPromptDriver( + endpoint=os.environ["SAGEMAKER_FALCON_ENDPOINT_NAME"], model="tiiuae/falcon-7b-instruct" ), enabled=False, ), @@ -221,6 +208,6 @@ class TesterPromptDriverOption: PROMPT_DRIVERS["AZURE_CHAT_4"], PROMPT_DRIVERS["AZURE_CHAT_4_32K"], PROMPT_DRIVERS["ANTHROPIC_CLAUDE_3_OPUS"], PROMPT_DRIVERS["GOOGLE_GEMINI_PRO"], ] ) @@ -249,7 +237,7 @@ def 
verify_structure_output(self, structure) -> dict: ) task_names = [task.__class__.__name__ for task in structure.tasks] prompt = structure.input_task.input.to_text() - actual = structure.output_task.output.to_text() + actual = structure.output.to_text() rules = [rule.value for ruleset in structure.input_task.all_rulesets for rule in ruleset.rules] agent = Agent(