From aa9c150b9ba2bad483ce4876b4a738b3d8e3cf63 Mon Sep 17 00:00:00 2001 From: Gabriele Ghisleni <74197369+GabrieleGhisleni@users.noreply.github.com> Date: Thu, 30 May 2024 17:28:50 +0200 Subject: [PATCH] integration for caching embeddings (#27) * add cache embeddings * fix typo * fix typo * run lint and fixes * run lint and fixes * add integration test * add integration test * add type annotation * add type annotation * add integration tests and collapsce function * add integration tests and collapsce function * add integration tests and collapsce function * add integration tests and collapsce function * fix docstring * minor refactoring * minor refactoring * as bytestore * as bytestore * fix toml * fix toml * fix toml * fix toml * removed setup connection * removed setup connection * removed setup connection --------- Co-authored-by: Gabriele Ghisleni --- libs/elasticsearch/README.md | 65 +++- .../langchain_elasticsearch/__init__.py | 6 +- .../langchain_elasticsearch/_utilities.py | 13 +- .../langchain_elasticsearch/cache.py | 315 +++++++++++++++--- .../langchain_elasticsearch/chat_history.py | 6 +- .../langchain_elasticsearch/vectorstores.py | 16 +- libs/elasticsearch/pyproject.toml | 2 +- libs/elasticsearch/tests/conftest.py | 38 ++- .../tests/integration_tests/test_cache.py | 180 +++++++++- .../integration_tests/test_chat_history.py | 6 +- .../integration_tests/test_vectorstores.py | 11 +- .../tests/unit_tests/test_cache.py | 310 +++++++++++++++-- .../tests/unit_tests/test_imports.py | 1 + 13 files changed, 836 insertions(+), 133 deletions(-) diff --git a/libs/elasticsearch/README.md b/libs/elasticsearch/README.md index d34d293..f56cbad 100644 --- a/libs/elasticsearch/README.md +++ b/libs/elasticsearch/README.md @@ -119,15 +119,13 @@ A caching layer for LLMs that uses Elasticsearch. Simple example: ```python -from elasticsearch import Elasticsearch from langchain.globals import set_llm_cache from langchain_elasticsearch import ElasticsearchCache -es_client = Elasticsearch(hosts="http://localhost:9200") set_llm_cache( ElasticsearchCache( - es_connection=es_client, + es_url="http://localhost:9200", index_name="llm-chat-cache", metadata={"project": "my_chatgpt_project"}, ) @@ -153,7 +151,6 @@ The new cache class can be applied also to a pre-existing cache index: import json from typing import Any, Dict, List -from elasticsearch import Elasticsearch from langchain.globals import set_llm_cache from langchain_core.caches import RETURN_VAL_TYPE @@ -185,11 +182,65 @@ class SearchableElasticsearchCache(ElasticsearchCache): ] -es_client = Elasticsearch(hosts="http://localhost:9200") set_llm_cache( - SearchableElasticsearchCache(es_connection=es_client, index_name="llm-chat-cache") + SearchableElasticsearchCache( + es_url="http://localhost:9200", + index_name="llm-chat-cache" + ) ) ``` When overriding the mapping and the document building, -please only make additive modifications, keeping the base mapping intact. \ No newline at end of file +please only make additive modifications, keeping the base mapping intact. + +### ElasticsearchEmbeddingsCache + +Store and temporarily cache embeddings. + +Caching embeddings is obtained by using the [CacheBackedEmbeddings](https://python.langchain.com/docs/modules/data_connection/text_embedding/caching_embeddings), it can be instantiated using `CacheBackedEmbeddings.from_bytes_store` method. + +```python +from langchain.embeddings import CacheBackedEmbeddings +from langchain_openai import OpenAIEmbeddings + +from langchain_elasticsearch import ElasticsearchEmbeddingsCache + +underlying_embeddings = OpenAIEmbeddings(model="text-embedding-3-small") + +store = ElasticsearchEmbeddingsCache( + es_url="http://localhost:9200", + index_name="llm-chat-cache", + metadata={"project": "my_chatgpt_project"}, + namespace="my_chatgpt_project", +) + +embeddings = CacheBackedEmbeddings.from_bytes_store( + underlying_embeddings=OpenAIEmbeddings(), + document_embedding_cache=store, + query_embedding_cache=store, +) +``` + +Similarly to the chat cache, one can subclass `ElasticsearchEmbeddingsCache` in order to index vectors for search. + +```python +from typing import Any, Dict, List +from langchain_elasticsearch import ElasticsearchEmbeddingsCache + +class SearchableElasticsearchStore(ElasticsearchEmbeddingsCache): + @property + def mapping(self) -> Dict[str, Any]: + mapping = super().mapping + mapping["mappings"]["properties"]["vector"] = { + "type": "dense_vector", + "dims": 1536, + "index": True, + "similarity": "dot_product", + } + return mapping + + def build_document(self, llm_input: str, vector: List[float]) -> Dict[str, Any]: + body = super().build_document(llm_input, vector) + body["vector"] = vector + return body +``` diff --git a/libs/elasticsearch/langchain_elasticsearch/__init__.py b/libs/elasticsearch/langchain_elasticsearch/__init__.py index 1e40763..17611c4 100644 --- a/libs/elasticsearch/langchain_elasticsearch/__init__.py +++ b/libs/elasticsearch/langchain_elasticsearch/__init__.py @@ -7,7 +7,10 @@ SparseVectorStrategy, ) -from langchain_elasticsearch.cache import ElasticsearchCache +from langchain_elasticsearch.cache import ( + ElasticsearchCache, + ElasticsearchEmbeddingsCache, +) from langchain_elasticsearch.chat_history import ElasticsearchChatMessageHistory from langchain_elasticsearch.embeddings import ElasticsearchEmbeddings from langchain_elasticsearch.retrievers import ElasticsearchRetriever @@ -23,6 +26,7 @@ "ElasticsearchCache", "ElasticsearchChatMessageHistory", "ElasticsearchEmbeddings", + "ElasticsearchEmbeddingsCache", "ElasticsearchRetriever", "ElasticsearchStore", # retrieval strategies diff --git a/libs/elasticsearch/langchain_elasticsearch/_utilities.py b/libs/elasticsearch/langchain_elasticsearch/_utilities.py index fe3893b..c6dc21d 100644 --- a/libs/elasticsearch/langchain_elasticsearch/_utilities.py +++ b/libs/elasticsearch/langchain_elasticsearch/_utilities.py @@ -1,8 +1,11 @@ +import logging from enum import Enum -from elasticsearch import BadRequestError, ConflictError, Elasticsearch, NotFoundError +from elasticsearch import Elasticsearch, exceptions from langchain_core import __version__ as langchain_version +logger = logging.getLogger(__name__) + class DistanceStrategy(str, Enum): """Enumerator of the Distance strategies for calculating distances @@ -29,15 +32,15 @@ def model_must_be_deployed(client: Elasticsearch, model_id: str) -> None: try: dummy = {"x": "y"} client.ml.infer_trained_model(model_id=model_id, docs=[dummy]) - except NotFoundError as err: + except exceptions.NotFoundError as err: raise err - except ConflictError as err: - raise NotFoundError( + except exceptions.ConflictError as err: + raise exceptions.NotFoundError( f"model '{model_id}' not found, please deploy it first", meta=err.meta, body=err.body, ) from err - except BadRequestError: + except exceptions.BadRequestError: # This error is expected because we do not know the expected document # shape and just use a dummy doc above. pass diff --git a/libs/elasticsearch/langchain_elasticsearch/cache.py b/libs/elasticsearch/langchain_elasticsearch/cache.py index 2f05658..0a8da8a 100644 --- a/libs/elasticsearch/langchain_elasticsearch/cache.py +++ b/libs/elasticsearch/langchain_elasticsearch/cache.py @@ -1,13 +1,29 @@ +import base64 import hashlib import logging from datetime import datetime from functools import cached_property -from typing import TYPE_CHECKING, Any, Dict, Optional +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Iterable, + Iterator, + List, + Optional, + Sequence, + Tuple, +) -import elasticsearch -from elasticsearch import NotFoundError +from elasticsearch import ( + Elasticsearch, + exceptions, + helpers, +) +from elasticsearch.helpers import BulkIndexError from langchain_core.caches import RETURN_VAL_TYPE, BaseCache from langchain_core.load import dumps, loads +from langchain_core.stores import ByteStore from langchain_elasticsearch.client import create_elasticsearch_client @@ -17,6 +33,22 @@ logger = logging.getLogger(__name__) +def _manage_cache_index( + es_client: Elasticsearch, index_name: str, mapping: Dict[str, Any] +) -> bool: + """Write or update an index or alias according to the default mapping""" + if es_client.indices.exists_alias(name=index_name): + es_client.indices.put_mapping(index=index_name, body=mapping["mappings"]) + return True + + elif not es_client.indices.exists(index=index_name): + logger.debug(f"Creating new Elasticsearch index: {index_name}") + es_client.indices.create(index=index_name, body=mapping) + return False + + return False + + class ElasticsearchCache(BaseCache): """An Elasticsearch cache integration for LLMs.""" @@ -27,7 +59,6 @@ def __init__( store_input_params: bool = True, metadata: Optional[Dict[str, Any]] = None, *, - es_connection: Optional["Elasticsearch"] = None, es_url: Optional[str] = None, es_cloud_id: Optional[str] = None, es_user: Optional[str] = None, @@ -37,8 +68,8 @@ def __init__( ): """ Initialize the Elasticsearch cache store by specifying the index/alias - to use and determining which additional information (like input, timestamp, - input parameters, and any other metadata) should be stored in the cache. + to use and determining which additional information (like input, input + parameters, and any other metadata) should be stored in the cache. Args: index_name (str): The name of the index or the alias to use for the cache. @@ -52,7 +83,6 @@ def __init__( metadata (Optional[dict]): Additional metadata to store in the cache, for filtering purposes. This must be JSON serializable in an Elasticsearch document. Default to None. - es_connection: Optional pre-existing Elasticsearch connection. es_url: URL of the Elasticsearch instance to connect to. es_cloud_id: Cloud ID of the Elasticsearch instance to connect to. es_user: Username to use when connecting to Elasticsearch. @@ -60,48 +90,23 @@ def __init__( es_api_key: API key to use when connecting to Elasticsearch. es_params: Other parameters for the Elasticsearch client. """ - if es_connection is not None: - self._es_client = es_connection - if not self._es_client.ping(): - raise elasticsearch.exceptions.ConnectionError( - "Elasticsearch cluster is not available," - " not able to set up the cache" - ) - elif es_url is not None or es_cloud_id is not None: - try: - self._es_client = create_elasticsearch_client( - url=es_url, - cloud_id=es_cloud_id, - api_key=es_api_key, - username=es_user, - password=es_password, - params=es_params, - ) - except Exception as err: - logger.error(f"Error connecting to Elasticsearch: {err}") - raise err - else: - raise ValueError( - """Either provide a pre-existing Elasticsearch connection, \ - or valid credentials for creating a new connection.""" - ) + self._index_name = index_name self._store_input = store_input self._store_input_params = store_input_params self._metadata = metadata - self._manage_index() - - def _manage_index(self) -> None: - """Write or update an index or alias according to the default mapping""" - self._is_alias = False - if self._es_client.indices.exists_alias(name=self._index_name): - self._is_alias = True - elif not self._es_client.indices.exists(index=self._index_name): - logger.debug(f"Creating new Elasticsearch index: {self._index_name}") - self._es_client.indices.create(index=self._index_name, body=self.mapping) - return - self._es_client.indices.put_mapping( - index=self._index_name, body=self.mapping["mappings"] + self._es_client = create_elasticsearch_client( + url=es_url, + cloud_id=es_cloud_id, + api_key=es_api_key, + username=es_user, + password=es_password, + params=es_params, + ) + self._is_alias = _manage_cache_index( + self._es_client, + self._index_name, + self.mapping, ) @cached_property @@ -147,7 +152,7 @@ def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: record = self._es_client.get( index=self._index_name, id=cache_key, source=["llm_output"] ) - except NotFoundError: + except exceptions.NotFoundError: return None return [loads(item) for item in record["_source"]["llm_output"]] @@ -186,3 +191,221 @@ def clear(self, **kwargs: Any) -> None: refresh=True, wait_for_completion=True, ) + + +class ElasticsearchEmbeddingsCache(ByteStore): + """An Elasticsearch store for caching embeddings.""" + + def __init__( + self, + index_name: str, + store_input: bool = True, + metadata: Optional[Dict[str, Any]] = None, + namespace: Optional[str] = None, + maximum_duplicates_allowed: int = 1, + *, + es_url: Optional[str] = None, + es_cloud_id: Optional[str] = None, + es_user: Optional[str] = None, + es_api_key: Optional[str] = None, + es_password: Optional[str] = None, + es_params: Optional[Dict[str, Any]] = None, + ): + """ + Initialize the Elasticsearch cache store by specifying the index/alias + to use and determining which additional information (like input, input + parameters, and any other metadata) should be stored in the cache. + Provide a namespace to organize the cache. + + + Args: + index_name (str): The name of the index or the alias to use for the cache. + If they do not exist an index is created, + according to the default mapping defined by the `mapping` property. + store_input (bool): Whether to store the input in the cache. + Default to True. + metadata (Optional[dict]): Additional metadata to store in the cache, + for filtering purposes. This must be JSON serializable in an + Elasticsearch document. Default to None. + namespace (Optional[str]): A namespace to use for the cache. + maximum_duplicates_allowed (int): Defines the maximum number of duplicate + keys permitted. Must be used in scenarios where the same key appears + across multiple indices that share the same alias. Default to 1. + es_url: URL of the Elasticsearch instance to connect to. + es_cloud_id: Cloud ID of the Elasticsearch instance to connect to. + es_user: Username to use when connecting to Elasticsearch. + es_password: Password to use when connecting to Elasticsearch. + es_api_key: API key to use when connecting to Elasticsearch. + es_params: Other parameters for the Elasticsearch client. + """ + self._namespace = namespace + self._maximum_duplicates_allowed = maximum_duplicates_allowed + self._index_name = index_name + self._store_input = store_input + self._metadata = metadata + self._es_client = create_elasticsearch_client( + url=es_url, + cloud_id=es_cloud_id, + api_key=es_api_key, + username=es_user, + password=es_password, + params=es_params, + ) + self._is_alias = _manage_cache_index( + self._es_client, + self._index_name, + self.mapping, + ) + + @staticmethod + def encode_vector(data: bytes) -> str: + """Encode the vector data as bytes to as a base64 string.""" + return base64.b64encode(data).decode("utf-8") + + @staticmethod + def decode_vector(data: str) -> bytes: + """Decode the base64 string to vector data as bytes.""" + return base64.b64decode(data) + + @cached_property + def mapping(self) -> Dict[str, Any]: + """Get the default mapping for the index.""" + return { + "mappings": { + "properties": { + "text_input": {"type": "text", "index": False}, + "vector_dump": { + "type": "binary", + "doc_values": False, + }, + "metadata": {"type": "object"}, + "timestamp": {"type": "date"}, + "namespace": {"type": "keyword"}, + } + } + } + + def _key(self, input_text: str) -> str: + """Generate a key for the store.""" + return hashlib.md5(((self._namespace or "") + input_text).encode()).hexdigest() + + @classmethod + def _deduplicate_hits(cls, hits: List[dict]) -> Dict[str, bytes]: + """ + Collapse the results from a search query with multiple indices + returning only the latest version of the documents + """ + map_ids = {} + for hit in sorted( + hits, + key=lambda x: datetime.fromisoformat(x["_source"]["timestamp"]), + reverse=True, + ): + vector_id: str = hit["_id"] + if vector_id not in map_ids: + map_ids[vector_id] = cls.decode_vector(hit["_source"]["vector_dump"]) + + return map_ids + + def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]: + """Get the values associated with the given keys.""" + if not any(keys): + return [] + + cache_keys = [self._key(k) for k in keys] + if self._is_alias: + try: + results = self._es_client.search( + index=self._index_name, + body={ + "query": {"ids": {"values": cache_keys}}, + "size": len(cache_keys) * self._maximum_duplicates_allowed, + }, + source_includes=["vector_dump", "timestamp"], + ) + + except exceptions.BadRequestError as e: + if "window too large" in ( + e.body.get("error", {}).get("root_cause", [{}])[0].get("reason", "") + ): + logger.warning( + "Exceeded the maximum window size, " + "Reduce the duplicates manually or lower " + "`maximum_duplicate_allowed.`" + ) + raise e + + total_hits = results["hits"]["total"]["value"] + if self._maximum_duplicates_allowed > 1 and total_hits > len(cache_keys): + logger.warning( + f"Deduplicating, found {total_hits} hits for {len(cache_keys)} keys" + ) + map_ids = self._deduplicate_hits(results["hits"]["hits"]) + else: + map_ids = { + r["_id"]: self.decode_vector(r["_source"]["vector_dump"]) + for r in results["hits"]["hits"] + } + + return [map_ids.get(k) for k in cache_keys] + + else: + records = self._es_client.mget( + index=self._index_name, ids=cache_keys, source_includes=["vector_dump"] + ) + return [ + self.decode_vector(r["_source"]["vector_dump"]) if r["found"] else None + for r in records["docs"] + ] + + def build_document(self, text_input: str, vector: bytes) -> Dict[str, Any]: + """Build the Elasticsearch document for storing a single embedding""" + body: Dict[str, Any] = { + "vector_dump": self.encode_vector(vector), + "timestamp": datetime.now().isoformat(), + } + if self._metadata is not None: + body["metadata"] = self._metadata + if self._store_input: + body["text_input"] = text_input + if self._namespace: + body["namespace"] = self._namespace + return body + + def _bulk(self, actions: Iterable[Dict[str, Any]]) -> None: + try: + helpers.bulk( + client=self._es_client, + actions=actions, + index=self._index_name, + require_alias=self._is_alias, + refresh=True, + ) + except BulkIndexError as e: + first_error = e.errors[0].get("index", {}).get("error", {}) + logger.error(f"First bulk error reason: {first_error.get('reason')}") + raise e + + def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None: + """Set the values for the given keys.""" + actions = ( + { + "_op_type": "index", + "_id": self._key(key), + "_source": self.build_document(key, vector), + } + for key, vector in key_value_pairs + ) + self._bulk(actions) + + def mdelete(self, keys: Sequence[str]) -> None: + """Delete the given keys and their associated values.""" + actions = ({"_op_type": "delete", "_id": self._key(key)} for key in keys) + self._bulk(actions) + + def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]: + """Get an iterator over keys that match the given prefix.""" + # TODO This method is not currently used by CacheBackedEmbeddings, + # we can leave it blank. It could be implemented with ES "index_prefixes", + # but they are limited and expensive. + raise NotImplementedError() diff --git a/libs/elasticsearch/langchain_elasticsearch/chat_history.py b/libs/elasticsearch/langchain_elasticsearch/chat_history.py index 5d75bea..d2f0774 100644 --- a/libs/elasticsearch/langchain_elasticsearch/chat_history.py +++ b/libs/elasticsearch/langchain_elasticsearch/chat_history.py @@ -4,11 +4,7 @@ from typing import TYPE_CHECKING, List, Optional from langchain_core.chat_history import BaseChatMessageHistory -from langchain_core.messages import ( - BaseMessage, - message_to_dict, - messages_from_dict, -) +from langchain_core.messages import BaseMessage, message_to_dict, messages_from_dict from langchain_elasticsearch._utilities import with_user_agent_header from langchain_elasticsearch.client import create_elasticsearch_client diff --git a/libs/elasticsearch/langchain_elasticsearch/vectorstores.py b/libs/elasticsearch/langchain_elasticsearch/vectorstores.py index 41e3d64..02d8000 100644 --- a/libs/elasticsearch/langchain_elasticsearch/vectorstores.py +++ b/libs/elasticsearch/langchain_elasticsearch/vectorstores.py @@ -1,16 +1,6 @@ import logging from abc import ABC, abstractmethod -from typing import ( - Any, - Callable, - Dict, - Iterable, - List, - Literal, - Optional, - Tuple, - Union, -) +from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union from elasticsearch import Elasticsearch from elasticsearch.helpers.vectorstore import ( @@ -21,9 +11,7 @@ RetrievalStrategy, SparseVectorStrategy, ) -from elasticsearch.helpers.vectorstore import ( - VectorStore as EVectorStore, -) +from elasticsearch.helpers.vectorstore import VectorStore as EVectorStore from langchain_core._api.deprecation import deprecated from langchain_core.documents import Document from langchain_core.embeddings import Embeddings diff --git a/libs/elasticsearch/pyproject.toml b/libs/elasticsearch/pyproject.toml index c0fe89a..910c2eb 100644 --- a/libs/elasticsearch/pyproject.toml +++ b/libs/elasticsearch/pyproject.toml @@ -92,4 +92,4 @@ markers = [ "requires: mark tests as requiring a specific library", "asyncio: mark tests as requiring asyncio", ] -asyncio_mode = "auto" +asyncio_mode = "auto" \ No newline at end of file diff --git a/libs/elasticsearch/tests/conftest.py b/libs/elasticsearch/tests/conftest.py index be65716..6d91265 100644 --- a/libs/elasticsearch/tests/conftest.py +++ b/libs/elasticsearch/tests/conftest.py @@ -1,4 +1,5 @@ from typing import Generator +from unittest import mock from unittest.mock import MagicMock import pytest @@ -8,7 +9,7 @@ from langchain_core.language_models import BaseChatModel from langchain_core.messages import AIMessage -from langchain_elasticsearch import ElasticsearchCache +from langchain_elasticsearch import ElasticsearchCache, ElasticsearchEmbeddingsCache @pytest.fixture @@ -18,15 +19,36 @@ def es_client_fx() -> Generator[MagicMock, None, None]: yield client_mock() +@pytest.fixture +def es_embeddings_cache_fx( + es_client_fx: MagicMock, +) -> Generator[ElasticsearchEmbeddingsCache, None, None]: + with mock.patch( + "langchain_elasticsearch.cache.create_elasticsearch_client", + return_value=es_client_fx, + ): + yield ElasticsearchEmbeddingsCache( + es_url="http://localhost:9200", + index_name="test_index", + store_input=True, + namespace="test", + metadata={"project": "test_project"}, + ) + + @pytest.fixture def es_cache_fx(es_client_fx: MagicMock) -> Generator[ElasticsearchCache, None, None]: - yield ElasticsearchCache( - es_connection=es_client_fx, - index_name="test_index", - store_input=True, - store_input_params=True, - metadata={"project": "test_project"}, - ) + with mock.patch( + "langchain_elasticsearch.cache.create_elasticsearch_client", + return_value=es_client_fx, + ): + yield ElasticsearchCache( + es_url="http://localhost:30096", + index_name="test_index", + store_input=True, + store_input_params=True, + metadata={"project": "test_project"}, + ) @pytest.fixture diff --git a/libs/elasticsearch/tests/integration_tests/test_cache.py b/libs/elasticsearch/tests/integration_tests/test_cache.py index 17a0399..4271210 100644 --- a/libs/elasticsearch/tests/integration_tests/test_cache.py +++ b/libs/elasticsearch/tests/integration_tests/test_cache.py @@ -1,10 +1,12 @@ from typing import Dict, Generator, Union import pytest +from elasticsearch.helpers import BulkIndexError +from langchain.embeddings.cache import _value_serializer from langchain.globals import set_llm_cache from langchain_core.language_models import BaseChatModel -from langchain_elasticsearch import ElasticsearchCache +from langchain_elasticsearch import ElasticsearchCache, ElasticsearchEmbeddingsCache from tests.integration_tests._test_utilities import ( clear_test_indices, create_es_client, @@ -23,12 +25,14 @@ def es_env_fx() -> Union[dict, Generator[dict, None, None]]: es.indices.put_alias(index="test_index1", name="test_alias") es.indices.put_alias(index="test_index2", name="test_alias", is_write_index=True) yield params - es.indices.delete_alias(index="test_index1,test_index2", name="test_alias") + es.options(ignore_status=404).indices.delete_alias( + index="test_index1,test_index2", name="test_alias" + ) clear_test_indices(es) return None -def test_index(es_env_fx: Dict, fake_chat_fx: BaseChatModel) -> None: +def test_index_llm_cache(es_env_fx: Dict, fake_chat_fx: BaseChatModel) -> None: cache = ElasticsearchCache( **es_env_fx, index_name="test_index1", metadata={"project": "test"} ) @@ -68,7 +72,7 @@ def test_index(es_env_fx: Dict, fake_chat_fx: BaseChatModel) -> None: assert all(record.get("metadata") == {"project": "test"} for record in records) -def test_alias(es_env_fx: Dict, fake_chat_fx: BaseChatModel) -> None: +def test_alias_llm_cache(es_env_fx: Dict, fake_chat_fx: BaseChatModel) -> None: cache = ElasticsearchCache( **es_env_fx, index_name="test_alias", metadata={"project": "test"} ) @@ -97,7 +101,7 @@ def test_alias(es_env_fx: Dict, fake_chat_fx: BaseChatModel) -> None: assert fake_chat_fx.invoke("test2") -def test_clear(es_env_fx: Dict, fake_chat_fx: BaseChatModel) -> None: +def test_clear_llm_cache(es_env_fx: Dict, fake_chat_fx: BaseChatModel) -> None: cache = ElasticsearchCache( **es_env_fx, index_name="test_alias", metadata={"project": "test"} ) @@ -115,3 +119,169 @@ def test_clear(es_env_fx: Dict, fake_chat_fx: BaseChatModel) -> None: assert es_client.count(index="test_alias")["count"] == 3 cache.clear() assert es_client.count(index="test_alias")["count"] == 0 + + +def test_mdelete_cache_store(es_env_fx: Dict) -> None: + store = ElasticsearchEmbeddingsCache( + **es_env_fx, index_name="test_alias", metadata={"project": "test"} + ) + + recors = ["my little tests", "my little tests2", "my little tests3"] + store.mset( + [ + (recors[0], _value_serializer([1, 2, 3])), + (recors[1], _value_serializer([1, 2, 3])), + (recors[2], _value_serializer([1, 2, 3])), + ] + ) + + assert store._es_client.count(index="test_alias")["count"] == 3 + + store.mdelete(recors[:2]) + assert store._es_client.count(index="test_alias")["count"] == 1 + + store.mdelete(recors[2:]) + assert store._es_client.count(index="test_alias")["count"] == 0 + + with pytest.raises(BulkIndexError): + store.mdelete(recors) + + +def test_mset_cache_store(es_env_fx: Dict) -> None: + store = ElasticsearchEmbeddingsCache( + **es_env_fx, index_name="test_alias", metadata={"project": "test"} + ) + + records = ["my little tests", "my little tests2", "my little tests3"] + + store.mset([(records[0], _value_serializer([1, 2, 3]))]) + assert store._es_client.count(index="test_alias")["count"] == 1 + store.mset([(records[0], _value_serializer([1, 2, 3]))]) + assert store._es_client.count(index="test_alias")["count"] == 1 + store.mset( + [ + (records[1], _value_serializer([1, 2, 3])), + (records[2], _value_serializer([1, 2, 3])), + ] + ) + assert store._es_client.count(index="test_alias")["count"] == 3 + + +def test_mget_cache_store(es_env_fx: Dict) -> None: + store_no_alias = ElasticsearchEmbeddingsCache( + **es_env_fx, + index_name="test_index3", + metadata={"project": "test"}, + namespace="test", + ) + + records = ["my little tests", "my little tests2", "my little tests3"] + docs = [(r, _value_serializer([0.1, 2, i])) for i, r in enumerate(records)] + + store_no_alias.mset(docs) + assert store_no_alias._es_client.count(index="test_index3")["count"] == 3 + + cached_records = store_no_alias.mget([d[0] for d in docs]) + assert all(cached_records) + assert all([r == d[1] for r, d in zip(cached_records, docs)]) + + store_alias = ElasticsearchEmbeddingsCache( + **es_env_fx, + index_name="test_alias", + metadata={"project": "test"}, + namespace="test", + maximum_duplicates_allowed=1, + ) + + store_alias.mset(docs) + assert store_alias._es_client.count(index="test_alias")["count"] == 3 + + cached_records = store_alias.mget([d[0] for d in docs]) + assert all(cached_records) + assert all([r == d[1] for r, d in zip(cached_records, docs)]) + + +def test_mget_cache_store_multiple_keys(es_env_fx: Dict) -> None: + """verify the logic of deduplication of keys in the cache store""" + + store_alias = ElasticsearchEmbeddingsCache( + **es_env_fx, + index_name="test_alias", + metadata={"project": "test"}, + namespace="test", + maximum_duplicates_allowed=2, + ) + + es_client = store_alias._es_client + + records = ["my little tests", "my little tests2", "my little tests3"] + docs = [(r, _value_serializer([0.1, 2, i])) for i, r in enumerate(records)] + + store_alias.mset(docs) + assert es_client.count(index="test_alias")["count"] == 3 + + store_no_alias = ElasticsearchEmbeddingsCache( + **es_env_fx, + index_name="test_index3", + metadata={"project": "test"}, + namespace="test", + maximum_duplicates_allowed=1, + ) + + new_records = records + ["my little tests4", "my little tests5"] + new_docs = [ + (r, _value_serializer([0.1, 2, i + 100])) for i, r in enumerate(new_records) + ] + + # store the same 3 previous records and 2 more in a fresh index + store_no_alias.mset(new_docs) + assert es_client.count(index="test_index3")["count"] == 5 + + # update the alias to point to the new index and verify the cache + es_client.indices.update_aliases( + actions=[ + { + "add": { + "index": "test_index3", + "alias": "test_alias", + } + } + ] + ) + + # the alias now point to two indices that contains multiple records + # of the same keys, the cache store should return the latest records. + cached_records = store_alias.mget([d[0] for d in new_docs]) + assert all(cached_records) + assert len(cached_records) == 5 + assert es_client.count(index="test_alias")["count"] == 8 + assert cached_records[:3] != [ + d[1] for d in docs + ], "the first 3 records should be updated" + assert cached_records == [ + d[1] for d in new_docs + ], "new records should be returned and the updated ones" + assert all([r == d[1] for r, d in zip(cached_records, new_docs)]) + es_client.options(ignore_status=404).indices.delete_alias( + index="test_index3", name="test_alias" + ) + + +def test_build_document_cache_store(es_env_fx: Dict) -> None: + store = ElasticsearchEmbeddingsCache( + **es_env_fx, + index_name="test_alias", + metadata={"project": "test"}, + namespace="test", + ) + + store.mset([("my little tests", _value_serializer([0.1, 2, 3]))]) + record = store._es_client.search(index="test_alias")["hits"]["hits"][0]["_source"] + + assert record.get("metadata") == {"project": "test"} + assert record.get("namespace") == "test" + assert record.get("timestamp") + assert record.get("text_input") == "my little tests" + assert record.get("vector_dump") == ElasticsearchEmbeddingsCache.encode_vector( + _value_serializer([0.1, 2, 3]) + ) diff --git a/libs/elasticsearch/tests/integration_tests/test_chat_history.py b/libs/elasticsearch/tests/integration_tests/test_chat_history.py index 5673a89..56db550 100644 --- a/libs/elasticsearch/tests/integration_tests/test_chat_history.py +++ b/libs/elasticsearch/tests/integration_tests/test_chat_history.py @@ -8,11 +8,7 @@ from langchain_elasticsearch.chat_history import ElasticsearchChatMessageHistory -from ._test_utilities import ( - clear_test_indices, - create_es_client, - read_env, -) +from ._test_utilities import clear_test_indices, create_es_client, read_env """ cd tests/integration_tests diff --git a/libs/elasticsearch/tests/integration_tests/test_vectorstores.py b/libs/elasticsearch/tests/integration_tests/test_vectorstores.py index 76d2724..39219f0 100644 --- a/libs/elasticsearch/tests/integration_tests/test_vectorstores.py +++ b/libs/elasticsearch/tests/integration_tests/test_vectorstores.py @@ -10,15 +10,8 @@ from langchain_elasticsearch.vectorstores import ElasticsearchStore -from ..fake_embeddings import ( - ConsistentFakeEmbeddings, - FakeEmbeddings, -) -from ._test_utilities import ( - clear_test_indices, - create_es_client, - read_env, -) +from ..fake_embeddings import ConsistentFakeEmbeddings, FakeEmbeddings +from ._test_utilities import clear_test_indices, create_es_client, read_env logging.basicConfig(level=logging.DEBUG) diff --git a/libs/elasticsearch/tests/unit_tests/test_cache.py b/libs/elasticsearch/tests/unit_tests/test_cache.py index ddd6962..fbbe3cc 100644 --- a/libs/elasticsearch/tests/unit_tests/test_cache.py +++ b/libs/elasticsearch/tests/unit_tests/test_cache.py @@ -1,45 +1,57 @@ from datetime import datetime from typing import Any, Dict -from unittest.mock import MagicMock +from unittest import mock +from unittest.mock import ANY, MagicMock, patch -import pytest from _pytest.fixtures import FixtureRequest from elastic_transport import ApiResponseMeta, HttpHeaders, NodeConfig -from elasticsearch import NotFoundError, exceptions +from elasticsearch import NotFoundError +from langchain.embeddings.cache import _value_serializer from langchain_core.load import dumps from langchain_core.outputs import Generation -from langchain_elasticsearch import ElasticsearchCache +from langchain_elasticsearch import ElasticsearchCache, ElasticsearchEmbeddingsCache -def test_initialization(es_client_fx: MagicMock) -> None: - es_client_fx.ping.return_value = False - with pytest.raises(exceptions.ConnectionError): - ElasticsearchCache(es_connection=es_client_fx, index_name="test_index") +def serialize_encode_vector(vector: Any) -> str: + return ElasticsearchEmbeddingsCache.encode_vector(_value_serializer(vector)) + + +def test_initialization_llm_cache(es_client_fx: MagicMock) -> None: es_client_fx.ping.return_value = True es_client_fx.indices.exists_alias.return_value = True - cache = ElasticsearchCache(es_connection=es_client_fx, index_name="test_index") - es_client_fx.indices.exists_alias.assert_called_with(name="test_index") - assert cache._is_alias - es_client_fx.indices.put_mapping.assert_called_with( - index="test_index", body=cache.mapping["mappings"] - ) - es_client_fx.indices.exists_alias.return_value = False - es_client_fx.indices.exists.return_value = False - cache = ElasticsearchCache(es_connection=es_client_fx, index_name="test_index") - assert not cache._is_alias - es_client_fx.indices.create.assert_called_with( - index="test_index", body=cache.mapping - ) + with mock.patch( + "langchain_elasticsearch.cache.create_elasticsearch_client", + return_value=es_client_fx, + ): + cache = ElasticsearchCache( + es_url="http://localhost:9200", index_name="test_index" + ) + es_client_fx.indices.exists_alias.assert_called_with(name="test_index") + assert cache._is_alias + es_client_fx.indices.put_mapping.assert_called_with( + index="test_index", body=cache.mapping["mappings"] + ) + es_client_fx.indices.exists_alias.return_value = False + es_client_fx.indices.exists.return_value = False + cache = ElasticsearchCache( + es_url="http://localhost:9200", index_name="test_index" + ) + assert not cache._is_alias + es_client_fx.indices.create.assert_called_with( + index="test_index", body=cache.mapping + ) -def test_mapping(es_cache_fx: ElasticsearchCache, request: FixtureRequest) -> None: +def test_mapping_llm_cache( + es_cache_fx: ElasticsearchCache, request: FixtureRequest +) -> None: mapping = request.getfixturevalue("es_cache_fx").mapping assert mapping.get("mappings") assert mapping["mappings"].get("properties") -def test_key_generation(es_cache_fx: ElasticsearchCache) -> None: +def test_key_generation_llm_cache(es_cache_fx: ElasticsearchCache) -> None: key1 = es_cache_fx._key("test_prompt", "test_llm_string") assert key1 and isinstance(key1, str) key2 = es_cache_fx._key("test_prompt", "test_llm_string1") @@ -48,7 +60,9 @@ def test_key_generation(es_cache_fx: ElasticsearchCache) -> None: assert key3 and key1 != key3 -def test_clear(es_client_fx: MagicMock, es_cache_fx: ElasticsearchCache) -> None: +def test_clear_llm_cache( + es_client_fx: MagicMock, es_cache_fx: ElasticsearchCache +) -> None: es_cache_fx.clear() es_client_fx.delete_by_query.assert_called_once_with( index="test_index", @@ -58,7 +72,7 @@ def test_clear(es_client_fx: MagicMock, es_cache_fx: ElasticsearchCache) -> None ) -def test_build_document(es_cache_fx: ElasticsearchCache) -> None: +def test_build_document_llm_cache(es_cache_fx: ElasticsearchCache) -> None: doc = es_cache_fx.build_document( "test_prompt", "test_llm_string", [Generation(text="test_prompt")] ) @@ -70,7 +84,9 @@ def test_build_document(es_cache_fx: ElasticsearchCache) -> None: assert doc["metadata"] == es_cache_fx._metadata -def test_update(es_client_fx: MagicMock, es_cache_fx: ElasticsearchCache) -> None: +def test_update_llm_cache( + es_client_fx: MagicMock, es_cache_fx: ElasticsearchCache +) -> None: es_cache_fx.update("test_prompt", "test_llm_string", [Generation(text="test")]) timestamp = es_client_fx.index.call_args.kwargs["body"]["timestamp"] doc = es_cache_fx.build_document( @@ -86,7 +102,9 @@ def test_update(es_client_fx: MagicMock, es_cache_fx: ElasticsearchCache) -> Non ) -def test_lookup(es_client_fx: MagicMock, es_cache_fx: ElasticsearchCache) -> None: +def test_lookup_llm_cache( + es_client_fx: MagicMock, es_cache_fx: ElasticsearchCache +) -> None: cache_key = es_cache_fx._key("test_prompt", "test_llm_string") doc: Dict[str, Any] = { "_source": { @@ -132,3 +150,241 @@ def test_lookup(es_client_fx: MagicMock, es_cache_fx: ElasticsearchCache) -> Non assert es_cache_fx.lookup("test_prompt", "test_llm_string") == [ Generation(text="test2") ] + + +def test_key_generation_cache_store( + es_embeddings_cache_fx: ElasticsearchEmbeddingsCache, +) -> None: + key1 = es_embeddings_cache_fx._key("test_text") + assert key1 and isinstance(key1, str) + key2 = es_embeddings_cache_fx._key("test_text2") + assert key2 and key1 != key2 + es_embeddings_cache_fx._namespace = "other" + key3 = es_embeddings_cache_fx._key("test_text") + assert key3 and key1 != key3 + es_embeddings_cache_fx._namespace = None + key4 = es_embeddings_cache_fx._key("test_text") + assert key4 and key1 != key4 and key3 != key4 + + +def test_build_document_cache_store( + es_embeddings_cache_fx: ElasticsearchEmbeddingsCache, +) -> None: + doc = es_embeddings_cache_fx.build_document( + "test_text", _value_serializer([1.5, 2, 3.6]) + ) + assert doc["text_input"] == "test_text" + assert doc["vector_dump"] == serialize_encode_vector([1.5, 2, 3.6]) + assert datetime.fromisoformat(str(doc["timestamp"])) + assert doc["metadata"] == es_embeddings_cache_fx._metadata + + +def test_mget_cache_store( + es_client_fx: MagicMock, es_embeddings_cache_fx: ElasticsearchEmbeddingsCache +) -> None: + cache_keys = [ + es_embeddings_cache_fx._key("test_text1"), + es_embeddings_cache_fx._key("test_text2"), + es_embeddings_cache_fx._key("test_text3"), + ] + docs = { + "docs": [ + {"_index": "test_index", "_id": cache_keys[0], "found": False}, + { + "_index": "test_index", + "_id": cache_keys[1], + "found": True, + "_source": {"vector_dump": serialize_encode_vector([1.5, 2, 3.6])}, + }, + { + "_index": "test_index", + "_id": cache_keys[2], + "found": True, + "_source": {"vector_dump": serialize_encode_vector([5, 6, 7.1])}, + }, + ] + } + es_embeddings_cache_fx._is_alias = False + es_client_fx.mget.return_value = docs + assert es_embeddings_cache_fx.mget([]) == [] + assert es_embeddings_cache_fx.mget(["test_text1", "test_text2", "test_text3"]) == [ + None, + _value_serializer([1.5, 2, 3.6]), + _value_serializer([5, 6, 7.1]), + ] + es_client_fx.mget.assert_called_with( + index="test_index", ids=cache_keys, source_includes=["vector_dump"] + ) + es_embeddings_cache_fx._is_alias = True + es_client_fx.search.return_value = {"hits": {"total": {"value": 0}, "hits": []}} + assert es_embeddings_cache_fx.mget([]) == [] + assert es_embeddings_cache_fx.mget(["test_text1", "test_text2", "test_text3"]) == [ + None, + None, + None, + ] + es_client_fx.search.assert_called_with( + index="test_index", + body={ + "query": {"ids": {"values": cache_keys}}, + "size": 3, + }, + source_includes=["vector_dump", "timestamp"], + ) + resp = { + "hits": {"total": {"value": 3}, "hits": [d for d in docs["docs"] if d["found"]]} + } + es_client_fx.search.return_value = resp + assert es_embeddings_cache_fx.mget(["test_text1", "test_text2", "test_text3"]) == [ + None, + _value_serializer([1.5, 2, 3.6]), + _value_serializer([5, 6, 7.1]), + ] + + +def test_deduplicate_hits(es_embeddings_cache_fx: ElasticsearchEmbeddingsCache) -> None: + hits = [ + { + "_id": "1", + "_source": { + "timestamp": "2022-01-01T00:00:00", + "vector_dump": serialize_encode_vector([1, 2, 3]), + }, + }, + { + "_id": "1", + "_source": { + "timestamp": "2022-01-02T00:00:00", + "vector_dump": serialize_encode_vector([4, 5, 6]), + }, + }, + { + "_id": "2", + "_source": { + "timestamp": "2022-01-01T00:00:00", + "vector_dump": serialize_encode_vector([7, 8, 9]), + }, + }, + ] + + result = es_embeddings_cache_fx._deduplicate_hits(hits) + + assert len(result) == 2 + assert result["1"] == _value_serializer([4, 5, 6]) + assert result["2"] == _value_serializer([7, 8, 9]) + + +def test_mget_duplicate_keys_cache_store( + es_client_fx: MagicMock, es_embeddings_cache_fx: ElasticsearchEmbeddingsCache +) -> None: + cache_keys = [ + es_embeddings_cache_fx._key("test_text1"), + es_embeddings_cache_fx._key("test_text2"), + ] + + resp = { + "hits": { + "total": {"value": 3}, + "hits": [ + { + "_index": "test_index", + "_id": cache_keys[1], + "found": True, + "_source": { + "vector_dump": serialize_encode_vector([1.5, 2, 3.6]), + "timestamp": "2024-03-07T13:25:36.410756", + }, + }, + { + "_index": "test_index", + "_id": cache_keys[0], + "found": True, + "_source": { + "vector_dump": serialize_encode_vector([1, 6, 7.1]), + "timestamp": "2024-03-07T13:25:46.410756", + }, + }, + { + "_index": "test_index", + "_id": cache_keys[0], + "found": True, + "_source": { + "vector_dump": serialize_encode_vector([2, 6, 7.1]), + "timestamp": "2024-03-07T13:27:46.410756", + }, + }, + ], + } + } + + es_embeddings_cache_fx._is_alias = True + es_client_fx.search.return_value = resp + assert es_embeddings_cache_fx.mget(["test_text1", "test_text2"]) == [ + _value_serializer([2, 6, 7.1]), + _value_serializer([1.5, 2, 3.6]), + ] + es_client_fx.search.assert_called_with( + index="test_index", + body={ + "query": {"ids": {"values": cache_keys}}, + "size": len(cache_keys), + }, + source_includes=["vector_dump", "timestamp"], + ) + + +def _del_timestamp(doc: Dict[str, Any]) -> Dict[str, Any]: + del doc["_source"]["timestamp"] + return doc + + +def test_mset_cache_store(es_embeddings_cache_fx: ElasticsearchEmbeddingsCache) -> None: + input = [ + ("test_text1", _value_serializer([1.5, 2, 3.6])), + ("test_text2", _value_serializer([5, 6, 7.1])), + ] + actions = [ + { + "_op_type": "index", + "_id": es_embeddings_cache_fx._key(k), + "_source": es_embeddings_cache_fx.build_document(k, v), + } + for k, v in input + ] + es_embeddings_cache_fx._is_alias = False + with patch("elasticsearch.helpers.bulk") as bulk_mock: + es_embeddings_cache_fx.mset([]) + bulk_mock.assert_called_once() + es_embeddings_cache_fx.mset(input) + bulk_mock.assert_called_with( + client=es_embeddings_cache_fx._es_client, + actions=ANY, + index="test_index", + require_alias=False, + refresh=True, + ) + assert [_del_timestamp(d) for d in bulk_mock.call_args.kwargs["actions"]] == [ + _del_timestamp(d) for d in actions + ] + + +def test_mdelete_cache_store( + es_embeddings_cache_fx: ElasticsearchEmbeddingsCache, +) -> None: + input = ["test_text1", "test_text2"] + actions = [ + {"_op_type": "delete", "_id": es_embeddings_cache_fx._key(k)} for k in input + ] + es_embeddings_cache_fx._is_alias = False + with patch("elasticsearch.helpers.bulk") as bulk_mock: + es_embeddings_cache_fx.mdelete([]) + bulk_mock.assert_called_once() + es_embeddings_cache_fx.mdelete(input) + bulk_mock.assert_called_with( + client=es_embeddings_cache_fx._es_client, + actions=ANY, + index="test_index", + require_alias=False, + refresh=True, + ) + assert list(bulk_mock.call_args.kwargs["actions"]) == actions diff --git a/libs/elasticsearch/tests/unit_tests/test_imports.py b/libs/elasticsearch/tests/unit_tests/test_imports.py index 9b13d57..5cd5fc2 100644 --- a/libs/elasticsearch/tests/unit_tests/test_imports.py +++ b/libs/elasticsearch/tests/unit_tests/test_imports.py @@ -4,6 +4,7 @@ "ElasticsearchCache", "ElasticsearchChatMessageHistory", "ElasticsearchEmbeddings", + "ElasticsearchEmbeddingsCache", "ElasticsearchRetriever", "ElasticsearchStore", # retrieval strategies