-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Azure AI Search DB Vector Support
- Loading branch information
1 parent
078aa66
commit 770b9ad
Showing
7 changed files
with
222 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
from typing import Any, Dict | ||
|
||
from pydantic import BaseModel, Field, model_validator | ||
|
||
|
||
class AzureAISearchConfig(BaseModel): | ||
collection_name: str = Field("mem0", description="Name of the collection") | ||
service_name: str = Field(None, description="Azure Cognitive Search service name") | ||
api_key: str = Field(None, description="API key for the Azure Cognitive Search service") | ||
embedding_model_dims: int = Field(None, description="Dimension of the embedding vector") | ||
|
||
@model_validator(mode="before") | ||
@classmethod | ||
def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]: | ||
allowed_fields = set(cls.model_fields.keys()) | ||
input_fields = set(values.keys()) | ||
extra_fields = input_fields - allowed_fields | ||
if extra_fields: | ||
raise ValueError( | ||
f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}" | ||
) | ||
return values | ||
|
||
model_config = { | ||
"arbitrary_types_allowed": True, | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,190 @@ | ||
import json | ||
import logging | ||
from typing import List, Optional | ||
|
||
from pydantic import BaseModel | ||
|
||
from mem0.vector_stores.base import VectorStoreBase | ||
|
||
try: | ||
from azure.core.credentials import AzureKeyCredential | ||
from azure.search.documents import SearchClient | ||
from azure.search.documents.indexes import SearchIndexClient | ||
from azure.search.documents.indexes.models import ( | ||
HnswAlgorithmConfiguration, | ||
SearchField, | ||
SearchFieldDataType, | ||
SearchIndex, | ||
SimpleField, | ||
VectorSearch, | ||
VectorSearchProfile, | ||
) | ||
from azure.search.documents.models import VectorizedQuery | ||
except ImportError: | ||
raise ImportError("The 'azure-search-documents' library is required. Please install it using 'pip install azure-search-documents==11.5.1'.") | ||
|
||
|
||
|
||
|
||
|
||
logger = logging.getLogger(__name__) | ||
|
||
class OutputData(BaseModel): | ||
id: Optional[str] | ||
score: Optional[float] | ||
payload: Optional[dict] | ||
|
||
class AzureAISearch(VectorStoreBase): | ||
def __init__(self, service_name, collection_name, api_key, embedding_model_dims): | ||
"""Initialize the Azure Cognitive Search vector store. | ||
Args: | ||
service_name (str): Azure Cognitive Search service name. | ||
collection_name (str): Index name. | ||
api_key (str): API key for the Azure Cognitive Search service. | ||
embedding_model_dims (int): Dimension of the embedding vector. | ||
""" | ||
self.index_name = collection_name | ||
self.collection_name = collection_name | ||
self.embedding_model_dims = embedding_model_dims | ||
self.search_client = SearchClient(endpoint=f"https://{service_name}.search.windows.net", | ||
index_name=self.index_name, | ||
credential=AzureKeyCredential(api_key)) | ||
self.index_client = SearchIndexClient(endpoint=f"https://{service_name}.search.windows.net", | ||
credential=AzureKeyCredential(api_key)) | ||
self.create_col() #create the collection / index | ||
|
||
def create_col(self): | ||
"""Create a new index in Azure Cognitive Search.""" | ||
vector_dimensions = self.embedding_model_dims # Set this to the number of dimensions in your vector | ||
fields = [ | ||
SimpleField(name="id", type=SearchFieldDataType.String, key=True), | ||
SearchField(name="vector", type=SearchFieldDataType.Collection(SearchFieldDataType.Single), | ||
searchable=True, vector_search_dimensions=vector_dimensions, vector_search_profile_name="my-vector-config"), | ||
SimpleField(name="payload", type=SearchFieldDataType.String) | ||
] | ||
vector_search = VectorSearch( | ||
profiles=[VectorSearchProfile(name="my-vector-config", algorithm_configuration_name="my-algorithms-config")], | ||
algorithms=[HnswAlgorithmConfiguration(name="my-algorithms-config")], | ||
) | ||
index = SearchIndex(name=self.index_name, fields=fields, vector_search=vector_search) | ||
self.index_client.create_or_update_index(index) | ||
|
||
def insert(self, vectors, payloads=None, ids=None): | ||
"""Insert vectors into the index. | ||
Args: | ||
vectors (List[List[float]]): List of vectors to insert. | ||
payloads (List[Dict], optional): List of payloads corresponding to vectors. | ||
ids (List[str], optional): List of IDs corresponding to vectors. | ||
""" | ||
logger.info(f"Inserting {len(vectors)} vectors into index {self.index_name}") | ||
documents = [ | ||
{"id": id, "vector": vector, "payload": json.dumps(payload)} | ||
for id, vector, payload in zip(ids, vectors, payloads) | ||
] | ||
self.search_client.upload_documents(documents) | ||
|
||
def search(self, query, limit=5, filters=None): | ||
"""Search for similar vectors. | ||
Args: | ||
query (List[float]): Query vectors. | ||
limit (int, optional): Number of results to return. Defaults to 5. | ||
filters (Dict, optional): Filters to apply to the search. Defaults to None. | ||
Returns: | ||
list: Search results. | ||
""" | ||
vector_query = VectorizedQuery(vector=query, k_nearest_neighbors=limit, fields="vector") | ||
search_results = self.search_client.search( | ||
vector_queries=[vector_query], | ||
top=limit | ||
) | ||
|
||
results = [] | ||
for result in search_results: | ||
results.append(OutputData(id=result["id"], score=result["@search.score"], payload=json.loads(result["payload"]))) | ||
return results | ||
|
||
def delete(self, vector_id): | ||
"""Delete a vector by ID. | ||
Args: | ||
vector_id (str): ID of the vector to delete. | ||
""" | ||
self.search_client.delete_documents(documents=[{"id": vector_id}]) | ||
|
||
def update(self, vector_id, vector=None, payload=None): | ||
"""Update a vector and its payload. | ||
Args: | ||
vector_id (str): ID of the vector to update. | ||
vector (List[float], optional): Updated vector. | ||
payload (Dict, optional): Updated payload. | ||
""" | ||
document = {"id": vector_id} | ||
if vector: | ||
document["vector"] = vector | ||
if payload: | ||
document["payload"] = json.dumps(payload) | ||
self.search_client.merge_or_upload_documents(documents=[document]) | ||
|
||
def get(self, vector_id) -> OutputData: | ||
"""Retrieve a vector by ID. | ||
Args: | ||
vector_id (str): ID of the vector to retrieve. | ||
Returns: | ||
OutputData: Retrieved vector. | ||
""" | ||
result = self.search_client.get_document(key=vector_id) | ||
if not result: | ||
return None | ||
return OutputData(id=result["id"], score=None, payload=json.loads(result["payload"])) | ||
|
||
def list_cols(self) -> List[str]: | ||
"""List all collections (indexes). | ||
Returns: | ||
List[str]: List of index names. | ||
""" | ||
indexes = self.index_client.list_indexes() | ||
return [index.name for index in indexes] | ||
|
||
def delete_col(self): | ||
"""Delete the index.""" | ||
self.index_client.delete_index(self.index_name) | ||
|
||
|
||
|
||
def col_info(self): | ||
"""Get information about the index. | ||
Returns: | ||
Dict[str, Any]: Index information. | ||
""" | ||
index = self.index_client.get_index(self.index_name) | ||
return {"name": index.name, "fields": index.fields} | ||
|
||
def list(self, filters=None, limit=100): | ||
"""List all vectors in the index. | ||
Args: | ||
filters (Dict, optional): Filters to apply to the list. | ||
limit (int, optional): Number of vectors to return. Defaults to 100. | ||
Returns: | ||
List[OutputData]: List of vectors. | ||
""" | ||
search_results = self.search_client.search(search_text="", top=limit) | ||
results = [] | ||
for result in search_results: | ||
results.append(OutputData(id=result["id"], score=None, payload=json.loads(result["payload"]))) | ||
return results | ||
|
||
def __del__(self): | ||
"""Close the search client when the object is deleted.""" | ||
self.search_client.close() | ||
self.index_client.close() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters