Update cohere prompt driver, add cohere embedding driver, cohere stru… #831

Merged: 4 commits, Jun 5, 2024
Changes from all commits
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -12,12 +12,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `BaseTask.add_parent()` to add a parent task to a child task.
- `BaseTask.add_parents()` to add multiple parent tasks to a child task.
- `Structure.resolve_relationships()` to resolve asymmetrically defined parent/child relationships. In other words, if a parent declares a child, but the child does not declare the parent, the parent will automatically be added as a parent of the child when running this method. The method is invoked automatically by `Structure.before_run()`.
- `CohereEmbeddingDriver` for using Cohere's embeddings API.
- `CohereStructureConfig` for providing Structures with quick Cohere configuration.

### Changed
- **BREAKING**: `Workflow` no longer modifies task relationships when adding tasks via `tasks` init param, `add_tasks()` or `add_task()`. Previously, adding a task would automatically add the previously added task as its parent. Existing code that relies on this behavior will need to be updated to explicitly add parent/child relationships using the API offered by `BaseTask`.
- `Structure.before_run()` now automatically resolves asymmetrically defined parent/child relationships using the new `Structure.resolve_relationships()`.
- Updated `HuggingFaceHubPromptDriver` to use `transformers`'s `apply_chat_template`.
- Updated `HuggingFacePipelinePromptDriver` to use chat features of `transformers.TextGenerationPipeline`.
- Updated `CoherePromptDriver` to use Cohere's latest SDK.

### Fixed
- `Workflow.insert_task()` no longer inserts duplicate tasks when given multiple parent tasks.
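For code affected by the breaking `Workflow` change above, here is a minimal sketch of declaring task relationships explicitly (the task ids and prompts are hypothetical, and the exact `add_parent()` signature is assumed from the changelog entry rather than taken from this diff):

```python
from griptape.structures import Workflow
from griptape.tasks import PromptTask

# Hypothetical tasks; before this change, adding them to the Workflow in order
# would have implicitly made the first task the parent of the second.
research = PromptTask("Research the topic", id="research")
summarize = PromptTask("Summarize the research", id="summarize")

# Relationships now have to be declared explicitly. Declaring only one side is
# enough: Structure.resolve_relationships(), invoked by Structure.before_run(),
# fills in the missing direction.
summarize.add_parent(research)

workflow = Workflow(tasks=[research, summarize])
workflow.run()
```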
23 changes: 23 additions & 0 deletions docs/griptape-framework/drivers/embedding-drivers.md
@@ -164,6 +164,29 @@ embeddings = driver.embed_string("Hello world!")
print(embeddings[:3])
```

### Cohere Embeddings

The [CohereEmbeddingDriver](../../reference/griptape/drivers/embedding/cohere_embedding_driver.md) uses the [Cohere Embeddings API](https://docs.cohere.com/docs/embeddings).

!!! info
    This driver requires the `drivers-embedding-cohere` [extra](../index.md#extras).

```python
import os
from griptape.drivers import CohereEmbeddingDriver

embedding_driver = CohereEmbeddingDriver(
    model="embed-english-v3.0",
    api_key=os.environ["COHERE_API_KEY"],
    input_type="search_document",
)

embeddings = embedding_driver.embed_string("Hello world!")

# display the first 3 embeddings
print(embeddings[:3])
```
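Worth noting for the example above: Cohere's v3 embed models require an `input_type`, and Cohere documents separate values for indexing documents versus embedding queries. A small sketch of the query side (same driver as above, different `input_type`):

```python
import os
from griptape.drivers import CohereEmbeddingDriver

# "search_document" (used above) is for the texts being indexed;
# "search_query" is Cohere's documented value for query-side embeddings.
query_driver = CohereEmbeddingDriver(
    model="embed-english-v3.0",
    api_key=os.environ["COHERE_API_KEY"],
    input_type="search_query",
)

print(query_driver.embed_string("What does griptape do?")[:3])
```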

### Override Default Structure Embedding Driver
Here is how you can override the Embedding Driver that is used by default in Structures.

2 changes: 2 additions & 0 deletions griptape/config/__init__.py
@@ -8,6 +8,7 @@
from .amazon_bedrock_structure_config import AmazonBedrockStructureConfig
from .anthropic_structure_config import AnthropicStructureConfig
from .google_structure_config import GoogleStructureConfig
from .cohere_structure_config import CohereStructureConfig


__all__ = [
@@ -19,4 +20,5 @@
"AmazonBedrockStructureConfig",
"AnthropicStructureConfig",
"GoogleStructureConfig",
"CohereStructureConfig",
]
37 changes: 37 additions & 0 deletions griptape/config/cohere_structure_config.py
@@ -0,0 +1,37 @@
from attrs import Factory, define, field

from griptape.config import StructureConfig
from griptape.drivers import (
    BaseEmbeddingDriver,
    BasePromptDriver,
    CoherePromptDriver,
    CohereEmbeddingDriver,
    BaseVectorStoreDriver,
    LocalVectorStoreDriver,
)


@define
class CohereStructureConfig(StructureConfig):
    api_key: str = field(metadata={"serializable": False}, kw_only=True)

    prompt_driver: BasePromptDriver = field(
        default=Factory(lambda self: CoherePromptDriver(model="command-r", api_key=self.api_key), takes_self=True),
        metadata={"serializable": True},
        kw_only=True,
    )
    embedding_driver: BaseEmbeddingDriver = field(
        default=Factory(
            lambda self: CohereEmbeddingDriver(
                model="embed-english-v3.0", api_key=self.api_key, input_type="search_document"
            ),
            takes_self=True,
        ),
        metadata={"serializable": True},
        kw_only=True,
    )
    vector_store_driver: BaseVectorStoreDriver = field(
        default=Factory(lambda self: LocalVectorStoreDriver(embedding_driver=self.embedding_driver), takes_self=True),
        kw_only=True,
        metadata={"serializable": True},
    )
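A hedged usage sketch for the new config, assuming it is passed to a Structure the same way the other `*StructureConfig` classes are (via the `config` parameter):

```python
import os
from griptape.config import CohereStructureConfig
from griptape.structures import Agent

# A single api_key fans out to the prompt, embedding, and vector store drivers
# defined by the config above.
agent = Agent(config=CohereStructureConfig(api_key=os.environ["COHERE_API_KEY"]))

agent.run("Summarize what the command-r model is in one sentence.")
```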
2 changes: 2 additions & 0 deletions griptape/drivers/__init__.py
@@ -29,6 +29,7 @@
from .embedding.huggingface_hub_embedding_driver import HuggingFaceHubEmbeddingDriver
from .embedding.google_embedding_driver import GoogleEmbeddingDriver
from .embedding.dummy_embedding_driver import DummyEmbeddingDriver
from .embedding.cohere_embedding_driver import CohereEmbeddingDriver

from .embedding_model.base_embedding_model_driver import BaseEmbeddingModelDriver
from .embedding_model.sagemaker_huggingface_embedding_model_driver import SageMakerHuggingFaceEmbeddingModelDriver
@@ -143,6 +144,7 @@
"GoogleEmbeddingDriver",
"DummyEmbeddingDriver",
"BaseEmbeddingModelDriver",
"CohereEmbeddingDriver",
"SageMakerHuggingFaceEmbeddingModelDriver",
"SageMakerTensorFlowHubEmbeddingModelDriver",
"BaseVectorStoreDriver",
43 changes: 43 additions & 0 deletions griptape/drivers/embedding/cohere_embedding_driver.py
@@ -0,0 +1,43 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from attrs import define, field, Factory
from griptape.drivers import BaseEmbeddingDriver
from griptape.tokenizers import CohereTokenizer
from griptape.utils import import_optional_dependency

if TYPE_CHECKING:
    from cohere import Client

Codecov warning: added line griptape/drivers/embedding/cohere_embedding_driver.py#L9 was not covered by tests.

@define
class CohereEmbeddingDriver(BaseEmbeddingDriver):
"""
Attributes:
api_key: Cohere API key.
model: Cohere model name.
client: Custom `cohere.Client`.
tokenizer: Custom `CohereTokenizer`.
input_type: Cohere embedding input type.
"""

DEFAULT_MODEL = "models/embedding-001"

api_key: str = field(kw_only=True, metadata={"serializable": False})
client: Client = field(
default=Factory(lambda self: import_optional_dependency("cohere").Client(self.api_key), takes_self=True),
kw_only=True,
)
tokenizer: CohereTokenizer = field(
default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True),
kw_only=True,
)

input_type: str = field(kw_only=True, metadata={"serializable": True})

def try_embed_chunk(self, chunk: str) -> list[float]:
result = self.client.embed(texts=[chunk], model=self.model, input_type=self.input_type)

if isinstance(result.embeddings, list):
return result.embeddings[0]
else:
raise ValueError("Non-float embeddings are not supported.")

Codecov warning: added line griptape/drivers/embedding/cohere_embedding_driver.py#L43 was not covered by tests.
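Since `CohereStructureConfig` pairs this driver with `LocalVectorStoreDriver`, here is a hedged sketch of that combination; the `upsert_text_artifact()` and `query()` names are assumed from the existing vector store driver API rather than taken from this diff:

```python
import os
from griptape.artifacts import TextArtifact
from griptape.drivers import CohereEmbeddingDriver, LocalVectorStoreDriver

vector_store = LocalVectorStoreDriver(
    embedding_driver=CohereEmbeddingDriver(
        model="embed-english-v3.0",
        api_key=os.environ["COHERE_API_KEY"],
        # Note: a single driver instance uses one input_type for both documents
        # and queries; a query-side driver could use "search_query" instead.
        input_type="search_document",
    )
)

# Embed and store a couple of documents, then run a similarity query.
vector_store.upsert_text_artifact(TextArtifact("Griptape is a Python framework for LLM applications."), namespace="docs")
vector_store.upsert_text_artifact(TextArtifact("Cohere provides embedding and chat models."), namespace="docs")

results = vector_store.query("Which company offers embedding models?", count=1, namespace="docs")
print(results[0].score)
```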
50 changes: 27 additions & 23 deletions griptape/drivers/prompt/cohere_prompt_driver.py
@@ -1,5 +1,5 @@
from __future__ import annotations
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
from collections.abc import Iterator
from attrs import define, field, Factory
from griptape.artifacts import TextArtifact
@@ -21,7 +21,7 @@
        tokenizer: Custom `CohereTokenizer`.
    """

-    api_key: str = field(kw_only=True, metadata={"serializable": True})
+    api_key: str = field(kw_only=True, metadata={"serializable": False})
    model: str = field(kw_only=True, metadata={"serializable": True})
    client: Client = field(
        default=Factory(lambda self: import_optional_dependency("cohere").Client(self.api_key), takes_self=True),
@@ -33,33 +33,37 @@
    )

    def try_run(self, prompt_stack: PromptStack) -> TextArtifact:
-        result = self.client.generate(**self._base_params(prompt_stack))
+        result = self.client.chat(**self._base_params(prompt_stack))

-        if result.generations:
-            if len(result.generations) == 1:
-                generation = result.generations[0]
-
-                return TextArtifact(value=generation.text.strip())
-            else:
-                raise Exception("completion with more than one choice is not supported yet")
-        else:
-            raise Exception("model response is empty")
+        return TextArtifact(value=result.text)

    def try_stream(self, prompt_stack: PromptStack) -> Iterator[TextArtifact]:
-        result = self.client.generate(
-            **self._base_params(prompt_stack),
-            stream=True, # pyright: ignore[reportCallIssue]
-        )
+        result = self.client.chat_stream(**self._base_params(prompt_stack))

-        for chunk in result:
-            yield TextArtifact(value=chunk.text)
+        for event in result:
+            if event.event_type == "text-generation":
+                yield TextArtifact(value=event.text)

    def _base_params(self, prompt_stack: PromptStack) -> dict:
-        prompt = self.prompt_stack_to_string(prompt_stack)
+        user_message = prompt_stack.inputs[-1].content
+        history_messages = [self.__to_cohere_message(input) for input in prompt_stack.inputs[:-1]]

        return {
-            "prompt": self.prompt_stack_to_string(prompt_stack),
            "model": self.model,
+            "message": user_message,
+            "chat_history": history_messages,
            "temperature": self.temperature,
-            "end_sequences": self.tokenizer.stop_sequences,
-            "max_tokens": self.max_output_tokens(prompt),
+            "stop_sequences": self.tokenizer.stop_sequences,
        }

+    def __to_cohere_message(self, input: PromptStack.Input) -> dict[str, Any]:
+        return {"role": self.__to_cohere_role(input.role), "text": input.content}
+
+    def __to_cohere_role(self, role: str) -> str:
+        if role == PromptStack.SYSTEM_ROLE:
+            return "SYSTEM"
+        if role == PromptStack.USER_ROLE:
+            return "USER"
+        elif role == PromptStack.ASSISTANT_ROLE:
+            return "CHATBOT"
+        else:
+            return "USER"

Codecov warning: added line griptape/drivers/prompt/cohere_prompt_driver.py#L67 was not covered by tests.
2 changes: 2 additions & 0 deletions griptape/schemas/base_schema.py
@@ -110,6 +110,7 @@ def _resolve_types(cls, attrs_cls: type) -> None:
        from typing import Any

        boto3 = import_optional_dependency("boto3") if is_dependency_installed("boto3") else Any
        Client = import_optional_dependency("cohere").Client if is_dependency_installed("cohere") else Any

        attrs.resolve_types(
            attrs_cls,
@@ -122,6 +123,7 @@ def _resolve_types(cls, attrs_cls: type) -> None:
                "BaseTokenizer": BaseTokenizer,
                "BasePromptModelDriver": BasePromptModelDriver,
                "boto3": boto3,
                "Client": Client,
            },
        )

4 changes: 2 additions & 2 deletions griptape/tokenizers/cohere_tokenizer.py
@@ -9,8 +9,8 @@

@define()
class CohereTokenizer(BaseTokenizer):
-    MODEL_PREFIXES_TO_MAX_INPUT_TOKENS = {"command": 4096}
-    MODEL_PREFIXES_TO_MAX_OUTPUT_TOKENS = {"command": 4096}
+    MODEL_PREFIXES_TO_MAX_INPUT_TOKENS = {"command-r": 128000, "command": 4096, "embed": 512}
+    MODEL_PREFIXES_TO_MAX_OUTPUT_TOKENS = {"command": 4096, "embed": 512}

    client: Client = field(kw_only=True)

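For context on the tokenizer change: the maps are keyed by model prefix, and the more specific `command-r` entry is listed ahead of `command` so the 128k limit is not shadowed by the shorter prefix. A standalone sketch of that kind of prefix lookup (illustrative only; the real resolution lives in `BaseTokenizer` and may differ in detail):

```python
# Mirrors the prefix map added to CohereTokenizer above.
MODEL_PREFIXES_TO_MAX_INPUT_TOKENS = {"command-r": 128000, "command": 4096, "embed": 512}


def max_input_tokens_for(model: str, default: int = 4096) -> int:
    # First matching prefix in insertion order wins, which is why the more
    # specific "command-r" key appears before "command" in the map.
    for prefix, max_tokens in MODEL_PREFIXES_TO_MAX_INPUT_TOKENS.items():
        if model.startswith(prefix):
            return max_tokens
    return default


print(max_input_tokens_for("command-r"))            # 128000
print(max_input_tokens_for("command-light"))        # 4096
print(max_input_tokens_for("embed-english-v3.0"))   # 512
```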