Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

voyageai[patch]: Upgrade root validators for pydantic 2 #25455

Merged
merged 1 commit into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
x
  • Loading branch information
eyurtsev committed Aug 15, 2024
commit 06ab3fd6bf1ea27fbaefb95449493c9fcb818361
27 changes: 12 additions & 15 deletions libs/partners/voyageai/langchain_voyageai/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
import os
from typing import Iterable, List, Optional

import voyageai # type: ignore
Expand All @@ -10,7 +9,7 @@
SecretStr,
root_validator,
)
from langchain_core.utils import convert_to_secret_str
from langchain_core.utils import secret_from_env

logger = logging.getLogger(__name__)

Expand All @@ -32,34 +31,32 @@ class VoyageAIEmbeddings(BaseModel, Embeddings):
batch_size: int
show_progress_bar: bool = False
truncation: Optional[bool] = None
voyage_api_key: Optional[SecretStr] = None
voyage_api_key: SecretStr = Field(
alias="api_key",
default_factory=secret_from_env(
"VOYAGE_API_KEY",
error_message="Must set `VOYAGE_API_KEY` environment variable or "
"pass `api_key` to VoyageAIEmbeddings constructor.",
),
)

class Config:
extra = "forbid"
allow_population_by_field_name = True

@root_validator(pre=True)
def default_values(cls, values: dict) -> dict:
    """Fill in ``batch_size`` from the model name when none was supplied.

    Registered as a ``pre=True`` root validator, so it receives the raw
    input mapping and returns that same mapping.

    Args:
        values: Raw keyword values passed to the model.

    Returns:
        ``values``, with ``batch_size`` set when it was missing or ``None``.
    """
    # An explicitly supplied batch size is left untouched.
    if values.get("batch_size") is None:
        batch_72_models = ("voyage-2", "voyage-02")
        if values.get("model") in batch_72_models:
            values["batch_size"] = 72
        else:
            values["batch_size"] = 7
    return values

@root_validator()
@root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: dict) -> dict:
"""Validate that VoyageAI credentials exist in environment."""
voyage_api_key = values.get("voyage_api_key") or os.getenv(
"VOYAGE_API_KEY", None
)
if voyage_api_key:
api_key_secretstr = convert_to_secret_str(voyage_api_key)
values["voyage_api_key"] = api_key_secretstr

api_key_str = api_key_secretstr.get_secret_value()
else:
api_key_str = None
api_key_str = values["voyage_api_key"].get_secret_value()
values["_client"] = voyageai.Client(api_key=api_key_str)
values["_aclient"] = voyageai.client_async.AsyncClient(api_key=api_key_str)
return values
Expand Down
15 changes: 13 additions & 2 deletions libs/partners/voyageai/tests/unit_tests/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,17 @@

def test_initialization_voyage_2() -> None:
    """Build the embeddings object via the `api_key` argument and check it."""
    embedder = VoyageAIEmbeddings(api_key="NOT_A_VALID_KEY", model=MODEL)

    assert isinstance(embedder, Embeddings)
    # batch_size is not passed here, so 72 is the value filled in for MODEL.
    assert embedder.batch_size == 72
    assert embedder.model == MODEL
    assert embedder._client is not None


def test_initialization_voyage_2_with_full_api_key_name() -> None:
"""Test embedding model initialization."""
# Testing that we can initialize the model using `voyage_api_key`
# instead of `api_key`
emb = VoyageAIEmbeddings(voyage_api_key="NOT_A_VALID_KEY", model=MODEL)
assert isinstance(emb, Embeddings)
assert emb.batch_size == 72
Expand All @@ -18,7 +29,7 @@ def test_initialization_voyage_2() -> None:

def test_initialization_voyage_1() -> None:
"""Test embedding model initialization."""
emb = VoyageAIEmbeddings(voyage_api_key="NOT_A_VALID_KEY", model="voyage-01")
emb = VoyageAIEmbeddings(api_key="NOT_A_VALID_KEY", model="voyage-01")
assert isinstance(emb, Embeddings)
assert emb.batch_size == 7
assert emb.model == "voyage-01"
Expand All @@ -28,7 +39,7 @@ def test_initialization_voyage_1() -> None:
def test_initialization_voyage_1_batch_size() -> None:
"""Test embedding model initialization."""
emb = VoyageAIEmbeddings(
voyage_api_key="NOT_A_VALID_KEY", model="voyage-01", batch_size=15
api_key="NOT_A_VALID_KEY", model="voyage-01", batch_size=15
)
assert isinstance(emb, Embeddings)
assert emb.batch_size == 15
Expand Down
Loading