Merge pull request #289 from dice-group/literal_example
dicee_vector_db included
Showing 5 changed files with 177 additions and 119 deletions.
@@ -0,0 +1,150 @@
""" | ||
$ docker pull qdrant/qdrant && docker run -p 6333:6333 -p 6334:6334 -v $(pwd)/qdrant_storage:/qdrant/storage:z qdrant/qdrant | ||
$ dicee_vector_db --index --serve --path CountryEmbeddings --collection "countries_vdb" | ||
""" | ||
import argparse | ||
import os | ||
import numpy as np | ||
import pandas as pd | ||
from qdrant_client import QdrantClient | ||
from qdrant_client.http.models import Distance, VectorParams | ||
from qdrant_client.http.models import PointStruct | ||
|
||
from fastapi import FastAPI | ||
import uvicorn | ||
from pydantic import BaseModel | ||
from typing import List, Optional | ||
|
||
def get_default_arguments(): | ||
parser = argparse.ArgumentParser(add_help=False) | ||
parser.add_argument("--path", type=str, required=True, | ||
help="The path of a directory containing embedding csv file(s)") | ||
parser.add_argument("--index", action="store_true", help="A flag for indexing") | ||
parser.add_argument("--serve", action="store_true", help="A flag for serving") | ||
|
||
parser.add_argument("--collection", type=str, required=True,help="Named of the vector database collection") | ||
|
||
parser.add_argument("--vdb_host", type=str,default="localhost",help="Host of qdrant vector database") | ||
parser.add_argument("--vdb_port", type=int,default=6333,help="port number") | ||
parser.add_argument("--host",type=str, default="0.0.0.0",help="Host") | ||
parser.add_argument("--port", type=int, default=8000,help="port number") | ||
return parser.parse_args() | ||
|
||
|
||
def index(args): | ||
client = QdrantClient(host=args.vdb_host, port=args.vdb_port) | ||
entity_to_idx = pd.read_csv(args.path + "/entity_to_idx.csv", index_col=0) | ||
assert entity_to_idx.index.is_monotonic_increasing, "Entity Index must be monotonically increasing!{}" | ||
entity_to_idx = {name: idx for idx, name in enumerate(entity_to_idx["entity"].tolist())} | ||
|
||
csv_files_holding_embeddings = [args.path + "/" + f for f in os.listdir(args.path) if "entity_embeddings.csv" in f] | ||
assert len( | ||
csv_files_holding_embeddings) == 1, f"There must be only single csv file containing entity_embeddings.csv prefix. Currently, :{len(csv_files_holding_embeddings)}" | ||
path_entity_embeddings_csv = csv_files_holding_embeddings[0] | ||
|
||
points = [] | ||
embedding_dim = None | ||
for ith_row, (index_name, embedding_vector) in enumerate( | ||
pd.read_csv(path_entity_embeddings_csv, index_col=0, header=0).iterrows()): | ||
index_name: str | ||
embedding_vector: np.ndarray | ||
embedding_vector = embedding_vector.values | ||
|
||
points.append(PointStruct(id=entity_to_idx[index_name], | ||
vector=embedding_vector, | ||
payload={"name": index_name})) | ||
|
||
embedding_dim = len(embedding_vector) | ||
|
||
assert embedding_dim > 0 | ||
# If the collection is not created, create it | ||
if args.collection in [i.name for i in client.get_collections().collections]: | ||
print("Deleting existing collection ", args.collection) | ||
client.delete_collection(collection_name=args.collection) | ||
|
||
print(f"Creating a collection {args.collection} with distance metric:Cosine") | ||
client.create_collection(collection_name=args.collection, | ||
vectors_config=VectorParams(size=embedding_dim, | ||
distance=Distance.COSINE)) | ||
client.upsert(collection_name=args.collection, points=points) | ||
print("Completed!") | ||
|
||
|
||
app = FastAPI() | ||
# Create a neural searcher instance | ||
neural_searcher = None | ||
|
||
class NeuralSearcher: | ||
def __init__(self, args): | ||
self.collection_name = args.collection | ||
assert os.path.exists(args.path + "/entity_to_idx.csv"), f"{args.path + '/entity_to_idx.csv'} does not exist!" | ||
self.entity_to_idx = pd.read_csv(args.path + "/entity_to_idx.csv", index_col=0) | ||
assert self.entity_to_idx.index.is_monotonic_increasing, "Entity Index must be monotonically increasing!{}" | ||
self.entity_to_idx = {name: idx for idx, name in enumerate(self.entity_to_idx["entity"].tolist())} | ||
# initialize Qdrant client | ||
self.qdrant_client = QdrantClient(host=args.vdb_host,port=args.vdb_port) | ||
# semantic search | ||
self.topk=5 | ||
|
||
def retrieve_embedding(self,entity:str=None,entities:List[str]=None)->List: | ||
ids=[] | ||
inputs= [entity] | ||
if entities is not None: | ||
inputs.extend(entities) | ||
for ent in inputs: | ||
if idx := self.entity_to_idx.get(ent, None): | ||
assert isinstance(idx, int) | ||
ids.append(idx) | ||
if len(ids)<1: | ||
return {"error":f"IDs are not found for ({entity} or {entities})"} | ||
else: | ||
return [{"name": result.payload["name"], "vector": result.vector} for result in self.qdrant_client.retrieve(collection_name=self.collection_name,ids=ids, with_vectors=True)] | ||
|
||
def search(self, entity: str): | ||
return self.qdrant_client.query_points(collection_name=self.collection_name, query=self.entity_to_idx[entity],limit=self.topk) | ||
|
||
@app.get("/") | ||
async def root(): | ||
return {"message": "Hello Dice Embedding User"} | ||
|
||
@app.get("/api/search") | ||
async def search_embeddings(q: str): | ||
return {"result": neural_searcher.search(entity=q)} | ||
|
||
@app.get("/api/get") | ||
async def retrieve_embeddings(q: str): | ||
return {"result": neural_searcher.retrieve_embedding(entity=q)} | ||
|
||
class StringListRequest(BaseModel): | ||
queries: List[str] | ||
reducer: Optional[str] = None # Add the reducer flag with default as None | ||
|
||
|
||
@app.post("/api/search_batch") | ||
async def search_embeddings_batch(request: StringListRequest): | ||
if request.reducer == "mean": | ||
names=[] | ||
vectors=[] | ||
for result in neural_searcher.retrieve_embedding(entities=request.queries): | ||
names.append(result["name"]) | ||
vectors.append(result["vector"]) | ||
embeddings = np.mean(vectors, axis=0).tolist() # Reduce to mean | ||
return {"results":{"name":names,"vectors":embeddings}} | ||
else: | ||
return {"results": neural_searcher.retrieve_embedding(entities=request.queries)} | ||
|
||
def serve(args): | ||
global neural_searcher | ||
neural_searcher = NeuralSearcher(args) | ||
uvicorn.run(app, host=args.host, port=args.port) | ||
|
||
def main(): | ||
args = get_default_arguments() | ||
if args.index: | ||
index(args) | ||
if args.serve: | ||
serve(args) | ||
|
||
if __name__ == '__main__': | ||
main() |
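A minimal client sketch for the endpoints defined above, assuming the service was started with the defaults from the module docstring (REST API on localhost:8000, collection built from CountryEmbeddings). The entity names used below are placeholders, not names taken from the repository; they must be replaced with entries that actually occur in entity_to_idx.csv.

# Sketch only: assumes `pip install requests` and a running server started via
# dicee_vector_db --index --serve --path CountryEmbeddings --collection "countries_vdb"
# "Germany" and "France" are hypothetical entity names; use IRIs from entity_to_idx.csv.
import requests

BASE = "http://localhost:8000"  # default --host/--port of serve()

# Retrieve the stored embedding of a single entity.
print(requests.get(f"{BASE}/api/get", params={"q": "Germany"}).json())

# Nearest neighbours of an entity (top-k is fixed to 5 in NeuralSearcher.topk).
print(requests.get(f"{BASE}/api/search", params={"q": "Germany"}).json())

# Batch retrieval; with reducer="mean" the returned vector is the mean of the retrieved embeddings.
payload = {"queries": ["Germany", "France"], "reducer": "mean"}
print(requests.post(f"{BASE}/api/search_batch", json=payload).json())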