Christy demo (#1342)
* Update notebooks to use Milvus Lite

Signed-off-by: Christy Bergman <[email protected]>

* Update notebooks to use Milvus Lite

Signed-off-by: Christy Bergman <[email protected]>

* Update notebooks to use Milvus Lite

Signed-off-by: Christy Bergman <[email protected]>

* Add multimodal demo

---------

Signed-off-by: Christy Bergman <[email protected]>
christy authored May 29, 2024
1 parent d38de1d commit 8451217
Showing 2 changed files with 1,203 additions and 0 deletions.
166 changes: 166 additions & 0 deletions bootcamp/workshops/multimodal/multimodal.py
@@ -0,0 +1,166 @@
import uform
from uform import get_model, Modality
import requests
from io import BytesIO
from PIL import Image
import torch
import torch.nn.functional as F
import numpy as np

import pymilvus, time
from pymilvus import (
    MilvusClient, utility, connections,
    FieldSchema, CollectionSchema, DataType, IndexType,
    Collection, AnnSearchRequest, RRFRanker, WeightedRanker
)
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# Use the light-weight portable ONNX model.
# Available combinations: cpu & fp32, gpu & fp32, gpu & fp16.
# See Unum's Hugging Face space for more details:
# https://huggingface.co/unum-cloud

# Define a class to compute embeddings.
class ComputeEmbeddings:
    def __init__(self, modelname):
        # Load the pre-trained model.
        self.model_name = modelname
        self.modalities = [Modality.TEXT_ENCODER, Modality.IMAGE_ENCODER]

        # Get the preprocessing function for the model.
        self.processors, self.models = get_model(self.model_name, modalities=self.modalities)

        # Get the text and image encoders.
        self.model_image = self.models[Modality.IMAGE_ENCODER]
        self.model_text = self.models[Modality.TEXT_ENCODER]
        self.processor_image = self.processors[Modality.IMAGE_ENCODER]
        self.processor_text = self.processors[Modality.TEXT_ENCODER]

    def __call__(self, batch_images=[], batch_texts=[]):
        img_converted_values = []
        text_converted_values = []

        # Encode a batch of images.
        if len(batch_images) > 0:
            # Process the images into embeddings.
            image_data = self.processor_image(batch_images)
            image_embeddings = self.model_image.encode(image_data, return_features=False)

            # Milvus requires list of `np.ndarray` arrays of `np.float32` numbers.
            img_converted_values = list(map(np.float32, image_embeddings))
            assert isinstance(img_converted_values, list)
            assert isinstance(img_converted_values[0], np.ndarray)
            assert isinstance(img_converted_values[0][0], np.float32)

        # Encode a batch of texts.
        if len(batch_texts) > 0:
            # Process the texts into embeddings.
            text_data = self.processor_text(batch_texts)
            text_embeddings = self.model_text.encode(text_data, return_features=False)

            # Milvus requires list of `np.ndarray` arrays of `np.float32` numbers.
            text_converted_values = list(map(np.float32, text_embeddings))
            assert isinstance(text_converted_values, list)
            assert isinstance(text_converted_values[0], np.ndarray)
            assert isinstance(text_converted_values[0][0], np.float32)

        return img_converted_values, text_converted_values


# Define a convenience search function.
def multi_modal_search(query_text, query_image,
                       embedding_model, col,
                       output_fields,
                       text_only=False,
                       image_only=False,
                       top_k=2):

    # Embed the question using the same encoder.
    query_img_embeddings, query_text_embeddings = \
        embedding_model(
            batch_images=[query_image],
            batch_texts=[query_text])

    # Prepare the search requests for both vector fields.
    image_search_params = {"metric_type": "COSINE"}
    image_req = AnnSearchRequest(
        query_img_embeddings,
        "image_vector", image_search_params, limit=top_k)

    text_search_params = {"metric_type": "COSINE"}
    text_req = AnnSearchRequest(
        query_text_embeddings,
        "text_vector", text_search_params, limit=top_k)

    # Run semantic vector search using Milvus.
    start_time = time.time()

    # User gave an image query only.
    if image_only:
        results = col.hybrid_search(
            reqs=[image_req, text_req],
            rerank=WeightedRanker(1.0, 0.0),
            limit=top_k,
            output_fields=output_fields
        )

    # User gave a text query only.
    elif text_only:
        results = col.hybrid_search(
            reqs=[image_req, text_req],
            rerank=WeightedRanker(0.0, 1.0),
            limit=top_k,
            output_fields=output_fields
        )

    # Use both the text and image parts of the query.
    else:
        results = col.hybrid_search(
            reqs=[image_req, text_req],
            rerank=RRFRanker(),
            limit=top_k,
            output_fields=output_fields)

    elapsed_time = time.time() - start_time
    # print(f"Milvus Client search time for {len(dict_list)} vectors: {elapsed_time} seconds")
    print(f"Milvus search time: {elapsed_time} seconds")

    # Currently Milvus only supports one query per hybrid search request, so
    # we inspect results[0] directly. A future release will accept batched
    # hybrid search queries in the same call.
    results = results[0]

    # Display the results in a 2-column grid of images, titled with the query.
    # Create the figure before calling plt.suptitle; otherwise plt.figure
    # would open a new, untitled figure and the suptitle would be lost.
    num_rows = int(round(top_k / 2, 0))
    plt.figure(figsize=(10, 5) if top_k == 2 else (10, 10))

    if text_only:
        plt.suptitle(f"Query: {query_text}")
    elif image_only:
        plt.suptitle("Query: using image on the left")
    else:
        plt.suptitle(f"Query: {query_text} AND image on the right")

    for i, result in enumerate(results):
        with Image.open(f"./images/{result.entity.image_filepath}.jpg") as img:
            plt.subplot(num_rows, 2, i + 1)
            plt.imshow(img)
            plt.title(f"COSINE distance: {round(result.distance, 4)}")
            plt.axis('off')
    plt.tight_layout()
    plt.show()

    return results
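
For reference, a minimal sketch of how these two helpers might be wired together. The connection details, collection name, UForm checkpoint, query text, and image path below are illustrative assumptions, not values taken from this commit; the collection is only assumed to expose the "image_vector", "text_vector", and "image_filepath" fields that the code above relies on.

from PIL import Image
from pymilvus import Collection, connections

# Placeholder connection and collection; adjust to your own deployment.
connections.connect(host="localhost", port="19530")
col = Collection("multimodal_demo")
col.load()

# Example checkpoint; any model exposing both a text and an image encoder
# through uform.get_model() should work here.
embedding_model = ComputeEmbeddings("unum-cloud/uform3-image-text-english-small")

query_image = Image.open("./images/example_query.jpg")

# Combined text + image query, fused with Reciprocal Rank Fusion.
results = multi_modal_search(
    query_text="a dog catching a frisbee",
    query_image=query_image,
    embedding_model=embedding_model,
    col=col,
    output_fields=["image_filepath"],
    top_k=4,
)

# Text-only query: both requests are still issued, but the image side
# is weighted to zero by WeightedRanker(0.0, 1.0).
results_text = multi_modal_search(
    query_text="a dog catching a frisbee",
    query_image=query_image,
    embedding_model=embedding_model,
    col=col,
    output_fields=["image_filepath"],
    text_only=True,
)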