Skip to content

Commit

Permalink
Update driver name to remove model specifics
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewfrench committed May 23, 2024
1 parent 1146304 commit 078f0f1
Show file tree
Hide file tree
Showing 13 changed files with 37 additions and 42 deletions.
8 changes: 4 additions & 4 deletions docs/griptape-framework/drivers/image-query-drivers.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,11 @@ print(result)
The [OpenAiVisionImageQueryDriver](../../reference/griptape/drivers/image_query/openai_vision_image_query_driver.md) is used to query images using the OpenAI Vision API. Here is an example of how to use it:

```python
from griptape.drivers import OpenAiVisionImageQueryDriver
from griptape.drivers import OpenAiImageQueryDriver
from griptape.engines import ImageQueryEngine
from griptape.loaders import ImageLoader

driver = OpenAiVisionImageQueryDriver(
driver = OpenAiImageQueryDriver(
model="gpt-4o",
max_tokens=256,
)
Expand All @@ -95,11 +95,11 @@ The [AzureOpenAiVisionImageQueryDriver](../../reference/griptape/drivers/image_q

```python
import os
from griptape.drivers import AzureOpenAiVisionImageQueryDriver
from griptape.drivers import AzureOpenAiImageQueryDriver
from griptape.engines import ImageQueryEngine
from griptape.loaders import ImageLoader

