Skip to content

Commit

Permalink
Add Azure AI Search DB Vector Support
Browse files Browse the repository at this point in the history
  • Loading branch information
maljazaery committed Oct 22, 2024
1 parent 770b9ad commit 3d4cbbb
Showing 1 changed file with 30 additions and 10 deletions.
40 changes: 30 additions & 10 deletions mem0/vector_stores/azure_ai_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

try:
from azure.core.credentials import AzureKeyCredential
from azure.core.exceptions import ResourceNotFoundError
from azure.search.documents import SearchClient
from azure.search.documents.indexes import SearchIndexClient
from azure.search.documents.indexes.models import (
Expand Down Expand Up @@ -61,7 +62,7 @@ def create_col(self):
SimpleField(name="id", type=SearchFieldDataType.String, key=True),
SearchField(name="vector", type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
searchable=True, vector_search_dimensions=vector_dimensions, vector_search_profile_name="my-vector-config"),
SimpleField(name="payload", type=SearchFieldDataType.String)
SimpleField(name="payload", type=SearchFieldDataType.String, searchable=True)
]
vector_search = VectorSearch(
profiles=[VectorSearchProfile(name="my-vector-config", algorithm_configuration_name="my-algorithms-config")],
Expand All @@ -84,7 +85,7 @@ def insert(self, vectors, payloads=None, ids=None):
for id, vector, payload in zip(ids, vectors, payloads)
]
self.search_client.upload_documents(documents)

def search(self, query, limit=5, filters=None):
"""Search for similar vectors.
Expand All @@ -96,6 +97,7 @@ def search(self, query, limit=5, filters=None):
Returns:
list: Search results.
"""

vector_query = VectorizedQuery(vector=query, k_nearest_neighbors=limit, fields="vector")
search_results = self.search_client.search(
vector_queries=[vector_query],
Expand All @@ -104,7 +106,12 @@ def search(self, query, limit=5, filters=None):

results = []
for result in search_results:
results.append(OutputData(id=result["id"], score=result["@search.score"], payload=json.loads(result["payload"])))
payload=json.loads(result["payload"])
if filters:
for key, value in filters.items():
if key not in payload or payload[key]!= value:
continue
results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
return results

def delete(self, vector_id):
Expand All @@ -114,6 +121,8 @@ def delete(self, vector_id):
vector_id (str): ID of the vector to delete.
"""
self.search_client.delete_documents(documents=[{"id": vector_id}])



def update(self, vector_id, vector=None, payload=None):
"""Update a vector and its payload.
Expand All @@ -139,10 +148,11 @@ def get(self, vector_id) -> OutputData:
Returns:
OutputData: Retrieved vector.
"""
result = self.search_client.get_document(key=vector_id)
if not result:
return None
return OutputData(id=result["id"], score=None, payload=json.loads(result["payload"]))
try:
result = self.search_client.get_document(key=vector_id)
except ResourceNotFoundError:
return None
return OutputData(id=result["id"], score=None, payload=json.loads(result["payload"]))

def list_cols(self) -> List[str]:
"""List all collections (indexes).
Expand Down Expand Up @@ -178,11 +188,21 @@ def list(self, filters=None, limit=100):
Returns:
List[OutputData]: List of vectors.
"""
search_results = self.search_client.search(search_text="", top=limit)
search_results = self.search_client.search(search_text="*", top=limit)
results = []
for result in search_results:
results.append(OutputData(id=result["id"], score=None, payload=json.loads(result["payload"])))
return results
payload=json.loads(result["payload"])
good=True
if filters:
for key, value in filters.items():
if (key not in payload) or (payload[key]!= filters[key]):
good=False
break
if good:
results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))


return [results]

def __del__(self):
"""Close the search client when the object is deleted."""
Expand Down

0 comments on commit 3d4cbbb

Please sign in to comment.