Skip to content

Commit

Permalink
Add AzureOpenAiStructureConfig
Browse files Browse the repository at this point in the history
  • Loading branch information
vachillo committed May 17, 2024
1 parent f1a2dba commit 0998f4b
Show file tree
Hide file tree
Showing 8 changed files with 343 additions and 3 deletions.
2 changes: 2 additions & 0 deletions griptape/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from .structure_config import StructureConfig
from .openai_structure_config import OpenAiStructureConfig
from .azure_openai_structure_config import AzureOpenAiStructureConfig
from .amazon_bedrock_structure_config import AmazonBedrockStructureConfig
from .anthropic_structure_config import AnthropicStructureConfig
from .google_structure_config import GoogleStructureConfig
Expand All @@ -28,6 +29,7 @@
"StructureTaskMemoryExtractionEngineJsonConfig",
"StructureConfig",
"OpenAiStructureConfig",
"AzureOpenAiStructureConfig",
"AmazonBedrockStructureConfig",
"AnthropicStructureConfig",
"GoogleStructureConfig",
Expand Down
106 changes: 106 additions & 0 deletions griptape/config/azure_openai_structure_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from typing import Callable, Optional
from attrs import Factory, define, field

from griptape.config import (
BaseStructureConfig,
StructureGlobalDriversConfig,
StructureTaskMemoryConfig,
StructureTaskMemoryExtractionEngineConfig,
StructureTaskMemoryExtractionEngineCsvConfig,
StructureTaskMemoryExtractionEngineJsonConfig,
StructureTaskMemoryQueryEngineConfig,
StructureTaskMemorySummaryEngineConfig,
)
from griptape.drivers import (
LocalVectorStoreDriver,
AzureOpenAiChatPromptDriver,
AzureOpenAiEmbeddingDriver,
AzureOpenAiImageGenerationDriver,
AzureOpenAiVisionImageQueryDriver,
)


@define
class AzureOpenAiStructureConfig(BaseStructureConfig):
"""Azure OpenAI Structure Configuration.
Attributes:
azure_endpoint: The endpoint for the Azure OpenAI instance.
azure_deployment_prefix: An optional prefix for Azure mode deployment names. By default, the deployment names are the same as the default models.
azure_ad_token: An optional Azure Active Directory token.
azure_ad_token_provider: An optional Azure Active Directory token provider.
api_key: An optional Azure API key.
global_drivers: A `StructureGlobalDriversConfig` instance.
task_memory: A `StructureTaskMemoryConfig` instance.
"""

azure_endpoint: str = field(kw_only=True, metadata={"serializable": True})
azure_deployment_prefix: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": True})
azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False})
azure_ad_token_provider: Optional[Callable[[], str]] = field(
kw_only=True, default=None, metadata={"serializable": False}
)
api_key: str = field(kw_only=True, default=None, metadata={"serializable": False})
global_drivers: StructureGlobalDriversConfig = field(
default=Factory(lambda self: _global_drivers(self), takes_self=True),
kw_only=True,
metadata={"serializable": True},
)
task_memory: StructureTaskMemoryConfig = field(
default=Factory(
lambda self: StructureTaskMemoryConfig(
query_engine=StructureTaskMemoryQueryEngineConfig(
prompt_driver=self.global_drivers.prompt_driver,
vector_store_driver=LocalVectorStoreDriver(embedding_driver=self.global_drivers.embedding_driver),
),
extraction_engine=StructureTaskMemoryExtractionEngineConfig(
csv=StructureTaskMemoryExtractionEngineCsvConfig(prompt_driver=self.global_drivers.prompt_driver),
json=StructureTaskMemoryExtractionEngineJsonConfig(prompt_driver=self.global_drivers.prompt_driver),
),
summary_engine=StructureTaskMemorySummaryEngineConfig(prompt_driver=self.global_drivers.prompt_driver),
),
takes_self=True,
),
kw_only=True,
metadata={"serializable": True},
)


