Skip to content

Commit

Permalink
elasticsearch: check for deployed models (#18973)
Browse files Browse the repository at this point in the history
When creating a new index, if we use a retrieval strategy that expects a
model to be deployed in Elasticsearch, check if a model with this name
is indeed deployed before creating an index. This lowers the probability
of ending up in a state in which an index was created with a faulty model
ID that can no longer be overwritten (the index then has to be deleted
manually).
  • Loading branch information
maxjakob authored and hinthornw committed Apr 26, 2024
1 parent e513b50 commit a5020f2
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import List, Union

import numpy as np
from elasticsearch import Elasticsearch
from elasticsearch import BadRequestError, ConflictError, Elasticsearch, NotFoundError
from langchain_core import __version__ as langchain_version

Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray]
Expand Down Expand Up @@ -88,3 +88,21 @@ def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm)
similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0
return similarity


def check_if_model_deployed(client: Elasticsearch, model_id: str) -> None:
    """Check that the model ``model_id`` can be used for inference.

    Sends one placeholder document to the trained-model inference endpoint
    and interprets the error, if any, that comes back.

    Args:
        client: Elasticsearch client used to issue the inference request.
        model_id: ID of the trained model that is expected to be deployed.

    Raises:
        NotFoundError: If the inference call raises ``NotFoundError``, or if
            it raises ``ConflictError``, which is converted into a
            ``NotFoundError`` asking the user to deploy the model first.
    """
    # The input schema of the model is unknown here, so probe with a
    # placeholder document. Built outside the ``try`` so that the guarded
    # region contains only the call that can raise the handled errors.
    placeholder_doc = {"x": "y"}
    try:
        client.ml.infer_trained_model(model_id=model_id, docs=[placeholder_doc])
    except NotFoundError:
        # Propagate unchanged; a bare ``raise`` keeps the original traceback.
        raise
    except ConflictError as err:
        # Surface the conflict to callers as "not found", keeping the
        # response metadata and body of the original error.
        raise NotFoundError(
            f"model '{model_id}' not found, please deploy it first",
            meta=err.meta,
            body=err.body,
        ) from err
    except BadRequestError:
        # This error is expected because we do not know the expected document
        # shape and just use a placeholder doc above.
        pass
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from langchain_elasticsearch._utilities import (
DistanceStrategy,
check_if_model_deployed,
maximal_marginal_relevance,
with_user_agent_header,
)
Expand Down Expand Up @@ -199,6 +200,12 @@ def query(
else:
return {"knn": knn}

def before_index_setup(
    self, client: "Elasticsearch", text_field: str, vector_query_field: str
) -> None:
    """Check the configured query model, if any, before the index is set up.

    ``text_field`` and ``vector_query_field`` are accepted but not used by
    this implementation.
    """
    if not self.query_model_id:
        # No query model configured: nothing to verify.
        return
    check_if_model_deployed(client, self.query_model_id)

def index(
self,
dims_length: Union[int, None],
Expand Down Expand Up @@ -340,8 +347,10 @@ def _get_pipeline_name(self) -> str:
def before_index_setup(
self, client: "Elasticsearch", text_field: str, vector_query_field: str
) -> None:
# If model_id is provided, create a pipeline for the model
if self.model_id:
check_if_model_deployed(client, self.model_id)

# Create a pipeline for the model
client.ingest.put_pipeline(
id=self._get_pipeline_name(),
description="Embedding pipeline for langchain vectorstore",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Any, Dict, Generator, List, Union

import pytest
from elasticsearch import Elasticsearch
from elasticsearch import Elasticsearch, NotFoundError
from elasticsearch.helpers import BulkIndexError
from langchain_core.documents import Document

Expand Down Expand Up @@ -40,7 +40,7 @@
"""

# IDs of the models deployed on the Elasticsearch instance used for testing.
# Tests that depend on a model (e.g. the ELSER test) are skipped via
# `skipif` when their model ID is not in this list; uncomment to enable them.
modelsDeployed: List[str] = [
# "elser",
# ".elser_model_1",
# "sentence-transformers__all-minilm-l6-v2",
]

Expand Down Expand Up @@ -709,7 +709,7 @@ def assert_query(query_body: dict, query: str) -> dict:
assert output == [Document(page_content="bar")]

@pytest.mark.skipif(
"elser" not in modelsDeployed,
".elser_model_1" not in modelsDeployed,
reason="ELSER not deployed in ML Node, skipping test",
)
def test_similarity_search_with_sparse_infer_instack(
Expand All @@ -726,6 +726,35 @@ def test_similarity_search_with_sparse_infer_instack(
output = docsearch.similarity_search("foo", k=1)
assert output == [Document(page_content="foo")]

def test_deployed_model_check_fails_approx(
    self, elasticsearch_connection: dict, index_name: str
) -> None:
    """from_texts must raise NotFoundError when the approx strategy's
    query model is not deployed"""
    # Everything stays inside the ``raises`` block so that the same set of
    # expressions is covered by the expectation.
    with pytest.raises(NotFoundError):
        fake_embeddings = ConsistentFakeEmbeddings(10)
        approx_strategy = ElasticsearchStore.ApproxRetrievalStrategy(
            query_model_id="non-existing model ID",
        )
        ElasticsearchStore.from_texts(
            texts=["foo", "bar", "baz"],
            embedding=fake_embeddings,
            **elasticsearch_connection,
            index_name=index_name,
            strategy=approx_strategy,
        )

def test_deployed_model_check_fails_sparse(
    self, elasticsearch_connection: dict, index_name: str
) -> None:
    """from_texts must raise NotFoundError when the sparse strategy's
    model is not deployed"""
    # Everything stays inside the ``raises`` block so that the same set of
    # expressions is covered by the expectation.
    with pytest.raises(NotFoundError):
        sparse_strategy = ElasticsearchStore.SparseVectorRetrievalStrategy(
            model_id="non-existing model ID"
        )
        ElasticsearchStore.from_texts(
            texts=["foo", "bar", "baz"],
            **elasticsearch_connection,
            index_name=index_name,
            strategy=sparse_strategy,
        )

def test_elasticsearch_with_relevance_score(
self, elasticsearch_connection: dict, index_name: str
) -> None:
Expand Down

0 comments on commit a5020f2

Please sign in to comment.