diff --git a/docs/examples/multiple-agent-shared-memory.md b/docs/examples/multiple-agent-shared-memory.md index 3e56634bdb..7332de1658 100644 --- a/docs/examples/multiple-agent-shared-memory.md +++ b/docs/examples/multiple-agent-shared-memory.md @@ -10,7 +10,7 @@ The `MongoDbAtlasVectorStoreDriver` assumes that you have a vector index configu import os from griptape.tools import WebScraper, VectorStoreClient, TaskMemoryClient from griptape.structures import Agent -from griptape.drivers import AzureMongoDbVectorStoreDriver +from griptape.drivers import AzureOpenAiEmbeddingDriver, AzureMongoDbVectorStoreDriver from griptape.engines import VectorQueryEngine, PromptSummaryEngine, CsvExtractionEngine, JsonExtractionEngine from griptape.memory import TaskMemory from griptape.artifacts import TextArtifact @@ -31,31 +31,32 @@ MONGODB_VECTOR_PATH = os.environ["MONGODB_VECTOR_PATH"] MONGODB_CONNECTION_STRING = f"mongodb+srv://{MONGODB_USERNAME}:{MONGODB_PASSWORD}@{MONGODB_HOST}/{MONGODB_DATABASE_NAME}?tls=true&authMechanism=SCRAM-SHA-256&retrywrites=false&maxIdleTimeMS=120000" -azure_embedding_driver = AzureOpenAiEmbeddingDriver( +embedding_driver = AzureOpenAiEmbeddingDriver( model='text-embedding-ada-002', azure_endpoint=AZURE_OPENAI_ENDPOINT_1, api_key=AZURE_OPENAI_API_KEY_1, - azure_deployment='text-embedding-ada-002' ) mongo_driver = AzureMongoDbVectorStoreDriver( connection_string=MONGODB_CONNECTION_STRING, database_name=MONGODB_DATABASE_NAME, collection_name=MONGODB_COLLECTION_NAME, - embedding_driver=azure_embedding_driver, + embedding_driver=embedding_driver, index_name=MONGODB_INDEX_NAME, - vector_path=MONGODB_VECTOR_PATH + vector_path=MONGODB_VECTOR_PATH, +) + +config = AzureOpenAiStructureConfig( + azure_endpoint=AZURE_OPENAI_ENDPOINT_1, + vector_store_driver=mongo_driver, + embedding_driver=embedding_driver, ) loader = Agent( tools=[ - WebScraper() + WebScraper(), ], - config=AzureOpenAiStructureConfig( - azure_endpoint=AZURE_OPENAI_ENDPOINT_1 - vector_store_driver=mongo_driver, - embedding_driver=embedding_driver - ), + config=config, ) asker = Agent( tools=[ @@ -63,11 +64,7 @@ asker = Agent( ], meta_memory=loader.meta_memory, task_memory=loader.task_memory, - config=AzureOpenAiStructureConfig( - azure_endpoint=AZURE_OPENAI_ENDPOINT_1 - vector_store_driver=mongo_driver, - embedding_driver=embedding_driver - ), + config=config, ) if __name__ == "__main__": diff --git a/docs/griptape-framework/structures/config.md b/docs/griptape-framework/structures/config.md index 2a72c123af..e74ef6f506 100644 --- a/docs/griptape-framework/structures/config.md +++ b/docs/griptape-framework/structures/config.md @@ -129,7 +129,11 @@ serialized_config = custom_config.to_json() deserialized_config = AmazonBedrockStructureConfig.from_json(serialized_config) agent = Agent( - config=deserialized_config, + config=deserialized_config.merge_config({ + "prompt_driver" : { + "model": "anthropic.claude-3-sonnet-20240229-v1:0", + }, + }), ) ``` diff --git a/griptape/config/azure_openai_structure_config.py b/griptape/config/azure_openai_structure_config.py index 6031c55c46..35adb1eaa9 100644 --- a/griptape/config/azure_openai_structure_config.py +++ b/griptape/config/azure_openai_structure_config.py @@ -17,7 +17,6 @@ class AzureOpenAiStructureConfig(StructureConfig): Attributes: azure_endpoint: The endpoint for the Azure OpenAI instance. - azure_deployment_prefix: An optional prefix for Azure mode deployment names. By default, the deployment names are the same as the default models. azure_ad_token: An optional Azure Active Directory token. azure_ad_token_provider: An optional Azure Active Directory token provider. api_key: An optional Azure API key. @@ -29,7 +28,6 @@ class AzureOpenAiStructureConfig(StructureConfig): """ azure_endpoint: str = field(kw_only=True, metadata={"serializable": True}) - azure_deployment_prefix: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": True}) azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False}) azure_ad_token_provider: Optional[Callable[[], str]] = field( kw_only=True, default=None, metadata={"serializable": False} @@ -39,7 +37,6 @@ class AzureOpenAiStructureConfig(StructureConfig): default=Factory( lambda self: AzureOpenAiChatPromptDriver( model="gpt-4o", - azure_deployment=f"{self.azure_deployment_prefix}gpt-4o", azure_endpoint=self.azure_endpoint, api_key=self.api_key, azure_ad_token=self.azure_ad_token, @@ -54,7 +51,6 @@ class AzureOpenAiStructureConfig(StructureConfig): default=Factory( lambda self: AzureOpenAiImageGenerationDriver( model="dall-e-2", - azure_deployment=f"{self.azure_deployment_prefix}dall-e-2", azure_endpoint=self.azure_endpoint, api_key=self.api_key, azure_ad_token=self.azure_ad_token, @@ -69,8 +65,7 @@ class AzureOpenAiStructureConfig(StructureConfig): image_query_driver: AzureOpenAiVisionImageQueryDriver = field( default=Factory( lambda self: AzureOpenAiVisionImageQueryDriver( - model="gpt-4-vision-preview", - azure_deployment=f"{self.azure_deployment_prefix}gpt-4-vision-preview", + model="gpt-4", azure_endpoint=self.azure_endpoint, api_key=self.api_key, azure_ad_token=self.azure_ad_token, @@ -85,7 +80,6 @@ class AzureOpenAiStructureConfig(StructureConfig): default=Factory( lambda self: AzureOpenAiEmbeddingDriver( model="text-embedding-3-small", - azure_deployment=f"{self.azure_deployment_prefix}text-embedding-3-small", azure_endpoint=self.azure_endpoint, api_key=self.api_key, azure_ad_token=self.azure_ad_token, diff --git a/griptape/drivers/embedding/azure_openai_embedding_driver.py b/griptape/drivers/embedding/azure_openai_embedding_driver.py index 3cf473c061..6b8ab1b2ba 100644 --- a/griptape/drivers/embedding/azure_openai_embedding_driver.py +++ b/griptape/drivers/embedding/azure_openai_embedding_driver.py @@ -20,7 +20,9 @@ class AzureOpenAiEmbeddingDriver(OpenAiEmbeddingDriver): client: An `openai.AzureOpenAI` client. """ - azure_deployment: str = field(kw_only=True, metadata={"serializable": True}) + azure_deployment: str = field( + kw_only=True, default=Factory(lambda self: self.model, takes_self=True), metadata={"serializable": True} + ) azure_endpoint: str = field(kw_only=True, metadata={"serializable": True}) azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False}) azure_ad_token_provider: Optional[Callable[[], str]] = field( diff --git a/griptape/drivers/image_generation/azure_openai_image_generation_driver.py b/griptape/drivers/image_generation/azure_openai_image_generation_driver.py index d93ec732e8..49debc5d85 100644 --- a/griptape/drivers/image_generation/azure_openai_image_generation_driver.py +++ b/griptape/drivers/image_generation/azure_openai_image_generation_driver.py @@ -20,7 +20,9 @@ class AzureOpenAiImageGenerationDriver(OpenAiImageGenerationDriver): client: An `openai.AzureOpenAI` client. """ - azure_deployment: str = field(kw_only=True, metadata={"serializable": True}) + azure_deployment: str = field( + kw_only=True, default=Factory(lambda self: self.model, takes_self=True), metadata={"serializable": True} + ) azure_endpoint: str = field(kw_only=True, metadata={"serializable": True}) azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False}) azure_ad_token_provider: Optional[Callable[[], str]] = field( diff --git a/griptape/drivers/image_query/azure_openai_vision_image_query_driver.py b/griptape/drivers/image_query/azure_openai_vision_image_query_driver.py index 460b343624..a065e40a7b 100644 --- a/griptape/drivers/image_query/azure_openai_vision_image_query_driver.py +++ b/griptape/drivers/image_query/azure_openai_vision_image_query_driver.py @@ -20,7 +20,9 @@ class AzureOpenAiVisionImageQueryDriver(OpenAiVisionImageQueryDriver): client: An `openai.AzureOpenAI` client. """ - azure_deployment: str = field(kw_only=True, metadata={"serializable": True}) + azure_deployment: str = field( + kw_only=True, default=Factory(lambda self: self.model, takes_self=True), metadata={"serializable": True} + ) azure_endpoint: str = field(kw_only=True, metadata={"serializable": True}) azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False}) azure_ad_token_provider: Optional[Callable[[], str]] = field( diff --git a/griptape/drivers/prompt/azure_openai_chat_prompt_driver.py b/griptape/drivers/prompt/azure_openai_chat_prompt_driver.py index fe4d6d1ef0..c401c7fd7f 100644 --- a/griptape/drivers/prompt/azure_openai_chat_prompt_driver.py +++ b/griptape/drivers/prompt/azure_openai_chat_prompt_driver.py @@ -17,7 +17,9 @@ class AzureOpenAiChatPromptDriver(OpenAiChatPromptDriver): client: An `openai.AzureOpenAI` client. """ - azure_deployment: str = field(kw_only=True, metadata={"serializable": True}) + azure_deployment: str = field( + kw_only=True, default=Factory(lambda self: self.model, takes_self=True), metadata={"serializable": True} + ) azure_endpoint: str = field(kw_only=True, metadata={"serializable": True}) azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False}) azure_ad_token_provider: Optional[Callable[[], str]] = field( diff --git a/griptape/drivers/prompt/azure_openai_completion_prompt_driver.py b/griptape/drivers/prompt/azure_openai_completion_prompt_driver.py index 53e51963ff..58dfe299f9 100644 --- a/griptape/drivers/prompt/azure_openai_completion_prompt_driver.py +++ b/griptape/drivers/prompt/azure_openai_completion_prompt_driver.py @@ -16,7 +16,9 @@ class AzureOpenAiCompletionPromptDriver(OpenAiCompletionPromptDriver): client: An `openai.AzureOpenAI` client. """ - azure_deployment: str = field(kw_only=True, metadata={"serializable": True}) + azure_deployment: str = field( + kw_only=True, default=Factory(lambda self: self.model, takes_self=True), metadata={"serializable": True} + ) azure_endpoint: str = field(kw_only=True, metadata={"serializable": True}) azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False}) azure_ad_token_provider: Optional[Callable[[], str]] = field( diff --git a/tests/unit/config/test_azure_openai_structure_config.py b/tests/unit/config/test_azure_openai_structure_config.py index 201ed60ef7..94443e1446 100644 --- a/tests/unit/config/test_azure_openai_structure_config.py +++ b/tests/unit/config/test_azure_openai_structure_config.py @@ -11,7 +11,6 @@ def mock_openai(self, mocker): def config(self): return AzureOpenAiStructureConfig( azure_endpoint="http://localhost:8080", - azure_deployment_prefix="test-", azure_ad_token="test-token", azure_ad_token_provider=lambda: "test-provider", ) @@ -20,12 +19,11 @@ def test_to_dict(self, config): assert config.to_dict() == { "type": "AzureOpenAiStructureConfig", "azure_endpoint": "http://localhost:8080", - "azure_deployment_prefix": "test-", "prompt_driver": { "type": "AzureOpenAiChatPromptDriver", "base_url": None, "model": "gpt-4o", - "azure_deployment": "test-gpt-4o", + "azure_deployment": "gpt-4o", "azure_endpoint": "http://localhost:8080", "api_version": "2023-05-15", "organization": None, @@ -41,7 +39,7 @@ def test_to_dict(self, config): "base_url": None, "model": "text-embedding-3-small", "api_version": "2023-05-15", - "azure_deployment": "test-text-embedding-3-small", + "azure_deployment": "text-embedding-3-small", "azure_endpoint": "http://localhost:8080", "organization": None, "type": "AzureOpenAiEmbeddingDriver", @@ -51,7 +49,7 @@ def test_to_dict(self, config): "base_url": None, "image_size": "512x512", "model": "dall-e-2", - "azure_deployment": "test-dall-e-2", + "azure_deployment": "dall-e-2", "azure_endpoint": "http://localhost:8080", "organization": None, "quality": "standard", @@ -66,7 +64,7 @@ def test_to_dict(self, config): "max_tokens": 256, "model": "gpt-4", "api_version": "2024-02-01", - "azure_deployment": "test-gpt-4", + "azure_deployment": "gpt-4", "azure_endpoint": "http://localhost:8080", "organization": None, "type": "AzureOpenAiVisionImageQueryDriver", @@ -76,7 +74,7 @@ def test_to_dict(self, config): "base_url": None, "model": "text-embedding-3-small", "api_version": "2023-05-15", - "azure_deployment": "test-text-embedding-3-small", + "azure_deployment": "text-embedding-3-small", "azure_endpoint": "http://localhost:8080", "organization": None, "type": "AzureOpenAiEmbeddingDriver", @@ -85,12 +83,15 @@ def test_to_dict(self, config): }, } - def test_from_dict(self, config): + def test_from_dict(self, config: AzureOpenAiStructureConfig): assert AzureOpenAiStructureConfig.from_dict(config.to_dict()).to_dict() == config.to_dict() # override values in the dict config # serialize and deserialize the config - new_config = config.to_dict() - new_config["prompt_driver"]["azure_deployment"] = "new-test-gpt-4o" - new_config["embedding_driver"]["model"] = "new-text-embedding-3-small" + new_config = config.merge_config( + { + "prompt_driver": {"azure_deployment": "new-test-gpt-4o"}, + "embedding_driver": {"model": "new-text-embedding-3-small"}, + } + ).to_dict() assert AzureOpenAiStructureConfig.from_dict(new_config).to_dict() == new_config diff --git a/tests/unit/drivers/image_generation/test_azure_openai_image_generation_driver.py b/tests/unit/drivers/image_generation/test_azure_openai_image_generation_driver.py index 13b122af76..d7bd82ee5f 100644 --- a/tests/unit/drivers/image_generation/test_azure_openai_image_generation_driver.py +++ b/tests/unit/drivers/image_generation/test_azure_openai_image_generation_driver.py @@ -23,12 +23,6 @@ def test_init_requires_endpoint(self): model="dall-e-3", client=Mock(), azure_deployment="dalle-deployment", image_size="512x512" ) # pyright: ignore - def test_init_requires_deployment(self): - with pytest.raises(TypeError): - AzureOpenAiImageGenerationDriver( - model="dall-e-3", client=Mock(), azure_endpoint="https://dalle.example.com", image_size="512x512" - ) # pyright: ignore - def test_try_text_to_image(self, driver): driver.client.images.generate.return_value = Mock(data=[Mock(b64_json=b"aW1hZ2UgZGF0YQ==")])