Skip to content

Commit

Permalink
Add stable diffusion integration (#1240)
Browse files Browse the repository at this point in the history
Reopen the #1111.

---------

Co-authored-by: sudoboi <[email protected]>
Co-authored-by: Abhijith S Raj <[email protected]>
  • Loading branch information
3 people authored Oct 15, 2023
1 parent 6a0cd76 commit bf02232
Show file tree
Hide file tree
Showing 13 changed files with 710 additions and 2 deletions.
2 changes: 2 additions & 0 deletions docs/_toc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ parts:
title: OpenAI
- file: source/reference/ai/yolo
title: YOLO
- file: source/reference/ai/stablediffusion
title: Stable Diffusion

- file: source/reference/ai/custom-ai-function
title: Bring Your Own AI Function
Expand Down
27 changes: 27 additions & 0 deletions docs/source/reference/ai/stablediffusion.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
.. _stablediffusion:

Stable Diffusion Models
======================================

This section provides an overview of how you can generate images from prompts in EvaDB using a Stable Diffusion model.


Introduction
------------

Stable Diffusion models leverage a controlled random walk process to generate intricate patterns and images from textual prompts,
bridging the gap between text and visual representation. EvaDB uses the stable diffusion implementation from `Replicate <https://replicate.com>`_.

Stable Diffusion UDF
--------------------

In order to create an image generation function in EvaDB, use the following SQL command:

.. code-block:: sql
CREATE FUNCTION IF NOT EXISTS StableDiffusion
IMPL 'evadb/functions/stable_diffusion.py';
EvaDB automatically uses the latest `stable diffusion release <https://replicate.com/stability-ai/stable-diffusion/versions>`_ available on Replicate.

To see a demo of how the function can be used, please check the `demo notebook <https://colab.research.google.com/github/georgia-tech-db/eva/blob/master/tutorials/18-stable-diffusion.ipynb>`_ on stable diffusion.
1 change: 1 addition & 0 deletions evadb/evadb.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,4 @@ third_party:
OPENAI_KEY: ""
PINECONE_API_KEY: ""
PINECONE_ENV: ""
REPLICATE_API_TOKEN: ""
88 changes: 88 additions & 0 deletions evadb/functions/dalle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# coding=utf-8
# Copyright 2018-2023 EvaDB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from io import BytesIO

import numpy as np
import pandas as pd
import requests
from PIL import Image

from evadb.catalog.catalog_type import NdArrayType
from evadb.configuration.configuration_manager import ConfigurationManager
from evadb.functions.abstract.abstract_function import AbstractFunction
from evadb.functions.decorators.decorators import forward
from evadb.functions.decorators.io_descriptors.data_types import PandasDataframe
from evadb.utils.generic_utils import try_to_import_openai


class DallEFunction(AbstractFunction):
@property
def name(self) -> str:
return "DallE"

def setup(self) -> None:
pass

@forward(
input_signatures=[
PandasDataframe(
columns=["prompt"],
column_types=[
NdArrayType.STR,
],
column_shapes=[(None,)],
)
],
output_signatures=[
PandasDataframe(
columns=["response"],
column_types=[NdArrayType.FLOAT32],
column_shapes=[(None, None, 3)],
)
],
)
def forward(self, text_df):
try_to_import_openai()
import openai

# Register API key, try configuration manager first
openai.api_key = ConfigurationManager().get_value("third_party", "OPENAI_KEY")
# If not found, try OS Environment Variable
if len(openai.api_key) == 0:
openai.api_key = os.environ.get("OPENAI_KEY", "")
assert (
len(openai.api_key) != 0
), "Please set your OpenAI API key in evadb.yml file (third_party, open_api_key) or environment variable (OPENAI_KEY)"

def generate_image(text_df: PandasDataframe):
results = []
queries = text_df[text_df.columns[0]]
for query in queries:
response = openai.Image.create(prompt=query, n=1, size="1024x1024")

# Download the image from the link
image_response = requests.get(response["data"][0]["url"])
image = Image.open(BytesIO(image_response.content))

# Convert the image to an array format suitable for the DataFrame
frame = np.array(image)
results.append(frame)

return results

df = pd.DataFrame({"response": generate_image(text_df=text_df)})
return df
14 changes: 14 additions & 0 deletions evadb/functions/function_bootstrap_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,18 @@
MODEL 'yolov8n.pt';
"""

stablediffusion_function_query = """CREATE FUNCTION IF NOT EXISTS StableDiffusion
IMPL '{}/functions/stable_diffusion.py';
""".format(
EvaDB_INSTALLATION_DIR
)

dalle_function_query = """CREATE FUNCTION IF NOT EXISTS DallE
IMPL '{}/functions/dalle.py';
""".format(
EvaDB_INSTALLATION_DIR
)


def init_builtin_functions(db: EvaDBDatabase, mode: str = "debug") -> None:
"""Load the built-in functions into the system during system bootstrapping.
Expand Down Expand Up @@ -247,6 +259,8 @@ def init_builtin_functions(db: EvaDBDatabase, mode: str = "debug") -> None:
# Mvit_function_query,
Sift_function_query,
Yolo_function_query,
stablediffusion_function_query,
dalle_function_query,
]