driver = AzureOpenAiVisionImageQueryDriver(
driver = AzureOpenAiImageQueryDriver(
azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_3"],
api_key=os.environ["AZURE_OPENAI_API_KEY_3"],
model="gpt-4",
Expand Down
6 changes: 3 additions & 3 deletions docs/griptape-framework/engines/image-query-engines.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

The [Image Query Engine](../../reference/griptape/engines/image_query/image_query_engine.md) is used to execute natural language queries on the contents of images. You can specify the provider and model used to query the image by providing the Engine with a particular [Image Query Driver](../drivers/image-query-drivers.md).

All Image Query Drivers default to a `max_tokens` of 256. You can tune this value based on your use case and the [Image Query Driver](../drivers/image-query-drivers.md) you are providing.
All Image Query Drivers default to a `max_tokens` of 256. You can tune this value based on your use case and the [Image Query Driver](../drivers/image-query-drivers.md) you are providing.

```python
from griptape.drivers import OpenAiVisionImageQueryDriver
from griptape.drivers import OpenAiImageQueryDriver
from griptape.engines import ImageQueryEngine
from griptape.loaders import ImageLoader

driver = OpenAiVisionImageQueryDriver(
driver = OpenAiImageQueryDriver(
model="gpt-4o",
max_tokens=256
)
Expand Down
5 changes: 2 additions & 3 deletions docs/griptape-framework/structures/tasks.md
Original file line number Diff line number Diff line change
Expand Up @@ -633,14 +633,13 @@ This Task accepts two inputs: a query (represented by either a string or a [Text

```python
from griptape.engines import ImageQueryEngine
from griptape.drivers import OpenAiVisionImageQueryDriver
from griptape.drivers import OpenAiImageQueryDriver
from griptape.tasks import ImageQueryTask
from griptape.loaders import ImageLoader
from griptape.structures import Pipeline


# Create a driver configured to use OpenAI's GPT-4 Vision model.
driver = OpenAiVisionImageQueryDriver(
driver = OpenAiImageQueryDriver(
model="gpt-4o",
max_tokens=100,
)
Expand Down
8 changes: 4 additions & 4 deletions docs/griptape-tools/official-tools/image-query-client.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@ This tool allows Agents to execute natural language queries on the contents of i
```python
from griptape.structures import Agent
from griptape.tools import ImageQueryClient
from griptape.drivers import OpenAiVisionImageQueryDriver
from griptape.engines import ImageQueryEngine
from griptape.drivers import OpenAiImageQueryDriver
from griptape.engines import ImageQueryEngine

# Create an Image Query Driver.
driver = OpenAiVisionImageQueryDriver(
driver = OpenAiImageQueryDriver(
model="gpt-4-vision-preview"
)

# Create an Image Query Engine configured to use the driver.
engine = ImageQueryEngine(
image_query_driver=driver,
Expand Down
4 changes: 2 additions & 2 deletions griptape/config/azure_openai_structure_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
AzureOpenAiChatPromptDriver,
AzureOpenAiEmbeddingDriver,
AzureOpenAiImageGenerationDriver,
AzureOpenAiVisionImageQueryDriver,
AzureOpenAiImageQueryDriver,
BasePromptDriver,
BaseEmbeddingDriver,
BaseImageGenerationDriver,
Expand Down Expand Up @@ -69,7 +69,7 @@ class AzureOpenAiStructureConfig(StructureConfig):
)
image_query_driver: BaseImageQueryDriver = field(
default=Factory(
lambda self: AzureOpenAiVisionImageQueryDriver(
lambda self: AzureOpenAiImageQueryDriver(
model="gpt-4",
azure_endpoint=self.azure_endpoint,
api_key=self.api_key,
Expand Down
6 changes: 2 additions & 4 deletions griptape/config/openai_structure_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
OpenAiChatPromptDriver,
OpenAiEmbeddingDriver,
OpenAiImageGenerationDriver,
OpenAiVisionImageQueryDriver,
OpenAiImageQueryDriver,
)


Expand All @@ -26,9 +26,7 @@ class OpenAiStructureConfig(StructureConfig):
metadata={"serializable": True},
)
image_query_driver: BaseImageQueryDriver = field(
default=Factory(lambda: OpenAiVisionImageQueryDriver(model="gpt-4o")),
kw_only=True,
metadata={"serializable": True},
default=Factory(lambda: OpenAiImageQueryDriver(model="gpt-4o")), kw_only=True, metadata={"serializable": True}
)
embedding_driver: BaseEmbeddingDriver = field(
default=Factory(lambda: OpenAiEmbeddingDriver(model="text-embedding-3-small")),
Expand Down
8 changes: 4 additions & 4 deletions griptape/drivers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,9 @@
from .image_query.base_image_query_driver import BaseImageQueryDriver
from .image_query.base_multi_model_image_query_driver import BaseMultiModelImageQueryDriver
from .image_query.dummy_image_query_driver import DummyImageQueryDriver
from .image_query.openai_vision_image_query_driver import OpenAiVisionImageQueryDriver
from .image_query.openai_image_query_driver import OpenAiImageQueryDriver
from .image_query.anthropic_image_query_driver import AnthropicImageQueryDriver
from .image_query.azure_openai_vision_image_query_driver import AzureOpenAiVisionImageQueryDriver
from .image_query.azure_openai_image_query_driver import AzureOpenAiImageQueryDriver
from .image_query.amazon_bedrock_image_query_driver import AmazonBedrockImageQueryDriver

from .web_scraper.base_web_scraper_driver import BaseWebScraperDriver
Expand Down Expand Up @@ -175,8 +175,8 @@
"BaseImageQueryModelDriver",
"BedrockClaudeImageQueryModelDriver",
"BaseImageQueryDriver",
"OpenAiVisionImageQueryDriver",
"AzureOpenAiVisionImageQueryDriver",
"OpenAiImageQueryDriver",
"AzureOpenAiImageQueryDriver",
"DummyImageQueryDriver",
"AnthropicImageQueryDriver",
"BaseMultiModelImageQueryDriver",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@

from attr import define, field, Factory
import openai
from griptape.drivers.image_query.openai_vision_image_query_driver import OpenAiVisionImageQueryDriver
from griptape.drivers.image_query.openai_image_query_driver import OpenAiImageQueryDriver


@define
class AzureOpenAiVisionImageQueryDriver(OpenAiVisionImageQueryDriver):
class AzureOpenAiImageQueryDriver(OpenAiImageQueryDriver):
"""Driver for Azure-hosted OpenAI image query API.
Attributes:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


@define
class OpenAiVisionImageQueryDriver(BaseImageQueryDriver):
class OpenAiImageQueryDriver(BaseImageQueryDriver):
model: str = field(kw_only=True, metadata={"serializable": True})
api_type: str = field(default=openai.api_type, kw_only=True)
api_version: Optional[str] = field(default=openai.api_version, kw_only=True, metadata={"serializable": True})
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/config/test_azure_openai_structure_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def test_to_dict(self, config):
"azure_deployment": "gpt-4",
"azure_endpoint": "http://localhost:8080",
"organization": None,
"type": "AzureOpenAiVisionImageQueryDriver",
"type": "AzureOpenAiImageQueryDriver",
},
"vector_store_driver": {
"embedding_driver": {
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/config/test_openai_structure_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def test_to_dict(self, config):
"max_tokens": 256,
"model": "gpt-4o",
"organization": None,
"type": "OpenAiVisionImageQueryDriver",
"type": "OpenAiImageQueryDriver",
},
"vector_store_driver": {
"embedding_driver": {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
from unittest.mock import Mock
from griptape.drivers import AzureOpenAiVisionImageQueryDriver
from griptape.drivers import AzureOpenAiImageQueryDriver
from griptape.artifacts import ImageArtifact


Expand All @@ -13,15 +13,13 @@ def mock_completion_create(self, mocker):
return mock_chat_create

def test_init(self):
assert AzureOpenAiVisionImageQueryDriver(
assert AzureOpenAiImageQueryDriver(
azure_endpoint="test-endpoint", azure_deployment="test-deployment", model="gpt-4"
)
assert (
AzureOpenAiVisionImageQueryDriver(azure_endpoint="test-endpoint", model="gpt-4").azure_deployment == "gpt-4"
)
assert AzureOpenAiImageQueryDriver(azure_endpoint="test-endpoint", model="gpt-4").azure_deployment == "gpt-4"

def test_try_query_defaults(self, mock_completion_create):
driver = AzureOpenAiVisionImageQueryDriver(
driver = AzureOpenAiImageQueryDriver(
azure_endpoint="test-endpoint", azure_deployment="test-deployment", model="gpt-4"
)
test_prompt_string = "Prompt String"
Expand All @@ -36,7 +34,7 @@ def test_try_query_defaults(self, mock_completion_create):
assert text_artifact.value == "expected_output_text"

def test_try_query_max_tokens(self, mock_completion_create):
driver = AzureOpenAiVisionImageQueryDriver(
driver = AzureOpenAiImageQueryDriver(
azure_endpoint="test-endpoint", azure_deployment="test-deployment", model="gpt-4", max_tokens=1024
)
test_prompt_string = "Prompt String"
Expand All @@ -50,7 +48,7 @@ def test_try_query_max_tokens(self, mock_completion_create):

def test_try_query_multiple_choices(self, mock_completion_create):
mock_completion_create.return_value.choices.append(Mock(message=Mock(content="expected_output_text2")))
driver = AzureOpenAiVisionImageQueryDriver(
driver = AzureOpenAiImageQueryDriver(
azure_endpoint="test-endpoint", azure_deployment="test-deployment", model="gpt-4"
)

Expand Down
10 changes: 5 additions & 5 deletions tests/unit/drivers/image_query/test_openai_image_query_driver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
from unittest.mock import Mock
from griptape.drivers import OpenAiVisionImageQueryDriver
from griptape.drivers import OpenAiImageQueryDriver
from griptape.artifacts import ImageArtifact


Expand All @@ -13,10 +13,10 @@ def mock_completion_create(self, mocker):
return mock_chat_create

def test_init(self):
assert OpenAiVisionImageQueryDriver(model="gpt-4-vision-preview")
assert OpenAiImageQueryDriver(model="gpt-4-vision-preview")

def test_try_query_defaults(self, mock_completion_create):
driver = OpenAiVisionImageQueryDriver(model="gpt-4-vision-preview")
driver = OpenAiImageQueryDriver(model="gpt-4-vision-preview")
test_prompt_string = "Prompt String"
test_binary_data = b"test-data"
test_image = ImageArtifact(value=test_binary_data, width=100, height=100, format="png")
Expand All @@ -29,7 +29,7 @@ def test_try_query_defaults(self, mock_completion_create):
assert text_artifact.value == "expected_output_text"

def test_try_query_max_tokens(self, mock_completion_create):
driver = OpenAiVisionImageQueryDriver(model="gpt-4-vision-preview", max_tokens=1024)
driver = OpenAiImageQueryDriver(model="gpt-4-vision-preview", max_tokens=1024)
test_prompt_string = "Prompt String"
test_binary_data = b"test-data"
test_image = ImageArtifact(value=test_binary_data, width=100, height=100, format="png")
Expand All @@ -41,7 +41,7 @@ def test_try_query_max_tokens(self, mock_completion_create):

def test_try_query_multiple_choices(self, mock_completion_create):
mock_completion_create.return_value.choices.append(Mock(message=Mock(content="expected_output_text2")))
driver = OpenAiVisionImageQueryDriver(model="gpt-4-vision-preview")
driver = OpenAiImageQueryDriver(model="gpt-4-vision-preview")

with pytest.raises(Exception):
driver.try_query("Prompt String", [ImageArtifact(value=b"test-data", width=100, height=100, format="png")])
Expand Down

0 comments on commit 078f0f1

Please sign in to comment.