def _global_drivers(config: AzureOpenAiStructureConfig) -> StructureGlobalDriversConfig:
embedding_driver = AzureOpenAiEmbeddingDriver(
model="text-embedding-3-small",
azure_deployment=f"{config.azure_deployment_prefix}text-embedding-3-small",
azure_endpoint=config.azure_endpoint,
api_key=config.api_key,
azure_ad_token=config.azure_ad_token,
azure_ad_token_provider=config.azure_ad_token_provider,
)
return StructureGlobalDriversConfig(
prompt_driver=AzureOpenAiChatPromptDriver(
model="gpt-4o",
azure_deployment=f"{config.azure_deployment_prefix}gpt-4o",
azure_endpoint=config.azure_endpoint,
api_key=config.api_key,
azure_ad_token=config.azure_ad_token,
azure_ad_token_provider=config.azure_ad_token_provider,
),
image_generation_driver=AzureOpenAiImageGenerationDriver(
model="dall-e-2",
azure_deployment=f"{config.azure_deployment_prefix}dall-e-2",
azure_endpoint=config.azure_endpoint,
api_key=config.api_key,
azure_ad_token=config.azure_ad_token,
azure_ad_token_provider=config.azure_ad_token_provider,
image_size="512x512",
),
image_query_driver=AzureOpenAiVisionImageQueryDriver(
model="gpt-4-vision-preview",
azure_deployment=f"{config.azure_deployment_prefix}gpt-4-vision-preview",
azure_endpoint=config.azure_endpoint,
api_key=config.api_key,
azure_ad_token=config.azure_ad_token,
azure_ad_token_provider=config.azure_ad_token_provider,
),
embedding_driver=embedding_driver,
vector_store_driver=LocalVectorStoreDriver(embedding_driver=embedding_driver),
)
2 changes: 2 additions & 0 deletions griptape/drivers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
from .image_query.dummy_image_query_driver import DummyImageQueryDriver
from .image_query.openai_vision_image_query_driver import OpenAiVisionImageQueryDriver
from .image_query.anthropic_image_query_driver import AnthropicImageQueryDriver
from .image_query.azure_openai_vision_image_query_driver import AzureOpenAiVisionImageQueryDriver
from .image_query.amazon_bedrock_image_query_driver import AmazonBedrockImageQueryDriver

from .web_scraper.base_web_scraper_driver import BaseWebScraperDriver
Expand Down Expand Up @@ -175,6 +176,7 @@
"BedrockClaudeImageQueryModelDriver",
"BaseImageQueryDriver",
"OpenAiVisionImageQueryDriver",
"AzureOpenAiVisionImageQueryDriver",
"DummyImageQueryDriver",
"AnthropicImageQueryDriver",
"BaseMultiModelImageQueryDriver",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class AzureOpenAiImageGenerationDriver(OpenAiImageGenerationDriver):