# if mode is 'debug', add debug functions
Expand Down
102 changes: 102 additions & 0 deletions evadb/functions/stable_diffusion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# coding=utf-8
# Copyright 2018-2023 EvaDB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from io import BytesIO

import numpy as np
import pandas as pd
import requests
from PIL import Image

from evadb.catalog.catalog_type import NdArrayType
from evadb.configuration.configuration_manager import ConfigurationManager
from evadb.functions.abstract.abstract_function import AbstractFunction
from evadb.functions.decorators.decorators import forward
from evadb.functions.decorators.io_descriptors.data_types import PandasDataframe
from evadb.utils.generic_utils import try_to_import_replicate


class StableDiffusion(AbstractFunction):
@property
def name(self) -> str:
return "StableDiffusion"

def setup(
self,
) -> None:
pass

@forward(
input_signatures=[
PandasDataframe(
columns=["prompt"],
column_types=[
NdArrayType.STR,
],
column_shapes=[(None,)],
)
],
output_signatures=[
PandasDataframe(
columns=["response"],
column_types=[
# FileFormatType.IMAGE,
NdArrayType.FLOAT32
],
column_shapes=[(None, None, 3)],
)
],
)
def forward(self, text_df):
try_to_import_replicate()
import replicate

# Register API key, try configuration manager first
replicate_api_key = ConfigurationManager().get_value(
"third_party", "REPLICATE_API_TOKEN"
)
# If not found, try OS Environment Variable
if len(replicate_api_key) == 0:
replicate_api_key = os.environ.get("REPLICATE_API_TOKEN", "")
assert (
len(replicate_api_key) != 0
), "Please set your Replicate API key in evadb.yml file (third_party, replicate_api_token) or environment variable (REPLICATE_API_TOKEN)"
os.environ["REPLICATE_API_TOKEN"] = replicate_api_key

model_id = (
replicate.models.get("stability-ai/stable-diffusion").versions.list()[0].id
)

def generate_image(text_df: PandasDataframe):
results = []
queries = text_df[text_df.columns[0]]
for query in queries:
output = replicate.run(
"stability-ai/stable-diffusion:" + model_id, input={"prompt": query}
)

# Download the image from the link
response = requests.get(output[0])
image = Image.open(BytesIO(response.content))

# Convert the image to an array format suitable for the DataFrame
frame = np.array(image)
results.append(frame)

return results

df = pd.DataFrame({"response": generate_image(text_df=text_df)})
return df
18 changes: 18 additions & 0 deletions evadb/utils/generic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,3 +629,21 @@ def string_comparison_case_insensitive(string_1, string_2) -> bool:
return False

return string_1.lower() == string_2.lower()


def try_to_import_replicate():
try:
import replicate # noqa: F401
except ImportError:
raise ValueError(
"""Could not import replicate python package.
Please install it with `pip install replicate`."""
)


def is_replicate_available():
try:
try_to_import_replicate()
return True
except ValueError:
return False
2 changes: 1 addition & 1 deletion script/test/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ long_integration_test() {
}

