diff --git a/apps/shark_studio/api/controlnet.py b/apps/shark_studio/api/controlnet.py index 2c8a8b566b..58ecafc1b9 100644 --- a/apps/shark_studio/api/controlnet.py +++ b/apps/shark_studio/api/controlnet.py @@ -2,15 +2,29 @@ import os import PIL import numpy as np +from apps.shark_studio.modules.pipeline import SharkPipelineBase from apps.shark_studio.web.utils.file_utils import ( get_generated_imgs_path, ) +from shark.iree_utils.compile_utils import ( + get_iree_compiled_module, + load_vmfb_using_mmap, + clean_device_info, + get_iree_target_triple, +) +from apps.shark_studio.web.utils.file_utils import ( + safe_name, + get_resource_path, + get_checkpoints_path, +) +import cv2 from datetime import datetime from PIL import Image from gradio.components.image_editor import ( EditorValue, ) - +# from turbine_models.custom_models.sd_inference import export_controlnet_model, ControlNetModel +import gc class control_adapter: def __init__( @@ -20,7 +34,14 @@ def __init__( self.model = None def export_control_adapter_model(model_keyword): - return None + if model_keyword == "canny": + return export_controlnet_model( + ControlNetModel("lllyasviel/control_v11p_sd15_canny"), + "lllyasviel/control_v11p_sd15_canny", + 1, + 512, + 512, + ) def export_xl_control_adapter_model(model_keyword): return None @@ -36,9 +57,16 @@ def __init__( def export_controlnet_model(model_keyword): return None +ireec_flags = [ + "--iree-flow-collapse-reduction-dims", + "--iree-opt-const-expr-hoisting=False", + "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807", + "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))", + "--iree-flow-inline-constants-max-byte-length=1" # Stopgap, take out when not needed +] control_adapter_map = { - "sd15": { + "runwayml/stable-diffusion-v1-5": { "canny": {"initializer": control_adapter.export_control_adapter_model}, "openpose": {"initializer": control_adapter.export_control_adapter_model}, "scribble": {"initializer": control_adapter.export_control_adapter_model}, @@ -64,14 +92,112 @@ def __init__( ): self.model = hf_model_id self.device = device + self.compiled_model = None def compile(self): + if self.compiled_model is not None: + return + if "canny" in self.model: + return + if "openpose" in self.model: + pass print("compile not implemented for preprocessor.") - return def run(self, inputs): - print("run not implemented for preprocessor.") - return inputs + if self.compiled_model is None: + self.compile() + if "canny" in self.model: + out = cv2.Canny(*inputs) + return out + if "openpose" in self.model: + self.compiled_model(*inputs) + + def __call__(self, *inputs): + return self.run(inputs) + + +class SharkControlnetPipeline(SharkPipelineBase): + def __init__( + self, + # model_map: dict, + # static_kwargs: dict, + device: str, + # import_mlir: bool = True, + ): + self.model_map = control_adapter_map + self.pipe_map = {} + # self.static_kwargs = static_kwargs + self.static_kwargs = {} + self.triple = get_iree_target_triple(device) + self.device, self.device_id = clean_device_info(device) + self.import_mlir = False + self.iree_module_dict = {} + self.tmp_dir = get_resource_path(os.path.join("..", "shark_tmp")) + if not os.path.exists(self.tmp_dir): + os.mkdir(self.tmp_dir) + self.tempfiles = {} + self.pipe_vmfb_path = "" + self.ireec_flags = ireec_flags + + def get_compiled_map(self, model, init_kwargs={}): + self.pipe_map[model] = {} + if model in self.iree_module_dict: + return + elif model not in self.tempfiles: + # if model in self.static_kwargs[model]: + # init_kwargs = self.static_kwargs[model] + init_kwargs = {} + # for key in self.static_kwargs["pipe"]: + # if key not in init_kwargs: + # init_kwargs[key] = self.static_kwargs["pipe"][key] + self.import_torch_ir(model, init_kwargs) + self.get_compiled_map(model) + else: + # weights_path = self.get_io_params(model) + + self.iree_module_dict[model] = get_iree_compiled_module( + self.tempfiles[model], + device=self.device, + frontend="torch", + mmap=True, + # external_weight_file=weights_path, + external_weight_file=None, + extra_args=self.ireec_flags, + write_to=os.path.join(self.pipe_vmfb_path, model + ".vmfb") + ) + + def import_torch_ir(self, model, kwargs): + # torch_ir = self.model_map[model]["initializer"]( + # **self.safe_dict(kwargs), compile_to="torch" + # ) + tmp_kwargs = { + "model_keyword": "canny" + } + torch_ir = self.model_map["sd15"][model]["initializer"]( + **self.safe_dict(tmp_kwargs) #, compile_to="torch" + ) + + self.tempfiles[model] = os.path.join( + self.tmp_dir, f"{model}.torch.tempfile" + ) + + with open(self.tempfiles[model], "w+") as f: + f.write(torch_ir) + del torch_ir + gc.collect() + return + + def get_precompiled(self, model): + vmfbs = [] + for dirpath, dirnames, filenames in os.walk(self.pipe_vmfb_path): + vmfbs.extend(filenames) + break + for file in vmfbs: + if model in file: + self.pipe_map[model]["vmfb_path"] = os.path.join( + self.pipe_vmfb_path, file + ) + return def cnet_preview(model, input_image): diff --git a/apps/shark_studio/api/sd.py b/apps/shark_studio/api/sd.py index c26c25bf00..093f09a11a 100644 --- a/apps/shark_studio/api/sd.py +++ b/apps/shark_studio/api/sd.py @@ -9,7 +9,7 @@ from pathlib import Path from random import randint from turbine_models.custom_models.sd_inference import clip, unet, vae -from apps.shark_studio.api.controlnet import control_adapter_map +from apps.shark_studio.api.controlnet import control_adapter_map, SharkControlnetPipeline from apps.shark_studio.web.utils.state import status_label from apps.shark_studio.web.utils.file_utils import ( safe_name, @@ -112,10 +112,10 @@ def __init__( "unet": { "hf_model_name": base_model_id, "unet_model": unet.UnetModel( - hf_model_name=base_model_id, hf_auth_token=None + hf_model_name=base_model_id, hf_auth_token=None, is_controlled=False, ), "batch_size": batch_size, - # "is_controlled": is_controlled, + "is_controlled": is_controlled, # "num_loras": num_loras, "height": height, "width": width, @@ -126,7 +126,6 @@ def __init__( "hf_model_name": base_model_id, "vae_model": vae.VaeModel( hf_model_name=base_model_id, - custom_vae=custom_vae, ), "batch_size": batch_size, "height": height, @@ -137,7 +136,6 @@ def __init__( "hf_model_name": base_model_id, "vae_model": vae.VaeModel( hf_model_name=base_model_id, - custom_vae=custom_vae, ), "batch_size": batch_size, "height": height, @@ -163,6 +161,7 @@ def __init__( print(f"\n[LOG] Pipeline initialized with pipe_id: {self.pipe_id}.") del static_kwargs gc.collect() + self.controlnet = SharkControlnetPipeline(device) def prepare_pipe(self, custom_weights, adapters, embeddings, is_img2img): print(f"\n[LOG] Preparing pipeline...") @@ -291,6 +290,7 @@ def produce_img_latents( mask=None, masked_image_latents=None, return_all_latents=False, + controlnet_latents=None ): # self.status = SD_STATE_IDLE step_time_sum = 0 @@ -299,6 +299,7 @@ def produce_img_latents( text_embeddings_numpy = text_embeddings.detach().numpy() guidance_scale = torch.Tensor([guidance_scale]).to(self.dtype) self.load_submodels(["unet"]) + control_scale = torch.tensor(1.0, dtype=self.dtype) for i, t in tqdm(enumerate(total_timesteps)): step_start_time = time.time() timestep = torch.tensor([t]).to(self.dtype).detach().numpy() @@ -319,15 +320,52 @@ def produce_img_latents( # Profiling Unet. # profile_device = start_profiling(file_path="unet.rdc") - noise_pred = self.run( - "unet", - [ - latent_model_input, - timestep, - text_embeddings_numpy, - guidance_scale, - ], - ) + if controlnet_latents is None: + noise_pred = self.run( + "unet", + [ + latent_model_input, + timestep, + text_embeddings_numpy, + guidance_scale, + ], + ) + else: + noise_pred = self.run( + "unet", + [ + latent_model_input, + timestep, + text_embeddings_numpy, + guidance_scale, + controlnet_latents[0], + controlnet_latents[1], + controlnet_latents[2], + controlnet_latents[3], + controlnet_latents[4], + controlnet_latents[5], + controlnet_latents[6], + controlnet_latents[7], + controlnet_latents[8], + controlnet_latents[9], + controlnet_latents[10], + controlnet_latents[11], + controlnet_latents[12], + control_scale, + control_scale, + control_scale, + control_scale, + control_scale, + control_scale, + control_scale, + control_scale, + control_scale, + control_scale, + control_scale, + control_scale, + control_scale, + ], + ) # end_profiling(profile_device) if cpu_scheduling: @@ -388,6 +426,7 @@ def generate_images( repeatable_seeds, resample_type, control_mode, + controlnet_models, hints, ): # TODO: Batched args @@ -432,12 +471,24 @@ def generate_images( strength=strength, ) + hints = [Image.load_file(x) for x in hints] + controlnet_latents = None + for (model, hint) in zip(controlnet_models, hints): + # if model not in self.controlnets: + # continue + self.controlnet.get_compiled_map("canny") + latent = self.controlnets[model].run(hint) + if controlnet_latents is None: + controlnet_latents = latent + break + latents = self.produce_img_latents( latents=init_latents, text_embeddings=text_embeddings, guidance_scale=guidance_scale, total_timesteps=final_timesteps, cpu_scheduling=True, # until we have schedulers through Turbine + controlnet_latents=controlnet_latents, ) # Img latents -> PIL images @@ -511,6 +562,7 @@ def shark_sd_fn( is_controlled = False control_mode = None hints = [] + controlnet_models = [] num_loras = 0 for i in embeddings: num_loras += 1 if embeddings[i] else 0 @@ -525,16 +577,17 @@ def shark_sd_fn( } else: adapters[f"control_adapter_{model}"] = { - "hf_id": control_adapter_map["stabilityai/stable-diffusion-xl-1.0"][ + "hf_id": +["stabilityai/stable-diffusion-xl-1.0"][ model ], "strength": controlnets["strength"][i], } if model is not None: is_controlled = True + controlnet_models.append(model) control_mode = controlnets["control_mode"] for i in controlnets["hint"]: - hints.append[i] + hints.append(i) submit_pipe_kwargs = { "base_model_id": base_model_id, @@ -567,6 +620,7 @@ def shark_sd_fn( "repeatable_seeds": repeatable_seeds, "resample_type": resample_type, "control_mode": control_mode, + "controlnet_models": controlnet_models, "hints": hints, } if ( diff --git a/apps/shark_studio/modules/pipeline.py b/apps/shark_studio/modules/pipeline.py index 5dee266b13..8c49bf81ee 100644 --- a/apps/shark_studio/modules/pipeline.py +++ b/apps/shark_studio/modules/pipeline.py @@ -1,4 +1,4 @@ -from msvcrt import kbhit +# from msvcrt import kbhit from shark.iree_utils.compile_utils import ( get_iree_compiled_module, load_vmfb_using_mmap, diff --git a/apps/shark_studio/web/ui/sd.py b/apps/shark_studio/web/ui/sd.py index cbb17457ed..cabd47dace 100644 --- a/apps/shark_studio/web/ui/sd.py +++ b/apps/shark_studio/web/ui/sd.py @@ -1,3 +1,4 @@ +import base64 import os import json import gradio as gr @@ -40,6 +41,7 @@ from apps.shark_studio.web.ui.common_events import lora_changed from apps.shark_studio.modules import logger import apps.shark_studio.web.utils.globals as global_obj +import random sd_default_models = [ "CompVis/stable-diffusion-v1-4", @@ -64,25 +66,34 @@ def submit_to_cnet_config( cnet_strength: int, control_mode: str, curr_config: dict, + curr_main_config: dict, ): - if any(i in [None, ""] for i in [stencil, preprocessed_hint]): + if ((stencil is None or stencil == "") or + (None in preprocessed_hint or "" in preprocessed_hint)): return gr.update() + + filename = stencil + "_" + str(random.getrandbits(32)) + ".png" + hint_img = Image.fromarray(preprocessed_hint) + hint_img.save(filename) + if curr_config is not None: if "controlnets" in curr_config: curr_config["controlnets"]["control_mode"] = control_mode curr_config["controlnets"]["model"].append(stencil) - curr_config["controlnets"]["hint"].append(preprocessed_hint) + curr_config["controlnets"]["hint"].append(filename) curr_config["controlnets"]["strength"].append(cnet_strength) - return curr_config + curr_main_config["controlnets"] = curr_config["controlnets"] + return (curr_config, curr_main_config) cnet_map = {} cnet_map["controlnets"] = { "control_mode": control_mode, "model": [stencil], - "hint": [preprocessed_hint], + "hint": [filename], "strength": [cnet_strength], } - return cnet_map + curr_main_config["controlnets"] = cnet_map["controlnets"] + return (cnet_map, curr_main_config) def update_embeddings_json(embedding): @@ -458,8 +469,8 @@ def base_model_changed(base_model_id): ) with gr.Accordion( label="Controlnet Options", - open=False, - visible=False, + open=True, + visible=True, ): preprocessed_hints = gr.State([]) with gr.Column(): @@ -572,20 +583,6 @@ def base_model_changed(base_model_id): preprocessed_hints, ], ) - use_result.click( - fn=submit_to_cnet_config, - inputs=[ - cnet_model, - cnet_output, - cnet_strength, - control_mode, - cnet_config, - ], - outputs=[ - cnet_config, - ], - queue=False, - ) with gr.Column(scale=3, min_width=600): with gr.Group(): sd_gallery = gr.Gallery( @@ -681,6 +678,23 @@ def base_model_changed(base_model_id): outputs=[sd_config_name], ) + use_result.click( + fn=submit_to_cnet_config, + inputs=[ + cnet_model, + cnet_output, + cnet_strength, + control_mode, + cnet_config, + sd_json, + ], + outputs=[ + cnet_config, + sd_json, + ], + queue=False, + ) + pull_kwargs = dict( fn=pull_sd_configs, inputs=[ @@ -731,7 +745,7 @@ def base_model_changed(base_model_id): neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(**pull_kwargs) generate_click = ( stable_diffusion.click(**status_kwargs) - .then(**pull_kwargs) + # .then(**pull_kwargs) .then(**gen_kwargs) ) stop_batch.click( diff --git a/requirements-importer.txt b/requirements-importer.txt index 3fe3a64659..59b53c7b05 100644 --- a/requirements-importer.txt +++ b/requirements-importer.txt @@ -26,7 +26,7 @@ sacremoses sentencepiece # web dependecies. -gradio==3.44.3 +gradio altair scipy diff --git a/requirements.txt b/requirements.txt index 3f7e719e67..bba6d382c9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -51,4 +51,5 @@ pyinstaller # For quantized GPTQ models optimum +peft==0.5.0 auto_gptq