From 490e20f28096dfa7d0a3befcfa928ed4c9c7e0bb Mon Sep 17 00:00:00 2001
From: Sai krishna
Date: Sat, 19 Oct 2024 23:37:43 +0530
Subject: [PATCH 1/4] feat: init chroma client

---
 backend/modules/vector_db/base.py   |  2 +-
 backend/modules/vector_db/chroma.py | 94 +++++++++++++++++++++++++++++
 backend/requirements.txt            |  1 +
 docker-compose.yaml                 | 21 +++++++
 4 files changed, 117 insertions(+), 1 deletion(-)
 create mode 100644 backend/modules/vector_db/chroma.py

diff --git a/backend/modules/vector_db/base.py b/backend/modules/vector_db/base.py
index 3cca30ff..6df1a0c7 100644
--- a/backend/modules/vector_db/base.py
+++ b/backend/modules/vector_db/base.py
@@ -11,7 +11,7 @@ class BaseVectorDB(ABC):
     @abstractmethod
-    def create_collection(self, collection_name: str, embeddings: Embeddings):
+    def create_collection(self, collection_name: str, embeddings: Optional[Embeddings] = None):
         """
         Create a collection in the vector database
         """

diff --git a/backend/modules/vector_db/chroma.py b/backend/modules/vector_db/chroma.py
new file mode 100644
index 00000000..d888a3d4
--- /dev/null
+++ b/backend/modules/vector_db/chroma.py
@@ -0,0 +1,94 @@
+from typing import List, Optional, Union
+from urllib.parse import urlparse
+
+from chromadb import HttpClient, PersistentClient
+from chromadb.api import ClientAPI
+from fastapi import HTTPException
+from langchain.docstore.document import Document
+from langchain.embeddings.base import Embeddings
+
+from backend.constants import DEFAULT_BATCH_SIZE_FOR_VECTOR_STORE
+from backend.modules.vector_db.base import BaseVectorDB
+from backend.types import VectorDBConfig
+
+
+class ChromaVectorDB(BaseVectorDB):
+    def __init__(self, db_config: VectorDBConfig):
+        self.db_config = db_config
+        self.client = self.get_client()
+
+    def get_client(self) -> ClientAPI:
+        # For local development, use a persistent client that saves data to a local directory
+        if self.db_config.local:
+            return PersistentClient()
+
+        # For production, use an HTTP client that connects to a remote Chroma server.
+        # chromadb's HttpClient expects host/port/ssl rather than a full URL, so the
+        # configured URL is parsed here; the Authorization header assumes Chroma's
+        # default token auth scheme.
+        parsed = urlparse(self.db_config.url)
+        return HttpClient(
+            host=parsed.hostname,
+            port=parsed.port or (443 if parsed.scheme == "https" else 8000),
+            ssl=parsed.scheme == "https",
+            headers=(
+                {"Authorization": f"Bearer {self.db_config.api_key}"}
+                if self.db_config.api_key
+                else None
+            ),
+            **(self.db_config.config or {}),
+        )
+
+    def get_vector_client(self) -> Union[PersistentClient, HttpClient]:
+        return self.client
+
+    def get_vector_store(self, collection_name: str, embeddings: Embeddings):
+        pass
+
+    def create_collection(self, collection_name: str, **kwargs):
+        try:
+            return self.client.create_collection(
+                name=collection_name,
+                **kwargs,
+            )
+        except ValueError as e:
+            # Error is raised if
+            # 1. collection already exists
+            # 2. collection name is invalid
+            raise HTTPException(
+                status_code=400, detail=f"Unable to create collection: {e}"
+            )
+
+    def delete_collection(self, collection_name: str):
+        try:
+            return self.client.delete_collection(name=collection_name)
+        except ValueError as e:
+            raise HTTPException(
+                status_code=400,
+                detail=f"Unable to delete collection. {collection_name} does not exist",
+            )
+
+    def list_collections(
+        self, limit: Optional[int] = None, offset: Optional[int] = None
+    ):
+        return self.client.list_collections(limit=limit, offset=offset)
+
+    def upsert_documents(
+        self,
+        collection_name: str,
+        documents: List[Document],
+        embeddings: Embeddings,
+        incremental: bool = True,
+    ):
+        return super().upsert_documents(
+            collection_name, documents, embeddings, incremental
+        )
+
+    def list_data_point_vectors(
+        self,
+        collection_name: str,
+        data_source_fqn: str,
+        batch_size: int = DEFAULT_BATCH_SIZE_FOR_VECTOR_STORE,
+    ):
+        return super().list_data_point_vectors(
+            collection_name, data_source_fqn, batch_size
+        )
+
+    def delete_data_point_vectors(
+        self,
+        collection_name: str,
+        data_source_fqn: str,
+        batch_size: int = DEFAULT_BATCH_SIZE_FOR_VECTOR_STORE,
+    ):
+        return super().delete_data_point_vectors(
+            collection_name, data_source_fqn, batch_size
+        )

diff --git a/backend/requirements.txt b/backend/requirements.txt
index e376fd8f..12a6c37d 100644
--- a/backend/requirements.txt
+++ b/backend/requirements.txt
@@ -56,6 +56,7 @@ langchain-openai==0.2.5
 
 ## vector db
 qdrant-client==1.9.0
+chromadb==0.5.15
 
 ## dev
 autoflake==2.3.1

diff --git a/docker-compose.yaml b/docker-compose.yaml
index 1cb3bdb1..88cf61ae 100644
--- a/docker-compose.yaml
+++ b/docker-compose.yaml
@@ -99,6 +99,27 @@ services:
     networks:
       - cognita-docker
 
+  chromadb-server:
+    image: chromadb/chroma:0.5.13
+    pull_policy: if_not_present
+    restart: unless-stopped
+    container_name: chroma
+    environment:
+      - IS_PERSISTENT=TRUE
+    ports:
+      - ${CHROMADB_PORT}:8000
+    healthcheck:
+      test: ["CMD-SHELL", "/bin/bash -c ':> /dev/tcp/0.0.0.0/8000'"]
+      interval: 30s
+      timeout: 5s
+      retries: 3
+      start_period: 10s
+    volumes:
+      - ./volumes/chromadb:/chroma/chroma
+    networks:
+      - cognita-docker
+
   unstructured-io-parsers:
     # Docs: http://localhost:9500/general/docs
     image: downloads.unstructured.io/unstructured-io/unstructured-api:0.0.73
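
A quick sketch of how this first cut of the client can be driven, either in local
mode or against the chromadb-server compose service above. The config values are
illustrative only; they are not part of the patch:

    from backend.modules.vector_db.chroma import ChromaVectorDB
    from backend.types import VectorDBConfig

    # Local mode: backed by chromadb.PersistentClient, no server required
    vector_db = ChromaVectorDB(db_config=VectorDBConfig(provider="chroma", local=True))

    # Remote mode against the compose service (port value is illustrative):
    # vector_db = ChromaVectorDB(
    #     db_config=VectorDBConfig(provider="chroma", url="http://localhost:8000")
    # )

    vector_db.create_collection("smoke-test")
    print(vector_db.list_collections())
    vector_db.delete_collection("smoke-test")
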
From 7cda8f75694fb14ab85d0a64f9c9b2f93ee1f67c Mon Sep 17 00:00:00 2001
From: Sai krishna
Date: Sun, 20 Oct 2024 23:52:41 +0530
Subject: [PATCH 2/4] feat: add config pydantic classes for chroma and qdrant,
 refactor qdrant init method for readability

---
 backend/modules/query_controllers/base.py   |  7 --
 .../query_controllers/example/controller.py |  3 +-
 backend/modules/vector_db/chroma.py         | 53 +++++++++++----
 backend/modules/vector_db/qdrant.py         | 68 ++++++++---------
 backend/types.py                            | 29 +++++++-
 5 files changed, 107 insertions(+), 53 deletions(-)

diff --git a/backend/modules/query_controllers/base.py b/backend/modules/query_controllers/base.py
index 0a9bc580..506849ba 100644
--- a/backend/modules/query_controllers/base.py
+++ b/backend/modules/query_controllers/base.py
@@ -4,7 +4,6 @@
 import async_timeout
 import requests
 from fastapi import HTTPException
-from langchain.prompts import PromptTemplate
 from langchain.retrievers import ContextualCompressionRetriever, MultiQueryRetriever
 from langchain.schema.vectorstore import VectorStoreRetriever
 from langchain_core.language_models.chat_models import BaseChatModel
@@ -33,12 +32,6 @@ class BaseQueryController:
         "relevance_score",
     ]
 
-    def _get_prompt_template(self, input_variables, template):
-        """
-        Get the prompt template
-        """
-        return PromptTemplate(input_variables=input_variables, template=template)
-
     def _format_docs(self, docs):
         return "\n\n".join([doc.page_content for doc in docs])

diff --git a/backend/modules/query_controllers/example/controller.py b/backend/modules/query_controllers/example/controller.py
index c9d9334f..65a26717 100644
--- a/backend/modules/query_controllers/example/controller.py
+++ b/backend/modules/query_controllers/example/controller.py
@@ -1,6 +1,7 @@
 from fastapi import Body
 from fastapi.responses import StreamingResponse
 from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import PromptTemplate
 from langchain_core.runnables import RunnableParallel, RunnablePassthrough
 
 from backend.modules.query_controllers.base import BaseQueryController
@@ -35,7 +36,7 @@ async def answer(
         vector_store = await self._get_vector_store(request.collection_name)
 
         # Create the QA prompt templates
-        QA_PROMPT = self._get_prompt_template(
+        QA_PROMPT = PromptTemplate(
             input_variables=["context", "question"],
             template=request.prompt_template,
         )

diff --git a/backend/modules/vector_db/chroma.py b/backend/modules/vector_db/chroma.py
index d888a3d4..898d18e6 100644
--- a/backend/modules/vector_db/chroma.py
+++ b/backend/modules/vector_db/chroma.py
@@ -1,22 +1,26 @@
-from typing import List, Optional, Union
+from typing import List, Optional
 from urllib.parse import urlparse
 
 from chromadb import HttpClient, PersistentClient
 from chromadb.api import ClientAPI
+from chromadb.api.models.Collection import Collection
 from fastapi import HTTPException
 from langchain.docstore.document import Document
 from langchain.embeddings.base import Embeddings
+from langchain_community.vectorstores import Chroma
 
 from backend.constants import DEFAULT_BATCH_SIZE_FOR_VECTOR_STORE
 from backend.modules.vector_db.base import BaseVectorDB
-from backend.types import VectorDBConfig
+from backend.types import ChromaVectorDBConfig
 
 
 class ChromaVectorDB(BaseVectorDB):
-    def __init__(self, db_config: VectorDBConfig):
+    def __init__(self, db_config: ChromaVectorDBConfig):
         self.db_config = db_config
-        self.client = self.get_client()
+        self.client = self._create_client()
 
-    def get_client(self) -> ClientAPI:
+    ## Initialization utility
+
+    def _create_client(self) -> ClientAPI:
         # For local development, use a persistent client that saves data to a local directory
         if self.db_config.local:
-            return PersistentClient()
+            return PersistentClient(path=self.db_config.path)
@@ -28,11 +32,20 @@ def _create_client(self) -> ClientAPI:
-            **(self.db_config.config or {}),
+            **self.db_config.config.model_dump(),
         )
 
-    def get_vector_client(self) -> Union[PersistentClient, HttpClient]:
+    ## Client
+    def get_vector_client(self) -> ClientAPI:
         return self.client
 
-    def get_vector_store(self, collection_name: str, embeddings: Embeddings):
-        pass
+    ## Vector store
+
+    def get_vector_store(self, collection_name: str, **kwargs):
+        return Chroma(
+            client=self.client,
+            collection_name=collection_name,
+            **kwargs,
+        )
+
+    ## Collections
 
     def create_collection(self, collection_name: str, **kwargs):
         try:
@@ -48,6 +61,9 @@ def create_collection(self, collection_name: str, **kwargs):
                 status_code=400, detail=f"Unable to create collection: {e}"
             )
 
+    def get_collection(self, collection_name: str):
+        return self.client.get_collection(name=collection_name)
+
     def delete_collection(self, collection_name: str):
         try:
             return self.client.delete_collection(name=collection_name)
@@ -59,9 +75,16 @@ def delete_collection(self, collection_name: str):
 
     def list_collections(
         self, limit: Optional[int] = None, offset: Optional[int] = None
-    ):
+    ) -> List[Collection]:
+        return self.client.list_collections(limit=limit, offset=offset)
+
+    def get_collections(
+        self, limit: Optional[int] = None, offset: Optional[int] = None
+    ) -> List[Collection]:
         return self.client.list_collections(limit=limit, offset=offset)
 
+    ## Documents
+
     def upsert_documents(
@@ -73,15 +96,21 @@ def upsert_documents(
             collection_name, documents, embeddings, incremental
         )
 
+    def delete_documents(self, collection_name: str, document_ids: List[str]):
+        # Fetch the collection
+        collection: Collection = self.client.get_collection(collection_name)
+        # Delete the documents in the collection by ids
+        collection.delete(ids=document_ids)
+
+    ## Data point vectors
+
     def list_data_point_vectors(
         self,
         collection_name: str,
         data_source_fqn: str,
         batch_size: int = DEFAULT_BATCH_SIZE_FOR_VECTOR_STORE,
     ):
-        return super().list_data_point_vectors(
-            collection_name, data_source_fqn, batch_size
-        )
+        pass
 
     def delete_data_point_vectors(

diff --git a/backend/modules/vector_db/qdrant.py b/backend/modules/vector_db/qdrant.py
index 341493e3..c1e1576d 100644
--- a/backend/modules/vector_db/qdrant.py
+++ b/backend/modules/vector_db/qdrant.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Optional
 from urllib.parse import urlparse
 
 from langchain.embeddings.base import Embeddings
@@ -9,36 +9,42 @@
 from backend.constants import DATA_POINT_FQN_METADATA_KEY, DATA_POINT_HASH_METADATA_KEY
 from backend.logger import logger
 from backend.modules.vector_db.base import BaseVectorDB
-from backend.types import DataPointVector, QdrantClientConfig, VectorDBConfig
+from backend.types import DataPointVector, QdrantVectorDBConfig
 
 MAX_SCROLL_LIMIT = int(1e6)
 BATCH_SIZE = 1000
 
 
 class QdrantVectorDB(BaseVectorDB):
-    def __init__(self, config: VectorDBConfig):
-        logger.debug(f"Connecting to qdrant using config: {config.model_dump()}")
-        if config.local is True:
-            # TODO: make this path customizable
-            self.qdrant_client = QdrantClient(
-                path="./qdrant_db",
-            )
-        else:
-            url = config.url
-            api_key = config.api_key
-            if not api_key:
-                api_key = None
-            qdrant_kwargs = QdrantClientConfig.model_validate(config.config or {})
-            if url.startswith("http://") or url.startswith("https://"):
-                if qdrant_kwargs.port is None:
-                    parsed_port = urlparse(url).port
-                    if parsed_port:
-                        qdrant_kwargs.port = parsed_port
-                    else:
-                        qdrant_kwargs.port = 443 if url.startswith("https://") else 6333
-            self.qdrant_client = QdrantClient(
-                url=url, api_key=api_key, **qdrant_kwargs.model_dump()
-            )
+    def __init__(self, db_config: QdrantVectorDBConfig):
+        logger.debug(f"Connecting to qdrant using config: {db_config.model_dump()}")
+        self.qdrant_client = self._create_client(db_config)
+
+    def _create_client(self, db_config: QdrantVectorDBConfig) -> QdrantClient:
+        # Local mode: store the index on disk at the configured path
+        if db_config.local:
+            return QdrantClient(path=db_config.path)
+
+        url = db_config.url
+
+        if url.startswith(("http://", "https://")):
+            db_config.config.port = self._get_port(url, db_config.config.port)
+
+        # Remote Qdrant server: create an HTTP client; normalize an empty api_key to None
+        return QdrantClient(
+            url=url, api_key=db_config.api_key or None, **db_config.config.model_dump()
+        )
+
+    @staticmethod
+    def _get_port(url: str, existing_port: Optional[int]) -> int:
+        if existing_port:
+            return existing_port
+
+        parsed_port = urlparse(url).port
+        if parsed_port:
+            return parsed_port
+
+        return 443 if url.startswith("https://") else 6333
 
     def create_collection(self, collection_name: str, embeddings: Embeddings):
         logger.debug(f"[Qdrant] Creating new collection {collection_name}")
@@ -113,11 +119,11 @@ def _get_records_to_be_upserted(
     def upsert_documents(
         self,
         collection_name: str,
-        documents,
+        documents: list,
         embeddings: Embeddings,
         incremental: bool = True,
     ):
-        if len(documents) == 0:
+        if not documents:
             logger.warning("No documents to index")
             return
         # get record IDs to be upserted
@@ -219,11 +225,9 @@ def list_data_point_vectors(
                 offset=offset,
             )
             for record in records:
-                metadata: dict = record.payload.get("metadata")
-                if (
-                    metadata
-                    and metadata.get(DATA_POINT_FQN_METADATA_KEY)
-                    and metadata.get(DATA_POINT_HASH_METADATA_KEY)
+                metadata: dict = record.payload.get("metadata", {})
+                if metadata.get(DATA_POINT_FQN_METADATA_KEY) and metadata.get(
+                    DATA_POINT_HASH_METADATA_KEY
                 ):
                     data_point_vectors.append(
                         DataPointVector(

diff --git a/backend/types.py b/backend/types.py
index d4827ba8..c943f1f3 100644
--- a/backend/types.py
+++ b/backend/types.py
@@ -176,10 +176,28 @@ class VectorDBConfig(ConfiguredBaseModel):
     """
 
     provider: str
+    path: Optional[str] = None
     local: bool = False
     url: Optional[str] = None
     api_key: Optional[str] = None
-    config: Optional[Dict[str, Any]] = Field(default_factory=dict)
+    config: Dict[str, Any] = Field(default_factory=dict)
+
+
+class ChromaClientConfig(ConfiguredBaseModel):
+    """
+    Chroma client configuration
+    """
+
+    model_config = ConfigDict(extra="allow")
+
+
+class ChromaVectorDBConfig(VectorDBConfig):
+    """
+    Chroma-specific vector db configuration
+    """
+
+    path: Optional[str] = "./chroma_db"
+    config: ChromaClientConfig = Field(default_factory=ChromaClientConfig)
 
 
 class QdrantClientConfig(ConfiguredBaseModel):
@@ -196,6 +214,15 @@ class QdrantClientConfig(ConfiguredBaseModel):
     timeout: int = 300
 
 
+class QdrantVectorDBConfig(VectorDBConfig):
+    """
+    Qdrant-specific vector db configuration
+    """
+
+    path: Optional[str] = "./qdrant_db"
+    config: QdrantClientConfig = Field(default_factory=QdrantClientConfig)
+
+
 class MetadataStoreConfig(ConfiguredBaseModel):
     """
     Metadata store configuration
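
How the new per-provider config models are expected to be hydrated, e.g. from a
JSON/YAML settings block. Field values here are illustrative:

    from backend.types import ChromaVectorDBConfig, QdrantVectorDBConfig

    qdrant_config = QdrantVectorDBConfig.model_validate(
        {
            "provider": "qdrant",
            "url": "https://qdrant.example.com",
            "api_key": "secret",
            "config": {"timeout": 60},  # validated into QdrantClientConfig
        }
    )
    assert qdrant_config.config.timeout == 60

    # ChromaClientConfig is declared with extra="allow", so arbitrary chromadb
    # client kwargs pass through validation untouched
    chroma_config = ChromaVectorDBConfig.model_validate(
        {"provider": "chroma", "local": True}
    )
    assert chroma_config.path == "./chroma_db"
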
From 32bf746c0d7de6337e4e11114d313613790b95a0 Mon Sep 17 00:00:00 2001
From: Sai krishna
Date: Mon, 21 Oct 2024 12:14:50 +0530
Subject: [PATCH 3/4] feat: list data point vectors, refactor qdrant

---
 backend/modules/vector_db/chroma.py | 99 +++++++++++++++++++++++------
 backend/modules/vector_db/qdrant.py | 52 ++++++++-------
 2 files changed, 107 insertions(+), 44 deletions(-)

diff --git a/backend/modules/vector_db/chroma.py b/backend/modules/vector_db/chroma.py
index 898d18e6..e1e6cf5a 100644
--- a/backend/modules/vector_db/chroma.py
+++ b/backend/modules/vector_db/chroma.py
@@ -8,9 +8,10 @@
 from langchain.embeddings.base import Embeddings
 from langchain_community.vectorstores import Chroma
 
-from backend.constants import DEFAULT_BATCH_SIZE_FOR_VECTOR_STORE
+from backend.constants import DATA_POINT_FQN_METADATA_KEY, DATA_POINT_HASH_METADATA_KEY, DEFAULT_BATCH_SIZE_FOR_VECTOR_STORE
+from backend.logger import logger
 from backend.modules.vector_db.base import BaseVectorDB
-from backend.types import ChromaVectorDBConfig
+from backend.types import ChromaVectorDBConfig, DataPointVector
@@ -47,24 +48,24 @@
 
     ## Collections
 
-    def create_collection(self, collection_name: str, **kwargs):
+    def create_collection(self, collection_name: str, **kwargs) -> Collection:
         try:
             return self.client.create_collection(
                 name=collection_name,
                 **kwargs,
             )
         except ValueError as e:
-            # Error is raised if
+            # Error is raised by chroma client if
             # 1. collection already exists
             # 2. collection name is invalid
             raise HTTPException(
                 status_code=400, detail=f"Unable to create collection: {e}"
             )
 
-    def get_collection(self, collection_name: str):
+    def get_collection(self, collection_name: str) -> Collection:
         return self.client.get_collection(name=collection_name)
 
-    def delete_collection(self, collection_name: str):
+    def delete_collection(self, collection_name: str) -> None:
         try:
             return self.client.delete_collection(name=collection_name)
         except ValueError as e:
@@ -73,15 +74,13 @@ def delete_collection(self, collection_name: str) -> None:
             detail=f"Unable to delete collection. {collection_name} does not exist",
         )
 
-    def list_collections(
-        self, limit: Optional[int] = None, offset: Optional[int] = None
-    ) -> List[Collection]:
-        return self.client.list_collections(limit=limit, offset=offset)
-
     def get_collections(
         self, limit: Optional[int] = None, offset: Optional[int] = None
-    ) -> List[Collection]:
-        return self.client.list_collections(limit=limit, offset=offset)
+    ) -> List[str]:
+        return [
+            collection.name
+            for collection in self.client.list_collections(limit=limit, offset=offset)
+        ]
 
     ## Documents
 
@@ -92,9 +91,20 @@ def upsert_documents(
     def upsert_documents(
         self,
         collection_name: str,
         documents: List[Document],
         embeddings: Embeddings,
         incremental: bool = True,
     ):
-        return super().upsert_documents(
-            collection_name, documents, embeddings, incremental
-        )
+        if not documents:
+            logger.warning("No documents to index")
+            return
+        logger.debug(
+            f"[Chroma] Adding {len(documents)} documents to collection {collection_name}"
+        )
+        # Collect the data point fqns from the documents
+        data_point_fqns = [
+            doc.metadata.get(DATA_POINT_FQN_METADATA_KEY)
+            for doc in documents
+            if doc.metadata.get(DATA_POINT_FQN_METADATA_KEY)
+        ]
+        # TODO: mirror the qdrant implementation and use `data_point_fqns` to
+        # delete outdated records when `incremental` is set
+        # Add the documents to the collection via the langchain vector store
+        self.get_vector_store(
+            collection_name, embedding_function=embeddings
+        ).add_documents(documents=documents)
+        logger.debug(
+            f"[Chroma] Added {len(documents)} documents to collection {collection_name}"
+        )
 
     def delete_documents(self, collection_name: str, document_ids: List[str]):
         # Fetch the collection
@@ -109,15 +119,62 @@
     def list_data_point_vectors(
         self,
         collection_name: str,
         data_source_fqn: str,
         batch_size: int = DEFAULT_BATCH_SIZE_FOR_VECTOR_STORE,
-    ):
-        pass
+    ) -> List[DataPointVector]:
+        # Fetch the collection by collection name
+        collection = self.get_collection(collection_name)
+        # Initialize the data point vectors list
+        data_point_vectors = []
+        # Initialize the offset
+        offset = 0
+
+        while True:
+            # Fetch the next batch of documents for the given data source
+            documents = collection.get(
+                where={
+                    DATA_POINT_FQN_METADATA_KEY: data_source_fqn,
+                },
+                # Only fetch the metadata; chroma always returns the ids
+                include=["metadatas"],
+                # Limit the fetch to `batch_size`
+                limit=batch_size,
+                # Offset the fetch by `offset`
+                offset=offset,
+            )
+
+            # Stop once no documents are fetched -> we've reached the end of the collection
+            if not documents["ids"]:
+                break
+
+            # Append a data point vector for every document whose metadata carries
+            # both the data point fqn and hash; records missing either key are skipped
+            for doc_id, metadata in zip(documents["ids"], documents["metadatas"]):
+                if metadata.get(DATA_POINT_FQN_METADATA_KEY) and metadata.get(
+                    DATA_POINT_HASH_METADATA_KEY
+                ):
+                    data_point_vectors.append(
+                        DataPointVector(
+                            data_point_vector_id=doc_id,
+                            data_point_fqn=metadata.get(DATA_POINT_FQN_METADATA_KEY),
+                            data_point_hash=metadata.get(DATA_POINT_HASH_METADATA_KEY),
+                        )
+                    )
+
+            # Increment the offset by the number of documents fetched
+            offset += len(documents["ids"])
+
+            # A partial batch means we've reached the end of the collection
+            if len(documents["ids"]) < batch_size:
+                break
+
+        return data_point_vectors
 
     def delete_data_point_vectors(
         self,
         collection_name: str,
-        data_source_fqn: str,
-        batch_size: int = DEFAULT_BATCH_SIZE_FOR_VECTOR_STORE,
+        data_point_vectors: List[DataPointVector],
+        **kwargs,
     ):
-        return super().delete_data_point_vectors(
-            collection_name, data_source_fqn, batch_size
+        # Fetch the collection by collection name
+        collection = self.get_collection(collection_name)
+        # Delete the documents in the collection by ids
+        collection.delete(
+            ids=[vector.data_point_vector_id for vector in data_point_vectors]
         )
+        logger.debug(f"[Chroma] Deleted {len(data_point_vectors)} data point vectors")

diff --git a/backend/modules/vector_db/qdrant.py b/backend/modules/vector_db/qdrant.py
index c1e1576d..13a93e02 100644
--- a/backend/modules/vector_db/qdrant.py
+++ b/backend/modules/vector_db/qdrant.py
@@ -130,12 +130,13 @@ def upsert_documents(
         logger.debug(
             f"[Qdrant] Adding {len(documents)} documents to collection {collection_name}"
         )
-        data_point_fqns = []
-        for document in documents:
-            if document.metadata.get(DATA_POINT_FQN_METADATA_KEY):
-                data_point_fqns.append(
-                    document.metadata.get(DATA_POINT_FQN_METADATA_KEY)
-                )
+        # Collect the data point fqns from the documents
+        data_point_fqns = [
+            doc.metadata.get(DATA_POINT_FQN_METADATA_KEY)
+            for doc in documents
+            if doc.metadata.get(DATA_POINT_FQN_METADATA_KEY)
+        ]
+        # Get the record ids to be upserted
         record_ids_to_be_upserted: List[str] = self._get_records_to_be_upserted(
             collection_name=collection_name,
             data_point_fqns=data_point_fqns,
@@ -152,24 +153,28 @@
             f"[Qdrant] Added {len(documents)} documents to collection {collection_name}"
         )
 
-        # Delete Documents
-        if len(record_ids_to_be_upserted):
-            logger.debug(
-                f"[Qdrant] Deleting {len(documents)} outdated documents from collection {collection_name}"
-            )
-            for i in range(0, len(record_ids_to_be_upserted), BATCH_SIZE):
-                record_ids_to_be_processed = record_ids_to_be_upserted[
-                    i : i + BATCH_SIZE
-                ]
-                self.qdrant_client.delete(
-                    collection_name=collection_name,
-                    points_selector=models.PointIdsList(
-                        points=record_ids_to_be_processed,
-                    ),
-                )
-            logger.debug(
-                f"[Qdrant] Deleted {len(documents)} outdated documents from collection {collection_name}"
-            )
+        # Delete old documents
+
+        # If there are no record ids to be upserted, return
+        if not record_ids_to_be_upserted:
+            return
+
+        logger.debug(
+            f"[Qdrant] Deleting {len(record_ids_to_be_upserted)} outdated documents from collection {collection_name}"
+        )
+        for i in range(0, len(record_ids_to_be_upserted), BATCH_SIZE):
+            record_ids_to_be_processed = record_ids_to_be_upserted[
+                i : i + BATCH_SIZE
+            ]
+            self.qdrant_client.delete(
+                collection_name=collection_name,
+                points_selector=models.PointIdsList(
+                    points=record_ids_to_be_processed,
+                ),
+            )
+        logger.debug(
+            f"[Qdrant] Deleted {len(record_ids_to_be_upserted)} outdated documents from collection {collection_name}"
+        )
 
     def get_collections(self) -> List[str]:
         logger.debug(f"[Qdrant] Fetching collections")
@@ -415,3 +420,4 @@ def list_document_vector_points(
         f"[Qdrant] Listing {len(document_vector_points)} document vector points for collection {collection_name}"
     )
     return document_vector_points
+
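
Taken together, the two Chroma methods above now satisfy the same
incremental-indexing contract the Qdrant provider honours. A sketch of that
flow, assuming `vector_db` is a ChromaVectorDB instance as constructed earlier
(collection name and fqn are illustrative):

    stale = vector_db.list_data_point_vectors(
        collection_name="docs",
        data_source_fqn="localdir::sample-data",
    )
    # Each DataPointVector carries the chroma document id plus fqn/hash metadata,
    # which is exactly what's needed to drop stale points after re-ingestion:
    vector_db.delete_data_point_vectors("docs", data_point_vectors=stale)
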
From 3e4ae8c3dd5d673434ed3dd4557c4482134d586c Mon Sep 17 00:00:00 2001
From: Sai krishna
Date: Tue, 5 Nov 2024 12:15:55 +0530
Subject: [PATCH 4/4] lint

---
 backend/modules/vector_db/chroma.py | 6 ++++--
 backend/modules/vector_db/qdrant.py | 9 +++------
 2 files changed, 7 insertions(+), 8 deletions(-)

diff --git a/backend/modules/vector_db/chroma.py b/backend/modules/vector_db/chroma.py
index e1e6cf5a..4a918010 100644
--- a/backend/modules/vector_db/chroma.py
+++ b/backend/modules/vector_db/chroma.py
@@ -8,7 +8,10 @@
 from langchain.embeddings.base import Embeddings
 from langchain_community.vectorstores import Chroma
 
-from backend.constants import DATA_POINT_FQN_METADATA_KEY, DATA_POINT_HASH_METADATA_KEY, DEFAULT_BATCH_SIZE_FOR_VECTOR_STORE
+from backend.constants import (
+    DATA_POINT_FQN_METADATA_KEY,
+    DATA_POINT_HASH_METADATA_KEY,
+    DEFAULT_BATCH_SIZE_FOR_VECTOR_STORE,
+)
 from backend.logger import logger
 from backend.modules.vector_db.base import BaseVectorDB
 from backend.types import ChromaVectorDBConfig, DataPointVector
@@ -104,7 +107,6 @@ def upsert_documents(
         logger.debug(
             f"[Chroma] Added {len(documents)} documents to collection {collection_name}"
         )
-
 
     def delete_documents(self, collection_name: str, document_ids: List[str]):
         # Fetch the collection

diff --git a/backend/modules/vector_db/qdrant.py b/backend/modules/vector_db/qdrant.py
index 13a93e02..e275ac39 100644
--- a/backend/modules/vector_db/qdrant.py
+++ b/backend/modules/vector_db/qdrant.py
@@ -154,18 +154,16 @@ def upsert_documents(
         )
 
         # Delete old documents
-        
+
         # If there are no record ids to be upserted, return
         if not record_ids_to_be_upserted:
             return
-        
+
         logger.debug(
             f"[Qdrant] Deleting {len(record_ids_to_be_upserted)} outdated documents from collection {collection_name}"
         )
         for i in range(0, len(record_ids_to_be_upserted), BATCH_SIZE):
-            record_ids_to_be_processed = record_ids_to_be_upserted[
-                i : i + BATCH_SIZE
-            ]
+            record_ids_to_be_processed = record_ids_to_be_upserted[i : i + BATCH_SIZE]
             self.qdrant_client.delete(
                 collection_name=collection_name,
                 points_selector=models.PointIdsList(
@@ -420,4 +418,3 @@ def list_document_vector_points(
         f"[Qdrant] Listing {len(document_vector_points)} document vector points for collection {collection_name}"
     )
     return document_vector_points
-
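
A minimal end-to-end smoke test for the series in local mode. OpenAI embeddings
are used purely for illustration (the package is pinned in requirements above);
any langchain Embeddings implementation works, and the fqn/hash values are
placeholders:

    from langchain.docstore.document import Document
    from langchain_openai import OpenAIEmbeddings

    from backend.constants import (
        DATA_POINT_FQN_METADATA_KEY,
        DATA_POINT_HASH_METADATA_KEY,
    )
    from backend.modules.vector_db.chroma import ChromaVectorDB
    from backend.types import ChromaVectorDBConfig

    vector_db = ChromaVectorDB(ChromaVectorDBConfig(provider="chroma", local=True))
    vector_db.create_collection("smoke-test")

    docs = [
        Document(
            page_content="hello world",
            metadata={
                DATA_POINT_FQN_METADATA_KEY: "localdir::demo",  # illustrative fqn
                DATA_POINT_HASH_METADATA_KEY: "abc123",
            },
        )
    ]
    vector_db.upsert_documents("smoke-test", docs, embeddings=OpenAIEmbeddings())
    print(vector_db.list_data_point_vectors("smoke-test", data_source_fqn="localdir::demo"))
    vector_db.delete_collection("smoke-test")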