Skip to content

Commit

Permalink
Merge pull request #289 from dice-group/literal_example
Browse files Browse the repository at this point in the history
dicee_vector_db included
  • Loading branch information
Demirrr authored Dec 4, 2024
2 parents e4a0bd1 + 9cdd418 commit 00aa49d
Show file tree
Hide file tree
Showing 5 changed files with 177 additions and 119 deletions.
48 changes: 26 additions & 22 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -178,42 +178,46 @@ _:1 <http://www.w3.org/1999/02/22-rdf-syntax-ns#type> <http://www.w3.org/2002/07
dicee --continual_learning "KeciFamilyRun" --path_single_kg "KGs/Family/family-benchmark_rich_background.owl" --model Keci --backend rdflib --eval_model None
```


</details>

## Creating an Embedding Vector Database
## Search and Retrieval via Qdrant Vector Database

<details> <summary> To see a code snippet </summary>

##### Learning Embeddings
```bash
# Train an embedding model
dicee --dataset_dir KGs/Countries-S1 --path_to_store_single_run CountryEmbeddings --model Keci --p 0 --q 1 --embedding_dim 32 --adaptive_swa
dicee --dataset_dir KGs/Countries-S1 --path_to_store_single_run CountryEmbeddings --model Keci --p 0 --q 1 --embedding_dim 256 --scoring_technique AllvsAll --num_epochs 300 --save_embeddings_as_csv
```
#### Loading Embeddings into Qdrant Vector Database
Start qdrant instance.

```bash
# Ensure that Qdrant is available
# docker pull qdrant/qdrant && docker run -p 6333:6333 -p 6334:6334 -v $(pwd)/qdrant_storage:/qdrant/storage:z qdrant/qdrant
diceeindex --path_model "CountryEmbeddings" --collection_name "dummy" --location "localhost"
pip3 install fastapi uvicorn qdrant-client
docker pull qdrant/qdrant && docker run -p 6333:6333 -p 6334:6334 -v $(pwd)/qdrant_storage:/qdrant/storage:z qdrant/qdrant
```
#### Launching Webservice
Upload Embeddings into vector database and start a webservice
```bash
diceeserve --path_model "CountryEmbeddings" --collection_name "dummy" --collection_location "localhost"
```
##### Retrieve and Search

Get embedding of germany
dicee_vector_db --index --serve --path CountryEmbeddings --collection "countries_vdb"
Creating a collection countries_vdb with distance metric:Cosine
Completed!
INFO: Started server process [28953]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
```
Retrieve an embedding vector.
```bash
curl -X 'GET' 'http://0.0.0.0:8000/api/get?q=germany' -H 'accept: application/json'
# {"result": [{"name": "germany","vector": [...]}]}
```

Get most similar things to europe
Retrieve embedding vectors.
```bash
curl -X 'POST' 'http://0.0.0.0:8000/api/search_batch' -H 'accept: application/json' -H 'Content-Type: application/json' -d '{"queries": ["brunei","guam"]}'
# {"results": [{ "name": "europe","vector": [...]},{ "name": "northern_europe","vector": [...]}]}
```
Retrieve an average of embedding vectors.
```bash
curl -X 'GET' 'http://0.0.0.0:8000/api/search?q=europe' -H 'accept: application/json'
{"result":[{"hit":"europe","score":1.0},
{"hit":"northern_europe","score":0.67126536},
{"hit":"western_europe","score":0.6010134},
{"hit":"puerto_rico","score":0.5051694},
{"hit":"southern_europe","score":0.4829831}]}
curl -X 'POST' 'http://0.0.0.0:8000/api/search_batch' -H 'accept: application/json' -H 'Content-Type: application/json' -d '{"queries": ["europe","northern_europe"],"reducer": "mean"}'
# {"results":{"name": ["europe","northern_europe"],"vectors": [...]}}
```

</details>
Expand Down
31 changes: 0 additions & 31 deletions dicee/scripts/index.py

This file was deleted.

150 changes: 150 additions & 0 deletions dicee/scripts/index_serve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
"""
$ docker pull qdrant/qdrant && docker run -p 6333:6333 -p 6334:6334 -v $(pwd)/qdrant_storage:/qdrant/storage:z qdrant/qdrant
$ dicee_vector_db --index --serve --path CountryEmbeddings --collection "countries_vdb"
"""
import argparse
import os
import numpy as np
import pandas as pd
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
from qdrant_client.http.models import PointStruct

from fastapi import FastAPI
import uvicorn
from pydantic import BaseModel
from typing import List, Optional

