Skip to content

Commit

Permalink
deprecate old retrieval classes
Browse files Browse the repository at this point in the history
  • Loading branch information
maxjakob committed May 7, 2024
1 parent 8ce55b7 commit 46e1c12
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 3 deletions.
21 changes: 20 additions & 1 deletion libs/elasticsearch/langchain_elasticsearch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,40 @@
from elasticsearch.helpers.vectorstore import (
BM25Strategy,
DenseVectorScriptScoreStrategy,
DenseVectorStrategy,
DistanceMetric,
RetrievalStrategy,
SparseVectorStrategy,
)

from langchain_elasticsearch.cache import ElasticsearchCache
from langchain_elasticsearch.chat_history import ElasticsearchChatMessageHistory
from langchain_elasticsearch.embeddings import ElasticsearchEmbeddings
from langchain_elasticsearch.retrievers import ElasticsearchRetriever
from langchain_elasticsearch.vectorstores import (
ApproxRetrievalStrategy,
BM25RetrievalStrategy,
ElasticsearchStore,
ExactRetrievalStrategy,
SparseRetrievalStrategy,
)

# Public names exported by the package. Every name appears exactly once;
# a repeated entry is harmless for star-imports but makes the list
# misleading to read and to compare against in tests.
__all__ = [
    "ElasticsearchCache",
    "ElasticsearchChatMessageHistory",
    "ElasticsearchEmbeddings",
    "ElasticsearchRetriever",
    "ElasticsearchStore",
    # retrieval strategies (imported from elasticsearch.helpers.vectorstore)
    "BM25Strategy",
    "DenseVectorScriptScoreStrategy",
    "DenseVectorStrategy",
    "DistanceMetric",
    "RetrievalStrategy",
    "SparseVectorStrategy",
    # deprecated retrieval strategies (imported from
    # langchain_elasticsearch.vectorstores)
    "ApproxRetrievalStrategy",
    "BM25RetrievalStrategy",
    "ExactRetrievalStrategy",
    "SparseRetrievalStrategy",
]
17 changes: 16 additions & 1 deletion libs/elasticsearch/langchain_elasticsearch/vectorstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from elasticsearch.helpers.vectorstore import (
VectorStore as EVectorStore,
)
from langchain_core._api.deprecation import deprecated
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore
Expand All @@ -39,6 +40,7 @@
logger = logging.getLogger(__name__)


@deprecated("0.1.4", alternative="RetrievalStrategy", pending=True)
class BaseRetrievalStrategy(ABC):
"""Base class for `Elasticsearch` retrieval strategies."""

Expand Down Expand Up @@ -124,6 +126,7 @@ def require_inference(self) -> bool:
return True


@deprecated("0.1.4", alternative="DenseVectorStrategy", pending=True)
class ApproxRetrievalStrategy(BaseRetrievalStrategy):
"""Approximate retrieval strategy using the `HNSW` algorithm."""

Expand Down Expand Up @@ -251,6 +254,7 @@ def index(
}


@deprecated("0.1.4", alternative="DenseVectorScriptScoreStrategy", pending=True)
class ExactRetrievalStrategy(BaseRetrievalStrategy):
"""Exact retrieval strategy using the `script_score` query."""

Expand Down Expand Up @@ -319,6 +323,7 @@ def index(
}


@deprecated("0.1.4", alternative="SparseVectorStrategy", pending=True)
class SparseRetrievalStrategy(BaseRetrievalStrategy):
"""Sparse retrieval strategy using the `text_expansion` processor."""

Expand Down Expand Up @@ -403,6 +408,7 @@ def require_inference(self) -> bool:
return False


@deprecated("0.1.4", alternative="BM25Strategy", pending=True)
class BM25RetrievalStrategy(BaseRetrievalStrategy):
"""Retrieval strategy using the native BM25 algorithm of Elasticsearch."""

Expand Down Expand Up @@ -474,9 +480,14 @@ def require_inference(self) -> bool:


def _convert_retrieval_strategy(
langchain_strategy: BaseRetrievalStrategy, distance: DistanceStrategy
langchain_strategy: BaseRetrievalStrategy,
distance: Optional[DistanceStrategy] = None,
) -> RetrievalStrategy:
if isinstance(langchain_strategy, ApproxRetrievalStrategy):
if distance is None:
raise ValueError(
"ApproxRetrievalStrategy requires a distance strategy to be provided."
)
return DenseVectorStrategy(
distance=DistanceMetric[distance],
model_id=langchain_strategy.query_model_id,
Expand All @@ -488,6 +499,10 @@ def _convert_retrieval_strategy(
rrf=False if langchain_strategy.rrf is None else langchain_strategy.rrf,
)
elif isinstance(langchain_strategy, ExactRetrievalStrategy):
if distance is None:
raise ValueError(
"ExactRetrievalStrategy requires a distance strategy to be provided."
)
return DenseVectorScriptScoreStrategy(distance=DistanceMetric[distance])
elif isinstance(langchain_strategy, SparseRetrievalStrategy):
return SparseVectorStrategy(langchain_strategy.model_id)
Expand Down
11 changes: 10 additions & 1 deletion libs/elasticsearch/tests/unit_tests/test_imports.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
from langchain_elasticsearch import __all__

# Names the package is expected to export. Each name is listed exactly once
# so the list can be compared one-to-one with the package's ``__all__``.
EXPECTED_ALL = [
    "ElasticsearchCache",
    "ElasticsearchChatMessageHistory",
    "ElasticsearchEmbeddings",
    "ElasticsearchRetriever",
    "ElasticsearchStore",
    # retrieval strategies
    "BM25Strategy",
    "DenseVectorScriptScoreStrategy",
    "DenseVectorStrategy",
    "DistanceMetric",
    "RetrievalStrategy",
    "SparseVectorStrategy",
    # deprecated retrieval strategies
    "ApproxRetrievalStrategy",
    "BM25RetrievalStrategy",
    "ExactRetrievalStrategy",
    "SparseRetrievalStrategy",
]
Expand Down
43 changes: 43 additions & 0 deletions libs/elasticsearch/tests/unit_tests/test_vectorstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,17 @@
from langchain_elasticsearch.embeddings import Embeddings, EmbeddingServiceAdapter
from langchain_elasticsearch.vectorstores import (
ApproxRetrievalStrategy,
BM25RetrievalStrategy,
BM25Strategy,
DenseVectorScriptScoreStrategy,
DenseVectorStrategy,
DistanceMetric,
DistanceStrategy,
ElasticsearchStore,
ExactRetrievalStrategy,
SparseRetrievalStrategy,
SparseVectorStrategy,
_convert_retrieval_strategy,
_hits_to_docs_scores,
)

Expand Down Expand Up @@ -143,6 +153,39 @@ def test_doc_field_to_metadata(self) -> None:
assert actual == expected


class TestConvertStrategy:
    """Conversion of legacy retrieval strategies to their replacements."""

    def test_dense_approx(self) -> None:
        """``ApproxRetrievalStrategy`` converts to ``DenseVectorStrategy``."""
        legacy = ApproxRetrievalStrategy(
            query_model_id="my model", hybrid=True, rrf=False
        )

        converted = _convert_retrieval_strategy(
            legacy, distance=DistanceStrategy.DOT_PRODUCT
        )

        assert isinstance(converted, DenseVectorStrategy)
        assert converted.distance == DistanceMetric.DOT_PRODUCT
        assert converted.model_id == "my model"
        assert converted.hybrid is True
        assert converted.rrf is False

    def test_dense_exact(self) -> None:
        """``ExactRetrievalStrategy`` converts to the script-score strategy."""
        legacy = ExactRetrievalStrategy()

        converted = _convert_retrieval_strategy(
            legacy, distance=DistanceStrategy.EUCLIDEAN_DISTANCE
        )

        assert isinstance(converted, DenseVectorScriptScoreStrategy)
        assert converted.distance == DistanceMetric.EUCLIDEAN_DISTANCE

    def test_sparse(self) -> None:
        """``SparseRetrievalStrategy`` converts without a distance argument."""
        legacy = SparseRetrievalStrategy(model_id="my model ID")

        converted = _convert_retrieval_strategy(legacy)

        assert isinstance(converted, SparseVectorStrategy)
        assert converted.model_id == "my model ID"

    def test_bm25(self) -> None:
        """``BM25RetrievalStrategy`` converts with ``k1`` and ``b`` preserved."""
        legacy = BM25RetrievalStrategy(k1=1.7, b=5.4)

        converted = _convert_retrieval_strategy(legacy)

        assert isinstance(converted, BM25Strategy)
        assert converted.k1 == 1.7
        assert converted.b == 5.4


class TestVectorStore:
@pytest.fixture
def embeddings(self) -> Embeddings:
Expand Down

0 comments on commit 46e1c12

Please sign in to comment.