azure_deployment: str = field(kw_only=True, metadata={"serializable": True})
azure_endpoint: str = field(kw_only=True, metadata={"serializable": True})
azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": True})
azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False})
azure_ad_token_provider: Optional[Callable[[], str]] = field(
kw_only=True, default=None, metadata={"serializable": False}
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from __future__ import annotations

from typing import Callable, Optional

from attr import define, field, Factory
from openai.types.chat import (
ChatCompletionUserMessageParam,
ChatCompletionContentPartParam,
ChatCompletionContentPartTextParam,
ChatCompletionContentPartImageParam,
)
import openai
from griptape.drivers.image_query.openai_vision_image_query_driver import OpenAiVisionImageQueryDriver


@define
class AzureOpenAiVisionImageQueryDriver(OpenAiVisionImageQueryDriver):
"""Driver for Azure-hosted OpenAI image query API.
Attributes:
azure_deployment: An Azure OpenAi deployment id.
azure_endpoint: An Azure OpenAi endpoint.
azure_ad_token: An optional Azure Active Directory token.
azure_ad_token_provider: An optional Azure Active Directory token provider.
api_version: An Azure OpenAi API version.
client: An `openai.AzureOpenAI` client.
"""

azure_deployment: str = field(kw_only=True, metadata={"serializable": True})
azure_endpoint: str = field(kw_only=True, metadata={"serializable": True})
azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False})
azure_ad_token_provider: Optional[Callable[[], str]] = field(
kw_only=True, default=None, metadata={"serializable": False}
)
api_version: str = field(default="2024-02-01", kw_only=True, metadata={"serializable": True})
client: openai.AzureOpenAI = field(
default=Factory(
lambda self: openai.AzureOpenAI(
organization=self.organization,
api_key=self.api_key,
api_version=self.api_version,
azure_endpoint=self.azure_endpoint,
azure_deployment=self.azure_deployment,
azure_ad_token=self.azure_ad_token,
azure_ad_token_provider=self.azure_ad_token_provider,
),
takes_self=True,
)
)
2 changes: 1 addition & 1 deletion griptape/drivers/prompt/azure_openai_chat_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class AzureOpenAiChatPromptDriver(OpenAiChatPromptDriver):

azure_deployment: str = field(kw_only=True, metadata={"serializable": True})
azure_endpoint: str = field(kw_only=True, metadata={"serializable": True})
azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": True})
azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False})
azure_ad_token_provider: Optional[Callable[[], str]] = field(
kw_only=True, default=None, metadata={"serializable": False}
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class AzureOpenAiCompletionPromptDriver(OpenAiCompletionPromptDriver):

azure_deployment: str = field(kw_only=True, metadata={"serializable": True})
azure_endpoint: str = field(kw_only=True, metadata={"serializable": True})
azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": True})
azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False})
azure_ad_token_provider: Optional[Callable[[], str]] = field(
kw_only=True, default=None, metadata={"serializable": False}
)
Expand Down
181 changes: 181 additions & 0 deletions tests/unit/config/test_azure_openai_structure_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
from pytest import fixture
from griptape.config import AzureOpenAiStructureConfig


class TestAzureOpenAiStructureConfig:
@fixture(autouse=True)
def mock_openai(self, mocker):
return mocker.patch("openai.AzureOpenAI")

@fixture
def config(self):
return AzureOpenAiStructureConfig(azure_endpoint="http://localhost:8080", azure_deployment_prefix="test-")