def get_default_arguments():
    """Parse the command-line arguments of the ``dicee_vector_db`` CLI.

    Returns:
        argparse.Namespace with the attributes ``path``, ``index``, ``serve``,
        ``collection``, ``vdb_host``, ``vdb_port``, ``host`` and ``port``.
    """
    # NOTE(review): add_help=False disables -h/--help; kept as-is to preserve
    # the existing CLI behavior.
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument("--path", type=str, required=True,
                        help="The path of a directory containing embedding csv file(s)")
    parser.add_argument("--index", action="store_true", help="A flag for indexing")
    parser.add_argument("--serve", action="store_true", help="A flag for serving")
    # Fixed typo in the help text ("Named of" -> "Name of") and disambiguated
    # the two "port number" help strings below.
    parser.add_argument("--collection", type=str, required=True,
                        help="Name of the vector database collection")
    parser.add_argument("--vdb_host", type=str, default="localhost",
                        help="Host of qdrant vector database")
    parser.add_argument("--vdb_port", type=int, default=6333,
                        help="port number of the qdrant vector database")
    parser.add_argument("--host", type=str, default="0.0.0.0",
                        help="Host of the webservice")
    parser.add_argument("--port", type=int, default=8000,
                        help="port number of the webservice")
    return parser.parse_args()


def index(args):
    """Load entity embeddings from CSV and upsert them into a Qdrant collection.

    Expects ``args.path`` to contain ``entity_to_idx.csv`` and exactly one file
    whose name contains ``entity_embeddings.csv``. An existing collection with
    the same name is deleted and recreated with cosine distance.
    """
    client = QdrantClient(host=args.vdb_host, port=args.vdb_port)
    entity_to_idx = pd.read_csv(args.path + "/entity_to_idx.csv", index_col=0)
    # Removed a stray unfilled "{}" placeholder from the assertion message.
    assert entity_to_idx.index.is_monotonic_increasing, "Entity Index must be monotonically increasing!"
    # Map entity name -> integer id; the CSV row order defines the id.
    entity_to_idx = {name: idx for idx, name in enumerate(entity_to_idx["entity"].tolist())}

    csv_files_holding_embeddings = [args.path + "/" + f for f in os.listdir(args.path)
                                    if "entity_embeddings.csv" in f]
    assert len(csv_files_holding_embeddings) == 1, \
        f"There must be only a single csv file containing entity_embeddings.csv. Currently: {len(csv_files_holding_embeddings)}"
    path_entity_embeddings_csv = csv_files_holding_embeddings[0]

    points = []
    # Initialized to 0 (not None) so the sanity check below raises a clear
    # AssertionError instead of a TypeError when the CSV holds no rows.
    embedding_dim = 0
    for index_name, embedding_vector in pd.read_csv(path_entity_embeddings_csv,
                                                    index_col=0, header=0).iterrows():
        embedding_vector = embedding_vector.values
        points.append(PointStruct(id=entity_to_idx[index_name],
                                  vector=embedding_vector,
                                  payload={"name": index_name}))
        embedding_dim = len(embedding_vector)

    assert embedding_dim > 0, "No embedding vectors were read from the CSV file!"
    # Recreate the collection from scratch if it already exists.
    if args.collection in [i.name for i in client.get_collections().collections]:
        print("Deleting existing collection ", args.collection)
        client.delete_collection(collection_name=args.collection)

    print(f"Creating a collection {args.collection} with distance metric:Cosine")
    client.create_collection(collection_name=args.collection,
                             vectors_config=VectorParams(size=embedding_dim,
                                                         distance=Distance.COSINE))
    client.upsert(collection_name=args.collection, points=points)
    print("Completed!")


# FastAPI application; the endpoints below register on it and serve() runs it
# via uvicorn.
app = FastAPI()
# Global NeuralSearcher instance; populated by serve() before the webservice
# starts handling requests.
neural_searcher = None

