Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
vachillo committed May 20, 2024
1 parent 5eff82a commit 592075e
Show file tree
Hide file tree
Showing 10 changed files with 46 additions and 46 deletions.
29 changes: 13 additions & 16 deletions docs/examples/multiple-agent-shared-memory.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ The `MongoDbAtlasVectorStoreDriver` assumes that you have a vector index configu
import os
from griptape.tools import WebScraper, VectorStoreClient, TaskMemoryClient
from griptape.structures import Agent
from griptape.drivers import AzureMongoDbVectorStoreDriver
from griptape.drivers import AzureOpenAiEmbeddingDriver, AzureMongoDbVectorStoreDriver
from griptape.engines import VectorQueryEngine, PromptSummaryEngine, CsvExtractionEngine, JsonExtractionEngine
from griptape.memory import TaskMemory
from griptape.artifacts import TextArtifact
Expand All @@ -31,43 +31,40 @@ MONGODB_VECTOR_PATH = os.environ["MONGODB_VECTOR_PATH"]
MONGODB_CONNECTION_STRING = f"mongodb+srv://{MONGODB_USERNAME}:{MONGODB_PASSWORD}@{MONGODB_HOST}/{MONGODB_DATABASE_NAME}?tls=true&authMechanism=SCRAM-SHA-256&retrywrites=false&maxIdleTimeMS=120000"


