From 3d4cbbb3a59928e4c890a18d64c1508769a4d5f7 Mon Sep 17 00:00:00 2001 From: Mohamad Aljazaery Date: Tue, 22 Oct 2024 01:01:49 +0000 Subject: [PATCH] Add Azure AI Search DB Vector Support --- mem0/vector_stores/azure_ai_search.py | 40 ++++++++++++++++++++------- 1 file changed, 30 insertions(+), 10 deletions(-) diff --git a/mem0/vector_stores/azure_ai_search.py b/mem0/vector_stores/azure_ai_search.py index 7938248103..def28643c7 100644 --- a/mem0/vector_stores/azure_ai_search.py +++ b/mem0/vector_stores/azure_ai_search.py @@ -8,6 +8,7 @@ try: from azure.core.credentials import AzureKeyCredential + from azure.core.exceptions import ResourceNotFoundError from azure.search.documents import SearchClient from azure.search.documents.indexes import SearchIndexClient from azure.search.documents.indexes.models import ( @@ -61,7 +62,7 @@ def create_col(self): SimpleField(name="id", type=SearchFieldDataType.String, key=True), SearchField(name="vector", type=SearchFieldDataType.Collection(SearchFieldDataType.Single), searchable=True, vector_search_dimensions=vector_dimensions, vector_search_profile_name="my-vector-config"), - SimpleField(name="payload", type=SearchFieldDataType.String) + SimpleField(name="payload", type=SearchFieldDataType.String, searchable=True) ] vector_search = VectorSearch( profiles=[VectorSearchProfile(name="my-vector-config", algorithm_configuration_name="my-algorithms-config")], @@ -84,7 +85,7 @@ def insert(self, vectors, payloads=None, ids=None): for id, vector, payload in zip(ids, vectors, payloads) ] self.search_client.upload_documents(documents) - + def search(self, query, limit=5, filters=None): """Search for similar vectors. @@ -96,6 +97,7 @@ def search(self, query, limit=5, filters=None): Returns: list: Search results. """ + vector_query = VectorizedQuery(vector=query, k_nearest_neighbors=limit, fields="vector") search_results = self.search_client.search( vector_queries=[vector_query], @@ -104,7 +106,12 @@ def search(self, query, limit=5, filters=None): results = [] for result in search_results: - results.append(OutputData(id=result["id"], score=result["@search.score"], payload=json.loads(result["payload"]))) + payload=json.loads(result["payload"]) + if filters: + for key, value in filters.items(): + if key not in payload or payload[key]!= value: + continue + results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload)) return results def delete(self, vector_id): @@ -114,6 +121,8 @@ def delete(self, vector_id): vector_id (str): ID of the vector to delete. """ self.search_client.delete_documents(documents=[{"id": vector_id}]) + + def update(self, vector_id, vector=None, payload=None): """Update a vector and its payload. @@ -139,10 +148,11 @@ def get(self, vector_id) -> OutputData: Returns: OutputData: Retrieved vector. """ - result = self.search_client.get_document(key=vector_id) - if not result: - return None - return OutputData(id=result["id"], score=None, payload=json.loads(result["payload"])) + try: + result = self.search_client.get_document(key=vector_id) + except ResourceNotFoundError: + return None + return OutputData(id=result["id"], score=None, payload=json.loads(result["payload"])) def list_cols(self) -> List[str]: """List all collections (indexes). @@ -178,11 +188,21 @@ def list(self, filters=None, limit=100): Returns: List[OutputData]: List of vectors. """ - search_results = self.search_client.search(search_text="", top=limit) + search_results = self.search_client.search(search_text="*", top=limit) results = [] for result in search_results: - results.append(OutputData(id=result["id"], score=None, payload=json.loads(result["payload"]))) - return results + payload=json.loads(result["payload"]) + good=True + if filters: + for key, value in filters.items(): + if (key not in payload) or (payload[key]!= filters[key]): + good=False + break + if good: + results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload)) + + + return [results] def __del__(self): """Close the search client when the object is deleted."""