class NeuralSearcher:
    """Retrieve and semantically search entity embeddings stored in Qdrant."""

    def __init__(self, args):
        self.collection_name = args.collection
        assert os.path.exists(args.path + "/entity_to_idx.csv"), f"{args.path + '/entity_to_idx.csv'} does not exist!"
        self.entity_to_idx = pd.read_csv(args.path + "/entity_to_idx.csv", index_col=0)
        # Removed a stray unfilled "{}" placeholder from the assertion message.
        assert self.entity_to_idx.index.is_monotonic_increasing, "Entity Index must be monotonically increasing!"
        # Map entity name -> integer id; the CSV row order defines the id.
        self.entity_to_idx = {name: idx for idx, name in enumerate(self.entity_to_idx["entity"].tolist())}
        # initialize Qdrant client
        self.qdrant_client = QdrantClient(host=args.vdb_host, port=args.vdb_port)
        # Number of hits returned by semantic search.
        self.topk = 5

    def retrieve_embedding(self, entity: str = None, entities: List[str] = None):
        """Return ``[{"name": ..., "vector": ...}]`` for each known input entity.

        Unknown names are silently skipped; if no input resolves to an id,
        an ``{"error": ...}`` dict is returned instead of a list.
        """
        inputs = []
        if entity is not None:
            inputs.append(entity)
        if entities is not None:
            inputs.extend(entities)
        ids = []
        for ent in inputs:
            idx = self.entity_to_idx.get(ent, None)
            # BUG FIX: the previous `if idx := ...:` treated id 0 as "not
            # found" because 0 is falsy; compare against None explicitly.
            if idx is not None:
                assert isinstance(idx, int)
                ids.append(idx)
        if len(ids) < 1:
            return {"error": f"IDs are not found for ({entity} or {entities})"}
        else:
            return [{"name": result.payload["name"], "vector": result.vector}
                    for result in self.qdrant_client.retrieve(collection_name=self.collection_name,
                                                              ids=ids, with_vectors=True)]

    def search(self, entity: str):
        """Return the topk points most similar to *entity* (KeyError if unknown)."""
        return self.qdrant_client.query_points(collection_name=self.collection_name,
                                               query=self.entity_to_idx[entity],
                                               limit=self.topk)

@app.get("/")
async def root():
    """Landing endpoint returning a static greeting."""
    greeting = {"message": "Hello Dice Embedding User"}
    return greeting

@app.get("/api/search")
async def search_embeddings(q: str):
    """Return the entities most similar to the entity named *q*."""
    hits = neural_searcher.search(entity=q)
    return {"result": hits}

@app.get("/api/get")
async def retrieve_embeddings(q: str):
    """Return the stored embedding vector of the entity named *q*."""
    embedding = neural_searcher.retrieve_embedding(entity=q)
    return {"result": embedding}

class StringListRequest(BaseModel):
    """Request body of /api/search_batch: entity names plus an optional reducer."""
    queries: List[str]
    reducer: Optional[str] = None  # Optional reducer flag, default None; "mean" averages the retrieved vectors


@app.post("/api/search_batch")
async def search_embeddings_batch(request: StringListRequest):
    """Retrieve embeddings for a batch of entity names.

    With ``reducer == "mean"`` the retrieved vectors are averaged into a single
    vector; otherwise the raw per-entity results are returned.
    """
    retrieved = neural_searcher.retrieve_embedding(entities=request.queries)
    # retrieve_embedding returns an {"error": ...} dict when no query resolves;
    # the previous code crashed with a TypeError trying to reduce that dict.
    # Propagate it unchanged instead.
    if isinstance(retrieved, dict):
        return {"results": retrieved}
    if request.reducer == "mean":
        names = [item["name"] for item in retrieved]
        vectors = [item["vector"] for item in retrieved]
        embeddings = np.mean(vectors, axis=0).tolist()  # Reduce to mean
        return {"results": {"name": names, "vectors": embeddings}}
    return {"results": retrieved}

def serve(args):
    """Initialise the global ``neural_searcher`` and start the webservice.

    Blocks inside ``uvicorn.run`` until the server is shut down (CTRL+C).
    """
    global neural_searcher
    neural_searcher = NeuralSearcher(args)
    uvicorn.run(app, host=args.host, port=args.port)

def main():
    """CLI entry point: run indexing and/or the webservice per the given flags."""
    args = get_default_arguments()
    # Dispatch requested actions in order: indexing first, then serving.
    for requested, action in ((args.index, index), (args.serve, serve)):
        if requested:
            action(args)


if __name__ == '__main__':
    main()
64 changes: 0 additions & 64 deletions dicee/scripts/serve.py

This file was deleted.

3 changes: 1 addition & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,7 @@ def deps_list(*pkgs):
python_requires='>=3.10',
entry_points={"console_scripts":
["dicee=dicee.scripts.run:main",
"diceeindex=dicee.scripts.index:main",
"diceeserve=dicee.scripts.serve:main"]},
"dicee_vector_db=dicee.scripts.index_serve:main"]},
long_description=long_description,
long_description_content_type="text/markdown",
)

0 comments on commit 00aa49d

Please sign in to comment.