feat: init chroma client #384

Open · wants to merge 4 commits into main
7 changes: 0 additions & 7 deletions backend/modules/query_controllers/base.py
@@ -4,7 +4,6 @@
import async_timeout
import requests
from fastapi import HTTPException
from langchain.prompts import PromptTemplate
from langchain.retrievers import ContextualCompressionRetriever, MultiQueryRetriever
from langchain.schema.vectorstore import VectorStoreRetriever
from langchain_core.language_models.chat_models import BaseChatModel
@@ -33,12 +32,6 @@ class BaseQueryController:
"relevance_score",
]

def _get_prompt_template(self, input_variables, template):
"""
Get the prompt template
"""
return PromptTemplate(input_variables=input_variables, template=template)

def _format_docs(self, docs):
return "\n\n".join([doc.page_content for doc in docs])

3 changes: 2 additions & 1 deletion backend/modules/query_controllers/example/controller.py
@@ -1,6 +1,7 @@
from fastapi import Body
from fastapi.responses import StreamingResponse
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableParallel, RunnablePassthrough

from backend.modules.query_controllers.base import BaseQueryController
@@ -35,7 +36,7 @@ async def answer(
vector_store = await self._get_vector_store(request.collection_name)

# Create the QA prompt templates
QA_PROMPT = self._get_prompt_template(
QA_PROMPT = PromptTemplate(
input_variables=["context", "question"],
template=request.prompt_template,
)
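        # e.g. request.prompt_template could be (hypothetical value):
        #   "Answer from the given context only:\n{context}\n\nQuestion: {question}"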
2 changes: 1 addition & 1 deletion backend/modules/vector_db/base.py
@@ -11,7 +11,7 @@

class BaseVectorDB(ABC):
@abstractmethod
def create_collection(self, collection_name: str, embeddings: Embeddings):
def create_collection(self, collection_name: str, embeddings: Optional[Embeddings] = None):
"""
Create a collection in the vector database
"""
182 changes: 182 additions & 0 deletions backend/modules/vector_db/chroma.py
@@ -0,0 +1,182 @@
from typing import List, Optional
from urllib.parse import urlparse

from chromadb import HttpClient, PersistentClient
from chromadb.api import ClientAPI
from chromadb.api.models.Collection import Collection
from fastapi import HTTPException
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain_community.vectorstores import Chroma

from backend.constants import (
DATA_POINT_FQN_METADATA_KEY,
DEFAULT_BATCH_SIZE_FOR_VECTOR_STORE,
)
from backend.logger import logger
from backend.modules.vector_db.base import BaseVectorDB
from backend.types import ChromaVectorDBConfig, DataPointVector


class ChromaVectorDB(BaseVectorDB):
def __init__(self, db_config: ChromaVectorDBConfig):
self.db_config = db_config
self.client = self._create_client()

## Initialization utility

def _create_client(self) -> ClientAPI:
# For local development, we use a persistent client that saves data to a temporary directory
if self.db_config.local:
return PersistentClient()

        # For production, we use an HTTP client that connects to a remote server.
        # NOTE: chromadb's HttpClient takes host/port/ssl/headers (not url/api_key),
        # so the configured URL is parsed here; sending the API key as a bearer
        # header is an assumption about how the remote server authenticates.
        parsed = urlparse(self.db_config.url)
        return HttpClient(
            host=parsed.hostname,
            port=parsed.port or (443 if parsed.scheme == "https" else 8000),
            ssl=parsed.scheme == "https",
            headers=(
                {"Authorization": f"Bearer {self.db_config.api_key}"}
                if self.db_config.api_key
                else None
            ),
        )

## Client
def get_vector_client(self) -> ClientAPI:
return self.client

## Vector store

def get_vector_store(self, collection_name: str, **kwargs):
return Chroma(
client=self.client,
collection_name=collection_name,
**kwargs,
)
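    # Usage sketch (assumed kwargs; the langchain Chroma wrapper accepts e.g.
    # embedding_function for embedding queries against this store):
    #   store = db.get_vector_store("my-collection", embedding_function=embeddings)
    #   docs = store.similarity_search("some query", k=4)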

## Collections

def create_collection(self, collection_name: str, **kwargs) -> Collection:
try:
return self.client.create_collection(
name=collection_name,
**kwargs,
)
except ValueError as e:
# Error is raised by chroma client if
# 1. collection already exists
# 2. collection name is invalid
raise HTTPException(
status_code=400, detail=f"Unable to create collection: {e}"
)

def get_collection(self, collection_name: str) -> Collection:
return self.client.get_collection(name=collection_name)

def delete_collection(self, collection_name: str) -> None:
try:
return self.client.delete_collection(name=collection_name)
        except ValueError as e:
            raise HTTPException(
                status_code=400,
                detail=f"Unable to delete collection {collection_name}: {e}",
            )

def get_collections(
self, limit: Optional[int] = None, offset: Optional[int] = None
) -> List[str]:
return [
collection.name
for collection in self.client.list_collections(limit=limit, offset=offset)
]

## Documents

def upsert_documents(
self,
collection_name: str,
documents: List[Document],
embeddings: Embeddings,
incremental: bool = True,
):
if not documents:
logger.warning("No documents to index")
return
        logger.debug(
            f"[Chroma] Adding {len(documents)} documents to collection {collection_name}"
        )
        # Collect the data point fqns from the documents (used for tracking)
        data_point_fqns = [
            doc.metadata.get(DATA_POINT_FQN_METADATA_KEY)
            for doc in documents
            if doc.metadata.get(DATA_POINT_FQN_METADATA_KEY)
        ]
        # Add the documents via the langchain Chroma wrapper so they are embedded
        # with the provided embeddings and stored in the collection.
        # NOTE: the upsert body was missing here; this is a minimal completion that
        # adds documents without the incremental de-duplication the Qdrant backend does.
        self.get_vector_store(
            collection_name, embedding_function=embeddings
        ).add_documents(documents=documents)

def delete_documents(self, collection_name: str, document_ids: List[str]):
# Fetch the collection
collection: Collection = self.client.get_collection(collection_name)
# Delete the documents in the collection by ids
collection.delete(ids=document_ids)

## Data point vectors

def list_data_point_vectors(
self,
collection_name: str,
data_source_fqn: str,
batch_size: int = DEFAULT_BATCH_SIZE_FOR_VECTOR_STORE,
) -> List[DataPointVector]:
# Fetch the collection by collection name
collection = self.get_collection(collection_name)
# Initialize the data point vectors list
data_point_vectors = []
# Initialize the offset
offset = 0
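        # Worked example: with batch_size=100 and 250 matching records, the loop
        # below fetches pages of 100, 100, and 50 (offset 0 -> 100 -> 200) and
        # stops after the short final page.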

while True:
# Fetch the documents from the collection and limit the fetch to batch size
documents = collection.get(
where={
"data_source_fqn": data_source_fqn,
},
                # Only request the metadatas; chroma always returns ids, and
                # "ids" is not a valid `include` value
                include=["metadatas"],
# Limit the fetch to `batch_size`
limit=batch_size,
# Offset the fetch by `offset`
offset=offset,
)

            # Stop if no documents were fetched: we've reached the end of the collection
            if not documents["ids"]:
                break

# Iterate over the documents and append the data point vectors to the list if the metadata contains the data source fqn and data point hash
for doc_id, metadata in zip(documents["ids"], documents["metadatas"]):
# TODO: what if either of the metadata keys are missing?
if metadata.get("data_source_fqn") and metadata.get("data_point_hash"):
# Append the data point vector to the list
data_point_vectors.append(
DataPointVector(
data_point_vector_id=doc_id,
data_point_fqn=metadata.get("data_source_fqn"),
data_point_hash=metadata.get("data_point_hash"),
)
)

            # Advance the offset by the number of documents fetched
            offset += len(documents["ids"])

            # A page shorter than batch_size means the collection is exhausted,
            # but only break after its documents have been processed above
            if len(documents["ids"]) < batch_size:
                break

return data_point_vectors

def delete_data_point_vectors(
self,
collection_name: str,
data_point_vectors: List[DataPointVector],
**kwargs,
):
# Fetch the collection by collection name
collection = self.get_collection(collection_name)
# Delete the documents in the collection by ids
collection.delete(
ids=[vector.data_point_vector_id for vector in data_point_vectors]
)
logger.debug(f"[Chroma] Deleted {len(data_point_vectors)} data point vectors")
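For reference, a minimal usage sketch of the new ChromaVectorDB (hypothetical config values; the exact ChromaVectorDBConfig fields are assumed from backend.types):

    from backend.modules.vector_db.chroma import ChromaVectorDB
    from backend.types import ChromaVectorDBConfig

    # Local development: a persistent on-disk client (assumed config shape)
    db = ChromaVectorDB(ChromaVectorDBConfig(provider="chroma", local=True))
    db.create_collection("test-collection")
    print(db.get_collections())  # ["test-collection"]
    db.delete_collection("test-collection")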
117 changes: 62 additions & 55 deletions backend/modules/vector_db/qdrant.py
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Optional
from urllib.parse import urlparse

from langchain.embeddings.base import Embeddings
@@ -9,36 +9,42 @@
from backend.constants import DATA_POINT_FQN_METADATA_KEY, DATA_POINT_HASH_METADATA_KEY
from backend.logger import logger
from backend.modules.vector_db.base import BaseVectorDB
from backend.types import DataPointVector, QdrantClientConfig, VectorDBConfig
from backend.types import DataPointVector, QdrantVectorDBConfig

MAX_SCROLL_LIMIT = int(1e6)
BATCH_SIZE = 1000


class QdrantVectorDB(BaseVectorDB):
def __init__(self, config: VectorDBConfig):
logger.debug(f"Connecting to qdrant using config: {config.model_dump()}")
if config.local is True:
# TODO: make this path customizable
self.qdrant_client = QdrantClient(
path="./qdrant_db",
)
else:
url = config.url
api_key = config.api_key
if not api_key:
api_key = None
qdrant_kwargs = QdrantClientConfig.model_validate(config.config or {})
if url.startswith("http://") or url.startswith("https://"):
if qdrant_kwargs.port is None:
parsed_port = urlparse(url).port
if parsed_port:
qdrant_kwargs.port = parsed_port
else:
qdrant_kwargs.port = 443 if url.startswith("https://") else 6333
self.qdrant_client = QdrantClient(
url=url, api_key=api_key, **qdrant_kwargs.model_dump()
)
def __init__(self, db_config: QdrantVectorDBConfig):
logger.debug(f"Connecting to qdrant using config: {db_config.model_dump()}")
self.qdrant_client = self._create_client(db_config)

def _create_client(self, db_config: QdrantVectorDBConfig) -> QdrantClient:
# Local
if db_config.local:
return QdrantClient(path=db_config.path)

url = db_config.url

if url.startswith(("http://", "https://")):
db_config.config.port = self._get_port(url, db_config.config.port)

# If the Qdrant server is hosted on a remote server, create an http client
return QdrantClient(
url=url, api_key=db_config.api_key, **db_config.config.model_dump()
)

@staticmethod
def _get_port(url: str, existing_port: Optional[int]) -> int:
if existing_port:
return existing_port

parsed_port = urlparse(url).port
if parsed_port:
return parsed_port

return 443 if url.startswith("https://") else 6333
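    # Sanity examples for the resolution order above (assumed values):
    #   _get_port("https://qdrant.example.com", None)      -> 443
    #   _get_port("http://qdrant.example.com:6334", None)  -> 6334 (from the URL)
    #   _get_port("http://qdrant.example.com", 7000)       -> 7000 (explicit port wins)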

def create_collection(self, collection_name: str, embeddings: Embeddings):
logger.debug(f"[Qdrant] Creating new collection {collection_name}")
@@ -113,23 +119,24 @@ def _get_records_to_be_upserted(
def upsert_documents(
self,
collection_name: str,
documents,
documents: list,
embeddings: Embeddings,
incremental: bool = True,
):
if len(documents) == 0:
if not documents:
logger.warning("No documents to index")
return
# get record IDs to be upserted
logger.debug(
f"[Qdrant] Adding {len(documents)} documents to collection {collection_name}"
)
data_point_fqns = []
for document in documents:
if document.metadata.get(DATA_POINT_FQN_METADATA_KEY):
data_point_fqns.append(
document.metadata.get(DATA_POINT_FQN_METADATA_KEY)
)
# Collect the data point fqns from the documents
data_point_fqns = [
doc.metadata.get(DATA_POINT_FQN_METADATA_KEY)
for doc in documents
if doc.metadata.get(DATA_POINT_FQN_METADATA_KEY)
]
# Get the record ids to be upserted
record_ids_to_be_upserted: List[str] = self._get_records_to_be_upserted(
collection_name=collection_name,
data_point_fqns=data_point_fqns,
@@ -146,24 +153,26 @@ def upsert_documents(
f"[Qdrant] Added {len(documents)} documents to collection {collection_name}"
)

# Delete Documents
if len(record_ids_to_be_upserted):
logger.debug(
f"[Qdrant] Deleting {len(documents)} outdated documents from collection {collection_name}"
)
for i in range(0, len(record_ids_to_be_upserted), BATCH_SIZE):
record_ids_to_be_processed = record_ids_to_be_upserted[
i : i + BATCH_SIZE
]
self.qdrant_client.delete(
collection_name=collection_name,
points_selector=models.PointIdsList(
points=record_ids_to_be_processed,
),
)
logger.debug(
f"[Qdrant] Deleted {len(documents)} outdated documents from collection {collection_name}"
# Delete old documents

    # If there are no record ids to be upserted, return early
    if not record_ids_to_be_upserted:
        return

    logger.debug(
        f"[Qdrant] Deleting {len(record_ids_to_be_upserted)} outdated documents from collection {collection_name}"
    )
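    # e.g. 2500 outdated record ids with BATCH_SIZE=1000 are deleted in three
    # batches covering slices [0:1000], [1000:2000], and [2000:2500]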
for i in range(0, len(record_ids_to_be_upserted), BATCH_SIZE):
record_ids_to_be_processed = record_ids_to_be_upserted[i : i + BATCH_SIZE]
self.qdrant_client.delete(
collection_name=collection_name,
points_selector=models.PointIdsList(
points=record_ids_to_be_processed,
),
)
    logger.debug(
        f"[Qdrant] Deleted {len(record_ids_to_be_upserted)} outdated documents from collection {collection_name}"
    )

def get_collections(self) -> List[str]:
logger.debug("[Qdrant] Fetching collections")
@@ -219,11 +228,9 @@ def list_data_point_vectors(
offset=offset,
)
for record in records:
metadata: dict = record.payload.get("metadata")
if (
metadata
and metadata.get(DATA_POINT_FQN_METADATA_KEY)
and metadata.get(DATA_POINT_HASH_METADATA_KEY)
metadata: dict = record.payload.get("metadata", {})
if metadata.get(DATA_POINT_FQN_METADATA_KEY) and metadata.get(
DATA_POINT_HASH_METADATA_KEY
):
data_point_vectors.append(
DataPointVector(