From e3e678399141821a635ff37224c0894a9f95685f Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Tue, 2 Apr 2024 09:28:58 +0000 Subject: [PATCH 01/27] Added vectordb base and chromadb --- autogen/agentchat/contrib/vectordb/base.py | 202 +++++++++++++ .../agentchat/contrib/vectordb/chromadb.py | 286 ++++++++++++++++++ autogen/agentchat/contrib/vectordb/utils.py | 92 ++++++ 3 files changed, 580 insertions(+) create mode 100644 autogen/agentchat/contrib/vectordb/base.py create mode 100644 autogen/agentchat/contrib/vectordb/chromadb.py create mode 100644 autogen/agentchat/contrib/vectordb/utils.py diff --git a/autogen/agentchat/contrib/vectordb/base.py b/autogen/agentchat/contrib/vectordb/base.py new file mode 100644 index 000000000000..06927b0c377a --- /dev/null +++ b/autogen/agentchat/contrib/vectordb/base.py @@ -0,0 +1,202 @@ +from typing import Any, Callable, List, Protocol, runtime_checkable + + +@runtime_checkable +class VectorDB(Protocol): + """ + Abstract class for vector database. A vector database is responsible for storing and retrieving documents. + """ + + def __init__(self, db_config: dict = None) -> None: + """ + Initialize the vector database. + + Args: + db_config: dict | configuration for initializing the vector database. Default is None. + + Returns: + None + """ + ... + + def create_collection(self, collection_name: str, overwrite: bool = False, get_or_create: bool = True) -> Any: + """ + Create a collection in the vector database. + Case 1. if the collection does not exist, create the collection. + Case 2. the collection exists, if overwrite is True, it will overwrite the collection. + Case 3. the collection exists and overwrite is False, if get_or_create is True, it will get the collection, + otherwise it raise a ValueError. + + Args: + collection_name: str | The name of the collection. + overwrite: bool | Whether to overwrite the collection if it exists. Default is False. + get_or_create: bool | Whether to get the collection if it exists. Default is True. + + Returns: + Any | The collection object. + """ + ... + + def get_collection(self, collection_name: str = None) -> Any: + """ + Get the collection from the vector database. + + Args: + collection_name: str | The name of the collection. Default is None. If None, return the + current active collection. + + Returns: + Any | The collection object. + """ + ... + + def delete_collection(self, collection_name: str) -> Any: + """ + Delete the collection from the vector database. + + Args: + collection_name: str | The name of the collection. + + Returns: + Any + """ + ... + + def insert_docs(self, docs: List[dict], collection_name: str = None, upsert: bool = False, **kwargs) -> Any: + """ + Insert documents into the collection of the vector database. + + Args: + docs: List[dict] | A list of documents. Each document is a dictionary. + It should include the following fields: + - required: "id", "content" + - optional: "embedding", "metadata", "distance", etc. + collection_name: str | The name of the collection. Default is None. + upsert: bool | Whether to update the document if it exists. Default is False. + kwargs: dict | Additional keyword arguments. + + Returns: + None + """ + ... + + def update_docs(self, docs: List[dict], collection_name: str = None, **kwargs) -> None: + """ + Update documents in the collection of the vector database. + + Args: + docs: List[dict] | A list of documents. + collection_name: str | The name of the collection. Default is None. + kwargs: dict | Additional keyword arguments. + + Returns: + None + """ + ... + + def delete_docs(self, ids: List[Any], collection_name: str = None, **kwargs) -> None: + """ + Delete documents from the collection of the vector database. + + Args: + ids: List[Any] | A list of document ids. + collection_name: str | The name of the collection. Default is None. + kwargs: dict | Additional keyword arguments. + + Returns: + None + """ + ... + + def retrieve_docs( + self, + queries: List[str], + collection_name: str = None, + n_results: int = 10, + distance_threshold: float = -1, + **kwargs, + ) -> List[List[dict]]: + """ + Retrieve documents from the collection of the vector database based on the queries. + + Args: + queries: List[str] | A list of queries. Each query is a string. + collection_name: str | The name of the collection. Default is None. + n_results: int | The number of relevant documents to return. Default is 10. + distance_threshold: float | The threshold for the distance score, only distance smaller than it will be + returned. Don't filter with it if < 0. Default is -1. + kwargs: dict | Additional keyword arguments. + + Returns: + List[List[dict]] | The query results. Each query result is a list of dictionaries. + It should include the following fields: + - required: "ids", "contents" + - optional: "embeddings", "metadatas", "distances", etc. + + queries example: ["query1", "query2"] + query results example: [ + { + "ids": ["id1", "id2", ...], + "contents": ["content1", "content2", ...], + "embeddings": ["embedding1", "embedding2", ...], + "metadatas": ["metadata1", "metadata2", ...], + "distances": ["distance1", "distance2", ...] + }, + { + "ids": ["id1", "id2", ...], + "contents": ["content1", "content2", ...], + "embeddings": ["embedding1", "embedding2", ...], + "metadatas": ["metadata1", "metadata2", ...], + "distances": ["distance1", "distance2", ...] + } + ] + + """ + ... + + def get_docs_by_ids(self, ids: List[Any], collection_name: str = None, include=None, **kwargs) -> List[dict]: + """ + Retrieve documents from the collection of the vector database based on the ids. + + Args: + ids: List[Any] | A list of document ids. + collection_name: str | The name of the collection. Default is None. + include: List[str] | The fields to include. Default is None. + If None, will include ["ids", "contents"] + kwargs: dict | Additional keyword arguments. + + Returns: + List[dict] | The query results. + """ + ... + + +class VectorDBFactory: + """ + Factory class for creating vector databases. + """ + + PREDEFINED_VECTOR_DB = ["chroma"] + + @staticmethod + def create_vector_db(db_type: str, db_config: dict = None) -> VectorDB: + """ + Create a vector database. + + Args: + db_type: str | The type of the vector database. + db_config: dict | The configuration of the vector database. Default is None. + + Returns: + VectorDB | The vector database. + """ + if db_config is None: + db_config = {} + if db_type.lower() in ["chroma", "chromadb"]: + from .chromadb import ChromaVectorDB + + return ChromaVectorDB(db_config=db_config) + else: + raise ValueError( + f"Unsupported vector database type: {db_type}. Valid types are {VectorDBFactory.PREDEFINED_VECTOR_DB}." + ) diff --git a/autogen/agentchat/contrib/vectordb/chromadb.py b/autogen/agentchat/contrib/vectordb/chromadb.py new file mode 100644 index 000000000000..04a9268625bb --- /dev/null +++ b/autogen/agentchat/contrib/vectordb/chromadb.py @@ -0,0 +1,286 @@ +import os +from typing import Any, Callable, List + +from .utils import get_logger, timer + +try: + import chromadb + + if chromadb.__version__ < "0.4.15": + raise ImportError("Please upgrade chromadb to version 0.4.15 or later.") + from chromadb.api.models.Collection import Collection +except ImportError: + raise ImportError("Please install chromadb: `pip install chromadb`") + +CHROMADB_MAX_BATCH_SIZE = os.environ.get("CHROMADB_MAX_BATCH_SIZE", 40000) +logger = get_logger(__name__) + + +class ChromaVectorDB: + """ + A vector database that uses ChromaDB as the backend. + """ + + def __init__(self, db_config: dict = None) -> None: + """ + Initialize the vector database. + + Args: + db_config: dict | configuration for initializing the vector database. Default is None. + It can contain the following keys: + client: chromadb.Client | The client object of the vector database. Default is None. + If not None, it will use the client object to connect to the vector database. + path: str | The path to the vector database. Default is None. + embedding_function: Callable | The embedding function used to generate the vector representation + of the documents. Default is None. + metadata: dict | The metadata of the vector database. Default is None. If None, it will use this + setting: {"hnsw:space": "ip", "hnsw:construction_ef": 30, "hnsw:M": 32}. For more details of + the metadata, please refer to [distances](https://github.com/nmslib/hnswlib#supported-distances), + [hnsw](https://github.com/chroma-core/chroma/blob/566bc80f6c8ee29f7d99b6322654f32183c368c4/chromadb/segment/impl/vector/local_hnsw.py#L184), + and [ALGO_PARAMS](https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md). + kwargs: dict | Additional keyword arguments. + + Returns: + None + """ + if db_config is None: + db_config = {} + self.client = db_config.get("client") + if not self.client: + self.path = db_config.get("path") + self.embedding_function = db_config.get("embedding_function") + self.metadata = db_config.get("metadata", {"hnsw:space": "ip", "hnsw:construction_ef": 30, "hnsw:M": 32}) + kwargs = db_config.get("kwargs", {}) + if self.path is not None: + self.client = chromadb.PersistentClient(path=self.path, **kwargs) + else: + self.client = chromadb.Client(**kwargs) + self.active_collection = None + + def create_collection( + self, collection_name: str, overwrite: bool = False, get_or_create: bool = True + ) -> Collection: + """ + Create a collection in the vector database. + Case 1. if the collection does not exist, create the collection. + Case 2. the collection exists, if overwrite is True, it will overwrite the collection. + Case 3. the collection exists and overwrite is False, if get_or_create is True, it will get the collection, + otherwise it raise a ValueError. + + Args: + collection_name: str | The name of the collection. + overwrite: bool | Whether to overwrite the collection if it exists. Default is False. + get_or_create: bool | Whether to get the collection if it exists. Default is True. + + Returns: + Collection | The collection object. + """ + try: + collection = self.client.get_collection(collection_name) + except ValueError: + collection = None + if collection is None: + return self.client.create_collection( + collection_name, + embedding_function=self.embedding_function, + get_or_create=get_or_create, + metadata=self.metadata, + ) + elif overwrite: + self.client.delete_collection(collection_name) + return self.client.create_collection( + collection_name, + embedding_function=self.embedding_function, + get_or_create=get_or_create, + metadata=self.metadata, + ) + elif get_or_create: + return collection + else: + raise ValueError(f"Collection {collection_name} already exists.") + + def get_collection(self, collection_name: str = None) -> Collection: + """ + Get the collection from the vector database. + + Args: + collection_name: str | The name of the collection. Default is None. If None, return the + current active collection. + + Returns: + Collection | The collection object. + """ + if collection_name is None: + if self.active_collection is None: + raise ValueError("No collection is specified.") + else: + logger.info( + f"No collection is specified. Using current active collection {self.active_collection.name}." + ) + else: + self.active_collection = self.client.get_collection(collection_name) + return self.active_collection + + def delete_collection(self, collection_name: str) -> None: + """ + Delete the collection from the vector database. + + Args: + collection_name: str | The name of the collection. + + Returns: + None + """ + self.client.delete_collection(collection_name) + if self.active_collection: + if self.active_collection.name == collection_name: + self.active_collection = None + + def _batch_insert(self, collection, embeddings=None, ids=None, metadata=None, documents=None, upsert=False): + batch_size = int(CHROMADB_MAX_BATCH_SIZE) + for i in range(0, len(documents), min(batch_size, len(documents))): + end_idx = i + min(batch_size, len(documents) - i) + collection_kwargs = { + "documents": documents[i:end_idx], + "ids": ids[i:end_idx], + "metadatas": metadata[i:end_idx] if metadata else None, + "embeddings": embeddings[i:end_idx] if embeddings else None, + } + if upsert: + collection.upsert(**collection_kwargs) + else: + collection.add(**collection_kwargs) + + @timer + def insert_docs(self, docs: List[dict], collection_name: str = None, upsert: bool = False) -> None: + """ + Insert documents into the collection of the vector database. + + Args: + docs: List[dict] | A list of documents. Each document is a dictionary. + It should include the following fields: + - required: "id", "content" + - optional: "embedding", "metadata", "distance", etc. + collection_name: str | The name of the collection. Default is None. + upsert: bool | Whether to update the document if it exists. Default is False. + kwargs: dict | Additional keyword arguments. + + Returns: + None + """ + if not docs: + return + collection = self.get_collection(collection_name) + if docs[0].get("embedding") is None: + logger.info( + "No content embedding is provided. Will use the VectorDB's embedding function to generate the content embedding." + ) + embeddings = None + else: + embeddings = [doc.embedding for doc in docs] + documents = [doc.content for doc in docs] + ids = [doc.id for doc in docs] + metadata = [doc.get("metadata") for doc in docs] + self._batch_insert(collection, embeddings, ids, metadata, documents, upsert) + + def update_docs(self, docs: List[dict], collection_name: str = None) -> None: + """ + Update documents in the collection of the vector database. + + Args: + docs: List[dict] | A list of documents. + collection_name: str | The name of the collection. Default is None. + + Returns: + None + """ + self.insert_docs(docs, collection_name, upsert=True) + + def delete_docs(self, ids: List[Any], collection_name: str = None, **kwargs) -> None: + """ + Delete documents from the collection of the vector database. + + Args: + ids: List[Any] | A list of document ids. + collection_name: str | The name of the collection. Default is None. + kwargs: dict | Additional keyword arguments. + + Returns: + None + """ + collection = self.get_collection(collection_name) + collection.delete(ids, **kwargs) + + def retrieve_docs( + self, + queries: List[str], + collection_name: str = None, + n_results: int = 10, + distance_threshold: float = -1, + **kwargs, + ) -> List[List[dict]]: + """ + Retrieve documents from the collection of the vector database based on the queries. + + Args: + queries: List[str] | A list of queries. Each query is a string. + collection_name: str | The name of the collection. Default is None. + n_results: int | The number of relevant documents to return. Default is 10. + distance_threshold: float | The threshold for the distance score, only distance smaller than it will be + returned. Don't filter with it if < 0. Default is -1. + kwargs: dict | Additional keyword arguments. + + Returns: + List[List[dict]] | The query results. Each query result is a list of dictionaries. + It should include the following fields: + - required: "ids", "contents" + - optional: "embeddings", "metadatas", "distances", etc. + + queries example: ["query1", "query2"] + query results example: [ + { + "ids": ["id1", "id2", ...], + "contents": ["content1", "content2", ...], + "embeddings": ["embedding1", "embedding2", ...], + "metadatas": ["metadata1", "metadata2", ...], + "distances": ["distance1", "distance2", ...] + }, + { + "ids": ["id1", "id2", ...], + "contents": ["content1", "content2", ...], + "embeddings": ["embedding1", "embedding2", ...], + "metadatas": ["metadata1", "metadata2", ...], + "distances": ["distance1", "distance2", ...] + } + ] + + """ + collection = self.get_collection(collection_name) + if isinstance(queries, str): + queries = [queries] + results = collection.query( + query_texts=queries, + n_results=n_results, + **kwargs, + ) + results["contents"] = results.pop("documents") + return results + + def get_docs_by_ids(self, ids: List[Any], collection_name: str = None, include=None, **kwargs) -> List[dict]: + """ + Retrieve documents from the collection of the vector database based on the ids. + + Args: + ids: List[Any] | A list of document ids. + collection_name: str | The name of the collection. Default is None. + include: List[str] | The fields to include. Default is None. + If None, will include ["metadatas", "documents"] + kwargs: dict | Additional keyword arguments. + + Returns: + List[dict] | The query results. + """ + collection = self.get_collection(collection_name) + include = include if include else ["metadatas", "documents"] + results = collection.get(ids, include=include, **kwargs) + return results diff --git a/autogen/agentchat/contrib/vectordb/utils.py b/autogen/agentchat/contrib/vectordb/utils.py new file mode 100644 index 000000000000..54e5a47ae9a8 --- /dev/null +++ b/autogen/agentchat/contrib/vectordb/utils.py @@ -0,0 +1,92 @@ +import importlib +import logging +import time +from functools import wraps +from typing import Any + +from termcolor import colored + + +class ColoredLogger(logging.Logger): + def __init__(self, name, level=logging.NOTSET): + super().__init__(name, level) + + def debug(self, msg, *args, color=None, **kwargs): + super().debug(colored(msg, color), *args, **kwargs) + + def info(self, msg, *args, color=None, **kwargs): + super().info(colored(msg, color), *args, **kwargs) + + def warning(self, msg, *args, color="yellow", **kwargs): + super().warning(colored(msg, color), *args, **kwargs) + + def error(self, msg, *args, color="light_red", **kwargs): + super().error(colored(msg, color), *args, **kwargs) + + def critical(self, msg, *args, color="red", **kwargs): + super().critical(colored(msg, color), *args, **kwargs) + + +def get_logger(name: str, level: int = logging.INFO) -> ColoredLogger: + logger = ColoredLogger(name, level) + console_handler = logging.StreamHandler() + logger.addHandler(console_handler) + formatter = logging.Formatter("%(asctime)s - %(filename)s:%(lineno)5d - %(levelname)s - %(message)s") + logger.handlers[0].setFormatter(formatter) + return logger + + +lazy_imported = {} +logger = get_logger(__name__) + + +def lazy_import(module_name: str, attr_name: str = None) -> Any: + """lazy import module and attribute. + + Args: + module_name: The name of the module to import. + attr_name: The name of the attribute to import. + + Returns: + The imported module or attribute. + + Example usage: + ```python + from autogen.agentchat.contrib.vectordb.utils import lazy_import + os = lazy_import("os") + p = lazy_import("os", "path") + print(os) + print(p) + print(os.path is p) # True + ``` + """ + if module_name not in lazy_imported: + try: + lazy_imported[module_name] = importlib.import_module(module_name) + except ImportError: + logger.error(f"Failed to import {module_name}.") + return None + if attr_name: + attr = getattr(lazy_imported[module_name], attr_name, None) + if attr is None: + logger.error(f"Failed to import {attr_name} from {module_name}") + return None + else: + return attr + else: + return lazy_imported[module_name] + + +def timer(func) -> Any: + """ + Timer decorator. + """ + + @wraps(func) + def wrapper(*args, **kwargs): + start = time.time() + result = func(*args, **kwargs) + logger.debug(f"{func.__name__} took {time.time() - start:.2f} seconds.") + return result + + return wrapper From 5c0770aeff24ff98dd02782a4d264c4068911b79 Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Wed, 3 Apr 2024 01:21:40 +0000 Subject: [PATCH 02/27] Remove timer and unused functions --- .../agentchat/contrib/vectordb/chromadb.py | 3 +- autogen/agentchat/contrib/vectordb/utils.py | 59 ------------------- 2 files changed, 1 insertion(+), 61 deletions(-) diff --git a/autogen/agentchat/contrib/vectordb/chromadb.py b/autogen/agentchat/contrib/vectordb/chromadb.py index 04a9268625bb..0230a5cd4574 100644 --- a/autogen/agentchat/contrib/vectordb/chromadb.py +++ b/autogen/agentchat/contrib/vectordb/chromadb.py @@ -1,7 +1,7 @@ import os from typing import Any, Callable, List -from .utils import get_logger, timer +from .utils import get_logger try: import chromadb @@ -151,7 +151,6 @@ def _batch_insert(self, collection, embeddings=None, ids=None, metadata=None, do else: collection.add(**collection_kwargs) - @timer def insert_docs(self, docs: List[dict], collection_name: str = None, upsert: bool = False) -> None: """ Insert documents into the collection of the vector database. diff --git a/autogen/agentchat/contrib/vectordb/utils.py b/autogen/agentchat/contrib/vectordb/utils.py index 54e5a47ae9a8..4a808449dded 100644 --- a/autogen/agentchat/contrib/vectordb/utils.py +++ b/autogen/agentchat/contrib/vectordb/utils.py @@ -1,9 +1,4 @@ -import importlib import logging -import time -from functools import wraps -from typing import Any - from termcolor import colored @@ -35,58 +30,4 @@ def get_logger(name: str, level: int = logging.INFO) -> ColoredLogger: logger.handlers[0].setFormatter(formatter) return logger - -lazy_imported = {} logger = get_logger(__name__) - - -def lazy_import(module_name: str, attr_name: str = None) -> Any: - """lazy import module and attribute. - - Args: - module_name: The name of the module to import. - attr_name: The name of the attribute to import. - - Returns: - The imported module or attribute. - - Example usage: - ```python - from autogen.agentchat.contrib.vectordb.utils import lazy_import - os = lazy_import("os") - p = lazy_import("os", "path") - print(os) - print(p) - print(os.path is p) # True - ``` - """ - if module_name not in lazy_imported: - try: - lazy_imported[module_name] = importlib.import_module(module_name) - except ImportError: - logger.error(f"Failed to import {module_name}.") - return None - if attr_name: - attr = getattr(lazy_imported[module_name], attr_name, None) - if attr is None: - logger.error(f"Failed to import {attr_name} from {module_name}") - return None - else: - return attr - else: - return lazy_imported[module_name] - - -def timer(func) -> Any: - """ - Timer decorator. - """ - - @wraps(func) - def wrapper(*args, **kwargs): - start = time.time() - result = func(*args, **kwargs) - logger.debug(f"{func.__name__} took {time.time() - start:.2f} seconds.") - return result - - return wrapper From 901ac1a0f4a985251f1e88fe3ca3e15e9d0e5dbf Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Wed, 3 Apr 2024 02:33:35 +0000 Subject: [PATCH 03/27] Added filter by distance --- .../agentchat/contrib/vectordb/__init__.py | 0 autogen/agentchat/contrib/vectordb/base.py | 33 +++++++---------- .../agentchat/contrib/vectordb/chromadb.py | 35 ++++++++----------- autogen/agentchat/contrib/vectordb/utils.py | 28 +++++++++++++++ 4 files changed, 54 insertions(+), 42 deletions(-) create mode 100644 autogen/agentchat/contrib/vectordb/__init__.py diff --git a/autogen/agentchat/contrib/vectordb/__init__.py b/autogen/agentchat/contrib/vectordb/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/autogen/agentchat/contrib/vectordb/base.py b/autogen/agentchat/contrib/vectordb/base.py index 06927b0c377a..8e55e3434a74 100644 --- a/autogen/agentchat/contrib/vectordb/base.py +++ b/autogen/agentchat/contrib/vectordb/base.py @@ -115,7 +115,7 @@ def retrieve_docs( n_results: int = 10, distance_threshold: float = -1, **kwargs, - ) -> List[List[dict]]: + ) -> dict[str, List[List[dict]]]: """ Retrieve documents from the collection of the vector database based on the queries. @@ -128,33 +128,24 @@ def retrieve_docs( kwargs: dict | Additional keyword arguments. Returns: - List[List[dict]] | The query results. Each query result is a list of dictionaries. + dict[str, List[List[dict]]] | The query results. Each query result is a dictionary. It should include the following fields: - required: "ids", "contents" - optional: "embeddings", "metadatas", "distances", etc. queries example: ["query1", "query2"] - query results example: [ - { - "ids": ["id1", "id2", ...], - "contents": ["content1", "content2", ...], - "embeddings": ["embedding1", "embedding2", ...], - "metadatas": ["metadata1", "metadata2", ...], - "distances": ["distance1", "distance2", ...] - }, - { - "ids": ["id1", "id2", ...], - "contents": ["content1", "content2", ...], - "embeddings": ["embedding1", "embedding2", ...], - "metadatas": ["metadata1", "metadata2", ...], - "distances": ["distance1", "distance2", ...] - } - ] + query results example: { + "ids": [["id1", "id2", ...], ["id3", "id4", ...]], + "contents": [["content1", "content2", ...], ["content3", "content4", ...]], + "embeddings": [["embedding1", "embedding2", ...], ["embedding3", "embedding4", ...]], + "metadatas": [["metadata1", "metadata2", ...], ["metadata3", "metadata4", ...]], + "distances": [["distance1", "distance2", ...], ["distance3", "distance4", ...]], + } """ ... - def get_docs_by_ids(self, ids: List[Any], collection_name: str = None, include=None, **kwargs) -> List[dict]: + def get_docs_by_ids(self, ids: List[Any], collection_name: str = None, include=None, **kwargs) -> dict[str, List[dict]]: """ Retrieve documents from the collection of the vector database based on the ids. @@ -162,11 +153,11 @@ def get_docs_by_ids(self, ids: List[Any], collection_name: str = None, include=N ids: List[Any] | A list of document ids. collection_name: str | The name of the collection. Default is None. include: List[str] | The fields to include. Default is None. - If None, will include ["ids", "contents"] + If None, will include ["metadatas", "documents"] kwargs: dict | Additional keyword arguments. Returns: - List[dict] | The query results. + dict[str, List[dict]] | The results. """ ... diff --git a/autogen/agentchat/contrib/vectordb/chromadb.py b/autogen/agentchat/contrib/vectordb/chromadb.py index 0230a5cd4574..49d564c2065a 100644 --- a/autogen/agentchat/contrib/vectordb/chromadb.py +++ b/autogen/agentchat/contrib/vectordb/chromadb.py @@ -1,7 +1,7 @@ import os from typing import Any, Callable, List -from .utils import get_logger +from .utils import get_logger, filter_results_by_distance try: import chromadb @@ -217,7 +217,7 @@ def retrieve_docs( n_results: int = 10, distance_threshold: float = -1, **kwargs, - ) -> List[List[dict]]: + ) -> dict[str, List[List[dict]]]: """ Retrieve documents from the collection of the vector database based on the queries. @@ -230,28 +230,19 @@ def retrieve_docs( kwargs: dict | Additional keyword arguments. Returns: - List[List[dict]] | The query results. Each query result is a list of dictionaries. + dict[str, List[List[dict]]] | The query results. Each query result is a dictionary. It should include the following fields: - required: "ids", "contents" - optional: "embeddings", "metadatas", "distances", etc. queries example: ["query1", "query2"] - query results example: [ - { - "ids": ["id1", "id2", ...], - "contents": ["content1", "content2", ...], - "embeddings": ["embedding1", "embedding2", ...], - "metadatas": ["metadata1", "metadata2", ...], - "distances": ["distance1", "distance2", ...] - }, - { - "ids": ["id1", "id2", ...], - "contents": ["content1", "content2", ...], - "embeddings": ["embedding1", "embedding2", ...], - "metadatas": ["metadata1", "metadata2", ...], - "distances": ["distance1", "distance2", ...] - } - ] + query results example: { + "ids": [["id1", "id2", ...], ["id3", "id4", ...]], + "contents": [["content1", "content2", ...], ["content3", "content4", ...]], + "embeddings": [["embedding1", "embedding2", ...], ["embedding3", "embedding4", ...]], + "metadatas": [["metadata1", "metadata2", ...], ["metadata3", "metadata4", ...]], + "distances": [["distance1", "distance2", ...], ["distance3", "distance4", ...]], + } """ collection = self.get_collection(collection_name) @@ -263,9 +254,11 @@ def retrieve_docs( **kwargs, ) results["contents"] = results.pop("documents") + results = filter_results_by_distance(results, distance_threshold) + return results - def get_docs_by_ids(self, ids: List[Any], collection_name: str = None, include=None, **kwargs) -> List[dict]: + def get_docs_by_ids(self, ids: List[Any], collection_name: str = None, include=None, **kwargs) -> dict[str, List[dict]]: """ Retrieve documents from the collection of the vector database based on the ids. @@ -277,7 +270,7 @@ def get_docs_by_ids(self, ids: List[Any], collection_name: str = None, include=N kwargs: dict | Additional keyword arguments. Returns: - List[dict] | The query results. + dict[str, List[dict]] | The results. """ collection = self.get_collection(collection_name) include = include if include else ["metadatas", "documents"] diff --git a/autogen/agentchat/contrib/vectordb/utils.py b/autogen/agentchat/contrib/vectordb/utils.py index 4a808449dded..3275022a7b7b 100644 --- a/autogen/agentchat/contrib/vectordb/utils.py +++ b/autogen/agentchat/contrib/vectordb/utils.py @@ -1,3 +1,4 @@ +from typing import List import logging from termcolor import colored @@ -31,3 +32,30 @@ def get_logger(name: str, level: int = logging.INFO) -> ColoredLogger: return logger logger = get_logger(__name__) + + +def filter_results_by_distance(results:dict[str, List[List[dict]]], distance_threshold: float = -1) -> dict[str, List[List[dict]]]: + """Filters results based on a distance threshold. + + Args: + results: A dictionary containing results to be filtered. + distance_threshold: The maximum distance allowed for results. + + Returns: + dict[str, List[List[dict]]] | A filtered dictionary containing only results within the threshold. + """ + + if distance_threshold > 0: + # Filter distances first: + return_ridx = [ + [ridx for ridx, distance in enumerate(distances) if distance < distance_threshold] + for distances in results["distances"] + ] + + # Filter other keys based on filtered distances: + results = { + key: [[value for ridx, value in enumerate(results_list) if ridx in return_ridx[qidx]] for qidx, results_list in enumerate(results_lists)] + for key, results_lists in results.items() + } + + return results From 5af4a0ca67cc505392f3a921fc783193a6bcdbed Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Wed, 3 Apr 2024 02:33:53 +0000 Subject: [PATCH 04/27] Added test utils --- test/agentchat/contrib/vectordb/test_utils.py | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 test/agentchat/contrib/vectordb/test_utils.py diff --git a/test/agentchat/contrib/vectordb/test_utils.py b/test/agentchat/contrib/vectordb/test_utils.py new file mode 100644 index 000000000000..eecd8a76f573 --- /dev/null +++ b/test/agentchat/contrib/vectordb/test_utils.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python3 -m pytest + +import pytest +import os +import sys + +sys.path.append(os.path.join(os.path.dirname(__file__), "../..")) +from conftest import skip_openai # noqa: E402 + +from autogen.agentchat.contrib.vectordb.utils import filter_results_by_distance + +skip = skip_openai + +@pytest.mark.skipif( + skip, + reason="do not run for openai", +) +def test_retrieve_config(): + results = { + "ids": [["id1", "id2"], ["id3", "id4"]], + "contents": [["content1", "content2"], ["content3", "content4", ]], + "embeddings": [["embedding1", "embedding2", ], ["embedding3", "embedding4", ]], + "metadatas": [["metadata1", "metadata2", ], ["metadata3", "metadata4", ]], + "distances": [[1,2], [2,3]], + } + print(filter_results_by_distance(results, 2.1)) + filter_results = {'ids': [['id1', 'id2'], ['id3']], 'contents': [['content1', 'content2'], ['content3']], 'embeddings': [['embedding1', 'embedding2'], ['embedding3']], 'metadatas': [['metadata1', 'metadata2'], ['metadata3']], 'distances': [[1, 2], [2]]} + assert filter_results == filter_results_by_distance(results, 2.1) \ No newline at end of file From 1d9984ec8d946dc64323ebe955d4d921db15d0fe Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Wed, 3 Apr 2024 02:36:41 +0000 Subject: [PATCH 05/27] Fix format --- autogen/agentchat/contrib/vectordb/base.py | 4 +- .../agentchat/contrib/vectordb/chromadb.py | 4 +- autogen/agentchat/contrib/vectordb/utils.py | 10 +++- test/agentchat/contrib/vectordb/test_utils.py | 47 +++++++++++++++---- 4 files changed, 53 insertions(+), 12 deletions(-) diff --git a/autogen/agentchat/contrib/vectordb/base.py b/autogen/agentchat/contrib/vectordb/base.py index 8e55e3434a74..3b940cdadeb6 100644 --- a/autogen/agentchat/contrib/vectordb/base.py +++ b/autogen/agentchat/contrib/vectordb/base.py @@ -145,7 +145,9 @@ def retrieve_docs( """ ... - def get_docs_by_ids(self, ids: List[Any], collection_name: str = None, include=None, **kwargs) -> dict[str, List[dict]]: + def get_docs_by_ids( + self, ids: List[Any], collection_name: str = None, include=None, **kwargs + ) -> dict[str, List[dict]]: """ Retrieve documents from the collection of the vector database based on the ids. diff --git a/autogen/agentchat/contrib/vectordb/chromadb.py b/autogen/agentchat/contrib/vectordb/chromadb.py index 49d564c2065a..07031f3f0901 100644 --- a/autogen/agentchat/contrib/vectordb/chromadb.py +++ b/autogen/agentchat/contrib/vectordb/chromadb.py @@ -258,7 +258,9 @@ def retrieve_docs( return results - def get_docs_by_ids(self, ids: List[Any], collection_name: str = None, include=None, **kwargs) -> dict[str, List[dict]]: + def get_docs_by_ids( + self, ids: List[Any], collection_name: str = None, include=None, **kwargs + ) -> dict[str, List[dict]]: """ Retrieve documents from the collection of the vector database based on the ids. diff --git a/autogen/agentchat/contrib/vectordb/utils.py b/autogen/agentchat/contrib/vectordb/utils.py index 3275022a7b7b..5f2e4d8f3da9 100644 --- a/autogen/agentchat/contrib/vectordb/utils.py +++ b/autogen/agentchat/contrib/vectordb/utils.py @@ -31,10 +31,13 @@ def get_logger(name: str, level: int = logging.INFO) -> ColoredLogger: logger.handlers[0].setFormatter(formatter) return logger + logger = get_logger(__name__) -def filter_results_by_distance(results:dict[str, List[List[dict]]], distance_threshold: float = -1) -> dict[str, List[List[dict]]]: +def filter_results_by_distance( + results: dict[str, List[List[dict]]], distance_threshold: float = -1 +) -> dict[str, List[List[dict]]]: """Filters results based on a distance threshold. Args: @@ -54,7 +57,10 @@ def filter_results_by_distance(results:dict[str, List[List[dict]]], distance_thr # Filter other keys based on filtered distances: results = { - key: [[value for ridx, value in enumerate(results_list) if ridx in return_ridx[qidx]] for qidx, results_list in enumerate(results_lists)] + key: [ + [value for ridx, value in enumerate(results_list) if ridx in return_ridx[qidx]] + for qidx, results_list in enumerate(results_lists) + ] for key, results_lists in results.items() } diff --git a/test/agentchat/contrib/vectordb/test_utils.py b/test/agentchat/contrib/vectordb/test_utils.py index eecd8a76f573..527af6be9a5c 100644 --- a/test/agentchat/contrib/vectordb/test_utils.py +++ b/test/agentchat/contrib/vectordb/test_utils.py @@ -11,18 +11,49 @@ skip = skip_openai + @pytest.mark.skipif( skip, reason="do not run for openai", ) def test_retrieve_config(): results = { - "ids": [["id1", "id2"], ["id3", "id4"]], - "contents": [["content1", "content2"], ["content3", "content4", ]], - "embeddings": [["embedding1", "embedding2", ], ["embedding3", "embedding4", ]], - "metadatas": [["metadata1", "metadata2", ], ["metadata3", "metadata4", ]], - "distances": [[1,2], [2,3]], - } + "ids": [["id1", "id2"], ["id3", "id4"]], + "contents": [ + ["content1", "content2"], + [ + "content3", + "content4", + ], + ], + "embeddings": [ + [ + "embedding1", + "embedding2", + ], + [ + "embedding3", + "embedding4", + ], + ], + "metadatas": [ + [ + "metadata1", + "metadata2", + ], + [ + "metadata3", + "metadata4", + ], + ], + "distances": [[1, 2], [2, 3]], + } print(filter_results_by_distance(results, 2.1)) - filter_results = {'ids': [['id1', 'id2'], ['id3']], 'contents': [['content1', 'content2'], ['content3']], 'embeddings': [['embedding1', 'embedding2'], ['embedding3']], 'metadatas': [['metadata1', 'metadata2'], ['metadata3']], 'distances': [[1, 2], [2]]} - assert filter_results == filter_results_by_distance(results, 2.1) \ No newline at end of file + filter_results = { + "ids": [["id1", "id2"], ["id3"]], + "contents": [["content1", "content2"], ["content3"]], + "embeddings": [["embedding1", "embedding2"], ["embedding3"]], + "metadatas": [["metadata1", "metadata2"], ["metadata3"]], + "distances": [[1, 2], [2]], + } + assert filter_results == filter_results_by_distance(results, 2.1) From 7793a06b2901a58dbe567c2f2aa1758d24ddd77f Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Wed, 3 Apr 2024 03:00:42 +0000 Subject: [PATCH 06/27] Fix type hint of dict --- autogen/agentchat/contrib/vectordb/base.py | 36 +++++++++---------- .../agentchat/contrib/vectordb/chromadb.py | 34 +++++++++--------- autogen/agentchat/contrib/vectordb/utils.py | 8 ++--- 3 files changed, 39 insertions(+), 39 deletions(-) diff --git a/autogen/agentchat/contrib/vectordb/base.py b/autogen/agentchat/contrib/vectordb/base.py index 3b940cdadeb6..3c922279c35b 100644 --- a/autogen/agentchat/contrib/vectordb/base.py +++ b/autogen/agentchat/contrib/vectordb/base.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, List, Protocol, runtime_checkable +from typing import Any, Callable, List, Protocol, runtime_checkable, Dict @runtime_checkable @@ -7,12 +7,12 @@ class VectorDB(Protocol): Abstract class for vector database. A vector database is responsible for storing and retrieving documents. """ - def __init__(self, db_config: dict = None) -> None: + def __init__(self, db_config: Dict = None) -> None: """ Initialize the vector database. Args: - db_config: dict | configuration for initializing the vector database. Default is None. + db_config: Dict | configuration for initializing the vector database. Default is None. Returns: None @@ -62,32 +62,32 @@ def delete_collection(self, collection_name: str) -> Any: """ ... - def insert_docs(self, docs: List[dict], collection_name: str = None, upsert: bool = False, **kwargs) -> Any: + def insert_docs(self, docs: List[Dict], collection_name: str = None, upsert: bool = False, **kwargs) -> Any: """ Insert documents into the collection of the vector database. Args: - docs: List[dict] | A list of documents. Each document is a dictionary. + docs: List[Dict] | A list of documents. Each document is a dictionary. It should include the following fields: - required: "id", "content" - optional: "embedding", "metadata", "distance", etc. collection_name: str | The name of the collection. Default is None. upsert: bool | Whether to update the document if it exists. Default is False. - kwargs: dict | Additional keyword arguments. + kwargs: Dict | Additional keyword arguments. Returns: None """ ... - def update_docs(self, docs: List[dict], collection_name: str = None, **kwargs) -> None: + def update_docs(self, docs: List[Dict], collection_name: str = None, **kwargs) -> None: """ Update documents in the collection of the vector database. Args: - docs: List[dict] | A list of documents. + docs: List[Dict] | A list of documents. collection_name: str | The name of the collection. Default is None. - kwargs: dict | Additional keyword arguments. + kwargs: Dict | Additional keyword arguments. Returns: None @@ -101,7 +101,7 @@ def delete_docs(self, ids: List[Any], collection_name: str = None, **kwargs) -> Args: ids: List[Any] | A list of document ids. collection_name: str | The name of the collection. Default is None. - kwargs: dict | Additional keyword arguments. + kwargs: Dict | Additional keyword arguments. Returns: None @@ -115,7 +115,7 @@ def retrieve_docs( n_results: int = 10, distance_threshold: float = -1, **kwargs, - ) -> dict[str, List[List[dict]]]: + ) -> Dict[str, List[List[Dict]]]: """ Retrieve documents from the collection of the vector database based on the queries. @@ -125,10 +125,10 @@ def retrieve_docs( n_results: int | The number of relevant documents to return. Default is 10. distance_threshold: float | The threshold for the distance score, only distance smaller than it will be returned. Don't filter with it if < 0. Default is -1. - kwargs: dict | Additional keyword arguments. + kwargs: Dict | Additional keyword arguments. Returns: - dict[str, List[List[dict]]] | The query results. Each query result is a dictionary. + Dict[str, List[List[Dict]]] | The query results. Each query result is a dictionary. It should include the following fields: - required: "ids", "contents" - optional: "embeddings", "metadatas", "distances", etc. @@ -147,7 +147,7 @@ def retrieve_docs( def get_docs_by_ids( self, ids: List[Any], collection_name: str = None, include=None, **kwargs - ) -> dict[str, List[dict]]: + ) -> Dict[str, List[Dict]]: """ Retrieve documents from the collection of the vector database based on the ids. @@ -156,10 +156,10 @@ def get_docs_by_ids( collection_name: str | The name of the collection. Default is None. include: List[str] | The fields to include. Default is None. If None, will include ["metadatas", "documents"] - kwargs: dict | Additional keyword arguments. + kwargs: Dict | Additional keyword arguments. Returns: - dict[str, List[dict]] | The results. + Dict[str, List[Dict]] | The results. """ ... @@ -172,13 +172,13 @@ class VectorDBFactory: PREDEFINED_VECTOR_DB = ["chroma"] @staticmethod - def create_vector_db(db_type: str, db_config: dict = None) -> VectorDB: + def create_vector_db(db_type: str, db_config: Dict = None) -> VectorDB: """ Create a vector database. Args: db_type: str | The type of the vector database. - db_config: dict | The configuration of the vector database. Default is None. + db_config: Dict | The configuration of the vector database. Default is None. Returns: VectorDB | The vector database. diff --git a/autogen/agentchat/contrib/vectordb/chromadb.py b/autogen/agentchat/contrib/vectordb/chromadb.py index 07031f3f0901..447a6c68f734 100644 --- a/autogen/agentchat/contrib/vectordb/chromadb.py +++ b/autogen/agentchat/contrib/vectordb/chromadb.py @@ -1,5 +1,5 @@ import os -from typing import Any, Callable, List +from typing import Any, Callable, List, Dict from .utils import get_logger, filter_results_by_distance @@ -21,24 +21,24 @@ class ChromaVectorDB: A vector database that uses ChromaDB as the backend. """ - def __init__(self, db_config: dict = None) -> None: + def __init__(self, db_config: Dict = None) -> None: """ Initialize the vector database. Args: - db_config: dict | configuration for initializing the vector database. Default is None. + db_config: Dict | configuration for initializing the vector database. Default is None. It can contain the following keys: client: chromadb.Client | The client object of the vector database. Default is None. If not None, it will use the client object to connect to the vector database. path: str | The path to the vector database. Default is None. embedding_function: Callable | The embedding function used to generate the vector representation of the documents. Default is None. - metadata: dict | The metadata of the vector database. Default is None. If None, it will use this + metadata: Dict | The metadata of the vector database. Default is None. If None, it will use this setting: {"hnsw:space": "ip", "hnsw:construction_ef": 30, "hnsw:M": 32}. For more details of the metadata, please refer to [distances](https://github.com/nmslib/hnswlib#supported-distances), [hnsw](https://github.com/chroma-core/chroma/blob/566bc80f6c8ee29f7d99b6322654f32183c368c4/chromadb/segment/impl/vector/local_hnsw.py#L184), and [ALGO_PARAMS](https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md). - kwargs: dict | Additional keyword arguments. + kwargs: Dict | Additional keyword arguments. Returns: None @@ -151,18 +151,18 @@ def _batch_insert(self, collection, embeddings=None, ids=None, metadata=None, do else: collection.add(**collection_kwargs) - def insert_docs(self, docs: List[dict], collection_name: str = None, upsert: bool = False) -> None: + def insert_docs(self, docs: List[Dict], collection_name: str = None, upsert: bool = False) -> None: """ Insert documents into the collection of the vector database. Args: - docs: List[dict] | A list of documents. Each document is a dictionary. + docs: List[Dict] | A list of documents. Each document is a dictionary. It should include the following fields: - required: "id", "content" - optional: "embedding", "metadata", "distance", etc. collection_name: str | The name of the collection. Default is None. upsert: bool | Whether to update the document if it exists. Default is False. - kwargs: dict | Additional keyword arguments. + kwargs: Dict | Additional keyword arguments. Returns: None @@ -182,12 +182,12 @@ def insert_docs(self, docs: List[dict], collection_name: str = None, upsert: boo metadata = [doc.get("metadata") for doc in docs] self._batch_insert(collection, embeddings, ids, metadata, documents, upsert) - def update_docs(self, docs: List[dict], collection_name: str = None) -> None: + def update_docs(self, docs: List[Dict], collection_name: str = None) -> None: """ Update documents in the collection of the vector database. Args: - docs: List[dict] | A list of documents. + docs: List[Dict] | A list of documents. collection_name: str | The name of the collection. Default is None. Returns: @@ -202,7 +202,7 @@ def delete_docs(self, ids: List[Any], collection_name: str = None, **kwargs) -> Args: ids: List[Any] | A list of document ids. collection_name: str | The name of the collection. Default is None. - kwargs: dict | Additional keyword arguments. + kwargs: Dict | Additional keyword arguments. Returns: None @@ -217,7 +217,7 @@ def retrieve_docs( n_results: int = 10, distance_threshold: float = -1, **kwargs, - ) -> dict[str, List[List[dict]]]: + ) -> Dict[str, List[List[Dict]]]: """ Retrieve documents from the collection of the vector database based on the queries. @@ -227,10 +227,10 @@ def retrieve_docs( n_results: int | The number of relevant documents to return. Default is 10. distance_threshold: float | The threshold for the distance score, only distance smaller than it will be returned. Don't filter with it if < 0. Default is -1. - kwargs: dict | Additional keyword arguments. + kwargs: Dict | Additional keyword arguments. Returns: - dict[str, List[List[dict]]] | The query results. Each query result is a dictionary. + Dict[str, List[List[Dict]]] | The query results. Each query result is a dictionary. It should include the following fields: - required: "ids", "contents" - optional: "embeddings", "metadatas", "distances", etc. @@ -260,7 +260,7 @@ def retrieve_docs( def get_docs_by_ids( self, ids: List[Any], collection_name: str = None, include=None, **kwargs - ) -> dict[str, List[dict]]: + ) -> Dict[str, List[Dict]]: """ Retrieve documents from the collection of the vector database based on the ids. @@ -269,10 +269,10 @@ def get_docs_by_ids( collection_name: str | The name of the collection. Default is None. include: List[str] | The fields to include. Default is None. If None, will include ["metadatas", "documents"] - kwargs: dict | Additional keyword arguments. + kwargs: Dict | Additional keyword arguments. Returns: - dict[str, List[dict]] | The results. + Dict[str, List[Dict]] | The results. """ collection = self.get_collection(collection_name) include = include if include else ["metadatas", "documents"] diff --git a/autogen/agentchat/contrib/vectordb/utils.py b/autogen/agentchat/contrib/vectordb/utils.py index 5f2e4d8f3da9..1926170c6571 100644 --- a/autogen/agentchat/contrib/vectordb/utils.py +++ b/autogen/agentchat/contrib/vectordb/utils.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Dict import logging from termcolor import colored @@ -36,8 +36,8 @@ def get_logger(name: str, level: int = logging.INFO) -> ColoredLogger: def filter_results_by_distance( - results: dict[str, List[List[dict]]], distance_threshold: float = -1 -) -> dict[str, List[List[dict]]]: + results: Dict[str, List[List[Dict]]], distance_threshold: float = -1 +) -> Dict[str, List[List[Dict]]]: """Filters results based on a distance threshold. Args: @@ -45,7 +45,7 @@ def filter_results_by_distance( distance_threshold: The maximum distance allowed for results. Returns: - dict[str, List[List[dict]]] | A filtered dictionary containing only results within the threshold. + Dict[str, List[List[Dict]]] | A filtered dictionary containing only results within the threshold. """ if distance_threshold > 0: From cf7fcda24e1dc1489778c9b46c61bb452ff6cbb0 Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Wed, 3 Apr 2024 03:04:48 +0000 Subject: [PATCH 07/27] Rename test --- .../contrib/vectordb/{test_utils.py => test_vectordb_utils.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename test/agentchat/contrib/vectordb/{test_utils.py => test_vectordb_utils.py} (100%) diff --git a/test/agentchat/contrib/vectordb/test_utils.py b/test/agentchat/contrib/vectordb/test_vectordb_utils.py similarity index 100% rename from test/agentchat/contrib/vectordb/test_utils.py rename to test/agentchat/contrib/vectordb/test_vectordb_utils.py From 7d32490a6535e4fc018fb05ea3062ae986689dd9 Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Wed, 3 Apr 2024 09:54:36 +0000 Subject: [PATCH 08/27] Add test chromadb --- .../agentchat/contrib/vectordb/__init__.py | 2 + .../agentchat/contrib/vectordb/chromadb.py | 15 +++- autogen/agentchat/contrib/vectordb/utils.py | 12 ++- .../contrib/vectordb/test_chromadb.py | 75 +++++++++++++++++++ 4 files changed, 96 insertions(+), 8 deletions(-) create mode 100644 test/agentchat/contrib/vectordb/test_chromadb.py diff --git a/autogen/agentchat/contrib/vectordb/__init__.py b/autogen/agentchat/contrib/vectordb/__init__.py index e69de29bb2d1..8f2902fb6ad6 100644 --- a/autogen/agentchat/contrib/vectordb/__init__.py +++ b/autogen/agentchat/contrib/vectordb/__init__.py @@ -0,0 +1,2 @@ +from .base import VectorDB, VectorDBFactory +from .chromadb import ChromaVectorDB diff --git a/autogen/agentchat/contrib/vectordb/chromadb.py b/autogen/agentchat/contrib/vectordb/chromadb.py index 447a6c68f734..75f7bbc60030 100644 --- a/autogen/agentchat/contrib/vectordb/chromadb.py +++ b/autogen/agentchat/contrib/vectordb/chromadb.py @@ -169,6 +169,12 @@ def insert_docs(self, docs: List[Dict], collection_name: str = None, upsert: boo """ if not docs: return + if docs[0].get("content") is None: + raise ValueError("The document content is required.") + if docs[0].get("id") is None: + raise ValueError("The document id is required.") + documents = [doc.get("content") for doc in docs] + ids = [doc.get("id") for doc in docs] collection = self.get_collection(collection_name) if docs[0].get("embedding") is None: logger.info( @@ -176,10 +182,11 @@ def insert_docs(self, docs: List[Dict], collection_name: str = None, upsert: boo ) embeddings = None else: - embeddings = [doc.embedding for doc in docs] - documents = [doc.content for doc in docs] - ids = [doc.id for doc in docs] - metadata = [doc.get("metadata") for doc in docs] + embeddings = [doc.get("embedding") for doc in docs] + if docs[0].get("metadata") is None: + metadata = None + else: + metadata = [doc.get("metadata") for doc in docs] self._batch_insert(collection, embeddings, ids, metadata, documents, upsert) def update_docs(self, docs: List[Dict], collection_name: str = None) -> None: diff --git a/autogen/agentchat/contrib/vectordb/utils.py b/autogen/agentchat/contrib/vectordb/utils.py index 1926170c6571..0bddf5915f48 100644 --- a/autogen/agentchat/contrib/vectordb/utils.py +++ b/autogen/agentchat/contrib/vectordb/utils.py @@ -57,10 +57,14 @@ def filter_results_by_distance( # Filter other keys based on filtered distances: results = { - key: [ - [value for ridx, value in enumerate(results_list) if ridx in return_ridx[qidx]] - for qidx, results_list in enumerate(results_lists) - ] + key: ( + [ + [value for ridx, value in enumerate(results_list) if ridx in return_ridx[qidx]] + for qidx, results_list in enumerate(results_lists) + ] + if isinstance(results_lists, list) + else results_lists + ) for key, results_lists in results.items() } diff --git a/test/agentchat/contrib/vectordb/test_chromadb.py b/test/agentchat/contrib/vectordb/test_chromadb.py new file mode 100644 index 000000000000..61a182cb791b --- /dev/null +++ b/test/agentchat/contrib/vectordb/test_chromadb.py @@ -0,0 +1,75 @@ +import os +import sys + +import pytest + +sys.path.append(os.path.join(os.path.dirname(__file__), "..")) + +try: + import chromadb + import sentence_transformers + + from autogen.agentchat.contrib.vectordb import ChromaVectorDB +except ImportError: + skip = True +else: + skip = False + + +@pytest.mark.skipif(skip, reason="dependency is not installed OR requested to skip") +def test_chromadb(): + # test create collection + db_config = {"path": ".db"} + db = ChromaVectorDB(db_config) + collection_name = "test_collection" + collection = db.create_collection(collection_name, overwrite=True, get_or_create=True) + assert collection.name == collection_name + + # test_delete_collection + db.delete_collection(collection_name) + pytest.raises(ValueError, db.get_collection, collection_name) + + # test more create collection + collection = db.create_collection(collection_name, overwrite=False, get_or_create=False) + assert collection.name == collection_name + pytest.raises(ValueError, db.create_collection, collection_name, overwrite=False, get_or_create=False) + collection = db.create_collection(collection_name, overwrite=True, get_or_create=False) + assert collection.name == collection_name + collection = db.create_collection(collection_name, overwrite=False, get_or_create=True) + assert collection.name == collection_name + + # test_get_collection + collection = db.get_collection(collection_name) + assert collection.name == collection_name + + # test_insert_docs + docs = [{"content": "doc1", "id": "1"}, {"content": "doc2", "id": "2"}, {"content": "doc3", "id": "3"}] + db.insert_docs(docs, collection_name, upsert=False) + res = db.get_docs_by_ids(["1", "2"], collection_name, include=["documents"]) + assert res["documents"] == ["doc1", "doc2"] + + # test_update_docs + docs = [{"content": "doc11", "id": "1"}, {"content": "doc2", "id": "2"}, {"content": "doc3", "id": "3"}] + db.update_docs(docs, collection_name) + res = db.get_docs_by_ids(["1", "2"], collection_name) + assert res["documents"] == ["doc11", "doc2"] + + # test_delete_docs + ids = ["1"] + collection_name = "test_collection" + db.delete_docs(ids, collection_name) + res = db.get_docs_by_ids(ids, collection_name) + assert res["documents"] == [] + + # test_retrieve_docs + queries = ["doc2", "doc3"] + collection_name = "test_collection" + res = db.retrieve_docs(queries, collection_name) + assert res["ids"] == [["2", "3"], ["3", "2"]] + res = db.retrieve_docs(queries, collection_name, distance_threshold=0.1) + print(res) + assert res["ids"] == [["2"], ["3"]] + + +if __name__ == "__main__": + test_chromadb() From b2a80b55e2ccc2721c50fa66fd192f71b0a0fe33 Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Wed, 3 Apr 2024 09:57:54 +0000 Subject: [PATCH 09/27] Fix test no chromadb --- autogen/agentchat/contrib/vectordb/__init__.py | 2 -- test/agentchat/contrib/vectordb/test_chromadb.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/autogen/agentchat/contrib/vectordb/__init__.py b/autogen/agentchat/contrib/vectordb/__init__.py index 8f2902fb6ad6..e69de29bb2d1 100644 --- a/autogen/agentchat/contrib/vectordb/__init__.py +++ b/autogen/agentchat/contrib/vectordb/__init__.py @@ -1,2 +0,0 @@ -from .base import VectorDB, VectorDBFactory -from .chromadb import ChromaVectorDB diff --git a/test/agentchat/contrib/vectordb/test_chromadb.py b/test/agentchat/contrib/vectordb/test_chromadb.py index 61a182cb791b..efab776fe42a 100644 --- a/test/agentchat/contrib/vectordb/test_chromadb.py +++ b/test/agentchat/contrib/vectordb/test_chromadb.py @@ -9,7 +9,7 @@ import chromadb import sentence_transformers - from autogen.agentchat.contrib.vectordb import ChromaVectorDB + from autogen.agentchat.contrib.vectordb.chromadb import ChromaVectorDB except ImportError: skip = True else: From c2dce1dfeaae1fc0e7e2cd9f033a9eddd43aa63f Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Wed, 3 Apr 2024 10:08:43 +0000 Subject: [PATCH 10/27] Add coverage --- .github/workflows/contrib-openai.yml | 2 +- .github/workflows/contrib-tests.yml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/contrib-openai.yml b/.github/workflows/contrib-openai.yml index 4eda8d930710..4cd31dac9229 100644 --- a/.github/workflows/contrib-openai.yml +++ b/.github/workflows/contrib-openai.yml @@ -53,7 +53,7 @@ jobs: AZURE_OPENAI_API_BASE: ${{ secrets.AZURE_OPENAI_API_BASE }} OAI_CONFIG_LIST: ${{ secrets.OAI_CONFIG_LIST }} run: | - coverage run -a -m pytest test/agentchat/contrib/test_retrievechat.py test/agentchat/contrib/test_qdrant_retrievechat.py + coverage run -a -m pytest test/agentchat/contrib/test_retrievechat.py test/agentchat/contrib/test_qdrant_retrievechat.py test/agentchat/contrib/vectordb coverage xml - name: Upload coverage to Codecov uses: codecov/codecov-action@v3 diff --git a/.github/workflows/contrib-tests.yml b/.github/workflows/contrib-tests.yml index 800085f9194c..2d8be8b70cbd 100644 --- a/.github/workflows/contrib-tests.yml +++ b/.github/workflows/contrib-tests.yml @@ -60,11 +60,11 @@ jobs: fi - name: Test RetrieveChat run: | - pytest test/test_retrieve_utils.py test/agentchat/contrib/test_retrievechat.py test/agentchat/contrib/test_qdrant_retrievechat.py --skip-openai + pytest test/test_retrieve_utils.py test/agentchat/contrib/test_retrievechat.py test/agentchat/contrib/test_qdrant_retrievechat.py test/agentchat/contrib/vectordb --skip-openai - name: Coverage run: | pip install coverage>=5.3 - coverage run -a -m pytest test/test_retrieve_utils.py test/agentchat/contrib/test_retrievechat.py test/agentchat/contrib/test_qdrant_retrievechat.py --skip-openai + coverage run -a -m pytest test/test_retrieve_utils.py test/agentchat/contrib/test_retrievechat.py test/agentchat/contrib/test_qdrant_retrievechat.py test/agentchat/contrib/vectordb --skip-openai coverage xml - name: Upload coverage to Codecov uses: codecov/codecov-action@v3 From dddf1ec3b840b21eb14322bfd4a1d897c0b46cb5 Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Wed, 3 Apr 2024 10:18:53 +0000 Subject: [PATCH 11/27] Don't skip test vectordb utils --- test/agentchat/contrib/vectordb/test_vectordb_utils.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/test/agentchat/contrib/vectordb/test_vectordb_utils.py b/test/agentchat/contrib/vectordb/test_vectordb_utils.py index 527af6be9a5c..8f6ba8395a9d 100644 --- a/test/agentchat/contrib/vectordb/test_vectordb_utils.py +++ b/test/agentchat/contrib/vectordb/test_vectordb_utils.py @@ -4,18 +4,9 @@ import os import sys -sys.path.append(os.path.join(os.path.dirname(__file__), "../..")) -from conftest import skip_openai # noqa: E402 - from autogen.agentchat.contrib.vectordb.utils import filter_results_by_distance -skip = skip_openai - -@pytest.mark.skipif( - skip, - reason="do not run for openai", -) def test_retrieve_config(): results = { "ids": [["id1", "id2"], ["id3", "id4"]], From 95df2009b34a6089f3626073279e909cd234fba0 Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Thu, 4 Apr 2024 10:38:11 +0800 Subject: [PATCH 12/27] Add types --- autogen/agentchat/contrib/vectordb/_types.py | 54 ++++++++++++++ autogen/agentchat/contrib/vectordb/base.py | 51 +++++-------- .../agentchat/contrib/vectordb/chromadb.py | 74 +++++++++---------- 3 files changed, 105 insertions(+), 74 deletions(-) create mode 100644 autogen/agentchat/contrib/vectordb/_types.py diff --git a/autogen/agentchat/contrib/vectordb/_types.py b/autogen/agentchat/contrib/vectordb/_types.py new file mode 100644 index 000000000000..f6b4c9debe50 --- /dev/null +++ b/autogen/agentchat/contrib/vectordb/_types.py @@ -0,0 +1,54 @@ +from typing import Any, Dict, List, Mapping, Optional, Sequence, TypedDict, Union + +Metadata = Union[Mapping[str, Union[str, int, float, bool, None]], None] +Vector = Union[Sequence[float], Sequence[int]] +ItemID = str # chromadb doesn't support int ids + + +class Document(TypedDict): + """A Document is a record in the vector database. + + id: ItemID | the unique identifier of the document. + content: str | the text content of the chunk. + metadata: Metadata | contains additional information about the document such as source, date, etc. + embedding: Vector | the vector representation of the content. + dimensions: int | the dimensions of the content_embedding. + """ + + id: ItemID + content: str + metadata: Optional[Metadata] + embedding: Optional[Vector] + dimensions: Optional[int] + + +class QueryResults(TypedDict): + """QueryResults is the response from the vector database for a query. + + ids: List[List[ItemID]] | the unique identifiers of the documents. + contents: List[List[str]] | the text content of the documents. + embeddings: List[List[Vector]] | the vector representations of the documents. + metadatas: List[List[Metadata]] | the metadata of the documents. + distances: List[List[float]] | the distances between the query and the documents. + """ + + ids: List[List[ItemID]] + contents: List[List[str]] + embeddings: Optional[List[List[Vector]]] + metadatas: Optional[List[List[Metadata]]] + distances: Optional[List[List[float]]] + + +class GetResults(TypedDict): + """GetResults is the response from the vector database for getting documents by ids. + + ids: List[ItemID] | the unique identifiers of the documents. + contents: List[str] | the text content of the documents. + embeddings: List[Vector] | the vector representations of the documents. + metadatas: List[Metadata] | the metadata of the documents. + """ + + ids: List[ItemID] + contents: Optional[List[str]] + embeddings: Optional[List[Vector]] + metadatas: Optional[List[Metadata]] diff --git a/autogen/agentchat/contrib/vectordb/base.py b/autogen/agentchat/contrib/vectordb/base.py index 3c922279c35b..9fe3becc34a3 100644 --- a/autogen/agentchat/contrib/vectordb/base.py +++ b/autogen/agentchat/contrib/vectordb/base.py @@ -1,4 +1,6 @@ -from typing import Any, Callable, List, Protocol, runtime_checkable, Dict +from typing import Any, Dict, List, Protocol, runtime_checkable + +from ._types import Document, GetResults, ItemID, QueryResults @runtime_checkable @@ -7,18 +9,6 @@ class VectorDB(Protocol): Abstract class for vector database. A vector database is responsible for storing and retrieving documents. """ - def __init__(self, db_config: Dict = None) -> None: - """ - Initialize the vector database. - - Args: - db_config: Dict | configuration for initializing the vector database. Default is None. - - Returns: - None - """ - ... - def create_collection(self, collection_name: str, overwrite: bool = False, get_or_create: bool = True) -> Any: """ Create a collection in the vector database. @@ -62,15 +52,12 @@ def delete_collection(self, collection_name: str) -> Any: """ ... - def insert_docs(self, docs: List[Dict], collection_name: str = None, upsert: bool = False, **kwargs) -> Any: + def insert_docs(self, docs: List[Document], collection_name: str = None, upsert: bool = False, **kwargs) -> None: """ Insert documents into the collection of the vector database. Args: - docs: List[Dict] | A list of documents. Each document is a dictionary. - It should include the following fields: - - required: "id", "content" - - optional: "embedding", "metadata", "distance", etc. + docs: List[Document] | A list of documents. Each document is a TypedDict `Document`. collection_name: str | The name of the collection. Default is None. upsert: bool | Whether to update the document if it exists. Default is False. kwargs: Dict | Additional keyword arguments. @@ -80,12 +67,12 @@ def insert_docs(self, docs: List[Dict], collection_name: str = None, upsert: boo """ ... - def update_docs(self, docs: List[Dict], collection_name: str = None, **kwargs) -> None: + def update_docs(self, docs: List[Document], collection_name: str = None, **kwargs) -> None: """ Update documents in the collection of the vector database. Args: - docs: List[Dict] | A list of documents. + docs: List[Document] | A list of documents. collection_name: str | The name of the collection. Default is None. kwargs: Dict | Additional keyword arguments. @@ -94,12 +81,12 @@ def update_docs(self, docs: List[Dict], collection_name: str = None, **kwargs) - """ ... - def delete_docs(self, ids: List[Any], collection_name: str = None, **kwargs) -> None: + def delete_docs(self, ids: List[ItemID], collection_name: str = None, **kwargs) -> None: """ Delete documents from the collection of the vector database. Args: - ids: List[Any] | A list of document ids. + ids: List[ItemID] | A list of document ids. Each id is a typed `ItemID`. collection_name: str | The name of the collection. Default is None. kwargs: Dict | Additional keyword arguments. @@ -115,7 +102,7 @@ def retrieve_docs( n_results: int = 10, distance_threshold: float = -1, **kwargs, - ) -> Dict[str, List[List[Dict]]]: + ) -> QueryResults: """ Retrieve documents from the collection of the vector database based on the queries. @@ -128,7 +115,7 @@ def retrieve_docs( kwargs: Dict | Additional keyword arguments. Returns: - Dict[str, List[List[Dict]]] | The query results. Each query result is a dictionary. + QueryResults | The query results. Each query result is a TypedDict `QueryResults`. It should include the following fields: - required: "ids", "contents" - optional: "embeddings", "metadatas", "distances", etc. @@ -145,9 +132,7 @@ def retrieve_docs( """ ... - def get_docs_by_ids( - self, ids: List[Any], collection_name: str = None, include=None, **kwargs - ) -> Dict[str, List[Dict]]: + def get_docs_by_ids(self, ids: List[ItemID], collection_name: str = None, include=None, **kwargs) -> GetResults: """ Retrieve documents from the collection of the vector database based on the ids. @@ -155,11 +140,11 @@ def get_docs_by_ids( ids: List[Any] | A list of document ids. collection_name: str | The name of the collection. Default is None. include: List[str] | The fields to include. Default is None. - If None, will include ["metadatas", "documents"] + If None, will include ["metadatas", "documents"]. ids are always included. kwargs: Dict | Additional keyword arguments. Returns: - Dict[str, List[Dict]] | The results. + GetResults | The results. """ ... @@ -172,23 +157,21 @@ class VectorDBFactory: PREDEFINED_VECTOR_DB = ["chroma"] @staticmethod - def create_vector_db(db_type: str, db_config: Dict = None) -> VectorDB: + def create_vector_db(db_type: str, **kwargs) -> VectorDB: """ Create a vector database. Args: db_type: str | The type of the vector database. - db_config: Dict | The configuration of the vector database. Default is None. + kwargs: Dict | The keyword arguments for initializing the vector database. Returns: VectorDB | The vector database. """ - if db_config is None: - db_config = {} if db_type.lower() in ["chroma", "chromadb"]: from .chromadb import ChromaVectorDB - return ChromaVectorDB(db_config=db_config) + return ChromaVectorDB(**kwargs) else: raise ValueError( f"Unsupported vector database type: {db_type}. Valid types are {VectorDBFactory.PREDEFINED_VECTOR_DB}." diff --git a/autogen/agentchat/contrib/vectordb/chromadb.py b/autogen/agentchat/contrib/vectordb/chromadb.py index 75f7bbc60030..1b07c4fa5230 100644 --- a/autogen/agentchat/contrib/vectordb/chromadb.py +++ b/autogen/agentchat/contrib/vectordb/chromadb.py @@ -1,7 +1,8 @@ import os -from typing import Any, Callable, List, Dict +from typing import Any, Callable, Dict, List -from .utils import get_logger, filter_results_by_distance +from ._types import Document, GetResults, ItemID, QueryResults +from .utils import filter_results_by_distance, get_logger try: import chromadb @@ -21,36 +22,32 @@ class ChromaVectorDB: A vector database that uses ChromaDB as the backend. """ - def __init__(self, db_config: Dict = None) -> None: + def __init__( + self, client=None, path: str = None, embedding_function: Callable = None, metadata: dict = None, **kwargs + ) -> None: """ Initialize the vector database. Args: - db_config: Dict | configuration for initializing the vector database. Default is None. - It can contain the following keys: - client: chromadb.Client | The client object of the vector database. Default is None. - If not None, it will use the client object to connect to the vector database. - path: str | The path to the vector database. Default is None. - embedding_function: Callable | The embedding function used to generate the vector representation - of the documents. Default is None. - metadata: Dict | The metadata of the vector database. Default is None. If None, it will use this - setting: {"hnsw:space": "ip", "hnsw:construction_ef": 30, "hnsw:M": 32}. For more details of - the metadata, please refer to [distances](https://github.com/nmslib/hnswlib#supported-distances), - [hnsw](https://github.com/chroma-core/chroma/blob/566bc80f6c8ee29f7d99b6322654f32183c368c4/chromadb/segment/impl/vector/local_hnsw.py#L184), - and [ALGO_PARAMS](https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md). - kwargs: Dict | Additional keyword arguments. + client: chromadb.Client | The client object of the vector database. Default is None. + path: str | The path to the vector database. Default is None. + embedding_function: Callable | The embedding function used to generate the vector representation + of the documents. Default is None. + metadata: dict | The metadata of the vector database. Default is None. If None, it will use this + setting: {"hnsw:space": "ip", "hnsw:construction_ef": 30, "hnsw:M": 32}. For more details of + the metadata, please refer to [distances](https://github.com/nmslib/hnswlib#supported-distances), + [hnsw](https://github.com/chroma-core/chroma/blob/566bc80f6c8ee29f7d99b6322654f32183c368c4/chromadb/segment/impl/vector/local_hnsw.py#L184), + and [ALGO_PARAMS](https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md). + kwargs: dict | Additional keyword arguments. Returns: None """ - if db_config is None: - db_config = {} - self.client = db_config.get("client") + self.client = client if not self.client: - self.path = db_config.get("path") - self.embedding_function = db_config.get("embedding_function") - self.metadata = db_config.get("metadata", {"hnsw:space": "ip", "hnsw:construction_ef": 30, "hnsw:M": 32}) - kwargs = db_config.get("kwargs", {}) + self.path = path + self.embedding_function = embedding_function + self.metadata = metadata if metadata else {"hnsw:space": "ip", "hnsw:construction_ef": 30, "hnsw:M": 32} if self.path is not None: self.client = chromadb.PersistentClient(path=self.path, **kwargs) else: @@ -136,7 +133,9 @@ def delete_collection(self, collection_name: str) -> None: if self.active_collection.name == collection_name: self.active_collection = None - def _batch_insert(self, collection, embeddings=None, ids=None, metadata=None, documents=None, upsert=False): + def _batch_insert( + self, collection: Collection, embeddings=None, ids=None, metadata=None, documents=None, upsert=False + ): batch_size = int(CHROMADB_MAX_BATCH_SIZE) for i in range(0, len(documents), min(batch_size, len(documents))): end_idx = i + min(batch_size, len(documents) - i) @@ -151,15 +150,12 @@ def _batch_insert(self, collection, embeddings=None, ids=None, metadata=None, do else: collection.add(**collection_kwargs) - def insert_docs(self, docs: List[Dict], collection_name: str = None, upsert: bool = False) -> None: + def insert_docs(self, docs: List[Document], collection_name: str = None, upsert: bool = False) -> None: """ Insert documents into the collection of the vector database. Args: - docs: List[Dict] | A list of documents. Each document is a dictionary. - It should include the following fields: - - required: "id", "content" - - optional: "embedding", "metadata", "distance", etc. + docs: List[Document] | A list of documents. Each document is a TypedDict `Document`. collection_name: str | The name of the collection. Default is None. upsert: bool | Whether to update the document if it exists. Default is False. kwargs: Dict | Additional keyword arguments. @@ -189,12 +185,12 @@ def insert_docs(self, docs: List[Dict], collection_name: str = None, upsert: boo metadata = [doc.get("metadata") for doc in docs] self._batch_insert(collection, embeddings, ids, metadata, documents, upsert) - def update_docs(self, docs: List[Dict], collection_name: str = None) -> None: + def update_docs(self, docs: List[Document], collection_name: str = None) -> None: """ Update documents in the collection of the vector database. Args: - docs: List[Dict] | A list of documents. + docs: List[Document] | A list of documents. collection_name: str | The name of the collection. Default is None. Returns: @@ -202,12 +198,12 @@ def update_docs(self, docs: List[Dict], collection_name: str = None) -> None: """ self.insert_docs(docs, collection_name, upsert=True) - def delete_docs(self, ids: List[Any], collection_name: str = None, **kwargs) -> None: + def delete_docs(self, ids: List[ItemID], collection_name: str = None, **kwargs) -> None: """ Delete documents from the collection of the vector database. Args: - ids: List[Any] | A list of document ids. + ids: List[ItemID] | A list of document ids. Each id is a typed `ItemID`. collection_name: str | The name of the collection. Default is None. kwargs: Dict | Additional keyword arguments. @@ -224,7 +220,7 @@ def retrieve_docs( n_results: int = 10, distance_threshold: float = -1, **kwargs, - ) -> Dict[str, List[List[Dict]]]: + ) -> QueryResults: """ Retrieve documents from the collection of the vector database based on the queries. @@ -237,7 +233,7 @@ def retrieve_docs( kwargs: Dict | Additional keyword arguments. Returns: - Dict[str, List[List[Dict]]] | The query results. Each query result is a dictionary. + QueryResults | The query results. Each query result is a TypedDict `QueryResults`. It should include the following fields: - required: "ids", "contents" - optional: "embeddings", "metadatas", "distances", etc. @@ -265,9 +261,7 @@ def retrieve_docs( return results - def get_docs_by_ids( - self, ids: List[Any], collection_name: str = None, include=None, **kwargs - ) -> Dict[str, List[Dict]]: + def get_docs_by_ids(self, ids: List[ItemID], collection_name: str = None, include=None, **kwargs) -> GetResults: """ Retrieve documents from the collection of the vector database based on the ids. @@ -275,11 +269,11 @@ def get_docs_by_ids( ids: List[Any] | A list of document ids. collection_name: str | The name of the collection. Default is None. include: List[str] | The fields to include. Default is None. - If None, will include ["metadatas", "documents"] + If None, will include ["metadatas", "documents"]. ids are always included. kwargs: Dict | Additional keyword arguments. Returns: - Dict[str, List[Dict]] | The results. + GetResults | The results. """ collection = self.get_collection(collection_name) include = include if include else ["metadatas", "documents"] From 83ed9d08f6dc0f8c81517a6898937396dc547344 Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Thu, 4 Apr 2024 10:42:09 +0800 Subject: [PATCH 13/27] Fix tests --- autogen/agentchat/contrib/vectordb/chromadb.py | 2 +- test/agentchat/contrib/vectordb/test_chromadb.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/autogen/agentchat/contrib/vectordb/chromadb.py b/autogen/agentchat/contrib/vectordb/chromadb.py index 1b07c4fa5230..ca371b2f734c 100644 --- a/autogen/agentchat/contrib/vectordb/chromadb.py +++ b/autogen/agentchat/contrib/vectordb/chromadb.py @@ -23,7 +23,7 @@ class ChromaVectorDB: """ def __init__( - self, client=None, path: str = None, embedding_function: Callable = None, metadata: dict = None, **kwargs + self, *, client=None, path: str = None, embedding_function: Callable = None, metadata: dict = None, **kwargs ) -> None: """ Initialize the vector database. diff --git a/test/agentchat/contrib/vectordb/test_chromadb.py b/test/agentchat/contrib/vectordb/test_chromadb.py index efab776fe42a..3e38b21ebba6 100644 --- a/test/agentchat/contrib/vectordb/test_chromadb.py +++ b/test/agentchat/contrib/vectordb/test_chromadb.py @@ -19,8 +19,7 @@ @pytest.mark.skipif(skip, reason="dependency is not installed OR requested to skip") def test_chromadb(): # test create collection - db_config = {"path": ".db"} - db = ChromaVectorDB(db_config) + db = ChromaVectorDB(path=".db") collection_name = "test_collection" collection = db.create_collection(collection_name, overwrite=True, get_or_create=True) assert collection.name == collection_name From b50d93cb858b29e9f71d5b71bb6a3fbe0bef6536 Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Thu, 4 Apr 2024 10:56:14 +0800 Subject: [PATCH 14/27] Fix docs build error --- autogen/agentchat/contrib/vectordb/base.py | 2 +- autogen/agentchat/contrib/vectordb/chromadb.py | 2 +- autogen/agentchat/contrib/vectordb/{_types.py => types.py} | 0 3 files changed, 2 insertions(+), 2 deletions(-) rename autogen/agentchat/contrib/vectordb/{_types.py => types.py} (100%) diff --git a/autogen/agentchat/contrib/vectordb/base.py b/autogen/agentchat/contrib/vectordb/base.py index 9fe3becc34a3..a4edbab56a16 100644 --- a/autogen/agentchat/contrib/vectordb/base.py +++ b/autogen/agentchat/contrib/vectordb/base.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List, Protocol, runtime_checkable -from ._types import Document, GetResults, ItemID, QueryResults +from .types import Document, GetResults, ItemID, QueryResults @runtime_checkable diff --git a/autogen/agentchat/contrib/vectordb/chromadb.py b/autogen/agentchat/contrib/vectordb/chromadb.py index ca371b2f734c..be298a050345 100644 --- a/autogen/agentchat/contrib/vectordb/chromadb.py +++ b/autogen/agentchat/contrib/vectordb/chromadb.py @@ -1,7 +1,7 @@ import os from typing import Any, Callable, Dict, List -from ._types import Document, GetResults, ItemID, QueryResults +from .types import Document, GetResults, ItemID, QueryResults from .utils import filter_results_by_distance, get_logger try: diff --git a/autogen/agentchat/contrib/vectordb/_types.py b/autogen/agentchat/contrib/vectordb/types.py similarity index 100% rename from autogen/agentchat/contrib/vectordb/_types.py rename to autogen/agentchat/contrib/vectordb/types.py From 5e37f4a339a210251988d7ddf10287a77893ad4b Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Thu, 4 Apr 2024 11:00:17 +0800 Subject: [PATCH 15/27] Add types to base --- autogen/agentchat/contrib/vectordb/base.py | 55 ++++++++++++++++++- .../agentchat/contrib/vectordb/chromadb.py | 4 +- autogen/agentchat/contrib/vectordb/types.py | 54 ------------------ autogen/agentchat/contrib/vectordb/utils.py | 3 +- 4 files changed, 57 insertions(+), 59 deletions(-) delete mode 100644 autogen/agentchat/contrib/vectordb/types.py diff --git a/autogen/agentchat/contrib/vectordb/base.py b/autogen/agentchat/contrib/vectordb/base.py index a4edbab56a16..e6d2d6532ab4 100644 --- a/autogen/agentchat/contrib/vectordb/base.py +++ b/autogen/agentchat/contrib/vectordb/base.py @@ -1,6 +1,57 @@ -from typing import Any, Dict, List, Protocol, runtime_checkable +from typing import Any, List, Mapping, Optional, Protocol, Sequence, TypedDict, Union, runtime_checkable -from .types import Document, GetResults, ItemID, QueryResults +Metadata = Union[Mapping[str, Union[str, int, float, bool, None]], None] +Vector = Union[Sequence[float], Sequence[int]] +ItemID = str # chromadb doesn't support int ids + + +class Document(TypedDict): + """A Document is a record in the vector database. + + id: ItemID | the unique identifier of the document. + content: str | the text content of the chunk. + metadata: Metadata | contains additional information about the document such as source, date, etc. + embedding: Vector | the vector representation of the content. + dimensions: int | the dimensions of the content_embedding. + """ + + id: ItemID + content: str + metadata: Optional[Metadata] + embedding: Optional[Vector] + dimensions: Optional[int] + + +class QueryResults(TypedDict): + """QueryResults is the response from the vector database for a query. + + ids: List[List[ItemID]] | the unique identifiers of the documents. + contents: List[List[str]] | the text content of the documents. + embeddings: List[List[Vector]] | the vector representations of the documents. + metadatas: List[List[Metadata]] | the metadata of the documents. + distances: List[List[float]] | the distances between the query and the documents. + """ + + ids: List[List[ItemID]] + contents: List[List[str]] + embeddings: Optional[List[List[Vector]]] + metadatas: Optional[List[List[Metadata]]] + distances: Optional[List[List[float]]] + + +class GetResults(TypedDict): + """GetResults is the response from the vector database for getting documents by ids. + + ids: List[ItemID] | the unique identifiers of the documents. + contents: List[str] | the text content of the documents. + embeddings: List[Vector] | the vector representations of the documents. + metadatas: List[Metadata] | the metadata of the documents. + """ + + ids: List[ItemID] + contents: Optional[List[str]] + embeddings: Optional[List[Vector]] + metadatas: Optional[List[Metadata]] @runtime_checkable diff --git a/autogen/agentchat/contrib/vectordb/chromadb.py b/autogen/agentchat/contrib/vectordb/chromadb.py index be298a050345..a393246b6fac 100644 --- a/autogen/agentchat/contrib/vectordb/chromadb.py +++ b/autogen/agentchat/contrib/vectordb/chromadb.py @@ -1,7 +1,7 @@ import os -from typing import Any, Callable, Dict, List +from typing import Callable, List -from .types import Document, GetResults, ItemID, QueryResults +from .base import Document, GetResults, ItemID, QueryResults from .utils import filter_results_by_distance, get_logger try: diff --git a/autogen/agentchat/contrib/vectordb/types.py b/autogen/agentchat/contrib/vectordb/types.py deleted file mode 100644 index f6b4c9debe50..000000000000 --- a/autogen/agentchat/contrib/vectordb/types.py +++ /dev/null @@ -1,54 +0,0 @@ -from typing import Any, Dict, List, Mapping, Optional, Sequence, TypedDict, Union - -Metadata = Union[Mapping[str, Union[str, int, float, bool, None]], None] -Vector = Union[Sequence[float], Sequence[int]] -ItemID = str # chromadb doesn't support int ids - - -class Document(TypedDict): - """A Document is a record in the vector database. - - id: ItemID | the unique identifier of the document. - content: str | the text content of the chunk. - metadata: Metadata | contains additional information about the document such as source, date, etc. - embedding: Vector | the vector representation of the content. - dimensions: int | the dimensions of the content_embedding. - """ - - id: ItemID - content: str - metadata: Optional[Metadata] - embedding: Optional[Vector] - dimensions: Optional[int] - - -class QueryResults(TypedDict): - """QueryResults is the response from the vector database for a query. - - ids: List[List[ItemID]] | the unique identifiers of the documents. - contents: List[List[str]] | the text content of the documents. - embeddings: List[List[Vector]] | the vector representations of the documents. - metadatas: List[List[Metadata]] | the metadata of the documents. - distances: List[List[float]] | the distances between the query and the documents. - """ - - ids: List[List[ItemID]] - contents: List[List[str]] - embeddings: Optional[List[List[Vector]]] - metadatas: Optional[List[List[Metadata]]] - distances: Optional[List[List[float]]] - - -class GetResults(TypedDict): - """GetResults is the response from the vector database for getting documents by ids. - - ids: List[ItemID] | the unique identifiers of the documents. - contents: List[str] | the text content of the documents. - embeddings: List[Vector] | the vector representations of the documents. - metadatas: List[Metadata] | the metadata of the documents. - """ - - ids: List[ItemID] - contents: Optional[List[str]] - embeddings: Optional[List[Vector]] - metadatas: Optional[List[Metadata]] diff --git a/autogen/agentchat/contrib/vectordb/utils.py b/autogen/agentchat/contrib/vectordb/utils.py index 0bddf5915f48..a8231cff8b57 100644 --- a/autogen/agentchat/contrib/vectordb/utils.py +++ b/autogen/agentchat/contrib/vectordb/utils.py @@ -1,5 +1,6 @@ -from typing import List, Dict import logging +from typing import Dict, List + from termcolor import colored From 9bd612aa90df378aaa8dfbd543f29277099d6e8d Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Thu, 4 Apr 2024 22:46:12 +0800 Subject: [PATCH 16/27] Update base --- autogen/agentchat/contrib/vectordb/base.py | 79 ++++------------------ 1 file changed, 13 insertions(+), 66 deletions(-) diff --git a/autogen/agentchat/contrib/vectordb/base.py b/autogen/agentchat/contrib/vectordb/base.py index e6d2d6532ab4..8db0006559f7 100644 --- a/autogen/agentchat/contrib/vectordb/base.py +++ b/autogen/agentchat/contrib/vectordb/base.py @@ -1,8 +1,8 @@ -from typing import Any, List, Mapping, Optional, Protocol, Sequence, TypedDict, Union, runtime_checkable +from typing import Any, List, Mapping, Optional, Protocol, Sequence, Tuple, TypedDict, Union, runtime_checkable -Metadata = Union[Mapping[str, Union[str, int, float, bool, None]], None] +Metadata = Union[Mapping[str, Any], None] Vector = Union[Sequence[float], Sequence[int]] -ItemID = str # chromadb doesn't support int ids +ItemID = Union[str, int] # chromadb doesn't support int ids, VikingDB does class Document(TypedDict): @@ -10,9 +10,9 @@ class Document(TypedDict): id: ItemID | the unique identifier of the document. content: str | the text content of the chunk. - metadata: Metadata | contains additional information about the document such as source, date, etc. - embedding: Vector | the vector representation of the content. - dimensions: int | the dimensions of the content_embedding. + metadata: Metadata, Optional | contains additional information about the document such as source, date, etc. + embedding: Vector, Optional | the vector representation of the content. + dimensions: int, Optional | the dimensions of the content_embedding. """ id: ItemID @@ -22,36 +22,11 @@ class Document(TypedDict): dimensions: Optional[int] -class QueryResults(TypedDict): - """QueryResults is the response from the vector database for a query. - - ids: List[List[ItemID]] | the unique identifiers of the documents. - contents: List[List[str]] | the text content of the documents. - embeddings: List[List[Vector]] | the vector representations of the documents. - metadatas: List[List[Metadata]] | the metadata of the documents. - distances: List[List[float]] | the distances between the query and the documents. - """ - - ids: List[List[ItemID]] - contents: List[List[str]] - embeddings: Optional[List[List[Vector]]] - metadatas: Optional[List[List[Metadata]]] - distances: Optional[List[List[float]]] - - -class GetResults(TypedDict): - """GetResults is the response from the vector database for getting documents by ids. - - ids: List[ItemID] | the unique identifiers of the documents. - contents: List[str] | the text content of the documents. - embeddings: List[Vector] | the vector representations of the documents. - metadatas: List[Metadata] | the metadata of the documents. - """ - - ids: List[ItemID] - contents: Optional[List[str]] - embeddings: Optional[List[Vector]] - metadatas: Optional[List[Metadata]] +"""QueryResults is the response from the vector database for a query/queries. +A query is a list containing one string while queries is a list containing multiple strings. +The response is a list of query results, each query result is a list of tuples containing the document and the distance. +""" +QueryResults = List[List[Tuple[Document, float]]] @runtime_checkable @@ -166,36 +141,8 @@ def retrieve_docs( kwargs: Dict | Additional keyword arguments. Returns: - QueryResults | The query results. Each query result is a TypedDict `QueryResults`. - It should include the following fields: - - required: "ids", "contents" - - optional: "embeddings", "metadatas", "distances", etc. - - queries example: ["query1", "query2"] - query results example: { - "ids": [["id1", "id2", ...], ["id3", "id4", ...]], - "contents": [["content1", "content2", ...], ["content3", "content4", ...]], - "embeddings": [["embedding1", "embedding2", ...], ["embedding3", "embedding4", ...]], - "metadatas": [["metadata1", "metadata2", ...], ["metadata3", "metadata4", ...]], - "distances": [["distance1", "distance2", ...], ["distance3", "distance4", ...]], - } - - """ - ... - - def get_docs_by_ids(self, ids: List[ItemID], collection_name: str = None, include=None, **kwargs) -> GetResults: - """ - Retrieve documents from the collection of the vector database based on the ids. - - Args: - ids: List[Any] | A list of document ids. - collection_name: str | The name of the collection. Default is None. - include: List[str] | The fields to include. Default is None. - If None, will include ["metadatas", "documents"]. ids are always included. - kwargs: Dict | Additional keyword arguments. - - Returns: - GetResults | The results. + QueryResults | The query results. Each query result is a list of list of tuples containing the document and + the distance. """ ... From bc61162db02ae9afa4295e75e4d0ea03aedd0d8b Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Thu, 4 Apr 2024 23:03:25 +0800 Subject: [PATCH 17/27] Update utils --- autogen/agentchat/contrib/vectordb/utils.py | 29 +++-------- .../contrib/vectordb/test_vectordb_utils.py | 49 ++++--------------- 2 files changed, 16 insertions(+), 62 deletions(-) diff --git a/autogen/agentchat/contrib/vectordb/utils.py b/autogen/agentchat/contrib/vectordb/utils.py index a8231cff8b57..abc6699032a8 100644 --- a/autogen/agentchat/contrib/vectordb/utils.py +++ b/autogen/agentchat/contrib/vectordb/utils.py @@ -3,6 +3,8 @@ from termcolor import colored +from .base import QueryResults + class ColoredLogger(logging.Logger): def __init__(self, name, level=logging.NOTSET): @@ -36,37 +38,18 @@ def get_logger(name: str, level: int = logging.INFO) -> ColoredLogger: logger = get_logger(__name__) -def filter_results_by_distance( - results: Dict[str, List[List[Dict]]], distance_threshold: float = -1 -) -> Dict[str, List[List[Dict]]]: +def filter_results_by_distance(results: QueryResults, distance_threshold: float = -1) -> QueryResults: """Filters results based on a distance threshold. Args: - results: A dictionary containing results to be filtered. + results: QueryResults | The query results. List[List[Tuple[Document, float]]] distance_threshold: The maximum distance allowed for results. Returns: - Dict[str, List[List[Dict]]] | A filtered dictionary containing only results within the threshold. + QueryResults | A filtered results containing only distances smaller than the threshold. """ if distance_threshold > 0: - # Filter distances first: - return_ridx = [ - [ridx for ridx, distance in enumerate(distances) if distance < distance_threshold] - for distances in results["distances"] - ] - - # Filter other keys based on filtered distances: - results = { - key: ( - [ - [value for ridx, value in enumerate(results_list) if ridx in return_ridx[qidx]] - for qidx, results_list in enumerate(results_lists) - ] - if isinstance(results_lists, list) - else results_lists - ) - for key, results_lists in results.items() - } + results = [[(key, value) for key, value in data if value < distance_threshold] for data in results] return results diff --git a/test/agentchat/contrib/vectordb/test_vectordb_utils.py b/test/agentchat/contrib/vectordb/test_vectordb_utils.py index 8f6ba8395a9d..dde5e67af6f2 100644 --- a/test/agentchat/contrib/vectordb/test_vectordb_utils.py +++ b/test/agentchat/contrib/vectordb/test_vectordb_utils.py @@ -1,50 +1,21 @@ #!/usr/bin/env python3 -m pytest -import pytest import os import sys +import pytest + from autogen.agentchat.contrib.vectordb.utils import filter_results_by_distance def test_retrieve_config(): - results = { - "ids": [["id1", "id2"], ["id3", "id4"]], - "contents": [ - ["content1", "content2"], - [ - "content3", - "content4", - ], - ], - "embeddings": [ - [ - "embedding1", - "embedding2", - ], - [ - "embedding3", - "embedding4", - ], - ], - "metadatas": [ - [ - "metadata1", - "metadata2", - ], - [ - "metadata3", - "metadata4", - ], - ], - "distances": [[1, 2], [2, 3]], - } + results = [ + [("id1", 1), ("id2", 2)], + [("id3", 2), ("id4", 3)], + ] print(filter_results_by_distance(results, 2.1)) - filter_results = { - "ids": [["id1", "id2"], ["id3"]], - "contents": [["content1", "content2"], ["content3"]], - "embeddings": [["embedding1", "embedding2"], ["embedding3"]], - "metadatas": [["metadata1", "metadata2"], ["metadata3"]], - "distances": [[1, 2], [2]], - } + filter_results = [ + [("id1", 1), ("id2", 2)], + [("id3", 2)], + ] assert filter_results == filter_results_by_distance(results, 2.1) From e9ece8a1ef7cd0ad3a3e83a32f3c550df02adda9 Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Fri, 5 Apr 2024 10:14:07 +0800 Subject: [PATCH 18/27] Update chromadb --- autogen/agentchat/contrib/vectordb/base.py | 2 - .../agentchat/contrib/vectordb/chromadb.py | 121 +++++++++++------- .../contrib/vectordb/test_chromadb.py | 37 +++++- 3 files changed, 107 insertions(+), 53 deletions(-) diff --git a/autogen/agentchat/contrib/vectordb/base.py b/autogen/agentchat/contrib/vectordb/base.py index 8db0006559f7..906eb49e4ac4 100644 --- a/autogen/agentchat/contrib/vectordb/base.py +++ b/autogen/agentchat/contrib/vectordb/base.py @@ -12,14 +12,12 @@ class Document(TypedDict): content: str | the text content of the chunk. metadata: Metadata, Optional | contains additional information about the document such as source, date, etc. embedding: Vector, Optional | the vector representation of the content. - dimensions: int, Optional | the dimensions of the content_embedding. """ id: ItemID content: str metadata: Optional[Metadata] embedding: Optional[Vector] - dimensions: Optional[int] """QueryResults is the response from the vector database for a query/queries. diff --git a/autogen/agentchat/contrib/vectordb/chromadb.py b/autogen/agentchat/contrib/vectordb/chromadb.py index a393246b6fac..d89c82a19079 100644 --- a/autogen/agentchat/contrib/vectordb/chromadb.py +++ b/autogen/agentchat/contrib/vectordb/chromadb.py @@ -1,7 +1,7 @@ import os from typing import Callable, List -from .base import Document, GetResults, ItemID, QueryResults +from .base import Document, ItemID, QueryResults, VectorDB from .utils import filter_results_by_distance, get_logger try: @@ -17,7 +17,7 @@ logger = get_logger(__name__) -class ChromaVectorDB: +class ChromaVectorDB(VectorDB): """ A vector database that uses ChromaDB as the backend. """ @@ -73,7 +73,10 @@ def create_collection( Collection | The collection object. """ try: - collection = self.client.get_collection(collection_name) + if self.active_collection and self.active_collection.name == collection_name: + collection = self.active_collection + else: + collection = self.client.get_collection(collection_name) except ValueError: collection = None if collection is None: @@ -115,7 +118,8 @@ def get_collection(self, collection_name: str = None) -> Collection: f"No collection is specified. Using current active collection {self.active_collection.name}." ) else: - self.active_collection = self.client.get_collection(collection_name) + if not (self.active_collection and self.active_collection.name == collection_name): + self.active_collection = self.client.get_collection(collection_name) return self.active_collection def delete_collection(self, collection_name: str) -> None: @@ -129,12 +133,11 @@ def delete_collection(self, collection_name: str) -> None: None """ self.client.delete_collection(collection_name) - if self.active_collection: - if self.active_collection.name == collection_name: - self.active_collection = None + if self.active_collection and self.active_collection.name == collection_name: + self.active_collection = None def _batch_insert( - self, collection: Collection, embeddings=None, ids=None, metadata=None, documents=None, upsert=False + self, collection: Collection, embeddings=None, ids=None, metadatas=None, documents=None, upsert=False ): batch_size = int(CHROMADB_MAX_BATCH_SIZE) for i in range(0, len(documents), min(batch_size, len(documents))): @@ -142,7 +145,7 @@ def _batch_insert( collection_kwargs = { "documents": documents[i:end_idx], "ids": ids[i:end_idx], - "metadatas": metadata[i:end_idx] if metadata else None, + "metadatas": metadatas[i:end_idx] if metadatas else None, "embeddings": embeddings[i:end_idx] if embeddings else None, } if upsert: @@ -180,10 +183,10 @@ def insert_docs(self, docs: List[Document], collection_name: str = None, upsert: else: embeddings = [doc.get("embedding") for doc in docs] if docs[0].get("metadata") is None: - metadata = None + metadatas = None else: - metadata = [doc.get("metadata") for doc in docs] - self._batch_insert(collection, embeddings, ids, metadata, documents, upsert) + metadatas = [doc.get("metadata") for doc in docs] + self._batch_insert(collection, embeddings, ids, metadatas, documents, upsert) def update_docs(self, docs: List[Document], collection_name: str = None) -> None: """ @@ -213,6 +216,63 @@ def delete_docs(self, ids: List[ItemID], collection_name: str = None, **kwargs) collection = self.get_collection(collection_name) collection.delete(ids, **kwargs) + @staticmethod + def _chroma_results_to_query_results(data_dict, special_key="distances"): + """Converts a dictionary with list-of-list values to a list of tuples. + + Args: + data_dict: A dictionary where keys map to lists of lists or None. + special_key: The key in the dictionary containing the special values + for each tuple. + + Returns: + A list of tuples, where each tuple contains a sub-dictionary with + some keys from the original dictionary and the value from the + special_key. + + Example: + data_dict = { + "key1s": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + "key2s": [["a", "b", "c"], ["c", "d", "e"], ["e", "f", "g"]], + "key3s": None, + "key4s": [["x", "y", "z"], ["1", "2", "3"], ["4", "5", "6"]], + "distances": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]], + } + + results = [ + [ + ({"key1": 1, "key2": "a", "key4": "x"}, 0.1), + ({"key1": 2, "key2": "b", "key4": "y"}, 0.2), + ({"key1": 3, "key2": "c", "key4": "z"}, 0.3), + ], + [ + ({"key1": 4, "key2": "c", "key4": "1"}, 0.4), + ({"key1": 5, "key2": "d", "key4": "2"}, 0.5), + ({"key1": 6, "key2": "e", "key4": "3"}, 0.6), + ], + [ + ({"key1": 7, "key2": "e", "key4": "4"}, 0.7), + ({"key1": 8, "key2": "f", "key4": "5"}, 0.8), + ({"key1": 9, "key2": "g", "key4": "6"}, 0.9), + ], + ] + """ + + keys = [key for key in data_dict if key != special_key] + result = [] + + for i in range(len(data_dict[special_key])): + sub_result = [] + for j, distance in enumerate(data_dict[special_key][i]): + sub_dict = {} + for key in keys: + if data_dict[key] is not None and len(data_dict[key]) > i: + sub_dict[key[:-1]] = data_dict[key][i][j] # remove 's' in the end from key + sub_result.append((sub_dict, distance)) + result.append(sub_result) + + return result + def retrieve_docs( self, queries: List[str], @@ -233,20 +293,8 @@ def retrieve_docs( kwargs: Dict | Additional keyword arguments. Returns: - QueryResults | The query results. Each query result is a TypedDict `QueryResults`. - It should include the following fields: - - required: "ids", "contents" - - optional: "embeddings", "metadatas", "distances", etc. - - queries example: ["query1", "query2"] - query results example: { - "ids": [["id1", "id2", ...], ["id3", "id4", ...]], - "contents": [["content1", "content2", ...], ["content3", "content4", ...]], - "embeddings": [["embedding1", "embedding2", ...], ["embedding3", "embedding4", ...]], - "metadatas": [["metadata1", "metadata2", ...], ["metadata3", "metadata4", ...]], - "distances": [["distance1", "distance2", ...], ["distance3", "distance4", ...]], - } - + QueryResults | The query results. Each query result is a list of list of tuples containing the document and + the distance. """ collection = self.get_collection(collection_name) if isinstance(queries, str): @@ -257,25 +305,6 @@ def retrieve_docs( **kwargs, ) results["contents"] = results.pop("documents") + results = self._chroma_results_to_query_results(results) results = filter_results_by_distance(results, distance_threshold) - - return results - - def get_docs_by_ids(self, ids: List[ItemID], collection_name: str = None, include=None, **kwargs) -> GetResults: - """ - Retrieve documents from the collection of the vector database based on the ids. - - Args: - ids: List[Any] | A list of document ids. - collection_name: str | The name of the collection. Default is None. - include: List[str] | The fields to include. Default is None. - If None, will include ["metadatas", "documents"]. ids are always included. - kwargs: Dict | Additional keyword arguments. - - Returns: - GetResults | The results. - """ - collection = self.get_collection(collection_name) - include = include if include else ["metadatas", "documents"] - results = collection.get(ids, include=include, **kwargs) return results diff --git a/test/agentchat/contrib/vectordb/test_chromadb.py b/test/agentchat/contrib/vectordb/test_chromadb.py index 3e38b21ebba6..569677c03912 100644 --- a/test/agentchat/contrib/vectordb/test_chromadb.py +++ b/test/agentchat/contrib/vectordb/test_chromadb.py @@ -44,30 +44,57 @@ def test_chromadb(): # test_insert_docs docs = [{"content": "doc1", "id": "1"}, {"content": "doc2", "id": "2"}, {"content": "doc3", "id": "3"}] db.insert_docs(docs, collection_name, upsert=False) - res = db.get_docs_by_ids(["1", "2"], collection_name, include=["documents"]) + res = db.get_collection(collection_name).get(["1", "2"]) assert res["documents"] == ["doc1", "doc2"] # test_update_docs docs = [{"content": "doc11", "id": "1"}, {"content": "doc2", "id": "2"}, {"content": "doc3", "id": "3"}] db.update_docs(docs, collection_name) - res = db.get_docs_by_ids(["1", "2"], collection_name) + res = db.get_collection(collection_name).get(["1", "2"]) assert res["documents"] == ["doc11", "doc2"] # test_delete_docs ids = ["1"] collection_name = "test_collection" db.delete_docs(ids, collection_name) - res = db.get_docs_by_ids(ids, collection_name) + res = db.get_collection(collection_name).get(ids) assert res["documents"] == [] # test_retrieve_docs queries = ["doc2", "doc3"] collection_name = "test_collection" res = db.retrieve_docs(queries, collection_name) - assert res["ids"] == [["2", "3"], ["3", "2"]] + assert [[r[0]["id"] for r in rr] for rr in res] == [["2", "3"], ["3", "2"]] res = db.retrieve_docs(queries, collection_name, distance_threshold=0.1) print(res) - assert res["ids"] == [["2"], ["3"]] + assert [[r[0]["id"] for r in rr] for rr in res] == [["2"], ["3"]] + + # test _chroma_results_to_query_results + data_dict = { + "key1s": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + "key2s": [["a", "b", "c"], ["c", "d", "e"], ["e", "f", "g"]], + "key3s": None, + "key4s": [["x", "y", "z"], ["1", "2", "3"], ["4", "5", "6"]], + "distances": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]], + } + results = [ + [ + ({"key1": 1, "key2": "a", "key4": "x"}, 0.1), + ({"key1": 2, "key2": "b", "key4": "y"}, 0.2), + ({"key1": 3, "key2": "c", "key4": "z"}, 0.3), + ], + [ + ({"key1": 4, "key2": "c", "key4": "1"}, 0.4), + ({"key1": 5, "key2": "d", "key4": "2"}, 0.5), + ({"key1": 6, "key2": "e", "key4": "3"}, 0.6), + ], + [ + ({"key1": 7, "key2": "e", "key4": "4"}, 0.7), + ({"key1": 8, "key2": "f", "key4": "5"}, 0.8), + ({"key1": 9, "key2": "g", "key4": "6"}, 0.9), + ], + ] + assert db._chroma_results_to_query_results(data_dict) == results if __name__ == "__main__": From 1db36d6922738bce7636a11a4aef6f678db4fedb Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Fri, 5 Apr 2024 10:52:37 +0800 Subject: [PATCH 19/27] Add get_docs_by_ids --- autogen/agentchat/contrib/vectordb/base.py | 16 +++++ .../agentchat/contrib/vectordb/chromadb.py | 58 +++++++++++++++++++ .../contrib/vectordb/test_chromadb.py | 19 ++++++ 3 files changed, 93 insertions(+) diff --git a/autogen/agentchat/contrib/vectordb/base.py b/autogen/agentchat/contrib/vectordb/base.py index 906eb49e4ac4..8d4409621f60 100644 --- a/autogen/agentchat/contrib/vectordb/base.py +++ b/autogen/agentchat/contrib/vectordb/base.py @@ -144,6 +144,22 @@ def retrieve_docs( """ ... + def get_docs_by_ids(self, ids: List[ItemID], collection_name: str = None, include=None, **kwargs) -> List[Document]: + """ + Retrieve documents from the collection of the vector database based on the ids. + + Args: + ids: List[ItemID] | A list of document ids. + collection_name: str | The name of the collection. Default is None. + include: List[str] | The fields to include. Default is None. + If None, will include ["metadatas", "documents"], ids will always be included. + kwargs: dict | Additional keyword arguments. + + Returns: + List[Document] | The results. + """ + ... + class VectorDBFactory: """ diff --git a/autogen/agentchat/contrib/vectordb/chromadb.py b/autogen/agentchat/contrib/vectordb/chromadb.py index d89c82a19079..73c287a3d93a 100644 --- a/autogen/agentchat/contrib/vectordb/chromadb.py +++ b/autogen/agentchat/contrib/vectordb/chromadb.py @@ -308,3 +308,61 @@ def retrieve_docs( results = self._chroma_results_to_query_results(results) results = filter_results_by_distance(results, distance_threshold) return results + + @staticmethod + def _chroma_get_results_to_list_documents(data_dict, include=None): + """Converts a dictionary with list values to a list of Document. + + Args: + data_dict: A dictionary where keys map to lists or None. + include: List[str] | The fields to include. Default is None. + If None, will include ["metadatas", "documents"], ids will always be included. + + Returns: + List[Document] | The list of Document. + + Example: + data_dict = { + "key1s": [1, 2, 3], + "key2s": ["a", "b", "c"], + "key3s": None, + "key4s": ["x", "y", "z"], + } + + results = [ + {"key1": 1, "key2": "a", "key4": "x"}, + {"key1": 2, "key2": "b", "key4": "y"}, + {"key1": 3, "key2": "c", "key4": "z"}, + ] + """ + + results = [] + keys = [key for key in data_dict if data_dict[key] is not None] + + for i in range(len(data_dict[keys[0]])): + sub_dict = {} + for key in data_dict.keys(): + if data_dict[key] is not None and len(data_dict[key]) > i: + sub_dict[key[:-1]] = data_dict[key][i] + results.append(sub_dict) + return results + + def get_docs_by_ids(self, ids: List[ItemID], collection_name: str = None, include=None, **kwargs) -> List[Document]: + """ + Retrieve documents from the collection of the vector database based on the ids. + + Args: + ids: List[ItemID] | A list of document ids. + collection_name: str | The name of the collection. Default is None. + include: List[str] | The fields to include. Default is None. + If None, will include ["metadatas", "documents"], ids will always be included. + kwargs: dict | Additional keyword arguments. + + Returns: + List[Document] | The results. + """ + collection = self.get_collection(collection_name) + include = include if include else ["metadatas", "documents"] + results = collection.get(ids, include=include, **kwargs) + results = self._chroma_get_results_to_list_documents(results) + return results diff --git a/test/agentchat/contrib/vectordb/test_chromadb.py b/test/agentchat/contrib/vectordb/test_chromadb.py index 569677c03912..7cefcba2dbb6 100644 --- a/test/agentchat/contrib/vectordb/test_chromadb.py +++ b/test/agentchat/contrib/vectordb/test_chromadb.py @@ -69,6 +69,10 @@ def test_chromadb(): print(res) assert [[r[0]["id"] for r in rr] for rr in res] == [["2"], ["3"]] + # test_get_docs_by_ids + res = db.get_docs_by_ids(["1", "2"], collection_name) + assert [r["id"] for r in res] == ["2"] # "1" has been deleted + # test _chroma_results_to_query_results data_dict = { "key1s": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], @@ -96,6 +100,21 @@ def test_chromadb(): ] assert db._chroma_results_to_query_results(data_dict) == results + # test _chroma_get_results_to_list_documents + data_dict = { + "key1s": [1, 2, 3], + "key2s": ["a", "b", "c"], + "key3s": None, + "key4s": ["x", "y", "z"], + } + + results = [ + {"key1": 1, "key2": "a", "key4": "x"}, + {"key1": 2, "key2": "b", "key4": "y"}, + {"key1": 3, "key2": "c", "key4": "z"}, + ] + assert db._chroma_get_results_to_list_documents(data_dict) == results + if __name__ == "__main__": test_chromadb() From 411a1dba6de4e5465a99a206ff591e1cc0c3a15b Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Sat, 6 Apr 2024 15:46:05 +0800 Subject: [PATCH 20/27] Improve docstring --- autogen/agentchat/contrib/vectordb/chromadb.py | 1 + 1 file changed, 1 insertion(+) diff --git a/autogen/agentchat/contrib/vectordb/chromadb.py b/autogen/agentchat/contrib/vectordb/chromadb.py index 73c287a3d93a..8141e2e1468c 100644 --- a/autogen/agentchat/contrib/vectordb/chromadb.py +++ b/autogen/agentchat/contrib/vectordb/chromadb.py @@ -30,6 +30,7 @@ def __init__( Args: client: chromadb.Client | The client object of the vector database. Default is None. + If provided, it will use the client object directly and ignore other arguments. path: str | The path to the vector database. Default is None. embedding_function: Callable | The embedding function used to generate the vector representation of the documents. Default is None. From a78d572b3091dd5bb223db7cedb37902df424df9 Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Sun, 7 Apr 2024 11:41:03 +0800 Subject: [PATCH 21/27] Add get all docs --- autogen/agentchat/contrib/vectordb/base.py | 6 ++++-- autogen/agentchat/contrib/vectordb/chromadb.py | 6 ++++-- test/agentchat/contrib/vectordb/test_chromadb.py | 2 ++ 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/autogen/agentchat/contrib/vectordb/base.py b/autogen/agentchat/contrib/vectordb/base.py index 8d4409621f60..d04f42b5a6bc 100644 --- a/autogen/agentchat/contrib/vectordb/base.py +++ b/autogen/agentchat/contrib/vectordb/base.py @@ -144,12 +144,14 @@ def retrieve_docs( """ ... - def get_docs_by_ids(self, ids: List[ItemID], collection_name: str = None, include=None, **kwargs) -> List[Document]: + def get_docs_by_ids( + self, ids: List[ItemID] = None, collection_name: str = None, include=None, **kwargs + ) -> List[Document]: """ Retrieve documents from the collection of the vector database based on the ids. Args: - ids: List[ItemID] | A list of document ids. + ids: List[ItemID] | A list of document ids. If None, will return all the documents. Default is None. collection_name: str | The name of the collection. Default is None. include: List[str] | The fields to include. Default is None. If None, will include ["metadatas", "documents"], ids will always be included. diff --git a/autogen/agentchat/contrib/vectordb/chromadb.py b/autogen/agentchat/contrib/vectordb/chromadb.py index 8141e2e1468c..1238237bab4c 100644 --- a/autogen/agentchat/contrib/vectordb/chromadb.py +++ b/autogen/agentchat/contrib/vectordb/chromadb.py @@ -348,12 +348,14 @@ def _chroma_get_results_to_list_documents(data_dict, include=None): results.append(sub_dict) return results - def get_docs_by_ids(self, ids: List[ItemID], collection_name: str = None, include=None, **kwargs) -> List[Document]: + def get_docs_by_ids( + self, ids: List[ItemID] = None, collection_name: str = None, include=None, **kwargs + ) -> List[Document]: """ Retrieve documents from the collection of the vector database based on the ids. Args: - ids: List[ItemID] | A list of document ids. + ids: List[ItemID] | A list of document ids. If None, will return all the documents. Default is None. collection_name: str | The name of the collection. Default is None. include: List[str] | The fields to include. Default is None. If None, will include ["metadatas", "documents"], ids will always be included. diff --git a/test/agentchat/contrib/vectordb/test_chromadb.py b/test/agentchat/contrib/vectordb/test_chromadb.py index 7cefcba2dbb6..465f1e85a832 100644 --- a/test/agentchat/contrib/vectordb/test_chromadb.py +++ b/test/agentchat/contrib/vectordb/test_chromadb.py @@ -72,6 +72,8 @@ def test_chromadb(): # test_get_docs_by_ids res = db.get_docs_by_ids(["1", "2"], collection_name) assert [r["id"] for r in res] == ["2"] # "1" has been deleted + res = db.get_docs_by_ids(collection_name=collection_name) + assert [r["id"] for r in res] == ["2", "3"] # test _chroma_results_to_query_results data_dict = { From 4b644c3cc7de1eed4d01c45a9a48d0d474cb3fd5 Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Sun, 7 Apr 2024 13:21:11 +0800 Subject: [PATCH 22/27] Move chroma_results_to_query_results to utils --- .../agentchat/contrib/vectordb/chromadb.py | 61 +------------------ autogen/agentchat/contrib/vectordb/utils.py | 57 +++++++++++++++++ .../contrib/vectordb/test_chromadb.py | 27 -------- .../contrib/vectordb/test_vectordb_utils.py | 30 ++++++++- 4 files changed, 88 insertions(+), 87 deletions(-) diff --git a/autogen/agentchat/contrib/vectordb/chromadb.py b/autogen/agentchat/contrib/vectordb/chromadb.py index 1238237bab4c..fc31ee1e07b7 100644 --- a/autogen/agentchat/contrib/vectordb/chromadb.py +++ b/autogen/agentchat/contrib/vectordb/chromadb.py @@ -2,7 +2,7 @@ from typing import Callable, List from .base import Document, ItemID, QueryResults, VectorDB -from .utils import filter_results_by_distance, get_logger +from .utils import chroma_results_to_query_results, filter_results_by_distance, get_logger try: import chromadb @@ -217,63 +217,6 @@ def delete_docs(self, ids: List[ItemID], collection_name: str = None, **kwargs) collection = self.get_collection(collection_name) collection.delete(ids, **kwargs) - @staticmethod - def _chroma_results_to_query_results(data_dict, special_key="distances"): - """Converts a dictionary with list-of-list values to a list of tuples. - - Args: - data_dict: A dictionary where keys map to lists of lists or None. - special_key: The key in the dictionary containing the special values - for each tuple. - - Returns: - A list of tuples, where each tuple contains a sub-dictionary with - some keys from the original dictionary and the value from the - special_key. - - Example: - data_dict = { - "key1s": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], - "key2s": [["a", "b", "c"], ["c", "d", "e"], ["e", "f", "g"]], - "key3s": None, - "key4s": [["x", "y", "z"], ["1", "2", "3"], ["4", "5", "6"]], - "distances": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]], - } - - results = [ - [ - ({"key1": 1, "key2": "a", "key4": "x"}, 0.1), - ({"key1": 2, "key2": "b", "key4": "y"}, 0.2), - ({"key1": 3, "key2": "c", "key4": "z"}, 0.3), - ], - [ - ({"key1": 4, "key2": "c", "key4": "1"}, 0.4), - ({"key1": 5, "key2": "d", "key4": "2"}, 0.5), - ({"key1": 6, "key2": "e", "key4": "3"}, 0.6), - ], - [ - ({"key1": 7, "key2": "e", "key4": "4"}, 0.7), - ({"key1": 8, "key2": "f", "key4": "5"}, 0.8), - ({"key1": 9, "key2": "g", "key4": "6"}, 0.9), - ], - ] - """ - - keys = [key for key in data_dict if key != special_key] - result = [] - - for i in range(len(data_dict[special_key])): - sub_result = [] - for j, distance in enumerate(data_dict[special_key][i]): - sub_dict = {} - for key in keys: - if data_dict[key] is not None and len(data_dict[key]) > i: - sub_dict[key[:-1]] = data_dict[key][i][j] # remove 's' in the end from key - sub_result.append((sub_dict, distance)) - result.append(sub_result) - - return result - def retrieve_docs( self, queries: List[str], @@ -306,7 +249,7 @@ def retrieve_docs( **kwargs, ) results["contents"] = results.pop("documents") - results = self._chroma_results_to_query_results(results) + results = chroma_results_to_query_results(results) results = filter_results_by_distance(results, distance_threshold) return results diff --git a/autogen/agentchat/contrib/vectordb/utils.py b/autogen/agentchat/contrib/vectordb/utils.py index abc6699032a8..16f318420181 100644 --- a/autogen/agentchat/contrib/vectordb/utils.py +++ b/autogen/agentchat/contrib/vectordb/utils.py @@ -53,3 +53,60 @@ def filter_results_by_distance(results: QueryResults, distance_threshold: float results = [[(key, value) for key, value in data if value < distance_threshold] for data in results] return results + + +def chroma_results_to_query_results(data_dict, special_key="distances"): + """Converts a dictionary with list-of-list values to a list of tuples. + + Args: + data_dict: A dictionary where keys map to lists of lists or None. + special_key: The key in the dictionary containing the special values + for each tuple. + + Returns: + A list of tuples, where each tuple contains a sub-dictionary with + some keys from the original dictionary and the value from the + special_key. + + Example: + data_dict = { + "key1s": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + "key2s": [["a", "b", "c"], ["c", "d", "e"], ["e", "f", "g"]], + "key3s": None, + "key4s": [["x", "y", "z"], ["1", "2", "3"], ["4", "5", "6"]], + "distances": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]], + } + + results = [ + [ + ({"key1": 1, "key2": "a", "key4": "x"}, 0.1), + ({"key1": 2, "key2": "b", "key4": "y"}, 0.2), + ({"key1": 3, "key2": "c", "key4": "z"}, 0.3), + ], + [ + ({"key1": 4, "key2": "c", "key4": "1"}, 0.4), + ({"key1": 5, "key2": "d", "key4": "2"}, 0.5), + ({"key1": 6, "key2": "e", "key4": "3"}, 0.6), + ], + [ + ({"key1": 7, "key2": "e", "key4": "4"}, 0.7), + ({"key1": 8, "key2": "f", "key4": "5"}, 0.8), + ({"key1": 9, "key2": "g", "key4": "6"}, 0.9), + ], + ] + """ + + keys = [key for key in data_dict if key != special_key] + result = [] + + for i in range(len(data_dict[special_key])): + sub_result = [] + for j, distance in enumerate(data_dict[special_key][i]): + sub_dict = {} + for key in keys: + if data_dict[key] is not None and len(data_dict[key]) > i: + sub_dict[key[:-1]] = data_dict[key][i][j] # remove 's' in the end from key + sub_result.append((sub_dict, distance)) + result.append(sub_result) + + return result diff --git a/test/agentchat/contrib/vectordb/test_chromadb.py b/test/agentchat/contrib/vectordb/test_chromadb.py index 465f1e85a832..9c36c121f4ea 100644 --- a/test/agentchat/contrib/vectordb/test_chromadb.py +++ b/test/agentchat/contrib/vectordb/test_chromadb.py @@ -75,33 +75,6 @@ def test_chromadb(): res = db.get_docs_by_ids(collection_name=collection_name) assert [r["id"] for r in res] == ["2", "3"] - # test _chroma_results_to_query_results - data_dict = { - "key1s": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], - "key2s": [["a", "b", "c"], ["c", "d", "e"], ["e", "f", "g"]], - "key3s": None, - "key4s": [["x", "y", "z"], ["1", "2", "3"], ["4", "5", "6"]], - "distances": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]], - } - results = [ - [ - ({"key1": 1, "key2": "a", "key4": "x"}, 0.1), - ({"key1": 2, "key2": "b", "key4": "y"}, 0.2), - ({"key1": 3, "key2": "c", "key4": "z"}, 0.3), - ], - [ - ({"key1": 4, "key2": "c", "key4": "1"}, 0.4), - ({"key1": 5, "key2": "d", "key4": "2"}, 0.5), - ({"key1": 6, "key2": "e", "key4": "3"}, 0.6), - ], - [ - ({"key1": 7, "key2": "e", "key4": "4"}, 0.7), - ({"key1": 8, "key2": "f", "key4": "5"}, 0.8), - ({"key1": 9, "key2": "g", "key4": "6"}, 0.9), - ], - ] - assert db._chroma_results_to_query_results(data_dict) == results - # test _chroma_get_results_to_list_documents data_dict = { "key1s": [1, 2, 3], diff --git a/test/agentchat/contrib/vectordb/test_vectordb_utils.py b/test/agentchat/contrib/vectordb/test_vectordb_utils.py index dde5e67af6f2..8c26ac9c3cdf 100644 --- a/test/agentchat/contrib/vectordb/test_vectordb_utils.py +++ b/test/agentchat/contrib/vectordb/test_vectordb_utils.py @@ -5,7 +5,7 @@ import pytest -from autogen.agentchat.contrib.vectordb.utils import filter_results_by_distance +from autogen.agentchat.contrib.vectordb.utils import chroma_results_to_query_results, filter_results_by_distance def test_retrieve_config(): @@ -19,3 +19,31 @@ def test_retrieve_config(): [("id3", 2)], ] assert filter_results == filter_results_by_distance(results, 2.1) + + +def test_chroma_results_to_query_results(): + data_dict = { + "key1s": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + "key2s": [["a", "b", "c"], ["c", "d", "e"], ["e", "f", "g"]], + "key3s": None, + "key4s": [["x", "y", "z"], ["1", "2", "3"], ["4", "5", "6"]], + "distances": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]], + } + results = [ + [ + ({"key1": 1, "key2": "a", "key4": "x"}, 0.1), + ({"key1": 2, "key2": "b", "key4": "y"}, 0.2), + ({"key1": 3, "key2": "c", "key4": "z"}, 0.3), + ], + [ + ({"key1": 4, "key2": "c", "key4": "1"}, 0.4), + ({"key1": 5, "key2": "d", "key4": "2"}, 0.5), + ({"key1": 6, "key2": "e", "key4": "3"}, 0.6), + ], + [ + ({"key1": 7, "key2": "e", "key4": "4"}, 0.7), + ({"key1": 8, "key2": "f", "key4": "5"}, 0.8), + ({"key1": 9, "key2": "g", "key4": "6"}, 0.9), + ], + ] + assert chroma_results_to_query_results(data_dict) == results From aed255e57a9ff5a7d761e9d9e464897576f97667 Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Sun, 7 Apr 2024 13:36:02 +0800 Subject: [PATCH 23/27] Improve type hints --- autogen/agentchat/contrib/vectordb/chromadb.py | 6 ++---- autogen/agentchat/contrib/vectordb/utils.py | 4 ++-- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/autogen/agentchat/contrib/vectordb/chromadb.py b/autogen/agentchat/contrib/vectordb/chromadb.py index fc31ee1e07b7..8fd915c93ced 100644 --- a/autogen/agentchat/contrib/vectordb/chromadb.py +++ b/autogen/agentchat/contrib/vectordb/chromadb.py @@ -139,7 +139,7 @@ def delete_collection(self, collection_name: str) -> None: def _batch_insert( self, collection: Collection, embeddings=None, ids=None, metadatas=None, documents=None, upsert=False - ): + ) -> None: batch_size = int(CHROMADB_MAX_BATCH_SIZE) for i in range(0, len(documents), min(batch_size, len(documents))): end_idx = i + min(batch_size, len(documents) - i) @@ -254,13 +254,11 @@ def retrieve_docs( return results @staticmethod - def _chroma_get_results_to_list_documents(data_dict, include=None): + def _chroma_get_results_to_list_documents(data_dict) -> List[Document]: """Converts a dictionary with list values to a list of Document. Args: data_dict: A dictionary where keys map to lists or None. - include: List[str] | The fields to include. Default is None. - If None, will include ["metadatas", "documents"], ids will always be included. Returns: List[Document] | The list of Document. diff --git a/autogen/agentchat/contrib/vectordb/utils.py b/autogen/agentchat/contrib/vectordb/utils.py index 16f318420181..5f5b638587bb 100644 --- a/autogen/agentchat/contrib/vectordb/utils.py +++ b/autogen/agentchat/contrib/vectordb/utils.py @@ -1,5 +1,5 @@ import logging -from typing import Dict, List +from typing import Any, Dict, List from termcolor import colored @@ -55,7 +55,7 @@ def filter_results_by_distance(results: QueryResults, distance_threshold: float return results -def chroma_results_to_query_results(data_dict, special_key="distances"): +def chroma_results_to_query_results(data_dict: Dict[str, List[List[Any]]], special_key="distances") -> QueryResults: """Converts a dictionary with list-of-list values to a list of tuples. Args: From 9cdc97bf6b24cb43fdde293b1b47a8288621af97 Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Sun, 7 Apr 2024 15:41:41 +0800 Subject: [PATCH 24/27] Update logger --- autogen/agentchat/contrib/vectordb/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autogen/agentchat/contrib/vectordb/utils.py b/autogen/agentchat/contrib/vectordb/utils.py index 5f5b638587bb..ae1ef1252519 100644 --- a/autogen/agentchat/contrib/vectordb/utils.py +++ b/autogen/agentchat/contrib/vectordb/utils.py @@ -30,7 +30,7 @@ def get_logger(name: str, level: int = logging.INFO) -> ColoredLogger: logger = ColoredLogger(name, level) console_handler = logging.StreamHandler() logger.addHandler(console_handler) - formatter = logging.Formatter("%(asctime)s - %(filename)s:%(lineno)5d - %(levelname)s - %(message)s") + formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") logger.handlers[0].setFormatter(formatter) return logger From 05d34a4127f3ce301338c0d9952b0ef2ee8fb30f Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Sun, 7 Apr 2024 15:43:44 +0800 Subject: [PATCH 25/27] Update init, add embedding func --- autogen/agentchat/contrib/vectordb/chromadb.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/autogen/agentchat/contrib/vectordb/chromadb.py b/autogen/agentchat/contrib/vectordb/chromadb.py index 8fd915c93ced..6e571d58abc2 100644 --- a/autogen/agentchat/contrib/vectordb/chromadb.py +++ b/autogen/agentchat/contrib/vectordb/chromadb.py @@ -9,6 +9,7 @@ if chromadb.__version__ < "0.4.15": raise ImportError("Please upgrade chromadb to version 0.4.15 or later.") + import chromadb.utils.embedding_functions as ef from chromadb.api.models.Collection import Collection except ImportError: raise ImportError("Please install chromadb: `pip install chromadb`") @@ -33,7 +34,7 @@ def __init__( If provided, it will use the client object directly and ignore other arguments. path: str | The path to the vector database. Default is None. embedding_function: Callable | The embedding function used to generate the vector representation - of the documents. Default is None. + of the documents. Default is None, SentenceTransformerEmbeddingFunction("all-MiniLM-L6-v2") will be used. metadata: dict | The metadata of the vector database. Default is None. If None, it will use this setting: {"hnsw:space": "ip", "hnsw:construction_ef": 30, "hnsw:M": 32}. For more details of the metadata, please refer to [distances](https://github.com/nmslib/hnswlib#supported-distances), @@ -45,15 +46,20 @@ def __init__( None """ self.client = client + self.path = path + self.embedding_function = ( + ef.SentenceTransformerEmbeddingFunction("all-MiniLM-L6-v2") + if embedding_function is None + else embedding_function + ) + self.metadata = metadata if metadata else {"hnsw:space": "ip", "hnsw:construction_ef": 30, "hnsw:M": 32} if not self.client: - self.path = path - self.embedding_function = embedding_function - self.metadata = metadata if metadata else {"hnsw:space": "ip", "hnsw:construction_ef": 30, "hnsw:M": 32} if self.path is not None: self.client = chromadb.PersistentClient(path=self.path, **kwargs) else: self.client = chromadb.Client(**kwargs) self.active_collection = None + self.type = "chroma" def create_collection( self, collection_name: str, overwrite: bool = False, get_or_create: bool = True From ac5224dfc0428d8c8c523cf8f38600b24255cd6d Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Tue, 9 Apr 2024 10:04:01 +0800 Subject: [PATCH 26/27] Improve docstring of vectordb, add two attributes --- autogen/agentchat/contrib/vectordb/base.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/autogen/agentchat/contrib/vectordb/base.py b/autogen/agentchat/contrib/vectordb/base.py index d04f42b5a6bc..187d0d6acbbe 100644 --- a/autogen/agentchat/contrib/vectordb/base.py +++ b/autogen/agentchat/contrib/vectordb/base.py @@ -31,8 +31,25 @@ class Document(TypedDict): class VectorDB(Protocol): """ Abstract class for vector database. A vector database is responsible for storing and retrieving documents. + + Attributes: + active_collection: Any | The active collection in the vector database. Make get_collection faster. Default is None. + type: str | The type of the vector database, chroma, pgvector, etc. Default is "". + + Methods: + create_collection: Callable[[str, bool, bool], Any] | Create a collection in the vector database. + get_collection: Callable[[str], Any] | Get the collection from the vector database. + delete_collection: Callable[[str], Any] | Delete the collection from the vector database. + insert_docs: Callable[[List[Document], str, bool], None] | Insert documents into the collection of the vector database. + update_docs: Callable[[List[Document], str], None] | Update documents in the collection of the vector database. + delete_docs: Callable[[List[ItemID], str], None] | Delete documents from the collection of the vector database. + retrieve_docs: Callable[[List[str], str, int, float], QueryResults] | Retrieve documents from the collection of the vector database based on the queries. + get_docs_by_ids: Callable[[List[ItemID], str], List[Document]] | Retrieve documents from the collection of the vector database based on the ids. """ + active_collection: Any = None + type: str = "" + def create_collection(self, collection_name: str, overwrite: bool = False, get_or_create: bool = True) -> Any: """ Create a collection in the vector database. From 6a126c80a87f1bec96659f3c8751404f81c24ef0 Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Wed, 10 Apr 2024 13:15:04 +0800 Subject: [PATCH 27/27] Improve test workflow --- .github/workflows/contrib-openai.yml | 2 +- .github/workflows/contrib-tests.yml | 3 --- test/agentchat/contrib/vectordb/test_chromadb.py | 2 +- 3 files changed, 2 insertions(+), 5 deletions(-) diff --git a/.github/workflows/contrib-openai.yml b/.github/workflows/contrib-openai.yml index 4cd31dac9229..5e4ba170370a 100644 --- a/.github/workflows/contrib-openai.yml +++ b/.github/workflows/contrib-openai.yml @@ -53,7 +53,7 @@ jobs: AZURE_OPENAI_API_BASE: ${{ secrets.AZURE_OPENAI_API_BASE }} OAI_CONFIG_LIST: ${{ secrets.OAI_CONFIG_LIST }} run: | - coverage run -a -m pytest test/agentchat/contrib/test_retrievechat.py test/agentchat/contrib/test_qdrant_retrievechat.py test/agentchat/contrib/vectordb + coverage run -a -m pytest test/agentchat/contrib/test_retrievechat.py::test_retrievechat test/agentchat/contrib/test_qdrant_retrievechat.py::test_retrievechat coverage xml - name: Upload coverage to Codecov uses: codecov/codecov-action@v3 diff --git a/.github/workflows/contrib-tests.yml b/.github/workflows/contrib-tests.yml index 9301900bd04e..719ff0861838 100644 --- a/.github/workflows/contrib-tests.yml +++ b/.github/workflows/contrib-tests.yml @@ -58,9 +58,6 @@ jobs: if [[ ${{ matrix.os }} != ubuntu-latest ]]; then echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV fi - - name: Test RetrieveChat - run: | - pytest test/test_retrieve_utils.py test/agentchat/contrib/test_retrievechat.py test/agentchat/contrib/test_qdrant_retrievechat.py test/agentchat/contrib/vectordb --skip-openai - name: Coverage run: | pip install coverage>=5.3 diff --git a/test/agentchat/contrib/vectordb/test_chromadb.py b/test/agentchat/contrib/vectordb/test_chromadb.py index 9c36c121f4ea..ee4886f5154d 100644 --- a/test/agentchat/contrib/vectordb/test_chromadb.py +++ b/test/agentchat/contrib/vectordb/test_chromadb.py @@ -16,7 +16,7 @@ skip = False -@pytest.mark.skipif(skip, reason="dependency is not installed OR requested to skip") +@pytest.mark.skipif(skip, reason="dependency is not installed") def test_chromadb(): # test create collection db = ChromaVectorDB(path=".db")