diff --git a/README.md b/README.md index 7e685abf..58e5f4f6 100644 --- a/README.md +++ b/README.md @@ -178,42 +178,46 @@ _:1 -## Creating an Embedding Vector Database +## Search and Retrieval via Qdrant Vector Database +
To see a code snippet -##### Learning Embeddings ```bash # Train an embedding model -dicee --dataset_dir KGs/Countries-S1 --path_to_store_single_run CountryEmbeddings --model Keci --p 0 --q 1 --embedding_dim 32 --adaptive_swa +dicee --dataset_dir KGs/Countries-S1 --path_to_store_single_run CountryEmbeddings --model Keci --p 0 --q 1 --embedding_dim 256 --scoring_technique AllvsAll --num_epochs 300 --save_embeddings_as_csv ``` -#### Loading Embeddings into Qdrant Vector Database +Start qdrant instance. + ```bash -# Ensure that Qdrant available -# docker pull qdrant/qdrant && docker run -p 6333:6333 -p 6334:6334 -v $(pwd)/qdrant_storage:/qdrant/storage:z qdrant/qdrant -diceeindex --path_model "CountryEmbeddings" --collection_name "dummy" --location "localhost" +pip3 install fastapi uvicorn qdrant-client +docker pull qdrant/qdrant && docker run -p 6333:6333 -p 6334:6334 -v $(pwd)/qdrant_storage:/qdrant/storage:z qdrant/qdrant ``` -#### Launching Webservice +Upload Embeddings into vector database and start a webservice ```bash -diceeserve --path_model "CountryEmbeddings" --collection_name "dummy" --collection_location "localhost" -``` -##### Retrieve and Search - -Get embedding of germany +dicee_vector_db --index --serve --path CountryEmbeddings --collection "countries_vdb" +Creating a collection countries_vdb with distance metric:Cosine +Completed! +INFO: Started server process [28953] +INFO: Waiting for application startup. +INFO: Application startup complete. +INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit) +``` +Retrieve an embedding vector. ```bash curl -X 'GET' 'http://0.0.0.0:8000/api/get?q=germany' -H 'accept: application/json' +# {"result": [{"name": "europe","vector": [...]}]} ``` - -Get most similar things to europe +Retrieve embedding vectors. +```bash +curl -X 'POST' 'http://0.0.0.0:8000/api/search_batch' -H 'accept: application/json' -H 'Content-Type: application/json' -d '{"queries": ["brunei","guam"]}' +# {"results": [{ "name": "europe","vector": [...]},{ "name": "northern_europe","vector": [...]}]} +``` +Retrieve an average of embedding vectors. ```bash -curl -X 'GET' 'http://0.0.0.0:8000/api/search?q=europe' -H 'accept: application/json' -{"result":[{"hit":"europe","score":1.0}, -{"hit":"northern_europe","score":0.67126536}, -{"hit":"western_europe","score":0.6010134}, -{"hit":"puerto_rico","score":0.5051694}, -{"hit":"southern_europe","score":0.4829831}]} +curl -X 'POST' 'http://0.0.0.0:8000/api/search_batch' -H 'accept: application/json' -H 'Content-Type: application/json' -d '{"queries": ["europe","northern_europe"],"reducer": "mean"}' +# {"results":{"name": ["europe","northern_europe"],"vectors": [...]}} ```
diff --git a/dicee/scripts/index.py b/dicee/scripts/index.py deleted file mode 100644 index f588b72c..00000000 --- a/dicee/scripts/index.py +++ /dev/null @@ -1,31 +0,0 @@ -import argparse - - -def get_default_arguments(): - parser = argparse.ArgumentParser(add_help=False) - parser.add_argument("--path_model", type=str, required=True, - help="The path of a directory containing pre-trained model") - parser.add_argument("--collection_name", type=str, required=True, - help="Named of the vector database collection") - parser.add_argument("--location", type=str, required=True, - help="location") - return parser.parse_args() - - -def main(): - args = get_default_arguments() - # docker pull qdrant/qdrant - # docker run -p 6333:6333 -p 6334:6334 -v $(pwd)/qdrant_storage:/qdrant/storage:z qdrant/qdrant - # pip install qdrant-client - - from dicee.knowledge_graph_embeddings import KGE - - # Train a model on Countries dataset - KGE(path=args.path_model).create_vector_database(collection_name=args.collection_name, - location=args.location, - distance="cosine") - return "Completed!" - - -if __name__ == '__main__': - main() diff --git a/dicee/scripts/index_serve.py b/dicee/scripts/index_serve.py new file mode 100644 index 00000000..9f484ce9 --- /dev/null +++ b/dicee/scripts/index_serve.py @@ -0,0 +1,150 @@ +""" +$ docker pull qdrant/qdrant && docker run -p 6333:6333 -p 6334:6334 -v $(pwd)/qdrant_storage:/qdrant/storage:z qdrant/qdrant +$ dicee_vector_db --index --serve --path CountryEmbeddings --collection "countries_vdb" + +""" +import argparse +import os +import numpy as np +import pandas as pd +from qdrant_client import QdrantClient +from qdrant_client.http.models import Distance, VectorParams +from qdrant_client.http.models import PointStruct + +from fastapi import FastAPI +import uvicorn +from pydantic import BaseModel +from typing import List, Optional + +def get_default_arguments(): + parser = argparse.ArgumentParser(add_help=False) + parser.add_argument("--path", type=str, required=True, + help="The path of a directory containing embedding csv file(s)") + parser.add_argument("--index", action="store_true", help="A flag for indexing") + parser.add_argument("--serve", action="store_true", help="A flag for serving") + + parser.add_argument("--collection", type=str, required=True,help="Named of the vector database collection") + + parser.add_argument("--vdb_host", type=str,default="localhost",help="Host of qdrant vector database") + parser.add_argument("--vdb_port", type=int,default=6333,help="port number") + parser.add_argument("--host",type=str, default="0.0.0.0",help="Host") + parser.add_argument("--port", type=int, default=8000,help="port number") + return parser.parse_args() + + +def index(args): + client = QdrantClient(host=args.vdb_host, port=args.vdb_port) + entity_to_idx = pd.read_csv(args.path + "/entity_to_idx.csv", index_col=0) + assert entity_to_idx.index.is_monotonic_increasing, "Entity Index must be monotonically increasing!{}" + entity_to_idx = {name: idx for idx, name in enumerate(entity_to_idx["entity"].tolist())} + + csv_files_holding_embeddings = [args.path + "/" + f for f in os.listdir(args.path) if "entity_embeddings.csv" in f] + assert len( + csv_files_holding_embeddings) == 1, f"There must be only single csv file containing entity_embeddings.csv prefix. Currently, :{len(csv_files_holding_embeddings)}" + path_entity_embeddings_csv = csv_files_holding_embeddings[0] + + points = [] + embedding_dim = None + for ith_row, (index_name, embedding_vector) in enumerate( + pd.read_csv(path_entity_embeddings_csv, index_col=0, header=0).iterrows()): + index_name: str + embedding_vector: np.ndarray + embedding_vector = embedding_vector.values + + points.append(PointStruct(id=entity_to_idx[index_name], + vector=embedding_vector, + payload={"name": index_name})) + + embedding_dim = len(embedding_vector) + + assert embedding_dim > 0 + # If the collection is not created, create it + if args.collection in [i.name for i in client.get_collections().collections]: + print("Deleting existing collection ", args.collection) + client.delete_collection(collection_name=args.collection) + + print(f"Creating a collection {args.collection} with distance metric:Cosine") + client.create_collection(collection_name=args.collection, + vectors_config=VectorParams(size=embedding_dim, + distance=Distance.COSINE)) + client.upsert(collection_name=args.collection, points=points) + print("Completed!") + + +app = FastAPI() +# Create a neural searcher instance +neural_searcher = None + +class NeuralSearcher: + def __init__(self, args): + self.collection_name = args.collection + assert os.path.exists(args.path + "/entity_to_idx.csv"), f"{args.path + '/entity_to_idx.csv'} does not exist!" + self.entity_to_idx = pd.read_csv(args.path + "/entity_to_idx.csv", index_col=0) + assert self.entity_to_idx.index.is_monotonic_increasing, "Entity Index must be monotonically increasing!{}" + self.entity_to_idx = {name: idx for idx, name in enumerate(self.entity_to_idx["entity"].tolist())} + # initialize Qdrant client + self.qdrant_client = QdrantClient(host=args.vdb_host,port=args.vdb_port) + # semantic search + self.topk=5 + + def retrieve_embedding(self,entity:str=None,entities:List[str]=None)->List: + ids=[] + inputs= [entity] + if entities is not None: + inputs.extend(entities) + for ent in inputs: + if idx := self.entity_to_idx.get(ent, None): + assert isinstance(idx, int) + ids.append(idx) + if len(ids)<1: + return {"error":f"IDs are not found for ({entity} or {entities})"} + else: + return [{"name": result.payload["name"], "vector": result.vector} for result in self.qdrant_client.retrieve(collection_name=self.collection_name,ids=ids, with_vectors=True)] + + def search(self, entity: str): + return self.qdrant_client.query_points(collection_name=self.collection_name, query=self.entity_to_idx[entity],limit=self.topk) + +@app.get("/") +async def root(): + return {"message": "Hello Dice Embedding User"} + +@app.get("/api/search") +async def search_embeddings(q: str): + return {"result": neural_searcher.search(entity=q)} + +@app.get("/api/get") +async def retrieve_embeddings(q: str): + return {"result": neural_searcher.retrieve_embedding(entity=q)} + +class StringListRequest(BaseModel): + queries: List[str] + reducer: Optional[str] = None # Add the reducer flag with default as None + + +@app.post("/api/search_batch") +async def search_embeddings_batch(request: StringListRequest): + if request.reducer == "mean": + names=[] + vectors=[] + for result in neural_searcher.retrieve_embedding(entities=request.queries): + names.append(result["name"]) + vectors.append(result["vector"]) + embeddings = np.mean(vectors, axis=0).tolist() # Reduce to mean + return {"results":{"name":names,"vectors":embeddings}} + else: + return {"results": neural_searcher.retrieve_embedding(entities=request.queries)} + +def serve(args): + global neural_searcher + neural_searcher = NeuralSearcher(args) + uvicorn.run(app, host=args.host, port=args.port) + +def main(): + args = get_default_arguments() + if args.index: + index(args) + if args.serve: + serve(args) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/dicee/scripts/serve.py b/dicee/scripts/serve.py deleted file mode 100644 index d4b41fdc..00000000 --- a/dicee/scripts/serve.py +++ /dev/null @@ -1,64 +0,0 @@ -import argparse -from ..knowledge_graph_embeddings import KGE -from fastapi import FastAPI -import uvicorn -from qdrant_client import QdrantClient - -app = FastAPI() -# Create a neural searcher instance -neural_searcher = None -def get_default_arguments(): - parser = argparse.ArgumentParser(add_help=False) - parser.add_argument("--path_model", type=str, required=True, - help="The path of a directory containing pre-trained model") - parser.add_argument("--collection_name", type=str, required=True, help="Named of the vector database collection") - parser.add_argument("--collection_location", type=str, required=True, help="location") - parser.add_argument("--host",type=str, default="0.0.0.0") - parser.add_argument("--port", type=int, default=8000) - return parser.parse_args() - -@app.get("/") -async def root(): - return {"message": "Hello Dice Embedding User"} - -@app.get("/api/search") -async def search_embeddings(q: str): - return {"result": neural_searcher.search(entity=q)} - -@app.get("/api/get") -async def retrieve_embeddings(q: str): - return {"result": neural_searcher.get(entity=q)} - -class NeuralSearcher: - def __init__(self, args): - self.collection_name = args.collection_name - # Initialize encoder model - self.model = KGE(path=args.path_model) - # initialize Qdrant client - self.qdrant_client = QdrantClient(location=args.collection_location) - - def get(self,entity:str): - return self.model.get_transductive_entity_embeddings(indices=[entity], as_list=True)[0] - - def search(self, entity: str): - # Convert text query into vector - vector=self.get(entity) - - # Use `vector` for search for closest vectors in the collection - search_result = self.qdrant_client.search( - collection_name=self.collection_name, - query_vector=vector, - query_filter=None, # If you don't want any filters for now - limit=5, # 5 the most closest results is enough - ) - return [{"hit": i.payload["name"], "score": i.score} for i in search_result] - - -def main(): - args = get_default_arguments() - global neural_searcher - neural_searcher = NeuralSearcher(args) - uvicorn.run(app, host=args.host, port=args.port) - -if __name__ == '__main__': - main() diff --git a/setup.py b/setup.py index 5b9b97d3..89f39d48 100644 --- a/setup.py +++ b/setup.py @@ -68,8 +68,7 @@ def deps_list(*pkgs): python_requires='>=3.10', entry_points={"console_scripts": ["dicee=dicee.scripts.run:main", - "diceeindex=dicee.scripts.index:main", - "diceeserve=dicee.scripts.serve:main"]}, + "dicee_vector_db=dicee.scripts.index_serve:main"]}, long_description=long_description, long_description_content_type="text/markdown", )