notebook_test() {
PYTHONPATH=./ python -m pytest --durations=5 --nbmake --overwrite "./tutorials" --capture=sys --tb=short -v --log-level=WARNING --nbmake-timeout=3000 --ignore="tutorials/08-chatgpt.ipynb" --ignore="tutorials/14-food-review-tone-analysis-and-response.ipynb" --ignore="tutorials/15-AI-powered-join.ipynb" --ignore="tutorials/16-homesale-forecasting.ipynb" --ignore="tutorials/17-home-rental-prediction.ipynb"
PYTHONPATH=./ python -m pytest --durations=5 --nbmake --overwrite "./tutorials" --capture=sys --tb=short -v --log-level=WARNING --nbmake-timeout=3000 --ignore="tutorials/08-chatgpt.ipynb" --ignore="tutorials/14-food-review-tone-analysis-and-response.ipynb" --ignore="tutorials/15-AI-powered-join.ipynb" --ignore="tutorials/16-homesale-forecasting.ipynb" --ignore="tutorials/17-home-rental-prediction.ipynb" --ignore="tutorials/18-stable-diffusion.ipynb"
code=$?
print_error_code $code "NOTEBOOK TEST"
}
Expand Down
6 changes: 5 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,10 @@ def read(path, encoding="utf-8"):
"neuralforecast" # MODEL TRAIN AND FINE TUNING
]

imagegen_libs = [
"replicate"
]

### NEEDED FOR DEVELOPER TESTING ONLY

dev_libs = [
Expand Down Expand Up @@ -167,7 +171,7 @@ def read(path, encoding="utf-8"):
"sklearn": sklearn_libs,
"forecasting": forecasting_libs,
# everything except ray, qdrant, ludwig and postgres. The first three fail on pyhton 3.11.
"dev": dev_libs + vision_libs + document_libs + function_libs + notebook_libs + forecasting_libs + sklearn_libs,
"dev": dev_libs + vision_libs + document_libs + function_libs + notebook_libs + forecasting_libs + sklearn_libs + imagegen_libs,
}

setup(
Expand Down
68 changes: 68 additions & 0 deletions test/integration_tests/long/functions/test_stablediffusion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# coding=utf-8
# Copyright 2018-2023 EvaDB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from test.markers import stable_diffusion_skip_marker
from test.util import get_evadb_for_testing

import numpy as np
import pytest

from evadb.server.command_handler import execute_query_fetch_all


class StableDiffusionTest(unittest.TestCase):
def setUp(self) -> None:
self.evadb = get_evadb_for_testing()
self.evadb.catalog().reset()
create_table_query = """CREATE TABLE IF NOT EXISTS ImageGen (
prompt TEXT);
"""
execute_query_fetch_all(self.evadb, create_table_query)

test_prompts = ["pink cat riding a rocket to the moon"]

for prompt in test_prompts:
insert_query = f"""INSERT INTO ImageGen (prompt) VALUES ('{prompt}')"""
execute_query_fetch_all(self.evadb, insert_query)

def tearDown(self) -> None:
execute_query_fetch_all(self.evadb, "DROP TABLE IF EXISTS ImageGen;")

@stable_diffusion_skip_marker
@pytest.mark.xfail(
reason="API call might be flaky due to rate limits or other issues."
)
def test_stable_diffusion_image_generation(self):
function_name = "StableDiffusion"

execute_query_fetch_all(self.evadb, f"DROP FUNCTION IF EXISTS {function_name};")

create_function_query = f"""CREATE FUNCTION IF NOT EXISTS {function_name}
IMPL 'evadb/functions/stable_diffusion.py';
"""
execute_query_fetch_all(self.evadb, create_function_query)

gpt_query = f"SELECT {function_name}(prompt) FROM ImageGen;"
output_batch = execute_query_fetch_all(self.evadb, gpt_query)

self.assertEqual(output_batch.columns, ["stablediffusion.response"])

# Check if the returned data is an np.array representing an image
img_data = output_batch.frames["stablediffusion.response"][0]
self.assertIsInstance(img_data, np.ndarray)
self.assertEqual(
img_data.shape[2], 3
) # Check if the image has 3 channels (RGB)
5 changes: 5 additions & 0 deletions test/markers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
is_ludwig_available,
is_pinecone_available,
is_qdrant_available,
is_replicate_available,
is_sklearn_available,
)

Expand Down Expand Up @@ -96,3 +97,7 @@
is_forecast_available() is False,
reason="Run only if forecasting packages available",
)

stable_diffusion_skip_marker = pytest.mark.skipif(
is_replicate_available() is False, reason="requires replicate"
)
Loading

0 comments on commit bf02232

Please sign in to comment.