From 078f0f150c73c84b4836a43e0df081780cf8c417 Mon Sep 17 00:00:00 2001 From: Andrew French Date: Thu, 23 May 2024 11:23:48 -0700 Subject: [PATCH] Update driver name to remove model specifics --- .../drivers/image-query-drivers.md | 8 ++++---- .../engines/image-query-engines.md | 6 +++--- docs/griptape-framework/structures/tasks.md | 5 ++--- .../official-tools/image-query-client.md | 8 ++++---- griptape/config/azure_openai_structure_config.py | 4 ++-- griptape/config/openai_structure_config.py | 6 ++---- griptape/drivers/__init__.py | 8 ++++---- ...river.py => azure_openai_image_query_driver.py} | 4 ++-- ...uery_driver.py => openai_image_query_driver.py} | 2 +- .../config/test_azure_openai_structure_config.py | 2 +- tests/unit/config/test_openai_structure_config.py | 2 +- .../test_azure_openai_image_query_driver.py | 14 ++++++-------- .../image_query/test_openai_image_query_driver.py | 10 +++++----- 13 files changed, 37 insertions(+), 42 deletions(-) rename griptape/drivers/image_query/{azure_openai_vision_image_query_driver.py => azure_openai_image_query_driver.py} (90%) rename griptape/drivers/image_query/{openai_vision_image_query_driver.py => openai_image_query_driver.py} (97%) diff --git a/docs/griptape-framework/drivers/image-query-drivers.md b/docs/griptape-framework/drivers/image-query-drivers.md index d569e69d8..e7745d418 100644 --- a/docs/griptape-framework/drivers/image-query-drivers.md +++ b/docs/griptape-framework/drivers/image-query-drivers.md @@ -67,11 +67,11 @@ print(result) The [OpenAiVisionImageQueryDriver](../../reference/griptape/drivers/image_query/openai_vision_image_query_driver.md) is used to query images using the OpenAI Vision API. Here is an example of how to use it: ```python -from griptape.drivers import OpenAiVisionImageQueryDriver +from griptape.drivers import OpenAiImageQueryDriver from griptape.engines import ImageQueryEngine from griptape.loaders import ImageLoader -driver = OpenAiVisionImageQueryDriver( +driver = OpenAiImageQueryDriver( model="gpt-4o", max_tokens=256, ) @@ -95,11 +95,11 @@ The [AzureOpenAiVisionImageQueryDriver](../../reference/griptape/drivers/image_q ```python import os -from griptape.drivers import AzureOpenAiVisionImageQueryDriver +from griptape.drivers import AzureOpenAiImageQueryDriver from griptape.engines import ImageQueryEngine from griptape.loaders import ImageLoader -driver = AzureOpenAiVisionImageQueryDriver( +driver = AzureOpenAiImageQueryDriver( azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_3"], api_key=os.environ["AZURE_OPENAI_API_KEY_3"], model="gpt-4", diff --git a/docs/griptape-framework/engines/image-query-engines.md b/docs/griptape-framework/engines/image-query-engines.md index 400caf2d0..0457657f8 100644 --- a/docs/griptape-framework/engines/image-query-engines.md +++ b/docs/griptape-framework/engines/image-query-engines.md @@ -2,14 +2,14 @@ The [Image Query Engine](../../reference/griptape/engines/image_query/image_query_engine.md) is used to execute natural language queries on the contents of images. You can specify the provider and model used to query the image by providing the Engine with a particular [Image Query Driver](../drivers/image-query-drivers.md). -All Image Query Drivers default to a `max_tokens` of 256. You can tune this value based on your use case and the [Image Query Driver](../drivers/image-query-drivers.md) you are providing. +All Image Query Drivers default to a `max_tokens` of 256. You can tune this value based on your use case and the [Image Query Driver](../drivers/image-query-drivers.md) you are providing. ```python -from griptape.drivers import OpenAiVisionImageQueryDriver +from griptape.drivers import OpenAiImageQueryDriver from griptape.engines import ImageQueryEngine from griptape.loaders import ImageLoader -driver = OpenAiVisionImageQueryDriver( +driver = OpenAiImageQueryDriver( model="gpt-4o", max_tokens=256 ) diff --git a/docs/griptape-framework/structures/tasks.md b/docs/griptape-framework/structures/tasks.md index 03c96323f..683648cef 100644 --- a/docs/griptape-framework/structures/tasks.md +++ b/docs/griptape-framework/structures/tasks.md @@ -633,14 +633,13 @@ This Task accepts two inputs: a query (represented by either a string or a [Text ```python from griptape.engines import ImageQueryEngine -from griptape.drivers import OpenAiVisionImageQueryDriver +from griptape.drivers import OpenAiImageQueryDriver from griptape.tasks import ImageQueryTask from griptape.loaders import ImageLoader from griptape.structures import Pipeline - # Create a driver configured to use OpenAI's GPT-4 Vision model. -driver = OpenAiVisionImageQueryDriver( +driver = OpenAiImageQueryDriver( model="gpt-4o", max_tokens=100, ) diff --git a/docs/griptape-tools/official-tools/image-query-client.md b/docs/griptape-tools/official-tools/image-query-client.md index fb5dcca33..62b6f6f52 100644 --- a/docs/griptape-tools/official-tools/image-query-client.md +++ b/docs/griptape-tools/official-tools/image-query-client.md @@ -5,14 +5,14 @@ This tool allows Agents to execute natural language queries on the contents of i ```python from griptape.structures import Agent from griptape.tools import ImageQueryClient -from griptape.drivers import OpenAiVisionImageQueryDriver -from griptape.engines import ImageQueryEngine +from griptape.drivers import OpenAiImageQueryDriver +from griptape.engines import ImageQueryEngine # Create an Image Query Driver. -driver = OpenAiVisionImageQueryDriver( +driver = OpenAiImageQueryDriver( model="gpt-4-vision-preview" ) - + # Create an Image Query Engine configured to use the driver. engine = ImageQueryEngine( image_query_driver=driver, diff --git a/griptape/config/azure_openai_structure_config.py b/griptape/config/azure_openai_structure_config.py index 8c7ce82d3..7ed08bdd9 100644 --- a/griptape/config/azure_openai_structure_config.py +++ b/griptape/config/azure_openai_structure_config.py @@ -7,7 +7,7 @@ AzureOpenAiChatPromptDriver, AzureOpenAiEmbeddingDriver, AzureOpenAiImageGenerationDriver, - AzureOpenAiVisionImageQueryDriver, + AzureOpenAiImageQueryDriver, BasePromptDriver, BaseEmbeddingDriver, BaseImageGenerationDriver, @@ -69,7 +69,7 @@ class AzureOpenAiStructureConfig(StructureConfig): ) image_query_driver: BaseImageQueryDriver = field( default=Factory( - lambda self: AzureOpenAiVisionImageQueryDriver( + lambda self: AzureOpenAiImageQueryDriver( model="gpt-4", azure_endpoint=self.azure_endpoint, api_key=self.api_key, diff --git a/griptape/config/openai_structure_config.py b/griptape/config/openai_structure_config.py index b4a3cd6c3..459160b11 100644 --- a/griptape/config/openai_structure_config.py +++ b/griptape/config/openai_structure_config.py @@ -11,7 +11,7 @@ OpenAiChatPromptDriver, OpenAiEmbeddingDriver, OpenAiImageGenerationDriver, - OpenAiVisionImageQueryDriver, + OpenAiImageQueryDriver, ) @@ -26,9 +26,7 @@ class OpenAiStructureConfig(StructureConfig): metadata={"serializable": True}, ) image_query_driver: BaseImageQueryDriver = field( - default=Factory(lambda: OpenAiVisionImageQueryDriver(model="gpt-4o")), - kw_only=True, - metadata={"serializable": True}, + default=Factory(lambda: OpenAiImageQueryDriver(model="gpt-4o")), kw_only=True, metadata={"serializable": True} ) embedding_driver: BaseEmbeddingDriver = field( default=Factory(lambda: OpenAiEmbeddingDriver(model="text-embedding-3-small")), diff --git a/griptape/drivers/__init__.py b/griptape/drivers/__init__.py index 1d76ba30d..6043b0b43 100644 --- a/griptape/drivers/__init__.py +++ b/griptape/drivers/__init__.py @@ -79,9 +79,9 @@ from .image_query.base_image_query_driver import BaseImageQueryDriver from .image_query.base_multi_model_image_query_driver import BaseMultiModelImageQueryDriver from .image_query.dummy_image_query_driver import DummyImageQueryDriver -from .image_query.openai_vision_image_query_driver import OpenAiVisionImageQueryDriver +from .image_query.openai_image_query_driver import OpenAiImageQueryDriver from .image_query.anthropic_image_query_driver import AnthropicImageQueryDriver -from .image_query.azure_openai_vision_image_query_driver import AzureOpenAiVisionImageQueryDriver +from .image_query.azure_openai_image_query_driver import AzureOpenAiImageQueryDriver from .image_query.amazon_bedrock_image_query_driver import AmazonBedrockImageQueryDriver from .web_scraper.base_web_scraper_driver import BaseWebScraperDriver @@ -175,8 +175,8 @@ "BaseImageQueryModelDriver", "BedrockClaudeImageQueryModelDriver", "BaseImageQueryDriver", - "OpenAiVisionImageQueryDriver", - "AzureOpenAiVisionImageQueryDriver", + "OpenAiImageQueryDriver", + "AzureOpenAiImageQueryDriver", "DummyImageQueryDriver", "AnthropicImageQueryDriver", "BaseMultiModelImageQueryDriver", diff --git a/griptape/drivers/image_query/azure_openai_vision_image_query_driver.py b/griptape/drivers/image_query/azure_openai_image_query_driver.py similarity index 90% rename from griptape/drivers/image_query/azure_openai_vision_image_query_driver.py rename to griptape/drivers/image_query/azure_openai_image_query_driver.py index 06b329c03..017c98f1f 100644 --- a/griptape/drivers/image_query/azure_openai_vision_image_query_driver.py +++ b/griptape/drivers/image_query/azure_openai_image_query_driver.py @@ -4,11 +4,11 @@ from attr import define, field, Factory import openai -from griptape.drivers.image_query.openai_vision_image_query_driver import OpenAiVisionImageQueryDriver +from griptape.drivers.image_query.openai_image_query_driver import OpenAiImageQueryDriver @define -class AzureOpenAiVisionImageQueryDriver(OpenAiVisionImageQueryDriver): +class AzureOpenAiImageQueryDriver(OpenAiImageQueryDriver): """Driver for Azure-hosted OpenAI image query API. Attributes: diff --git a/griptape/drivers/image_query/openai_vision_image_query_driver.py b/griptape/drivers/image_query/openai_image_query_driver.py similarity index 97% rename from griptape/drivers/image_query/openai_vision_image_query_driver.py rename to griptape/drivers/image_query/openai_image_query_driver.py index c37cebba9..515bdcc7c 100644 --- a/griptape/drivers/image_query/openai_vision_image_query_driver.py +++ b/griptape/drivers/image_query/openai_image_query_driver.py @@ -16,7 +16,7 @@ @define -class OpenAiVisionImageQueryDriver(BaseImageQueryDriver): +class OpenAiImageQueryDriver(BaseImageQueryDriver): model: str = field(kw_only=True, metadata={"serializable": True}) api_type: str = field(default=openai.api_type, kw_only=True) api_version: Optional[str] = field(default=openai.api_version, kw_only=True, metadata={"serializable": True}) diff --git a/tests/unit/config/test_azure_openai_structure_config.py b/tests/unit/config/test_azure_openai_structure_config.py index acec4526a..0cdeba043 100644 --- a/tests/unit/config/test_azure_openai_structure_config.py +++ b/tests/unit/config/test_azure_openai_structure_config.py @@ -67,7 +67,7 @@ def test_to_dict(self, config): "azure_deployment": "gpt-4", "azure_endpoint": "http://localhost:8080", "organization": None, - "type": "AzureOpenAiVisionImageQueryDriver", + "type": "AzureOpenAiImageQueryDriver", }, "vector_store_driver": { "embedding_driver": { diff --git a/tests/unit/config/test_openai_structure_config.py b/tests/unit/config/test_openai_structure_config.py index c2f0ea7b8..efcac008e 100644 --- a/tests/unit/config/test_openai_structure_config.py +++ b/tests/unit/config/test_openai_structure_config.py @@ -52,7 +52,7 @@ def test_to_dict(self, config): "max_tokens": 256, "model": "gpt-4o", "organization": None, - "type": "OpenAiVisionImageQueryDriver", + "type": "OpenAiImageQueryDriver", }, "vector_store_driver": { "embedding_driver": { diff --git a/tests/unit/drivers/image_query/test_azure_openai_image_query_driver.py b/tests/unit/drivers/image_query/test_azure_openai_image_query_driver.py index d3a34fe44..a44319861 100644 --- a/tests/unit/drivers/image_query/test_azure_openai_image_query_driver.py +++ b/tests/unit/drivers/image_query/test_azure_openai_image_query_driver.py @@ -1,6 +1,6 @@ import pytest from unittest.mock import Mock -from griptape.drivers import AzureOpenAiVisionImageQueryDriver +from griptape.drivers import AzureOpenAiImageQueryDriver from griptape.artifacts import ImageArtifact @@ -13,15 +13,13 @@ def mock_completion_create(self, mocker): return mock_chat_create def test_init(self): - assert AzureOpenAiVisionImageQueryDriver( + assert AzureOpenAiImageQueryDriver( azure_endpoint="test-endpoint", azure_deployment="test-deployment", model="gpt-4" ) - assert ( - AzureOpenAiVisionImageQueryDriver(azure_endpoint="test-endpoint", model="gpt-4").azure_deployment == "gpt-4" - ) + assert AzureOpenAiImageQueryDriver(azure_endpoint="test-endpoint", model="gpt-4").azure_deployment == "gpt-4" def test_try_query_defaults(self, mock_completion_create): - driver = AzureOpenAiVisionImageQueryDriver( + driver = AzureOpenAiImageQueryDriver( azure_endpoint="test-endpoint", azure_deployment="test-deployment", model="gpt-4" ) test_prompt_string = "Prompt String" @@ -36,7 +34,7 @@ def test_try_query_defaults(self, mock_completion_create): assert text_artifact.value == "expected_output_text" def test_try_query_max_tokens(self, mock_completion_create): - driver = AzureOpenAiVisionImageQueryDriver( + driver = AzureOpenAiImageQueryDriver( azure_endpoint="test-endpoint", azure_deployment="test-deployment", model="gpt-4", max_tokens=1024 ) test_prompt_string = "Prompt String" @@ -50,7 +48,7 @@ def test_try_query_max_tokens(self, mock_completion_create): def test_try_query_multiple_choices(self, mock_completion_create): mock_completion_create.return_value.choices.append(Mock(message=Mock(content="expected_output_text2"))) - driver = AzureOpenAiVisionImageQueryDriver( + driver = AzureOpenAiImageQueryDriver( azure_endpoint="test-endpoint", azure_deployment="test-deployment", model="gpt-4" ) diff --git a/tests/unit/drivers/image_query/test_openai_image_query_driver.py b/tests/unit/drivers/image_query/test_openai_image_query_driver.py index d6ce8550b..08f0c70c9 100644 --- a/tests/unit/drivers/image_query/test_openai_image_query_driver.py +++ b/tests/unit/drivers/image_query/test_openai_image_query_driver.py @@ -1,6 +1,6 @@ import pytest from unittest.mock import Mock -from griptape.drivers import OpenAiVisionImageQueryDriver +from griptape.drivers import OpenAiImageQueryDriver from griptape.artifacts import ImageArtifact @@ -13,10 +13,10 @@ def mock_completion_create(self, mocker): return mock_chat_create def test_init(self): - assert OpenAiVisionImageQueryDriver(model="gpt-4-vision-preview") + assert OpenAiImageQueryDriver(model="gpt-4-vision-preview") def test_try_query_defaults(self, mock_completion_create): - driver = OpenAiVisionImageQueryDriver(model="gpt-4-vision-preview") + driver = OpenAiImageQueryDriver(model="gpt-4-vision-preview") test_prompt_string = "Prompt String" test_binary_data = b"test-data" test_image = ImageArtifact(value=test_binary_data, width=100, height=100, format="png") @@ -29,7 +29,7 @@ def test_try_query_defaults(self, mock_completion_create): assert text_artifact.value == "expected_output_text" def test_try_query_max_tokens(self, mock_completion_create): - driver = OpenAiVisionImageQueryDriver(model="gpt-4-vision-preview", max_tokens=1024) + driver = OpenAiImageQueryDriver(model="gpt-4-vision-preview", max_tokens=1024) test_prompt_string = "Prompt String" test_binary_data = b"test-data" test_image = ImageArtifact(value=test_binary_data, width=100, height=100, format="png") @@ -41,7 +41,7 @@ def test_try_query_max_tokens(self, mock_completion_create): def test_try_query_multiple_choices(self, mock_completion_create): mock_completion_create.return_value.choices.append(Mock(message=Mock(content="expected_output_text2"))) - driver = OpenAiVisionImageQueryDriver(model="gpt-4-vision-preview") + driver = OpenAiImageQueryDriver(model="gpt-4-vision-preview") with pytest.raises(Exception): driver.try_query("Prompt String", [ImageArtifact(value=b"test-data", width=100, height=100, format="png")])