Text to Speech Support (griptape-ai#755)
andrewfrench authored and hkhajgiwale committed May 25, 2024
1 parent a7fa3c7 commit c98b26d
Showing 51 changed files with 955 additions and 32 deletions.
1 change: 1 addition & 0 deletions .github/workflows/docs-integration-tests.yml
@@ -100,6 +100,7 @@ jobs:
GT_CLOUD_STRUCTURE_RUN_ID: ${{ secrets.INTEG_GT_CLOUD_STRUCTURE_RUN_ID }}
AWS_IOT_CORE_ENDPOINT: ${{ secrets.INTEG_AWS_IOT_CORE_ENDPOINT }}
AWS_IOT_CORE_TOPIC: ${{ secrets.INTEG_AWS_IOT_CORE_TOPIC }}
ELEVEN_LABS_API_KEY: ${{ secrets.INTEG_ELEVEN_LABS_API_KEY }}
services:
postgres:
image: ankane/pgvector:v0.5.0
6 changes: 5 additions & 1 deletion docs/griptape-framework/data/artifacts.md
@@ -35,4 +35,8 @@ Each blob has a [name](../../reference/griptape/artifacts/base_artifact.md#gript

## ImageArtifact

An [ImageArtifact](../../reference/griptape/artifacts/image_artifact.md) is used for passing images back to the LLM. In addition to binary image data, an ImageArtifact includes image metadata like MIME type, dimensions, and prompt and model information for images returned by [image generation Drivers](../drivers/image-generation-drivers.md). It inherits from [BlobArtifact](#blobartifact).
An [ImageArtifact](../../reference/griptape/artifacts/image_artifact.md) is used for passing images back to the LLM. In addition to binary image data, an Image Artifact includes image metadata like MIME type, dimensions, and prompt and model information for images returned by [image generation Drivers](../drivers/image-generation-drivers.md). It inherits from [BlobArtifact](#blobartifact).

## AudioArtifact

An [AudioArtifact](../../reference/griptape/artifacts/audio_artifact.md) allows the Framework to interact with audio content. An Audio Artifact includes binary audio content as well as metadata like format, duration, and prompt and model information for audio returned by generative models. It inherits from [BlobArtifact](#blobartifact).
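The metadata described above can be pictured with a minimal stand-in class (a sketch only; the real `AudioArtifact` lives in `griptape.artifacts` and inherits from `BlobArtifact`, and the exact field names here are assumptions for illustration):

```python
from dataclasses import dataclass
from typing import Optional


# Illustrative stand-in for AudioArtifact's metadata, not the real griptape class.
# The model/prompt fields mirror the "prompt and model information" mentioned above.
@dataclass
class AudioArtifactSketch:
    value: bytes                  # binary audio content
    format: str                   # container format, e.g. "mp3"
    media_type: str = "audio"     # fixed for audio artifacts
    model: Optional[str] = None   # generative model that produced the audio, if any
    prompt: Optional[str] = None  # prompt used to generate the audio, if any


artifact = AudioArtifactSketch(value=b"\x00\x01", format="mp3")
print(artifact.media_type)  # audio
```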
54 changes: 54 additions & 0 deletions docs/griptape-framework/drivers/text-to-speech-drivers.md
@@ -0,0 +1,54 @@
## Overview

[Text to Speech Drivers](../../reference/griptape/drivers/text_to_speech/index.md) are used by [Text to Speech Engines](../engines/audio-engines.md) to build and execute API calls to audio generation models.

Provide a Driver when building an [Engine](../engines/audio-engines.md), then pass the Engine to a [Tool](../tools/index.md) for use by an [Agent](../structures/agents.md).

### Eleven Labs

The [Eleven Labs Text to Speech Driver](../../reference/griptape/drivers/text_to_speech/elevenlabs_text_to_speech_driver.md) provides support for text-to-speech models hosted by Eleven Labs. This Driver supports configurations specific to Eleven Labs, like voice selection and output format.

```python
import os

from griptape.drivers import ElevenLabsTextToSpeechDriver
from griptape.engines import TextToSpeechEngine
from griptape.tools.text_to_speech_client.tool import TextToSpeechClient
from griptape.structures import Agent


driver = ElevenLabsTextToSpeechDriver(
api_key=os.getenv("ELEVEN_LABS_API_KEY"),
model="eleven_multilingual_v2",
voice="Matilda",
)

tool = TextToSpeechClient(
engine=TextToSpeechEngine(
text_to_speech_driver=driver,
),
)

Agent(tools=[tool]).run("Generate audio from this text: 'Hello, world!'")
```

### OpenAI

The [OpenAI Text to Speech Driver](../../reference/griptape/drivers/text_to_speech/openai_text_to_speech_driver.md) provides support for text-to-speech models hosted by OpenAI. This Driver supports configurations specific to OpenAI, like voice selection and output format.

```python
from griptape.drivers import OpenAiTextToSpeechDriver
from griptape.engines import TextToSpeechEngine
from griptape.tools.text_to_speech_client.tool import TextToSpeechClient
from griptape.structures import Agent

driver = OpenAiTextToSpeechDriver()

tool = TextToSpeechClient(
engine=TextToSpeechEngine(
text_to_speech_driver=driver,
),
)

Agent(tools=[tool]).run("Generate audio from this text: 'Hello, world!'")
```
29 changes: 29 additions & 0 deletions docs/griptape-framework/engines/audio-engines.md
@@ -0,0 +1,29 @@
## Overview

[Audio Generation Engines](../../reference/griptape/engines/audio/index.md) facilitate audio generation. Each Engine provides a `run` method that accepts the inputs required for its particular mode and passes the request to the configured [Driver](../drivers/text-to-speech-drivers.md).

### Text to Speech Engine

This Engine facilitates synthesizing speech from text inputs.

```python
import os

from griptape.drivers import ElevenLabsTextToSpeechDriver
from griptape.engines import TextToSpeechEngine


driver = ElevenLabsTextToSpeechDriver(
api_key=os.getenv("ELEVEN_LABS_API_KEY"),
model="eleven_multilingual_v2",
voice="Rachel",
)

engine = TextToSpeechEngine(
text_to_speech_driver=driver,
)

engine.run(
prompts=["Hello, world!"],
)
```
27 changes: 27 additions & 0 deletions docs/griptape-tools/official-tools/text-to-speech-client.md
@@ -0,0 +1,27 @@
# TextToSpeechClient

This tool enables LLMs to synthesize speech from text using [Text to Speech Engines](../../reference/griptape/engines/audio/text_to_speech_engine.md) and [Text to Speech Drivers](../../reference/griptape/drivers/text_to_speech/index.md).

```python
import os

from griptape.drivers import ElevenLabsTextToSpeechDriver
from griptape.engines import TextToSpeechEngine
from griptape.tools.text_to_speech_client.tool import TextToSpeechClient
from griptape.structures import Agent


driver = ElevenLabsTextToSpeechDriver(
api_key=os.getenv("ELEVEN_LABS_API_KEY"),
model="eleven_multilingual_v2",
voice="Matilda",
)

tool = TextToSpeechClient(
engine=TextToSpeechEngine(
text_to_speech_driver=driver,
),
)

Agent(tools=[tool]).run("Generate audio from this text: 'Hello, world!'")
```
4 changes: 3 additions & 1 deletion griptape/artifacts/__init__.py
@@ -7,6 +7,7 @@
from .list_artifact import ListArtifact
from .media_artifact import MediaArtifact
from .image_artifact import ImageArtifact
from .audio_artifact import AudioArtifact


__all__ = [
@@ -17,6 +18,7 @@
"BlobArtifact",
"CsvRowArtifact",
"ListArtifact",
"ImageArtifact",
"MediaArtifact",
"ImageArtifact",
"AudioArtifact",
]
12 changes: 12 additions & 0 deletions griptape/artifacts/audio_artifact.py
@@ -0,0 +1,12 @@
from __future__ import annotations

from attr import define

from griptape.artifacts import MediaArtifact


@define
class AudioArtifact(MediaArtifact):
"""AudioArtifact is a type of MediaArtifact representing audio."""

media_type: str = "audio"
5 changes: 5 additions & 0 deletions griptape/config/structure_global_drivers_config.py
@@ -14,7 +14,9 @@
DummyPromptDriver,
DummyImageQueryDriver,
BaseImageQueryDriver,
BaseTextToSpeechDriver,
)
from griptape.drivers.text_to_speech.dummy_text_to_speech_driver import DummyTextToSpeechDriver
from griptape.mixins.serializable_mixin import SerializableMixin


@@ -38,3 +40,6 @@ class StructureGlobalDriversConfig(SerializableMixin):
conversation_memory_driver: Optional[BaseConversationMemoryDriver] = field(
default=None, kw_only=True, metadata={"serializable": True}
)
text_to_speech_driver: BaseTextToSpeechDriver = field(
default=Factory(lambda: DummyTextToSpeechDriver()), kw_only=True, metadata={"serializable": True}
)
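The `Factory` default above follows the usual attrs pattern: each config instance gets its own fresh dummy driver rather than sharing one mutable default. A rough stdlib equivalent (a sketch with stand-in names, not griptape's actual classes) uses `dataclasses.field(default_factory=...)`:

```python
from dataclasses import dataclass, field


class DummyDriverSketch:
    """Stand-in for DummyTextToSpeechDriver; raises when actually used."""

    def try_text_to_audio(self, prompts: list[str]) -> bytes:
        raise RuntimeError("text_to_speech_driver was not configured")


@dataclass
class DriversConfigSketch:
    # default_factory builds a new driver per instance, mirroring
    # Factory(lambda: DummyTextToSpeechDriver()) in the diff above.
    text_to_speech_driver: DummyDriverSketch = field(default_factory=DummyDriverSketch)


a, b = DriversConfigSketch(), DriversConfigSketch()
print(a.text_to_speech_driver is b.text_to_speech_driver)  # False
```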
9 changes: 9 additions & 0 deletions griptape/drivers/__init__.py
@@ -97,6 +97,11 @@
from .file_manager.local_file_manager_driver import LocalFileManagerDriver
from .file_manager.amazon_s3_file_manager_driver import AmazonS3FileManagerDriver

from .text_to_speech.base_text_to_speech_driver import BaseTextToSpeechDriver
from .text_to_speech.dummy_text_to_speech_driver import DummyTextToSpeechDriver
from .text_to_speech.elevenlabs_text_to_speech_driver import ElevenLabsTextToSpeechDriver
from .text_to_speech.openai_text_to_speech_driver import OpenAiTextToSpeechDriver

from .structure_run.base_structure_run_driver import BaseStructureRunDriver
from .structure_run.griptape_cloud_structure_run_driver import GriptapeCloudStructureRunDriver
from .structure_run.local_structure_run_driver import LocalStructureRunDriver
@@ -185,6 +190,10 @@
"BaseFileManagerDriver",
"LocalFileManagerDriver",
"AmazonS3FileManagerDriver",
"BaseTextToSpeechDriver",
"DummyTextToSpeechDriver",
"ElevenLabsTextToSpeechDriver",
"OpenAiTextToSpeechDriver",
"BaseStructureRunDriver",
"GriptapeCloudStructureRunDriver",
"LocalStructureRunDriver",
Empty file.
44 changes: 44 additions & 0 deletions griptape/drivers/text_to_speech/base_text_to_speech_driver.py
@@ -0,0 +1,44 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional

from attr import define, field

from griptape.artifacts.audio_artifact import AudioArtifact
from griptape.events.finish_text_to_speech_event import FinishTextToSpeechEvent
from griptape.events.start_text_to_speech_event import StartTextToSpeechEvent
from griptape.mixins import ExponentialBackoffMixin, SerializableMixin

if TYPE_CHECKING:
from griptape.structures import Structure


@define
class BaseTextToSpeechDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
model: str = field(kw_only=True, metadata={"serializable": True})
structure: Optional[Structure] = field(default=None, kw_only=True)

def before_run(self, prompts: list[str]) -> None:
if self.structure:
self.structure.publish_event(StartTextToSpeechEvent(prompts=prompts))

def after_run(self) -> None:
if self.structure:
self.structure.publish_event(FinishTextToSpeechEvent())

def run_text_to_audio(self, prompts: list[str]) -> AudioArtifact:
for attempt in self.retrying():
with attempt:
self.before_run(prompts)
result = self.try_text_to_audio(prompts)
self.after_run()

return result

else:
raise Exception("Failed to run text to audio generation")

@abstractmethod
def try_text_to_audio(self, prompts: list[str]) -> AudioArtifact:
...
13 changes: 13 additions & 0 deletions griptape/drivers/text_to_speech/dummy_text_to_speech_driver.py
@@ -0,0 +1,13 @@
from typing import Optional
from attrs import define, field
from griptape.artifacts.audio_artifact import AudioArtifact
from griptape.drivers import BaseTextToSpeechDriver
from griptape.exceptions import DummyException


@define
class DummyTextToSpeechDriver(BaseTextToSpeechDriver):
model: str = field(init=False)

def try_text_to_audio(self, prompts: list[str]) -> AudioArtifact:
raise DummyException(__class__.__name__, "try_text_to_audio")
@@ -0,0 +1,42 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Optional, Any

from attr import define, field, Factory

from griptape.artifacts.audio_artifact import AudioArtifact
from griptape.drivers import BaseTextToSpeechDriver
from griptape.utils import import_optional_dependency

if TYPE_CHECKING:
from elevenlabs.client import ElevenLabs


@define
class ElevenLabsTextToSpeechDriver(BaseTextToSpeechDriver):
api_key: str = field(kw_only=True, metadata={"serializable": True})
client: Any = field(
default=Factory(
lambda self: import_optional_dependency("elevenlabs.client").ElevenLabs(api_key=self.api_key),
takes_self=True,
),
kw_only=True,
metadata={"serializable": True},
)
voice: str = field(kw_only=True, metadata={"serializable": True})
output_format: str = field(default="mp3_44100_128", kw_only=True, metadata={"serializable": True})

def try_text_to_audio(self, prompts: list[str]) -> AudioArtifact:
audio = self.client.generate(
text=". ".join(prompts), voice=self.voice, model=self.model, output_format=self.output_format
)

content = b""
for chunk in audio:
content += chunk

# All ElevenLabs audio format strings have the following structure:
# {format}_{sample_rate}_{bitrate}
artifact_format = self.output_format.split("_")[0]

return AudioArtifact(value=content, format=artifact_format)
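The format-string convention noted in the comment above can be checked in isolation: given an output format like `mp3_44100_128`, the container format is the first underscore-delimited token.

```python
# ElevenLabs output_format strings follow "{format}_{sample_rate}_{bitrate}",
# so splitting on "_" and taking the first token recovers the container format.
output_format = "mp3_44100_128"
artifact_format = output_format.split("_")[0]
print(artifact_format)  # mp3
```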
36 changes: 36 additions & 0 deletions griptape/drivers/text_to_speech/openai_text_to_speech_driver.py
@@ -0,0 +1,36 @@
from __future__ import annotations

from typing import Optional, Literal

import openai
from attr import define, field, Factory

from griptape.artifacts.audio_artifact import AudioArtifact
from griptape.drivers import BaseTextToSpeechDriver


@define
class OpenAiTextToSpeechDriver(BaseTextToSpeechDriver):
model: str = field(default="tts-1", kw_only=True, metadata={"serializable": True})
voice: Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"] = field(
default="alloy", kw_only=True, metadata={"serializable": True}
)
format: Literal["mp3", "opus", "aac", "flac"] = field(default="mp3", kw_only=True, metadata={"serializable": True})
api_type: str = field(default=openai.api_type, kw_only=True)
api_version: Optional[str] = field(default=openai.api_version, kw_only=True, metadata={"serializable": True})
base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
api_key: Optional[str] = field(default=None, kw_only=True)
organization: Optional[str] = field(default=openai.organization, kw_only=True, metadata={"serializable": True})
client: openai.OpenAI = field(
default=Factory(
lambda self: openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization),
takes_self=True,
)
)

def try_text_to_audio(self, prompts: list[str]) -> AudioArtifact:
response = self.client.audio.speech.create(
input=". ".join(prompts), voice=self.voice, model=self.model, response_format=self.format
)

return AudioArtifact(value=response.content, format=self.format)
2 changes: 2 additions & 0 deletions griptape/engines/__init__.py
Expand Up @@ -11,6 +11,7 @@
from .image.inpainting_image_generation_engine import InpaintingImageGenerationEngine
from .image.outpainting_image_generation_engine import OutpaintingImageGenerationEngine
from .image_query.image_query_engine import ImageQueryEngine
from .audio.text_to_speech_engine import TextToSpeechEngine

__all__ = [
"BaseQueryEngine",
Expand All @@ -26,4 +27,5 @@
"InpaintingImageGenerationEngine",
"OutpaintingImageGenerationEngine",
"ImageQueryEngine",
"TextToSpeechEngine",
]
Empty file.
14 changes: 14 additions & 0 deletions griptape/engines/audio/text_to_speech_engine.py
@@ -0,0 +1,14 @@
from __future__ import annotations

from attr import define, field

from griptape.artifacts.audio_artifact import AudioArtifact
from griptape.drivers import BaseTextToSpeechDriver


@define
class TextToSpeechEngine:
text_to_speech_driver: BaseTextToSpeechDriver = field(kw_only=True)

def run(self, prompts: list[str], *args, **kwargs) -> AudioArtifact:
return self.text_to_speech_driver.try_text_to_audio(prompts=prompts)
7 changes: 6 additions & 1 deletion griptape/events/__init__.py
@@ -16,7 +16,9 @@
from .finish_image_generation_event import FinishImageGenerationEvent
from .start_image_query_event import StartImageQueryEvent
from .finish_image_query_event import FinishImageQueryEvent

from .base_text_to_speech_event import BaseTextToSpeechEvent
from .start_text_to_speech_event import StartTextToSpeechEvent
from .finish_text_to_speech_event import FinishTextToSpeechEvent

__all__ = [
"BaseEvent",
@@ -37,4 +39,7 @@
"FinishImageGenerationEvent",
"StartImageQueryEvent",
"FinishImageQueryEvent",
"BaseTextToSpeechEvent",
"StartTextToSpeechEvent",
"FinishTextToSpeechEvent",
]