Skip to content

Commit

Permalink
Add Azure AI Search DB Vector Support
Browse files Browse the repository at this point in the history
  • Loading branch information
maljazaery committed Oct 21, 2024
1 parent 078aa66 commit 770b9ad
Show file tree
Hide file tree
Showing 7 changed files with 222 additions and 2 deletions.
2 changes: 1 addition & 1 deletion docs/components/vectordbs/config.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ Config in mem0 is a dictionary that specifies the settings for your vector datab

The config is defined as a Python dictionary with two main keys:
- `vector_store`: Specifies the vector database provider and its configuration
- `provider`: The name of the vector database (e.g., "chroma", "pgvector", "qdrant", "milvus")
- `provider`: The name of the vector database (e.g., "chroma", "pgvector", "qdrant", "milvus", "azure_ai_search")
- `config`: A nested dictionary containing provider-specific settings

## How to Use Config
Expand Down
1 change: 1 addition & 0 deletions docs/components/vectordbs/overview.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ See the list of supported vector databases below.
<Card title="Qdrant" href="/components/vectordbs/dbs/qdrant"></Card>
<Card title="Chroma" href="/components/vectordbs/dbs/chroma"></Card>
<Card title="Pgvector" href="/components/vectordbs/dbs/pgvector"></Card>
<Card title="Azure AI Search" href="/components/vectordbs/dbs/azure_ai_search"></Card>
</CardGroup>

## Usage
Expand Down
3 changes: 2 additions & 1 deletion docs/mint.json
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@
"components/vectordbs/dbs/chroma",
"components/vectordbs/dbs/pgvector",
"components/vectordbs/dbs/qdrant",
"components/vectordbs/dbs/milvus"
"components/vectordbs/dbs/milvus",
"components/vectordbs/dbs/azure_ai_search"
]
}
]
Expand Down
26 changes: 26 additions & 0 deletions mem0/configs/vector_stores/azure_ai_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from typing import Any, Dict

from pydantic import BaseModel, Field, model_validator


class AzureAISearchConfig(BaseModel):
    """Settings accepted by the Azure AI Search vector store.

    Unknown keys are rejected up front by ``validate_extra_fields`` so that a
    misspelled option fails loudly instead of being silently ignored.
    """

    collection_name: str = Field("mem0", description="Name of the collection")
    # NOTE(review): the three fields below are annotated with plain types but
    # default to None; callers are presumably expected to always supply them.
    service_name: str = Field(None, description="Azure Cognitive Search service name")
    api_key: str = Field(None, description="API key for the Azure Cognitive Search service")
    embedding_model_dims: int = Field(None, description="Dimension of the embedding vector")

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject any input key that is not a declared field.

        Args:
            values: Raw input mapping, inspected before field validation.

        Returns:
            The unchanged ``values`` mapping when every key is recognised.

        Raises:
            ValueError: If ``values`` contains keys that are not model fields.
        """
        allowed_fields = set(cls.model_fields.keys())
        extra_fields = set(values.keys()) - allowed_fields
        if extra_fields:
            # Sets have no stable iteration order; sort so the message is
            # deterministic across runs.
            raise ValueError(
                f"Extra fields not allowed: {', '.join(sorted(extra_fields))}. Please input only the following fields: {', '.join(sorted(allowed_fields))}"
            )
        return values

    model_config = {
        "arbitrary_types_allowed": True,
    }
1 change: 1 addition & 0 deletions mem0/utils/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class VectorStoreFactory:
"chroma": "mem0.vector_stores.chroma.ChromaDB",
"pgvector": "mem0.vector_stores.pgvector.PGVector",
"milvus": "mem0.vector_stores.milvus.MilvusDB",
"azure_ai_search": "mem0.vector_stores.azure_ai_search.AzureAISearch",
}

@classmethod
Expand Down
190 changes: 190 additions & 0 deletions mem0/vector_stores/azure_ai_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
import json
import logging
from typing import List, Optional

from pydantic import BaseModel

from mem0.vector_stores.base import VectorStoreBase

try:
from azure.core.credentials import AzureKeyCredential
from azure.search.documents import SearchClient
from azure.search.documents.indexes import SearchIndexClient
from azure.search.documents.indexes.models import (
HnswAlgorithmConfiguration,
SearchField,
SearchFieldDataType,
SearchIndex,
SimpleField,
VectorSearch,
VectorSearchProfile,
)
from azure.search.documents.models import VectorizedQuery
except ImportError:
raise ImportError("The 'azure-search-documents' library is required. Please install it using 'pip install azure-search-documents==11.5.1'.")





# Module-level logger named after this module.
logger = logging.getLogger(__name__)

class OutputData(BaseModel):
    """Single record returned by the vector store's read operations."""

    # Document key of the stored vector.
    id: Optional[str]
    # Similarity score reported by a search; set to None by non-search reads.
    score: Optional[float]
    # Metadata dictionary decoded from the stored JSON payload.
    payload: Optional[dict]

