Add AzureOpenAiStructureConfig, default azure_deployment (#788)
vachillo authored May 20, 2024
1 parent 7ed3b90 commit 540e8f3
Showing 19 changed files with 432 additions and 38 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/docs-integration-tests.yml
@@ -75,6 +75,8 @@ jobs:
AZURE_OPENAI_API_KEY_1: ${{ secrets.INTEG_AZURE_OPENAI_API_KEY_1 }}
AZURE_OPENAI_ENDPOINT_2: ${{ secrets.INTEG_AZURE_OPENAI_ENDPOINT_2 }}
AZURE_OPENAI_API_KEY_2: ${{ secrets.INTEG_AZURE_OPENAI_API_KEY_2 }}
AZURE_OPENAI_ENDPOINT_3: ${{ secrets.INTEG_AZURE_OPENAI_ENDPOINT_3 }}
AZURE_OPENAI_API_KEY_3: ${{ secrets.INTEG_AZURE_OPENAI_API_KEY_3 }}
AZURE_OPENAI_35_TURBO_16K_DEPLOYMENT_ID: ${{ secrets.INTEG_OPENAI_35_TURBO_16K_DEPLOYMENT_ID }}
AZURE_OPENAI_DAVINCI_DEPLOYMENT_ID: ${{ secrets.INTEG_OPENAI_DAVINCI_DEPLOYMENT_ID }}
AZURE_OPENAI_4_DEPLOYMENT_ID: ${{ secrets.INTEG_OPENAI_4_DEPLOYMENT_ID }}
8 changes: 8 additions & 0 deletions CHANGELOG.md
@@ -6,6 +6,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Added
- `AzureOpenAiStructureConfig` for providing Structures with all Azure OpenAI Driver configuration.
- `AzureOpenAiVisionImageQueryDriver` to support queries on images using Azure's OpenAI Vision models.

### Changed
- Default the value of `azure_deployment` on all Azure Drivers to the model the Driver is using.
- Field `azure_ad_token` on all Azure Drivers is no longer serializable.
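The two changes above can be illustrated with a minimal sketch. This uses stdlib `dataclasses` rather than griptape's actual attrs-based fields, and `AzureDriverSketch` is a hypothetical stand-in for the Azure Drivers; it only demonstrates the new fallback behavior of `azure_deployment`:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class AzureDriverSketch:
    # Hypothetical stand-in for an Azure Driver; mirrors the
    # Factory(lambda self: self.model, takes_self=True) default in the diff.
    model: str
    azure_deployment: Optional[str] = None

    def __post_init__(self) -> None:
        # When azure_deployment is omitted, fall back to the model name.
        if self.azure_deployment is None:
            self.azure_deployment = self.model

# Omitting azure_deployment now defaults it to the model name.
print(AzureDriverSketch(model="gpt-4").azure_deployment)  # -> gpt-4

# An explicit deployment id still wins.
print(AzureDriverSketch(model="gpt-4", azure_deployment="my-gpt-4").azure_deployment)  # -> my-gpt-4
```

This only matters when your Azure deployment name differs from the model name; in that case you must still pass `azure_deployment` explicitly.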

## [0.25.1] - 2024-05-15

### Fixed
33 changes: 14 additions & 19 deletions docs/examples/multiple-agent-shared-memory.md
@@ -10,12 +10,12 @@ The `MongoDbAtlasVectorStoreDriver` assumes that you have a vector index configu
import os
from griptape.tools import WebScraper, VectorStoreClient, TaskMemoryClient
from griptape.structures import Agent
from griptape.drivers import AzureOpenAiChatPromptDriver, AzureOpenAiEmbeddingDriver, AzureMongoDbVectorStoreDriver
from griptape.drivers import AzureOpenAiEmbeddingDriver, AzureMongoDbVectorStoreDriver
from griptape.engines import VectorQueryEngine, PromptSummaryEngine, CsvExtractionEngine, JsonExtractionEngine
from griptape.memory import TaskMemory
from griptape.artifacts import TextArtifact
from griptape.memory.task.storage import TextArtifactStorage
from griptape.config import StructureConfig
from griptape.config import AzureOpenAiStructureConfig


AZURE_OPENAI_ENDPOINT_1 = os.environ["AZURE_OPENAI_ENDPOINT_1"]
@@ -31,45 +31,40 @@ MONGODB_VECTOR_PATH = os.environ["MONGODB_VECTOR_PATH"]
MONGODB_CONNECTION_STRING = f"mongodb+srv://{MONGODB_USERNAME}:{MONGODB_PASSWORD}@{MONGODB_HOST}/{MONGODB_DATABASE_NAME}?tls=true&authMechanism=SCRAM-SHA-256&retrywrites=false&maxIdleTimeMS=120000"


azure_embedding_driver = AzureOpenAiEmbeddingDriver(
embedding_driver = AzureOpenAiEmbeddingDriver(
model='text-embedding-ada-002',
azure_endpoint=AZURE_OPENAI_ENDPOINT_1,
api_key=AZURE_OPENAI_API_KEY_1,
azure_deployment='text-embedding-ada-002'
)

azure_prompt_driver = AzureOpenAiChatPromptDriver(
model='gpt-4',
azure_endpoint=AZURE_OPENAI_ENDPOINT_1,
api_key=AZURE_OPENAI_API_KEY_1,
azure_deployment='gpt-4'
)

mongo_driver = AzureMongoDbVectorStoreDriver(
connection_string=MONGODB_CONNECTION_STRING,
database_name=MONGODB_DATABASE_NAME,
collection_name=MONGODB_COLLECTION_NAME,
embedding_driver=azure_embedding_driver,
embedding_driver=embedding_driver,
index_name=MONGODB_INDEX_NAME,
vector_path=MONGODB_VECTOR_PATH
vector_path=MONGODB_VECTOR_PATH,
)

config = AzureOpenAiStructureConfig(
azure_endpoint=AZURE_OPENAI_ENDPOINT_1,
vector_store_driver=mongo_driver,
embedding_driver=embedding_driver,
)

loader = Agent(
tools=[
WebScraper()
WebScraper(),
],
config=StructureConfig(
prompt_driver=azure_prompt_driver,
vector_store_driver=mongo_driver,
embedding_driver=azure_embedding_driver
),
config=config,
)
asker = Agent(
tools=[
TaskMemoryClient(off_prompt=False),
],
meta_memory=loader.meta_memory,
task_memory=loader.task_memory,
config=config,
)

if __name__ == "__main__":
33 changes: 32 additions & 1 deletion docs/griptape-framework/drivers/image-query-drivers.md
@@ -62,7 +62,7 @@ print(result)
## OpenAiVisionImageQueryDriver

!!! info
This Driver defaults to using the `gpt-4-vision-preview` model. As other multimodal models are released, they can be specified using the `model` field. While the `max_tokens` field is optional, it is recommended to set this to a value that corresponds to the desired response length. Without an explicit value, the model will default to very short responses. See [OpenAI's documentation](https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them) for more information on how to relate token count to response length.
While the `max_tokens` field is optional, it is recommended to set this to a value that corresponds to the desired response length. Without an explicit value, the model will default to very short responses. See [OpenAI's documentation](https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them) for more information on how to relate token count to response length.

The [OpenAiVisionImageQueryDriver](../../reference/griptape/drivers/image_query/openai_vision_image_query_driver.md) is used to query images using the OpenAI Vision API. Here is an example of how to use it:

@@ -86,6 +86,37 @@ with open("tests/resources/mountain.png", "rb") as f:
engine.run("Describe the weather in the image", [image_artifact])
```

## AzureOpenAiVisionImageQueryDriver

!!! info
In order to use the `gpt-4-vision-preview` model on Azure OpenAI, the `gpt-4` model must be deployed with the version set to `vision-preview`. More information can be found in the [Azure OpenAI documentation](https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/gpt-with-vision).

The [AzureOpenAiVisionImageQueryDriver](../../reference/griptape/drivers/image_query/azure_openai_vision_image_query_driver.md) is used to query images using the Azure OpenAI Vision API. Here is an example of how to use it:

```python
import os
from griptape.drivers import AzureOpenAiVisionImageQueryDriver
from griptape.engines import ImageQueryEngine
from griptape.loaders import ImageLoader

driver = AzureOpenAiVisionImageQueryDriver(
azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_3"],
api_key=os.environ["AZURE_OPENAI_API_KEY_3"],
model="gpt-4",
azure_deployment="gpt-4-vision-preview",
max_tokens=256,
)

engine = ImageQueryEngine(
image_query_driver=driver,
)

with open("tests/resources/mountain.png", "rb") as f:
image_artifact = ImageLoader().load(f.read())

engine.run("Describe the weather in the image", [image_artifact])
```

## AmazonBedrockImageQueryDriver

The [Amazon Bedrock Image Query Driver](../../reference/griptape/drivers/image_query/amazon_bedrock_image_query_driver.md) provides multi-model access to image query models hosted by Amazon Bedrock. This Driver manages API calls to the Bedrock API, while the specific Model Drivers below format the API requests and parse the responses.
28 changes: 27 additions & 1 deletion docs/griptape-framework/structures/config.md
@@ -22,6 +22,28 @@ agent = Agent(
agent = Agent() # This is equivalent to the above
```

#### Azure OpenAI

The [Azure OpenAI Structure Config](../../reference/griptape/config/azure_openai_structure_config.md) provides default Drivers for Azure's OpenAI APIs.
```python
import os
from griptape.structures import Agent
from griptape.config import AzureOpenAiStructureConfig

agent = Agent(
config=AzureOpenAiStructureConfig(
azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_3"],
api_key=os.environ["AZURE_OPENAI_API_KEY_3"]
).merge_config({
"image_query_driver": {
"azure_deployment": "gpt-4-vision-preview",
},
}),
)
```

#### Amazon Bedrock
The [Amazon Bedrock Structure Config](../../reference/griptape/config/amazon_bedrock_structure_config.md) provides default Drivers for Amazon Bedrock's APIs.

@@ -112,7 +134,11 @@ serialized_config = custom_config.to_json()
deserialized_config = AmazonBedrockStructureConfig.from_json(serialized_config)

agent = Agent(
config=deserialized_config,
config=deserialized_config.merge_config({
"prompt_driver" : {
"model": "anthropic.claude-3-sonnet-20240229-v1:0",
},
}),
)
```
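The `merge_config` calls above patch individual Driver settings without rebuilding the whole config. A rough sketch of the semantics, assuming a recursive dict merge (this `deep_merge` helper is an illustration, not griptape's implementation):

```python
def deep_merge(base: dict, updates: dict) -> dict:
    """Recursively merge `updates` into a copy of `base`, overriding leaves."""
    merged = dict(base)
    for key, value in updates.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged

config = {"prompt_driver": {"model": "gpt-4", "stream": False}}
patched = deep_merge(config, {"prompt_driver": {"model": "anthropic.claude-3-sonnet-20240229-v1:0"}})

# Only the overridden leaf changes; sibling keys like "stream" survive.
print(patched["prompt_driver"]["model"])   # -> anthropic.claude-3-sonnet-20240229-v1:0
print(patched["prompt_driver"]["stream"])  # -> False
```

This is why overriding `image_query_driver.azure_deployment` earlier did not disturb the rest of the Azure config.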

2 changes: 2 additions & 0 deletions griptape/config/__init__.py
@@ -4,6 +4,7 @@

from .structure_config import StructureConfig
from .openai_structure_config import OpenAiStructureConfig
from .azure_openai_structure_config import AzureOpenAiStructureConfig
from .amazon_bedrock_structure_config import AmazonBedrockStructureConfig
from .anthropic_structure_config import AnthropicStructureConfig
from .google_structure_config import GoogleStructureConfig
@@ -14,6 +15,7 @@
"BaseStructureConfig",
"StructureConfig",
"OpenAiStructureConfig",
"AzureOpenAiStructureConfig",
"AmazonBedrockStructureConfig",
"AnthropicStructureConfig",
"GoogleStructureConfig",
102 changes: 102 additions & 0 deletions griptape/config/azure_openai_structure_config.py
@@ -0,0 +1,102 @@
from typing import Callable, Optional
from attrs import Factory, define, field

from griptape.config import StructureConfig
from griptape.drivers import (
LocalVectorStoreDriver,
AzureOpenAiChatPromptDriver,
AzureOpenAiEmbeddingDriver,
AzureOpenAiImageGenerationDriver,
AzureOpenAiVisionImageQueryDriver,
BasePromptDriver,
BaseEmbeddingDriver,
BaseImageGenerationDriver,
BaseImageQueryDriver,
BaseVectorStoreDriver,
)


@define
class AzureOpenAiStructureConfig(StructureConfig):
"""Azure OpenAI Structure Configuration.
Attributes:
azure_endpoint: The endpoint for the Azure OpenAI instance.
azure_ad_token: An optional Azure Active Directory token.
azure_ad_token_provider: An optional Azure Active Directory token provider.
api_key: An optional Azure API key.
prompt_driver: An Azure OpenAI Chat Prompt Driver.
image_generation_driver: An Azure OpenAI Image Generation Driver.
image_query_driver: An Azure OpenAI Vision Image Query Driver.
embedding_driver: An Azure OpenAI Embedding Driver.
vector_store_driver: A Local Vector Store Driver.
"""

azure_endpoint: str = field(kw_only=True, metadata={"serializable": True})
azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False})
azure_ad_token_provider: Optional[Callable[[], str]] = field(
kw_only=True, default=None, metadata={"serializable": False}
)
api_key: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False})
prompt_driver: BasePromptDriver = field(
default=Factory(
lambda self: AzureOpenAiChatPromptDriver(
model="gpt-4",
azure_endpoint=self.azure_endpoint,
api_key=self.api_key,
azure_ad_token=self.azure_ad_token,
azure_ad_token_provider=self.azure_ad_token_provider,
),
takes_self=True,
),
metadata={"serializable": True},
kw_only=True,
)
image_generation_driver: BaseImageGenerationDriver = field(
default=Factory(
lambda self: AzureOpenAiImageGenerationDriver(
model="dall-e-2",
azure_endpoint=self.azure_endpoint,
api_key=self.api_key,
azure_ad_token=self.azure_ad_token,
azure_ad_token_provider=self.azure_ad_token_provider,
image_size="512x512",
),
takes_self=True,
),
metadata={"serializable": True},
kw_only=True,
)
image_query_driver: BaseImageQueryDriver = field(
default=Factory(
lambda self: AzureOpenAiVisionImageQueryDriver(
model="gpt-4",
azure_endpoint=self.azure_endpoint,
api_key=self.api_key,
azure_ad_token=self.azure_ad_token,
azure_ad_token_provider=self.azure_ad_token_provider,
),
takes_self=True,
),
metadata={"serializable": True},
kw_only=True,
)
embedding_driver: BaseEmbeddingDriver = field(
default=Factory(
lambda self: AzureOpenAiEmbeddingDriver(
model="text-embedding-3-small",
azure_endpoint=self.azure_endpoint,
api_key=self.api_key,
azure_ad_token=self.azure_ad_token,
azure_ad_token_provider=self.azure_ad_token_provider,
),
takes_self=True,
),
metadata={"serializable": True},
kw_only=True,
)
vector_store_driver: BaseVectorStoreDriver = field(
default=Factory(lambda self: LocalVectorStoreDriver(embedding_driver=self.embedding_driver), takes_self=True),
metadata={"serializable": True},
kw_only=True,
)
2 changes: 2 additions & 0 deletions griptape/drivers/__init__.py
@@ -81,6 +81,7 @@
from .image_query.dummy_image_query_driver import DummyImageQueryDriver
from .image_query.openai_vision_image_query_driver import OpenAiVisionImageQueryDriver
from .image_query.anthropic_image_query_driver import AnthropicImageQueryDriver
from .image_query.azure_openai_vision_image_query_driver import AzureOpenAiVisionImageQueryDriver
from .image_query.amazon_bedrock_image_query_driver import AmazonBedrockImageQueryDriver

from .web_scraper.base_web_scraper_driver import BaseWebScraperDriver
@@ -175,6 +176,7 @@
"BedrockClaudeImageQueryModelDriver",
"BaseImageQueryDriver",
"OpenAiVisionImageQueryDriver",
"AzureOpenAiVisionImageQueryDriver",
"DummyImageQueryDriver",
"AnthropicImageQueryDriver",
"BaseMultiModelImageQueryDriver",
6 changes: 4 additions & 2 deletions griptape/drivers/embedding/azure_openai_embedding_driver.py
@@ -11,7 +11,7 @@
class AzureOpenAiEmbeddingDriver(OpenAiEmbeddingDriver):
"""
Attributes:
azure_deployment: An Azure OpenAi deployment id.
azure_deployment: An optional Azure OpenAi deployment id. Defaults to the model name.
azure_endpoint: An Azure OpenAi endpoint.
azure_ad_token: An optional Azure Active Directory token.
azure_ad_token_provider: An optional Azure Active Directory token provider.
@@ -20,7 +20,9 @@ class AzureOpenAiEmbeddingDriver(OpenAiEmbeddingDriver):
client: An `openai.AzureOpenAI` client.
"""

azure_deployment: str = field(kw_only=True, metadata={"serializable": True})
azure_deployment: str = field(
kw_only=True, default=Factory(lambda self: self.model, takes_self=True), metadata={"serializable": True}
)
azure_endpoint: str = field(kw_only=True, metadata={"serializable": True})
azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False})
azure_ad_token_provider: Optional[Callable[[], str]] = field(
@@ -12,17 +12,19 @@ class AzureOpenAiImageGenerationDriver(OpenAiImageGenerationDriver):
"""Driver for Azure-hosted OpenAI image generation API.
Attributes:
azure_deployment: An Azure OpenAi deployment id.
azure_deployment: An optional Azure OpenAi deployment id. Defaults to the model name.
azure_endpoint: An Azure OpenAi endpoint.
azure_ad_token: An optional Azure Active Directory token.
azure_ad_token_provider: An optional Azure Active Directory token provider.
api_version: An Azure OpenAi API version.
client: An `openai.AzureOpenAI` client.
"""

azure_deployment: str = field(kw_only=True, metadata={"serializable": True})
azure_deployment: str = field(
kw_only=True, default=Factory(lambda self: self.model, takes_self=True), metadata={"serializable": True}
)
azure_endpoint: str = field(kw_only=True, metadata={"serializable": True})
azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": True})
azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False})
azure_ad_token_provider: Optional[Callable[[], str]] = field(
kw_only=True, default=None, metadata={"serializable": False}
)
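The `serializable` metadata flips above control whether a field appears in serialized output, keeping secrets like `azure_ad_token` out of `to_json()`. A stdlib sketch of that filtering pattern (hypothetical `DriverSketch` and `to_dict`, shown with `dataclasses` metadata rather than attrs):

```python
from dataclasses import dataclass, field, fields
from typing import Optional

@dataclass
class DriverSketch:
    # Endpoint is safe to serialize; the AD token is marked non-serializable.
    azure_endpoint: str = field(metadata={"serializable": True})
    azure_ad_token: Optional[str] = field(default=None, metadata={"serializable": False})

    def to_dict(self) -> dict:
        # Emit only fields whose metadata opts in to serialization.
        return {
            f.name: getattr(self, f.name)
            for f in fields(self)
            if f.metadata.get("serializable")
        }

d = DriverSketch(azure_endpoint="https://example.openai.azure.com", azure_ad_token="secret")
print(d.to_dict())  # azure_ad_token is excluded
```

The commit applies the same `serializable: False` marking to `azure_ad_token` across all Azure Drivers.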