azure_embedding_driver = AzureOpenAiEmbeddingDriver(
embedding_driver = AzureOpenAiEmbeddingDriver(
model='text-embedding-ada-002',
azure_endpoint=AZURE_OPENAI_ENDPOINT_1,
api_key=AZURE_OPENAI_API_KEY_1,
azure_deployment='text-embedding-ada-002'
)

mongo_driver = AzureMongoDbVectorStoreDriver(
connection_string=MONGODB_CONNECTION_STRING,
database_name=MONGODB_DATABASE_NAME,
collection_name=MONGODB_COLLECTION_NAME,
embedding_driver=azure_embedding_driver,
embedding_driver=embedding_driver,
index_name=MONGODB_INDEX_NAME,
vector_path=MONGODB_VECTOR_PATH
vector_path=MONGODB_VECTOR_PATH,
)

config = AzureOpenAiStructureConfig(
azure_endpoint=AZURE_OPENAI_ENDPOINT_1,
vector_store_driver=mongo_driver,
embedding_driver=embedding_driver,
)

loader = Agent(
tools=[
WebScraper()
WebScraper(),
],
config=AzureOpenAiStructureConfig(
azure_endpoint=AZURE_OPENAI_ENDPOINT_1
vector_store_driver=mongo_driver,
embedding_driver=embedding_driver
),
config=config,
)
asker = Agent(
tools=[
TaskMemoryClient(off_prompt=False),
],
meta_memory=loader.meta_memory,
task_memory=loader.task_memory,
config=AzureOpenAiStructureConfig(
azure_endpoint=AZURE_OPENAI_ENDPOINT_1
vector_store_driver=mongo_driver,
embedding_driver=embedding_driver
),
config=config,
)

if __name__ == "__main__":
Expand Down
6 changes: 5 additions & 1 deletion docs/griptape-framework/structures/config.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,11 @@ serialized_config = custom_config.to_json()
deserialized_config = AmazonBedrockStructureConfig.from_json(serialized_config)

agent = Agent(
config=deserialized_config,
config=deserialized_config.merge_config({
"prompt_driver" : {
"model": "anthropic.claude-3-sonnet-20240229-v1:0",
},
}),
)
```

8 changes: 1 addition & 7 deletions griptape/config/azure_openai_structure_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ class AzureOpenAiStructureConfig(StructureConfig):
Attributes:
azure_endpoint: The endpoint for the Azure OpenAI instance.
azure_deployment_prefix: An optional prefix for Azure mode deployment names. By default, the deployment names are the same as the default models.
azure_ad_token: An optional Azure Active Directory token.
azure_ad_token_provider: An optional Azure Active Directory token provider.
api_key: An optional Azure API key.
Expand All @@ -29,7 +28,6 @@ class AzureOpenAiStructureConfig(StructureConfig):
"""

azure_endpoint: str = field(kw_only=True, metadata={"serializable": True})
azure_deployment_prefix: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": True})
azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False})
azure_ad_token_provider: Optional[Callable[[], str]] = field(
kw_only=True, default=None, metadata={"serializable": False}
Expand All @@ -39,7 +37,6 @@ class AzureOpenAiStructureConfig(StructureConfig):
default=Factory(
lambda self: AzureOpenAiChatPromptDriver(
model="gpt-4o",
azure_deployment=f"{self.azure_deployment_prefix}gpt-4o",
azure_endpoint=self.azure_endpoint,
api_key=self.api_key,
azure_ad_token=self.azure_ad_token,
Expand All @@ -54,7 +51,6 @@ class AzureOpenAiStructureConfig(StructureConfig):
default=Factory(
lambda self: AzureOpenAiImageGenerationDriver(
model="dall-e-2",
azure_deployment=f"{self.azure_deployment_prefix}dall-e-2",
azure_endpoint=self.azure_endpoint,
api_key=self.api_key,
azure_ad_token=self.azure_ad_token,
Expand All @@ -69,8 +65,7 @@ class AzureOpenAiStructureConfig(StructureConfig):
image_query_driver: AzureOpenAiVisionImageQueryDriver = field(
default=Factory(
lambda self: AzureOpenAiVisionImageQueryDriver(
model="gpt-4-vision-preview",
azure_deployment=f"{self.azure_deployment_prefix}gpt-4-vision-preview",
model="gpt-4",
azure_endpoint=self.azure_endpoint,
api_key=self.api_key,
azure_ad_token=self.azure_ad_token,
Expand All @@ -85,7 +80,6 @@ class AzureOpenAiStructureConfig(StructureConfig):
default=Factory(
lambda self: AzureOpenAiEmbeddingDriver(
model="text-embedding-3-small",
azure_deployment=f"{self.azure_deployment_prefix}text-embedding-3-small",
azure_endpoint=self.azure_endpoint,
api_key=self.api_key,
azure_ad_token=self.azure_ad_token,
Expand Down
4 changes: 3 additions & 1 deletion griptape/drivers/embedding/azure_openai_embedding_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ class AzureOpenAiEmbeddingDriver(OpenAiEmbeddingDriver):
client: An `openai.AzureOpenAI` client.
"""

azure_deployment: str = field(kw_only=True, metadata={"serializable": True})
azure_deployment: str = field(
kw_only=True, default=Factory(lambda self: self.model, takes_self=True), metadata={"serializable": True}
)
azure_endpoint: str = field(kw_only=True, metadata={"serializable": True})
azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False})
azure_ad_token_provider: Optional[Callable[[], str]] = field(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ class AzureOpenAiImageGenerationDriver(OpenAiImageGenerationDriver):
client: An `openai.AzureOpenAI` client.
"""

azure_deployment: str = field(kw_only=True, metadata={"serializable": True})
azure_deployment: str = field(
kw_only=True, default=Factory(lambda self: self.model, takes_self=True), metadata={"serializable": True}
)
azure_endpoint: str = field(kw_only=True, metadata={"serializable": True})
azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False})
azure_ad_token_provider: Optional[Callable[[], str]] = field(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ class AzureOpenAiVisionImageQueryDriver(OpenAiVisionImageQueryDriver):
client: An `openai.AzureOpenAI` client.
"""

azure_deployment: str = field(kw_only=True, metadata={"serializable": True})
azure_deployment: str = field(
kw_only=True, default=Factory(lambda self: self.model, takes_self=True), metadata={"serializable": True}
)
azure_endpoint: str = field(kw_only=True, metadata={"serializable": True})
azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False})
azure_ad_token_provider: Optional[Callable[[], str]] = field(
Expand Down
4 changes: 3 additions & 1 deletion griptape/drivers/prompt/azure_openai_chat_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ class AzureOpenAiChatPromptDriver(OpenAiChatPromptDriver):
client: An `openai.AzureOpenAI` client.
"""

azure_deployment: str = field(kw_only=True, metadata={"serializable": True})
azure_deployment: str = field(
kw_only=True, default=Factory(lambda self: self.model, takes_self=True), metadata={"serializable": True}
)
azure_endpoint: str = field(kw_only=True, metadata={"serializable": True})
azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False})
azure_ad_token_provider: Optional[Callable[[], str]] = field(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ class AzureOpenAiCompletionPromptDriver(OpenAiCompletionPromptDriver):
client: An `openai.AzureOpenAI` client.
"""

azure_deployment: str = field(kw_only=True, metadata={"serializable": True})
azure_deployment: str = field(
kw_only=True, default=Factory(lambda self: self.model, takes_self=True), metadata={"serializable": True}
)
azure_endpoint: str = field(kw_only=True, metadata={"serializable": True})
azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False})
azure_ad_token_provider: Optional[Callable[[], str]] = field(
Expand Down
23 changes: 12 additions & 11 deletions tests/unit/config/test_azure_openai_structure_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ def mock_openai(self, mocker):
def config(self):
return AzureOpenAiStructureConfig(
azure_endpoint="http://localhost:8080",
azure_deployment_prefix="test-",
azure_ad_token="test-token",
azure_ad_token_provider=lambda: "test-provider",
)
Expand All @@ -20,12 +19,11 @@ def test_to_dict(self, config):
assert config.to_dict() == {
"type": "AzureOpenAiStructureConfig",
"azure_endpoint": "http://localhost:8080",
"azure_deployment_prefix": "test-",
"prompt_driver": {
"type": "AzureOpenAiChatPromptDriver",
"base_url": None,
"model": "gpt-4o",
"azure_deployment": "test-gpt-4o",
"azure_deployment": "gpt-4o",
"azure_endpoint": "http://localhost:8080",
"api_version": "2023-05-15",
"organization": None,
Expand All @@ -41,7 +39,7 @@ def test_to_dict(self, config):
"base_url": None,
"model": "text-embedding-3-small",
"api_version": "2023-05-15",
"azure_deployment": "test-text-embedding-3-small",
"azure_deployment": "text-embedding-3-small",
"azure_endpoint": "http://localhost:8080",
"organization": None,
"type": "AzureOpenAiEmbeddingDriver",
Expand All @@ -51,7 +49,7 @@ def test_to_dict(self, config):
"base_url": None,
"image_size": "512x512",
"model": "dall-e-2",
"azure_deployment": "test-dall-e-2",
"azure_deployment": "dall-e-2",
"azure_endpoint": "http://localhost:8080",
"organization": None,
"quality": "standard",
Expand All @@ -66,7 +64,7 @@ def test_to_dict(self, config):
"max_tokens": 256,
"model": "gpt-4",
"api_version": "2024-02-01",
"azure_deployment": "test-gpt-4",
"azure_deployment": "gpt-4",
"azure_endpoint": "http://localhost:8080",
"organization": None,
"type": "AzureOpenAiVisionImageQueryDriver",
Expand All @@ -76,7 +74,7 @@ def test_to_dict(self, config):
"base_url": None,
"model": "text-embedding-3-small",
"api_version": "2023-05-15",
"azure_deployment": "test-text-embedding-3-small",
"azure_deployment": "text-embedding-3-small",
"azure_endpoint": "http://localhost:8080",
"organization": None,
"type": "AzureOpenAiEmbeddingDriver",
Expand All @@ -85,12 +83,15 @@ def test_to_dict(self, config):
},
}

def test_from_dict(self, config):
def test_from_dict(self, config: AzureOpenAiStructureConfig):
assert AzureOpenAiStructureConfig.from_dict(config.to_dict()).to_dict() == config.to_dict()

# override values in the dict config
# serialize and deserialize the config
new_config = config.to_dict()
new_config["prompt_driver"]["azure_deployment"] = "new-test-gpt-4o"
new_config["embedding_driver"]["model"] = "new-text-embedding-3-small"
new_config = config.merge_config(
{
"prompt_driver": {"azure_deployment": "new-test-gpt-4o"},
"embedding_driver": {"model": "new-text-embedding-3-small"},
}
).to_dict()
assert AzureOpenAiStructureConfig.from_dict(new_config).to_dict() == new_config
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,6 @@ def test_init_requires_endpoint(self):
model="dall-e-3", client=Mock(), azure_deployment="dalle-deployment", image_size="512x512"
) # pyright: ignore

def test_init_requires_deployment(self):
with pytest.raises(TypeError):
AzureOpenAiImageGenerationDriver(
model="dall-e-3", client=Mock(), azure_endpoint="https://dalle.example.com", image_size="512x512"
) # pyright: ignore

def test_try_text_to_image(self, driver):
driver.client.images.generate.return_value = Mock(data=[Mock(b64_json=b"aW1hZ2UgZGF0YQ==")])

Expand Down

0 comments on commit 592075e

Please sign in to comment.