def test_to_dict(self, config):
assert config.to_dict() == {
"type": "AzureOpenAiStructureConfig",
"azure_endpoint": "http://localhost:8080",
"azure_deployment_prefix": "test-",
"global_drivers": {
"type": "StructureGlobalDriversConfig",
"prompt_driver": {
"type": "AzureOpenAiChatPromptDriver",
"base_url": None,
"model": "gpt-4o",
"azure_deployment": "test-gpt-4o",
"azure_endpoint": "http://localhost:8080",
"api_version": "2023-05-15",
"organization": None,
"response_format": None,
"seed": None,
"temperature": 0.1,
"max_tokens": None,
"stream": False,
"user": "",
},
"conversation_memory_driver": None,
"embedding_driver": {
"base_url": None,
"model": "text-embedding-3-small",
"api_version": "2023-05-15",
"azure_deployment": "test-text-embedding-3-small",
"azure_endpoint": "http://localhost:8080",
"organization": None,
"type": "AzureOpenAiEmbeddingDriver",
},
"image_generation_driver": {
"api_version": "2024-02-01",
"base_url": None,
"image_size": "512x512",
"model": "dall-e-2",
"azure_deployment": "test-dall-e-2",
"azure_endpoint": "http://localhost:8080",
"organization": None,
"quality": "standard",
"response_format": "b64_json",
"style": None,
"type": "AzureOpenAiImageGenerationDriver",
},
"text_to_speech_driver": {"type": "DummyTextToSpeechDriver"},
"image_query_driver": {
"api_version": None,
"base_url": None,
"image_quality": "auto",
"max_tokens": 256,
"model": "gpt-4-vision-preview",
"api_version": "2024-02-01",
"azure_deployment": "test-gpt-4-vision-preview",
"azure_endpoint": "http://localhost:8080",
"organization": None,
"type": "AzureOpenAiVisionImageQueryDriver",
},
"vector_store_driver": {
"embedding_driver": {
"base_url": None,
"model": "text-embedding-3-small",
"api_version": "2023-05-15",
"azure_deployment": "test-text-embedding-3-small",
"azure_endpoint": "http://localhost:8080",
"organization": None,
"type": "AzureOpenAiEmbeddingDriver",
},
"type": "LocalVectorStoreDriver",
},
},
"task_memory": {
"type": "StructureTaskMemoryConfig",
"query_engine": {
"type": "StructureTaskMemoryQueryEngineConfig",
"prompt_driver": {
"base_url": None,
"type": "AzureOpenAiChatPromptDriver",
"api_version": "2023-05-15",
"model": "gpt-4o",
"azure_deployment": "test-gpt-4o",
"azure_endpoint": "http://localhost:8080",
"organization": None,
"response_format": None,
"seed": None,
"temperature": 0.1,
"max_tokens": None,
"stream": False,
"user": "",
},
"vector_store_driver": {
"type": "LocalVectorStoreDriver",
"embedding_driver": {
"type": "AzureOpenAiEmbeddingDriver",
"base_url": None,
"api_version": "2023-05-15",
"organization": None,
"model": "text-embedding-3-small",
"azure_deployment": "test-text-embedding-3-small",
"azure_endpoint": "http://localhost:8080",
},
},
},
"extraction_engine": {
"type": "StructureTaskMemoryExtractionEngineConfig",
"csv": {
"type": "StructureTaskMemoryExtractionEngineCsvConfig",
"prompt_driver": {
"type": "AzureOpenAiChatPromptDriver",
"api_version": "2023-05-15",
"base_url": None,
"model": "gpt-4o",
"azure_deployment": "test-gpt-4o",
"azure_endpoint": "http://localhost:8080",
"organization": None,
"response_format": None,
"seed": None,
"temperature": 0.1,
"max_tokens": None,
"stream": False,
"user": "",
},
},
"json": {
"type": "StructureTaskMemoryExtractionEngineJsonConfig",
"prompt_driver": {
"type": "AzureOpenAiChatPromptDriver",
"api_version": "2023-05-15",
"base_url": None,
"model": "gpt-4o",
"azure_deployment": "test-gpt-4o",
"azure_endpoint": "http://localhost:8080",
"organization": None,
"response_format": None,
"seed": None,
"temperature": 0.1,
"max_tokens": None,
"stream": False,
"user": "",
},
},
},
"summary_engine": {
"type": "StructureTaskMemorySummaryEngineConfig",
"prompt_driver": {
"api_version": "2023-05-15",
"type": "AzureOpenAiChatPromptDriver",
"base_url": None,
"model": "gpt-4o",
"azure_deployment": "test-gpt-4o",
"azure_endpoint": "http://localhost:8080",
"organization": None,
"response_format": None,
"seed": None,
"temperature": 0.1,
"max_tokens": None,
"stream": False,
"user": "",
},
},
},
}

def test_from_dict(self, config):
assert AzureOpenAiStructureConfig.from_dict(config.to_dict()).to_dict() == config.to_dict()
new_config = config.to_dict()
new_config["global_drivers"]["prompt_driver"]["azure_deployment"] = "new-test-gpt-4o"
assert AzureOpenAiStructureConfig.from_dict(new_config).to_dict() == new_config

0 comments on commit 0998f4b

Please sign in to comment.