Skip to content

Commit

Permalink
voyageai[patch]: Upgrade root validators for pydantic 2 (#25455)
Browse files Browse the repository at this point in the history
Update @root_validators to be consistent with pydantic 2 semantics
  • Loading branch information
eyurtsev committed Aug 15, 2024
1 parent 4cdaca6 commit b297af5
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 17 deletions.
27 changes: 12 additions & 15 deletions libs/partners/voyageai/langchain_voyageai/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
import os
from typing import Iterable, List, Optional

import voyageai # type: ignore
Expand All @@ -10,7 +9,7 @@
SecretStr,
root_validator,
)
from langchain_core.utils import convert_to_secret_str
from langchain_core.utils import secret_from_env

logger = logging.getLogger(__name__)

Expand All @@ -32,34 +31,32 @@ class VoyageAIEmbeddings(BaseModel, Embeddings):
batch_size: int
show_progress_bar: bool = False
truncation: Optional[bool] = None
voyage_api_key: Optional[SecretStr] = None
voyage_api_key: SecretStr = Field(
alias="api_key",
default_factory=secret_from_env(
"VOYAGE_API_KEY",
error_message="Must set `VOYAGE_API_KEY` environment variable or "
"pass `api_key` to VoyageAIEmbeddings constructor.",
),
)

class Config:
extra = "forbid"
allow_population_by_field_name = True

@root_validator(pre=True)
def default_values(cls, values: dict) -> dict:
"""Set default batch size based on model"""

model = values.get("model")
batch_size = values.get("batch_size")
if batch_size is None:
values["batch_size"] = 72 if model in ["voyage-2", "voyage-02"] else 7
return values

@root_validator()
@root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: dict) -> dict:
"""Validate that VoyageAI credentials exist in environment."""
voyage_api_key = values.get("voyage_api_key") or os.getenv(
"VOYAGE_API_KEY", None
)
if voyage_api_key:
api_key_secretstr = convert_to_secret_str(voyage_api_key)
values["voyage_api_key"] = api_key_secretstr

api_key_str = api_key_secretstr.get_secret_value()
else:
api_key_str = None
api_key_str = values["voyage_api_key"].get_secret_value()
values["_client"] = voyageai.Client(api_key=api_key_str)
values["_aclient"] = voyageai.client_async.AsyncClient(api_key=api_key_str)
return values
Expand Down
15 changes: 13 additions & 2 deletions libs/partners/voyageai/tests/unit_tests/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,17 @@

def test_initialization_voyage_2() -> None:
"""Test embedding model initialization."""
emb = VoyageAIEmbeddings(api_key="NOT_A_VALID_KEY", model=MODEL)
assert isinstance(emb, Embeddings)
assert emb.batch_size == 72
assert emb.model == MODEL
assert emb._client is not None


def test_initialization_voyage_2_with_full_api_key_name() -> None:
"""Test embedding model initialization."""
# Testing that we can initialize the model using `voyage_api_key`
# instead of `api_key`
emb = VoyageAIEmbeddings(voyage_api_key="NOT_A_VALID_KEY", model=MODEL)
assert isinstance(emb, Embeddings)
assert emb.batch_size == 72
Expand All @@ -18,7 +29,7 @@ def test_initialization_voyage_2() -> None:

def test_initialization_voyage_1() -> None:
"""Test embedding model initialization."""
emb = VoyageAIEmbeddings(voyage_api_key="NOT_A_VALID_KEY", model="voyage-01")
emb = VoyageAIEmbeddings(api_key="NOT_A_VALID_KEY", model="voyage-01")
assert isinstance(emb, Embeddings)
assert emb.batch_size == 7
assert emb.model == "voyage-01"
Expand All @@ -28,7 +39,7 @@ def test_initialization_voyage_1() -> None:
def test_initialization_voyage_1_batch_size() -> None:
"""Test embedding model initialization."""
emb = VoyageAIEmbeddings(
voyage_api_key="NOT_A_VALID_KEY", model="voyage-01", batch_size=15
api_key="NOT_A_VALID_KEY", model="voyage-01", batch_size=15
)
assert isinstance(emb, Embeddings)
assert emb.batch_size == 15
Expand Down

0 comments on commit b297af5

Please sign in to comment.