Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Controlnet to 2.0 #2088

Open
wants to merge 1 commit into
base: sd-studio2
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 132 additions & 6 deletions apps/shark_studio/api/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,29 @@
import os
import PIL
import numpy as np
from apps.shark_studio.modules.pipeline import SharkPipelineBase
from apps.shark_studio.web.utils.file_utils import (
get_generated_imgs_path,
)
from shark.iree_utils.compile_utils import (
get_iree_compiled_module,
load_vmfb_using_mmap,
clean_device_info,
get_iree_target_triple,
)
from apps.shark_studio.web.utils.file_utils import (
safe_name,
get_resource_path,
get_checkpoints_path,
)
import cv2
from datetime import datetime
from PIL import Image
from gradio.components.image_editor import (
EditorValue,
)

# from turbine_models.custom_models.sd_inference import export_controlnet_model, ControlNetModel
import gc

class control_adapter:
def __init__(
Expand All @@ -20,7 +34,14 @@ def __init__(
self.model = None

def export_control_adapter_model(model_keyword):
    """Export a compiled ControlNet adapter for the given keyword.

    Only the "canny" adapter is wired up; any other keyword returns None.

    NOTE(review): the diff left a stale stub `return None` ahead of the
    canny branch, making it unreachable — removed here.
    """
    if model_keyword == "canny":
        # NOTE(review): export_controlnet_model / ControlNetModel come from
        # turbine_models; their import is commented out at the top of the
        # file, so this branch raises NameError until it is restored.
        return export_controlnet_model(
            ControlNetModel("lllyasviel/control_v11p_sd15_canny"),
            "lllyasviel/control_v11p_sd15_canny",
            1,
            512,
            512,
        )
    # Explicit fallback: unsupported adapters export nothing.
    return None

def export_xl_control_adapter_model(model_keyword):
    """Placeholder: SDXL control-adapter export is not implemented yet."""
    exported = None
    return exported
Expand All @@ -36,9 +57,16 @@ def __init__(
def export_controlnet_model(model_keyword):
    """Placeholder export hook for preprocessor models; nothing is exported."""
    result = None
    return result

# Flag set handed to the IREE compiler (ireec) when building ControlNet
# modules; consumed by SharkControlnetPipeline below via self.ireec_flags.
# NOTE(review): flag semantics are IREE-version dependent — confirm against
# the pinned IREE release before changing.
ireec_flags = [
"--iree-flow-collapse-reduction-dims",
"--iree-opt-const-expr-hoisting=False",
"--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))",
"--iree-flow-inline-constants-max-byte-length=1" # Stopgap, take out when not needed
]

control_adapter_map = {
"sd15": {
"runwayml/stable-diffusion-v1-5": {
"canny": {"initializer": control_adapter.export_control_adapter_model},
"openpose": {"initializer": control_adapter.export_control_adapter_model},
"scribble": {"initializer": control_adapter.export_control_adapter_model},
Expand All @@ -64,14 +92,112 @@ def __init__(
):
self.model = hf_model_id
self.device = device
self.compiled_model = None

def compile(self):
    """Compile the preprocessor model on first use.

    A model that is already compiled is left untouched, and canny needs no
    compiled module at all (it runs through OpenCV in `run`). Every other
    preprocessor is currently unimplemented.
    """
    already_done = self.compiled_model is not None
    if already_done or "canny" in self.model:
        return
    if "openpose" in self.model:
        pass  # openpose compilation not wired up yet
    print("compile not implemented for preprocessor.")
    return

def run(self, inputs):
    """Run the preprocessor on `inputs` (a tuple of positional args).

    Lazily compiles the backing model, then dispatches on the model name:
    canny goes through OpenCV directly, openpose through the compiled
    module. Unrecognized models pass `inputs` through unchanged.

    NOTE(review): the diff left a stale `print(...); return inputs` stub
    ahead of the real logic, making it unreachable — removed here. The
    openpose branch also discarded its result; it now returns it.
    """
    if self.compiled_model is None:
        self.compile()
    if "canny" in self.model:
        return cv2.Canny(*inputs)
    if "openpose" in self.model:
        return self.compiled_model(*inputs)
    print("run not implemented for preprocessor.")
    return inputs

def __call__(self, *inputs):
    """Make the preprocessor callable; forwards positional args to run()."""
    packed = inputs
    return self.run(packed)


class SharkControlnetPipeline(SharkPipelineBase):
    """Imports, compiles and caches ControlNet adapter modules through IREE.

    Per-model flow: `import_torch_ir` generates torch IR into a tempfile,
    then `get_compiled_map` compiles it into a mmap-able .vmfb and stores
    the loaded module in `iree_module_dict`, keyed by the short model name
    (e.g. "canny").
    """

    def __init__(
        self,
        # model_map: dict,
        # static_kwargs: dict,
        device: str,
        # import_mlir: bool = True,
    ):
        # Static table of known adapters; per-model artifact paths are
        # filled into pipe_map as models are prepared.
        self.model_map = control_adapter_map
        self.pipe_map = {}
        # self.static_kwargs = static_kwargs
        self.static_kwargs = {}
        self.triple = get_iree_target_triple(device)
        self.device, self.device_id = clean_device_info(device)
        self.import_mlir = False
        self.iree_module_dict = {}
        # Scratch directory for generated torch-IR tempfiles.
        self.tmp_dir = get_resource_path(os.path.join("..", "shark_tmp"))
        if not os.path.exists(self.tmp_dir):
            os.mkdir(self.tmp_dir)
        self.tempfiles = {}
        self.pipe_vmfb_path = ""
        self.ireec_flags = ireec_flags

    def get_compiled_map(self, model, init_kwargs=None):
        """Ensure a compiled IREE module for `model` exists in iree_module_dict.

        If no torch-IR tempfile exists yet, imports it first and recurses
        once to hit the compile branch.

        NOTE(review): the original signature used a mutable default
        (`init_kwargs={}`) — replaced with a None sentinel; behavior is
        unchanged since the argument was overwritten below anyway.
        """
        if init_kwargs is None:
            init_kwargs = {}
        self.pipe_map[model] = {}
        if model in self.iree_module_dict:
            return
        elif model not in self.tempfiles:
            # if model in self.static_kwargs[model]:
            #     init_kwargs = self.static_kwargs[model]
            init_kwargs = {}
            # for key in self.static_kwargs["pipe"]:
            #     if key not in init_kwargs:
            #         init_kwargs[key] = self.static_kwargs["pipe"][key]
            self.import_torch_ir(model, init_kwargs)
            self.get_compiled_map(model)
        else:
            # weights_path = self.get_io_params(model)

            self.iree_module_dict[model] = get_iree_compiled_module(
                self.tempfiles[model],
                device=self.device,
                frontend="torch",
                mmap=True,
                # external_weight_file=weights_path,
                external_weight_file=None,
                extra_args=self.ireec_flags,
                write_to=os.path.join(self.pipe_vmfb_path, model + ".vmfb"),
            )

    def import_torch_ir(self, model, kwargs):
        """Generate torch IR for `model` and write it to a tempfile.

        TODO(review): "model_keyword" is hard-coded to "canny", so every
        adapter currently exports the canny model — confirm intent. The
        "sd15" map key may also be stale if control_adapter_map now keys
        by full HF id — verify against the map definition.
        """
        # torch_ir = self.model_map[model]["initializer"](
        #     **self.safe_dict(kwargs), compile_to="torch"
        # )
        tmp_kwargs = {
            "model_keyword": "canny"
        }
        torch_ir = self.model_map["sd15"][model]["initializer"](
            **self.safe_dict(tmp_kwargs)  # , compile_to="torch"
        )

        self.tempfiles[model] = os.path.join(
            self.tmp_dir, f"{model}.torch.tempfile"
        )

        # Persist the IR and drop the in-memory copy promptly — the IR
        # string can be very large.
        with open(self.tempfiles[model], "w+") as f:
            f.write(torch_ir)
        del torch_ir
        gc.collect()
        return

    def get_precompiled(self, model):
        """Record the path of an existing .vmfb for `model` (if any) found
        directly under pipe_vmfb_path into pipe_map."""
        vmfbs = []
        for dirpath, dirnames, filenames in os.walk(self.pipe_vmfb_path):
            vmfbs.extend(filenames)
            break  # top-level directory only; do not descend
        for file in vmfbs:
            if model in file:
                self.pipe_map[model]["vmfb_path"] = os.path.join(
                    self.pipe_vmfb_path, file
                )
        return


def cnet_preview(model, input_image):
Expand Down
86 changes: 70 additions & 16 deletions apps/shark_studio/api/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pathlib import Path
from random import randint
from turbine_models.custom_models.sd_inference import clip, unet, vae
from apps.shark_studio.api.controlnet import control_adapter_map
from apps.shark_studio.api.controlnet import control_adapter_map, SharkControlnetPipeline
from apps.shark_studio.web.utils.state import status_label
from apps.shark_studio.web.utils.file_utils import (
safe_name,
Expand Down Expand Up @@ -112,10 +112,10 @@ def __init__(
"unet": {
"hf_model_name": base_model_id,
"unet_model": unet.UnetModel(
hf_model_name=base_model_id, hf_auth_token=None
hf_model_name=base_model_id, hf_auth_token=None, is_controlled=False,
),
"batch_size": batch_size,
# "is_controlled": is_controlled,
"is_controlled": is_controlled,
# "num_loras": num_loras,
"height": height,
"width": width,
Expand All @@ -126,7 +126,6 @@ def __init__(
"hf_model_name": base_model_id,
"vae_model": vae.VaeModel(
hf_model_name=base_model_id,
custom_vae=custom_vae,
),
"batch_size": batch_size,
"height": height,
Expand All @@ -137,7 +136,6 @@ def __init__(
"hf_model_name": base_model_id,
"vae_model": vae.VaeModel(
hf_model_name=base_model_id,
custom_vae=custom_vae,
),
"batch_size": batch_size,
"height": height,
Expand All @@ -163,6 +161,7 @@ def __init__(
print(f"\n[LOG] Pipeline initialized with pipe_id: {self.pipe_id}.")
del static_kwargs
gc.collect()
self.controlnet = SharkControlnetPipeline(device)

def prepare_pipe(self, custom_weights, adapters, embeddings, is_img2img):
print(f"\n[LOG] Preparing pipeline...")
Expand Down Expand Up @@ -291,6 +290,7 @@ def produce_img_latents(
mask=None,
masked_image_latents=None,
return_all_latents=False,
controlnet_latents=None
):
# self.status = SD_STATE_IDLE
step_time_sum = 0
Expand All @@ -299,6 +299,7 @@ def produce_img_latents(
text_embeddings_numpy = text_embeddings.detach().numpy()
guidance_scale = torch.Tensor([guidance_scale]).to(self.dtype)
self.load_submodels(["unet"])
control_scale = torch.tensor(1.0, dtype=self.dtype)
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
timestep = torch.tensor([t]).to(self.dtype).detach().numpy()
Expand All @@ -319,15 +320,52 @@ def produce_img_latents(

# Profiling Unet.
# profile_device = start_profiling(file_path="unet.rdc")
noise_pred = self.run(
"unet",
[
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
],
)
if controlnet_latents is None:
noise_pred = self.run(
"unet",
[
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
],
)
else:
noise_pred = self.run(
"unet",
[
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
controlnet_latents[0],
controlnet_latents[1],
controlnet_latents[2],
controlnet_latents[3],
controlnet_latents[4],
controlnet_latents[5],
controlnet_latents[6],
controlnet_latents[7],
controlnet_latents[8],
controlnet_latents[9],
controlnet_latents[10],
controlnet_latents[11],
controlnet_latents[12],
control_scale,
control_scale,
control_scale,
control_scale,
control_scale,
control_scale,
control_scale,
control_scale,
control_scale,
control_scale,
control_scale,
control_scale,
control_scale,
],
)
# end_profiling(profile_device)

if cpu_scheduling:
Expand Down Expand Up @@ -388,6 +426,7 @@ def generate_images(
repeatable_seeds,
resample_type,
control_mode,
controlnet_models,
hints,
):
# TODO: Batched args
Expand Down Expand Up @@ -432,12 +471,24 @@ def generate_images(
strength=strength,
)

hints = [Image.load_file(x) for x in hints]
controlnet_latents = None
for (model, hint) in zip(controlnet_models, hints):
# if model not in self.controlnets:
# continue
self.controlnet.get_compiled_map("canny")
latent = self.controlnets[model].run(hint)
if controlnet_latents is None:
controlnet_latents = latent
break

latents = self.produce_img_latents(
latents=init_latents,
text_embeddings=text_embeddings,
guidance_scale=guidance_scale,
total_timesteps=final_timesteps,
cpu_scheduling=True, # until we have schedulers through Turbine
controlnet_latents=controlnet_latents,
)

# Img latents -> PIL images
Expand Down Expand Up @@ -511,6 +562,7 @@ def shark_sd_fn(
is_controlled = False
control_mode = None
hints = []
controlnet_models = []
num_loras = 0
for i in embeddings:
num_loras += 1 if embeddings[i] else 0
Expand All @@ -525,16 +577,17 @@ def shark_sd_fn(
}
else:
adapters[f"control_adapter_{model}"] = {
"hf_id": control_adapter_map["stabilityai/stable-diffusion-xl-1.0"][
"hf_id": +["stabilityai/stable-diffusion-xl-1.0"][
model
],
"strength": controlnets["strength"][i],
}
if model is not None:
is_controlled = True
controlnet_models.append(model)
control_mode = controlnets["control_mode"]
for i in controlnets["hint"]:
hints.append[i]
hints.append(i)

submit_pipe_kwargs = {
"base_model_id": base_model_id,
Expand Down Expand Up @@ -567,6 +620,7 @@ def shark_sd_fn(
"repeatable_seeds": repeatable_seeds,
"resample_type": resample_type,
"control_mode": control_mode,
"controlnet_models": controlnet_models,
"hints": hints,
}
if (
Expand Down
2 changes: 1 addition & 1 deletion apps/shark_studio/modules/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from msvcrt import kbhit
# from msvcrt import kbhit
from shark.iree_utils.compile_utils import (
get_iree_compiled_module,
load_vmfb_using_mmap,
Expand Down
Loading