-
Notifications
You must be signed in to change notification settings - Fork 163
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
343 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
from typing import Callable, Optional | ||
from attrs import Factory, define, field | ||
|
||
from griptape.config import ( | ||
BaseStructureConfig, | ||
StructureGlobalDriversConfig, | ||
StructureTaskMemoryConfig, | ||
StructureTaskMemoryExtractionEngineConfig, | ||
StructureTaskMemoryExtractionEngineCsvConfig, | ||
StructureTaskMemoryExtractionEngineJsonConfig, | ||
StructureTaskMemoryQueryEngineConfig, | ||
StructureTaskMemorySummaryEngineConfig, | ||
) | ||
from griptape.drivers import ( | ||
LocalVectorStoreDriver, | ||
AzureOpenAiChatPromptDriver, | ||
AzureOpenAiEmbeddingDriver, | ||
AzureOpenAiImageGenerationDriver, | ||
AzureOpenAiVisionImageQueryDriver, | ||
) | ||
|
||
|
||
@define | ||
class AzureOpenAiStructureConfig(BaseStructureConfig): | ||
"""Azure OpenAI Structure Configuration. | ||
Attributes: | ||
azure_endpoint: The endpoint for the Azure OpenAI instance. | ||
azure_deployment_prefix: An optional prefix for Azure mode deployment names. By default, the deployment names are the same as the default models. | ||
azure_ad_token: An optional Azure Active Directory token. | ||
azure_ad_token_provider: An optional Azure Active Directory token provider. | ||
api_key: An optional Azure API key. | ||
global_drivers: A `StructureGlobalDriversConfig` instance. | ||
task_memory: A `StructureTaskMemoryConfig` instance. | ||
""" | ||
|
||
azure_endpoint: str = field(kw_only=True, metadata={"serializable": True}) | ||
azure_deployment_prefix: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": True}) | ||
azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False}) | ||
azure_ad_token_provider: Optional[Callable[[], str]] = field( | ||
kw_only=True, default=None, metadata={"serializable": False} | ||
) | ||
api_key: str = field(kw_only=True, default=None, metadata={"serializable": False}) | ||
global_drivers: StructureGlobalDriversConfig = field( | ||
default=Factory(lambda self: _global_drivers(self), takes_self=True), | ||
kw_only=True, | ||
metadata={"serializable": True}, | ||
) | ||
task_memory: StructureTaskMemoryConfig = field( | ||
default=Factory( | ||
lambda self: StructureTaskMemoryConfig( | ||
query_engine=StructureTaskMemoryQueryEngineConfig( | ||
prompt_driver=self.global_drivers.prompt_driver, | ||
vector_store_driver=LocalVectorStoreDriver(embedding_driver=self.global_drivers.embedding_driver), | ||
), | ||
extraction_engine=StructureTaskMemoryExtractionEngineConfig( | ||
csv=StructureTaskMemoryExtractionEngineCsvConfig(prompt_driver=self.global_drivers.prompt_driver), | ||
json=StructureTaskMemoryExtractionEngineJsonConfig(prompt_driver=self.global_drivers.prompt_driver), | ||
), | ||
summary_engine=StructureTaskMemorySummaryEngineConfig(prompt_driver=self.global_drivers.prompt_driver), | ||
), | ||
takes_self=True, | ||
), | ||
kw_only=True, | ||
metadata={"serializable": True}, | ||
) | ||
|
||
|
||
def _global_drivers(config: AzureOpenAiStructureConfig) -> StructureGlobalDriversConfig: | ||
embedding_driver = AzureOpenAiEmbeddingDriver( | ||
model="text-embedding-3-small", | ||
azure_deployment=f"{config.azure_deployment_prefix}text-embedding-3-small", | ||
azure_endpoint=config.azure_endpoint, | ||
api_key=config.api_key, | ||
azure_ad_token=config.azure_ad_token, | ||
azure_ad_token_provider=config.azure_ad_token_provider, | ||
) | ||
return StructureGlobalDriversConfig( | ||
prompt_driver=AzureOpenAiChatPromptDriver( | ||
model="gpt-4o", | ||
azure_deployment=f"{config.azure_deployment_prefix}gpt-4o", | ||
azure_endpoint=config.azure_endpoint, | ||
api_key=config.api_key, | ||
azure_ad_token=config.azure_ad_token, | ||
azure_ad_token_provider=config.azure_ad_token_provider, | ||
), | ||
image_generation_driver=AzureOpenAiImageGenerationDriver( | ||
model="dall-e-2", | ||
azure_deployment=f"{config.azure_deployment_prefix}dall-e-2", | ||
azure_endpoint=config.azure_endpoint, | ||
api_key=config.api_key, | ||
azure_ad_token=config.azure_ad_token, | ||
azure_ad_token_provider=config.azure_ad_token_provider, | ||
image_size="512x512", | ||
), | ||
image_query_driver=AzureOpenAiVisionImageQueryDriver( | ||
model="gpt-4-vision-preview", | ||
azure_deployment=f"{config.azure_deployment_prefix}gpt-4-vision-preview", | ||
azure_endpoint=config.azure_endpoint, | ||
api_key=config.api_key, | ||
azure_ad_token=config.azure_ad_token, | ||
azure_ad_token_provider=config.azure_ad_token_provider, | ||
), | ||
embedding_driver=embedding_driver, | ||
vector_store_driver=LocalVectorStoreDriver(embedding_driver=embedding_driver), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
49 changes: 49 additions & 0 deletions
49
griptape/drivers/image_query/azure_openai_vision_image_query_driver.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Callable, Optional | ||
|
||
from attr import define, field, Factory | ||
from openai.types.chat import ( | ||
ChatCompletionUserMessageParam, | ||
ChatCompletionContentPartParam, | ||
ChatCompletionContentPartTextParam, | ||
ChatCompletionContentPartImageParam, | ||
) | ||
import openai | ||
from griptape.drivers.image_query.openai_vision_image_query_driver import OpenAiVisionImageQueryDriver | ||
|
||
|
||
@define | ||
class AzureOpenAiVisionImageQueryDriver(OpenAiVisionImageQueryDriver): | ||
"""Driver for Azure-hosted OpenAI image query API. | ||
Attributes: | ||
azure_deployment: An Azure OpenAi deployment id. | ||
azure_endpoint: An Azure OpenAi endpoint. | ||
azure_ad_token: An optional Azure Active Directory token. | ||
azure_ad_token_provider: An optional Azure Active Directory token provider. | ||
api_version: An Azure OpenAi API version. | ||
client: An `openai.AzureOpenAI` client. | ||
""" | ||
|
||
azure_deployment: str = field(kw_only=True, metadata={"serializable": True}) | ||
azure_endpoint: str = field(kw_only=True, metadata={"serializable": True}) | ||
azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False}) | ||
azure_ad_token_provider: Optional[Callable[[], str]] = field( | ||
kw_only=True, default=None, metadata={"serializable": False} | ||
) | ||
api_version: str = field(default="2024-02-01", kw_only=True, metadata={"serializable": True}) | ||
client: openai.AzureOpenAI = field( | ||
default=Factory( | ||
lambda self: openai.AzureOpenAI( | ||
organization=self.organization, | ||
api_key=self.api_key, | ||
api_version=self.api_version, | ||
azure_endpoint=self.azure_endpoint, | ||
azure_deployment=self.azure_deployment, | ||
azure_ad_token=self.azure_ad_token, | ||
azure_ad_token_provider=self.azure_ad_token_provider, | ||
), | ||
takes_self=True, | ||
) | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
181 changes: 181 additions & 0 deletions
181
tests/unit/config/test_azure_openai_structure_config.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,181 @@ | ||
from pytest import fixture | ||
from griptape.config import AzureOpenAiStructureConfig | ||
|
||
|
||
class TestAzureOpenAiStructureConfig: | ||
@fixture(autouse=True) | ||
def mock_openai(self, mocker): | ||
return mocker.patch("openai.AzureOpenAI") | ||
|
||
@fixture | ||
def config(self): | ||
return AzureOpenAiStructureConfig(azure_endpoint="http://localhost:8080", azure_deployment_prefix="test-") | ||
|
||
def test_to_dict(self, config): | ||
assert config.to_dict() == { | ||
"type": "AzureOpenAiStructureConfig", | ||
"azure_endpoint": "http://localhost:8080", | ||
"azure_deployment_prefix": "test-", | ||
"global_drivers": { | ||
"type": "StructureGlobalDriversConfig", | ||
"prompt_driver": { | ||
"type": "AzureOpenAiChatPromptDriver", | ||
"base_url": None, | ||
"model": "gpt-4o", | ||
"azure_deployment": "test-gpt-4o", | ||
"azure_endpoint": "http://localhost:8080", | ||
"api_version": "2023-05-15", | ||
"organization": None, | ||
"response_format": None, | ||
"seed": None, | ||
"temperature": 0.1, | ||
"max_tokens": None, | ||
"stream": False, | ||
"user": "", | ||
}, | ||
"conversation_memory_driver": None, | ||
"embedding_driver": { | ||
"base_url": None, | ||
"model": "text-embedding-3-small", | ||
"api_version": "2023-05-15", | ||
"azure_deployment": "test-text-embedding-3-small", | ||
"azure_endpoint": "http://localhost:8080", | ||
"organization": None, | ||
"type": "AzureOpenAiEmbeddingDriver", | ||
}, | ||
"image_generation_driver": { | ||
"api_version": "2024-02-01", | ||
"base_url": None, | ||
"image_size": "512x512", | ||
"model": "dall-e-2", | ||
"azure_deployment": "test-dall-e-2", | ||
"azure_endpoint": "http://localhost:8080", | ||
"organization": None, | ||
"quality": "standard", | ||
"response_format": "b64_json", | ||
"style": None, | ||
"type": "AzureOpenAiImageGenerationDriver", | ||
}, | ||
"text_to_speech_driver": {"type": "DummyTextToSpeechDriver"}, | ||
"image_query_driver": { | ||
"api_version": None, | ||
"base_url": None, | ||
"image_quality": "auto", | ||
"max_tokens": 256, | ||
"model": "gpt-4-vision-preview", | ||
"api_version": "2024-02-01", | ||
"azure_deployment": "test-gpt-4-vision-preview", | ||
"azure_endpoint": "http://localhost:8080", | ||
"organization": None, | ||
"type": "AzureOpenAiVisionImageQueryDriver", | ||
}, | ||
"vector_store_driver": { | ||
"embedding_driver": { | ||
"base_url": None, | ||
"model": "text-embedding-3-small", | ||
"api_version": "2023-05-15", | ||
"azure_deployment": "test-text-embedding-3-small", | ||
"azure_endpoint": "http://localhost:8080", | ||
"organization": None, | ||
"type": "AzureOpenAiEmbeddingDriver", | ||
}, | ||
"type": "LocalVectorStoreDriver", | ||
}, | ||
}, | ||
"task_memory": { | ||
"type": "StructureTaskMemoryConfig", | ||
"query_engine": { | ||
"type": "StructureTaskMemoryQueryEngineConfig", | ||
"prompt_driver": { | ||
"base_url": None, | ||
"type": "AzureOpenAiChatPromptDriver", | ||
"api_version": "2023-05-15", | ||
"model": "gpt-4o", | ||
"azure_deployment": "test-gpt-4o", | ||
"azure_endpoint": "http://localhost:8080", | ||
"organization": None, | ||
"response_format": None, | ||
"seed": None, | ||
"temperature": 0.1, | ||
"max_tokens": None, | ||
"stream": False, | ||
"user": "", | ||
}, | ||
"vector_store_driver": { | ||
"type": "LocalVectorStoreDriver", | ||
"embedding_driver": { | ||
"type": "AzureOpenAiEmbeddingDriver", | ||
"base_url": None, | ||
"api_version": "2023-05-15", | ||
"organization": None, | ||
"model": "text-embedding-3-small", | ||
"azure_deployment": "test-text-embedding-3-small", | ||
"azure_endpoint": "http://localhost:8080", | ||
}, | ||
}, | ||
}, | ||
"extraction_engine": { | ||
"type": "StructureTaskMemoryExtractionEngineConfig", | ||
"csv": { | ||
"type": "StructureTaskMemoryExtractionEngineCsvConfig", | ||
"prompt_driver": { | ||
"type": "AzureOpenAiChatPromptDriver", | ||
"api_version": "2023-05-15", | ||
"base_url": None, | ||
"model": "gpt-4o", | ||
"azure_deployment": "test-gpt-4o", | ||
"azure_endpoint": "http://localhost:8080", | ||
"organization": None, | ||
"response_format": None, | ||
"seed": None, | ||
"temperature": 0.1, | ||
"max_tokens": None, | ||
"stream": False, | ||
"user": "", | ||
}, | ||
}, | ||
"json": { | ||
"type": "StructureTaskMemoryExtractionEngineJsonConfig", | ||
"prompt_driver": { | ||
"type": "AzureOpenAiChatPromptDriver", | ||
"api_version": "2023-05-15", | ||
"base_url": None, | ||
"model": "gpt-4o", | ||
"azure_deployment": "test-gpt-4o", | ||
"azure_endpoint": "http://localhost:8080", | ||
"organization": None, | ||
"response_format": None, | ||
"seed": None, | ||
"temperature": 0.1, | ||
"max_tokens": None, | ||
"stream": False, | ||
"user": "", | ||
}, | ||
}, | ||
}, | ||
"summary_engine": { | ||
"type": "StructureTaskMemorySummaryEngineConfig", | ||
"prompt_driver": { | ||
"api_version": "2023-05-15", | ||
"type": "AzureOpenAiChatPromptDriver", | ||
"base_url": None, | ||
"model": "gpt-4o", | ||
"azure_deployment": "test-gpt-4o", | ||
"azure_endpoint": "http://localhost:8080", | ||
"organization": None, | ||
"response_format": None, | ||
"seed": None, | ||
"temperature": 0.1, | ||
"max_tokens": None, | ||
"stream": False, | ||
"user": "", | ||
}, | ||
}, | ||
}, | ||
} | ||
|
||
def test_from_dict(self, config): | ||
assert AzureOpenAiStructureConfig.from_dict(config.to_dict()).to_dict() == config.to_dict() | ||
new_config = config.to_dict() | ||
new_config["global_drivers"]["prompt_driver"]["azure_deployment"] = "new-test-gpt-4o" | ||
assert AzureOpenAiStructureConfig.from_dict(new_config).to_dict() == new_config |