class AzureAISearch(VectorStoreBase):
    """Vector store backed by an Azure Cognitive Search index.

    Every document in the index carries three fields: ``id`` (the key),
    ``vector`` (the embedding) and ``payload`` (the metadata dictionary
    serialized as a JSON string).
    """

    def __init__(self, service_name, collection_name, api_key, embedding_model_dims):
        """Initialize the Azure Cognitive Search vector store.

        Args:
            service_name (str): Azure Cognitive Search service name.
            collection_name (str): Index name.
            api_key (str): API key for the Azure Cognitive Search service.
            embedding_model_dims (int): Dimension of the embedding vector.
        """
        self.index_name = collection_name
        self.collection_name = collection_name
        self.embedding_model_dims = embedding_model_dims

        # Both clients talk to the same service with the same key.
        endpoint = f"https://{service_name}.search.windows.net"
        credential = AzureKeyCredential(api_key)
        self.search_client = SearchClient(
            endpoint=endpoint,
            index_name=self.index_name,
            credential=credential,
        )
        self.index_client = SearchIndexClient(endpoint=endpoint, credential=credential)
        self.create_col()  # create the collection / index

    def create_col(self):
        """Create (or update) the index in Azure Cognitive Search."""
        fields = [
            SimpleField(name="id", type=SearchFieldDataType.String, key=True),
            SearchField(
                name="vector",
                type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
                searchable=True,
                vector_search_dimensions=self.embedding_model_dims,
                vector_search_profile_name="my-vector-config",
            ),
            SimpleField(name="payload", type=SearchFieldDataType.String),
        ]
        vector_search = VectorSearch(
            profiles=[
                VectorSearchProfile(
                    name="my-vector-config",
                    algorithm_configuration_name="my-algorithms-config",
                )
            ],
            algorithms=[HnswAlgorithmConfiguration(name="my-algorithms-config")],
        )
        index = SearchIndex(name=self.index_name, fields=fields, vector_search=vector_search)
        self.index_client.create_or_update_index(index)

    @staticmethod
    def _matches_filters(payload, filters):
        """Return True when ``payload`` satisfies every key/value in ``filters``.

        An empty or missing ``filters`` matches everything. Matching is plain
        equality on top-level payload keys.
        """
        if not filters:
            return True
        if not isinstance(payload, dict):
            return False
        return all(key in payload and payload[key] == value for key, value in filters.items())

    def insert(self, vectors, payloads=None, ids=None):
        """Insert vectors into the index.

        Args:
            vectors (List[List[float]]): List of vectors to insert.
            payloads (List[Dict], optional): List of payloads corresponding to vectors.
                Defaults to an empty payload per vector.
            ids (List[str], optional): List of IDs corresponding to vectors.
                Defaults to a freshly generated UUID per vector.

        Raises:
            ValueError: If ``payloads`` or ``ids`` is given with a length that
                differs from ``vectors``.
        """
        import uuid

        count = len(vectors)
        # The signature advertises both as optional, so supply defaults rather
        # than failing inside zip().
        if payloads is None:
            payloads = [{} for _ in range(count)]
        if ids is None:
            ids = [str(uuid.uuid4()) for _ in range(count)]
        if len(payloads) != count or len(ids) != count:
            # zip() would silently drop the unmatched tail.
            raise ValueError(
                f"Length mismatch: {count} vectors, {len(payloads)} payloads, {len(ids)} ids"
            )

        logger.info(f"Inserting {count} vectors into index {self.index_name}")
        documents = [
            {"id": doc_id, "vector": vector, "payload": json.dumps(payload)}
            for doc_id, vector, payload in zip(ids, vectors, payloads)
        ]
        self.search_client.upload_documents(documents)

    def search(self, query, limit=5, filters=None):
        """Search for similar vectors.

        Args:
            query (List[float]): Query vectors.
            limit (int, optional): Number of results to return. Defaults to 5.
            filters (Dict, optional): Key/value pairs that a result's payload
                must match. Defaults to None.

        Returns:
            list: Search results. Filters are applied after retrieval because
            the payload is stored as a JSON string, so fewer than ``limit``
            results may be returned when filters are given.
        """
        vector_query = VectorizedQuery(vector=query, k_nearest_neighbors=limit, fields="vector")
        hits = self.search_client.search(vector_queries=[vector_query], top=limit)

        results = []
        for hit in hits:
            payload = json.loads(hit["payload"])
            if self._matches_filters(payload, filters):
                results.append(OutputData(id=hit["id"], score=hit["@search.score"], payload=payload))
        return results

    def delete(self, vector_id):
        """Delete a vector by ID.

        Args:
            vector_id (str): ID of the vector to delete.
        """
        self.search_client.delete_documents(documents=[{"id": vector_id}])

    def update(self, vector_id, vector=None, payload=None):
        """Update a vector and its payload.

        Args:
            vector_id (str): ID of the vector to update.
            vector (List[float], optional): Updated vector.
            payload (Dict, optional): Updated payload.
        """
        document = {"id": vector_id}
        # Compare against None so that an explicitly empty payload ({}) is
        # still written instead of being mistaken for "not provided".
        if vector is not None:
            document["vector"] = vector
        if payload is not None:
            document["payload"] = json.dumps(payload)
        self.search_client.merge_or_upload_documents(documents=[document])

    def get(self, vector_id) -> Optional[OutputData]:
        """Retrieve a vector by ID.

        Args:
            vector_id (str): ID of the vector to retrieve.

        Returns:
            OutputData: Retrieved vector, or None when no document has this ID.
        """
        from azure.core.exceptions import ResourceNotFoundError

        try:
            result = self.search_client.get_document(key=vector_id)
        except ResourceNotFoundError:
            # The SDK signals a missing key by raising, not by returning None.
            return None
        if not result:
            return None
        return OutputData(id=result["id"], score=None, payload=json.loads(result["payload"]))

    def list_cols(self) -> List[str]:
        """List all collections (indexes).

        Returns:
            List[str]: List of index names.
        """
        return [index.name for index in self.index_client.list_indexes()]

    def delete_col(self):
        """Delete the index."""
        self.index_client.delete_index(self.index_name)

    def col_info(self):
        """Get information about the index.

        Returns:
            Dict[str, Any]: Index name and field definitions.
        """
        index = self.index_client.get_index(self.index_name)
        return {"name": index.name, "fields": index.fields}

    def list(self, filters=None, limit=100):
        """List vectors in the index.

        Args:
            filters (Dict, optional): Key/value pairs that a document's payload
                must match. Defaults to None.
            limit (int, optional): Number of vectors to request. Defaults to 100.

        Returns:
            List[OutputData]: List of vectors. Filters are applied after
            retrieval, so fewer than ``limit`` entries may be returned when
            filters are given.
        """
        hits = self.search_client.search(search_text="", top=limit)
        results = []
        for hit in hits:
            payload = json.loads(hit["payload"])
            if self._matches_filters(payload, filters):
                results.append(OutputData(id=hit["id"], score=None, payload=payload))
        return results

    def __del__(self):
        """Close the search clients when the object is deleted."""
        # __init__ may have failed before either client was assigned, so look
        # the attributes up defensively.
        for attr in ("search_client", "index_client"):
            client = getattr(self, attr, None)
            if client is not None:
                client.close()
1 change: 1 addition & 0 deletions mem0/vector_stores/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class VectorStoreConfig(BaseModel):
"chroma": "ChromaDbConfig",
"pgvector": "PGVectorConfig",
"milvus": "MilvusDBConfig",
"azure_ai_search": "AzureAISearchConfig",
}

@model_validator(mode="after")
Expand Down

0 comments on commit 770b9ad

Please sign in to comment.