diff --git a/libs/partners/voyageai/langchain_voyageai/embeddings.py b/libs/partners/voyageai/langchain_voyageai/embeddings.py index 7aa7bdd3319eb..3010a8dd6948e 100644 --- a/libs/partners/voyageai/langchain_voyageai/embeddings.py +++ b/libs/partners/voyageai/langchain_voyageai/embeddings.py @@ -1,5 +1,4 @@ import logging -import os from typing import Iterable, List, Optional import voyageai # type: ignore @@ -10,7 +9,7 @@ SecretStr, root_validator, ) -from langchain_core.utils import convert_to_secret_str +from langchain_core.utils import secret_from_env logger = logging.getLogger(__name__) @@ -32,34 +31,32 @@ class VoyageAIEmbeddings(BaseModel, Embeddings): batch_size: int show_progress_bar: bool = False truncation: Optional[bool] = None - voyage_api_key: Optional[SecretStr] = None + voyage_api_key: SecretStr = Field( + alias="api_key", + default_factory=secret_from_env( + "VOYAGE_API_KEY", + error_message="Must set `VOYAGE_API_KEY` environment variable or " + "pass `api_key` to VoyageAIEmbeddings constructor.", + ), + ) class Config: extra = "forbid" + allow_population_by_field_name = True @root_validator(pre=True) def default_values(cls, values: dict) -> dict: """Set default batch size based on model""" - model = values.get("model") batch_size = values.get("batch_size") if batch_size is None: values["batch_size"] = 72 if model in ["voyage-2", "voyage-02"] else 7 return values - @root_validator() + @root_validator(pre=False, skip_on_failure=True) def validate_environment(cls, values: dict) -> dict: """Validate that VoyageAI credentials exist in environment.""" - voyage_api_key = values.get("voyage_api_key") or os.getenv( - "VOYAGE_API_KEY", None - ) - if voyage_api_key: - api_key_secretstr = convert_to_secret_str(voyage_api_key) - values["voyage_api_key"] = api_key_secretstr - - api_key_str = api_key_secretstr.get_secret_value() - else: - api_key_str = None + api_key_str = values["voyage_api_key"].get_secret_value() values["_client"] = voyageai.Client(api_key=api_key_str) values["_aclient"] = voyageai.client_async.AsyncClient(api_key=api_key_str) return values diff --git a/libs/partners/voyageai/tests/unit_tests/test_embeddings.py b/libs/partners/voyageai/tests/unit_tests/test_embeddings.py index 3ccee4e1a6e57..990ffeb5a97d3 100644 --- a/libs/partners/voyageai/tests/unit_tests/test_embeddings.py +++ b/libs/partners/voyageai/tests/unit_tests/test_embeddings.py @@ -9,6 +9,17 @@ def test_initialization_voyage_2() -> None: """Test embedding model initialization.""" + emb = VoyageAIEmbeddings(api_key="NOT_A_VALID_KEY", model=MODEL) + assert isinstance(emb, Embeddings) + assert emb.batch_size == 72 + assert emb.model == MODEL + assert emb._client is not None + + +def test_initialization_voyage_2_with_full_api_key_name() -> None: + """Test embedding model initialization.""" + # Testing that we can initialize the model using `voyage_api_key` + # instead of `api_key` emb = VoyageAIEmbeddings(voyage_api_key="NOT_A_VALID_KEY", model=MODEL) assert isinstance(emb, Embeddings) assert emb.batch_size == 72 @@ -18,7 +29,7 @@ def test_initialization_voyage_2() -> None: def test_initialization_voyage_1() -> None: """Test embedding model initialization.""" - emb = VoyageAIEmbeddings(voyage_api_key="NOT_A_VALID_KEY", model="voyage-01") + emb = VoyageAIEmbeddings(api_key="NOT_A_VALID_KEY", model="voyage-01") assert isinstance(emb, Embeddings) assert emb.batch_size == 7 assert emb.model == "voyage-01" @@ -28,7 +39,7 @@ def test_initialization_voyage_1() -> None: def test_initialization_voyage_1_batch_size() -> None: """Test embedding model initialization.""" emb = VoyageAIEmbeddings( - voyage_api_key="NOT_A_VALID_KEY", model="voyage-01", batch_size=15 + api_key="NOT_A_VALID_KEY", model="voyage-01", batch_size=15 ) assert isinstance(emb, Embeddings) assert emb.batch_size == 15