Skip to content

Commit

Permalink
Enable ElasticsearchStore to retrieve with the pure BM25 algorithm wi…
Browse files Browse the repository at this point in the history
…thout vector search (#6)
  • Loading branch information
g-votte authored Apr 2, 2024
1 parent 8a30585 commit 10a96cb
Show file tree
Hide file tree
Showing 2 changed files with 151 additions and 0 deletions.
90 changes: 90 additions & 0 deletions libs/elasticsearch/langchain_elasticsearch/vectorstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def index(
self,
dims_length: Union[int, None],
vector_query_field: str,
text_field: str,
similarity: Union[DistanceStrategy, None],
) -> Dict:
"""
Expand All @@ -80,6 +81,7 @@ def index(
or None if not using vector-based query.
vector_query_field: The field containing the vector
representations in the index.
text_field: The field containing the text data in the index.
similarity: The similarity strategy to use,
or None if not using one.
Expand Down Expand Up @@ -210,6 +212,7 @@ def index(
self,
dims_length: Union[int, None],
vector_query_field: str,
text_field: str,
similarity: Union[DistanceStrategy, None],
) -> Dict:
"""Create the mapping for the Elasticsearch index."""
Expand Down Expand Up @@ -289,6 +292,7 @@ def index(
self,
dims_length: Union[int, None],
vector_query_field: str,
text_field: str,
similarity: Union[DistanceStrategy, None],
) -> Dict:
"""Create the mapping for the Elasticsearch index."""
Expand Down Expand Up @@ -372,6 +376,7 @@ def index(
self,
dims_length: Union[int, None],
vector_query_field: str,
text_field: str,
similarity: Union[DistanceStrategy, None],
) -> Dict:
return {
Expand All @@ -389,6 +394,76 @@ def require_inference(self) -> bool:
return False


class BM25RetrievalStrategy(BaseRetrievalStrategy):
"""Retrieval strategy using the native BM25 algorithm of Elasticsearch."""

def __init__(self, k1: Union[float, None] = None, b: Union[float, None] = None):
self.k1 = k1
self.b = b

def query(
self,
query_vector: Union[List[float], None],
query: Union[str, None],
k: int,
fetch_k: int,
vector_query_field: str,
text_field: str,
filter: List[dict],
similarity: Union[DistanceStrategy, None],
) -> Dict:
return {
"query": {
"bool": {
"must": [
{
"match": {
text_field: {
"query": query,
}
},
},
],
"filter": filter,
},
},
}

def index(
self,
dims_length: Union[int, None],
vector_query_field: str,
text_field: str,
similarity: Union[DistanceStrategy, None],
) -> Dict:
mappings: Dict = {
"properties": {
text_field: {
"type": "text",
"similarity": "custom_bm25",
},
},
}
settings: Dict = {
"similarity": {
"custom_bm25": {
"type": "BM25",
},
},
}

if self.k1 is not None:
settings["similarity"]["custom_bm25"]["k1"] = self.k1

if self.b is not None:
settings["similarity"]["custom_bm25"]["b"] = self.b

return {"mappings": mappings, "settings": settings}

def require_inference(self) -> bool:
return False


class ElasticsearchStore(VectorStore):
"""`Elasticsearch` vector store.
Expand Down Expand Up @@ -905,6 +980,7 @@ def _create_index_if_not_exists(

indexSettings = self.strategy.index(
vector_query_field=self.vector_query_field,
text_field=self.query_field,
dims_length=dims_length,
similarity=self.distance_strategy,
)
Expand Down Expand Up @@ -1284,3 +1360,17 @@ def SparseVectorRetrievalStrategy(
deployed to Elasticsearch.
"""
return SparseRetrievalStrategy(model_id=model_id)

@staticmethod
def BM25RetrievalStrategy(
k1: Union[float, None] = None, b: Union[float, None] = None
) -> "BM25RetrievalStrategy":
"""Used to apply BM25 without vector search.
Args:
k1: Optional. This corresponds to the BM25 parameter, k1. Default is None,
which uses the default setting of Elasticsearch.
b: Optional. This corresponds to the BM25 parameter, b. Default is None,
which uses the default setting of Elasticsearch.
"""
return BM25RetrievalStrategy(k1=k1, b=b)
61 changes: 61 additions & 0 deletions libs/elasticsearch/tests/integration_tests/test_vectorstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,6 +777,67 @@ def test_elasticsearch_with_relevance_score(
)
assert output == [(Document(page_content="foo", metadata={"page": "0"}), 1.0)]

def test_similarity_search_bm25_search(
self, elasticsearch_connection: dict, index_name: str
) -> None:
"""Test end to end using the BM25 retrieval strategy."""
texts = ["foo", "bar", "baz"]
docsearch = ElasticsearchStore.from_texts(
texts,
None,
**elasticsearch_connection,
index_name=index_name,
strategy=ElasticsearchStore.BM25RetrievalStrategy(),
)

def assert_query(query_body: dict, query: str) -> dict:
assert query_body == {
"query": {
"bool": {
"must": [{"match": {"text": {"query": "foo"}}}],
"filter": [],
}
}
}
return query_body

output = docsearch.similarity_search("foo", k=1, custom_query=assert_query)
assert output == [Document(page_content="foo")]

def test_similarity_search_bm25_search_with_filter(
self, elasticsearch_connection: dict, index_name: str
) -> None:
"""Test end to using the BM25 retrieval strategy with metadata."""
texts = ["foo", "foo", "foo"]
metadatas = [{"page": i} for i in range(len(texts))]
docsearch = ElasticsearchStore.from_texts(
texts,
None,
**elasticsearch_connection,
index_name=index_name,
metadatas=metadatas,
strategy=ElasticsearchStore.BM25RetrievalStrategy(),
)

def assert_query(query_body: dict, query: str) -> dict:
assert query_body == {
"query": {
"bool": {
"must": [{"match": {"text": {"query": "foo"}}}],
"filter": [{"term": {"metadata.page": 1}}],
}
}
}
return query_body

output = docsearch.similarity_search(
"foo",
k=3,
custom_query=assert_query,
filter=[{"term": {"metadata.page": 1}}],
)
assert output == [Document(page_content="foo", metadata={"page": 1})]

def test_elasticsearch_with_relevance_threshold(
self, elasticsearch_connection: dict, index_name: str
) -> None:
Expand Down

0 comments on commit 10a96cb

Please sign in to comment.