
Commit

Update test.py
MC-E committed Oct 20, 2023
1 parent 9e79a5d commit c408b05
Showing 1 changed file with 34 additions and 102 deletions.
136 changes: 34 additions & 102 deletions test.py
@@ -1,107 +1,39 @@
from omegaconf import OmegaConf
from diffusers import StableDiffusionXLAdapterPipeline, T2IAdapter, EulerAncestralDiscreteScheduler, AutoencoderKL
from diffusers.utils import load_image, make_image_grid
from controlnet_aux.lineart import LineartDetector
import torch
import os
import cv2
import datetime
from huggingface_hub import hf_hub_url
import subprocess
import shlex
import copy
from basicsr.utils import tensor2img

from Adapter.Sampling import diffusion_inference
from configs.utils import instantiate_from_config
from Adapter.inference_base import get_base_argument_parser
from Adapter.extra_condition.api import get_cond_model, ExtraCondition
from Adapter.extra_condition import api

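# checkpoint files to fetch from the Hugging Face Hub (T2I-Adapter XL weights and third-party detector models)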
urls = {
    'TencentARC/T2I-Adapter': [
        'models_XL/adapter-xl-canny.pth', 'models_XL/adapter-xl-sketch.pth',
        'models_XL/adapter-xl-openpose.pth', 'third-party-models/body_pose_model.pth',
        'third-party-models/table5_pidinet.pth'
    ]
}

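# download any checkpoint that is not already present into ./checkpoints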
if os.path.exists('checkpoints') == False:
    os.mkdir('checkpoints')
for repo in urls:
    files = urls[repo]
    for file in files:
        url = hf_hub_url(repo, file)
        name_ckp = url.split('/')[-1]
        save_path = os.path.join('checkpoints', name_ckp)
        if os.path.exists(save_path) == False:
            subprocess.run(shlex.split(f'wget {url} -O {save_path}'))

# config
parser = get_base_argument_parser()
parser.add_argument(
    '--model_id',
    type=str,
    default="stabilityai/stable-diffusion-xl-base-1.0",
    help='huggingface url to stable diffusion model',
)
parser.add_argument(
    '--config',
    type=str,
    default='configs/inference/Adapter-XL-sketch.yaml',
    help='config path to T2I-Adapter',
)
parser.add_argument(
    '--path_source',
    type=str,
    default='examples/dog.png',
    help='config path to the source image',
)
parser.add_argument(
    '--in_type',
    type=str,
    default='image',
    help='config path to the source image',
)
global_opt = parser.parse_args()
global_opt.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

if __name__ == '__main__':
    config = OmegaConf.load(global_opt.config)
    # Adapter creation
    cond_name = config.model.params.adapter_config.name
    adapter_config = config.model.params.adapter_config
    adapter = instantiate_from_config(adapter_config).cuda()
    adapter.load_state_dict(torch.load(config.model.params.adapter_config.pretrained))
    cond_model = get_cond_model(global_opt, getattr(ExtraCondition, cond_name))
    process_cond_module = getattr(api, f'get_cond_{cond_name}')

    # diffusion sampler creation
    sampler = diffusion_inference(global_opt.model_id)
# load adapter
adapter = T2IAdapter.from_pretrained(
    "TencentARC/t2i-adapter-lineart-sdxl-1.0", torch_dtype=torch.float16, variant="fp16"
).to("cuda")

    # diffusion generation
    cond = process_cond_module(
        global_opt,
        global_opt.path_source,
        cond_inp_type = global_opt.in_type,
        cond_model = cond_model)
# load euler_a scheduler
model_id = 'stabilityai/stable-diffusion-xl-base-1.0'
euler_a = EulerAncestralDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
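# fp16-safe SDXL VAE (madebyollin/sdxl-vae-fp16-fix)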
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
    model_id, vae=vae, adapter=adapter, scheduler=euler_a, torch_dtype=torch.float16, variant="fp16",
).to("cuda")
pipe.enable_xformers_memory_efficient_attention()

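# load the line art detector (annotator) on the GPU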
line_detector = LineartDetector.from_pretrained("lllyasviel/Annotators").to("cuda")

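# fetch the example image and convert it to a line-art conditioning map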
url = "https://huggingface.co/Adapter/t2iadapter/resolve/main/figs_SDXLV1.0/org_lin.jpg"
image = load_image(url)
image = line_detector(
    image, detect_resolution=384, image_resolution=1024
)
    with torch.no_grad():
        adapter_features = adapter(cond)
        result = sampler.inference(
            prompt = global_opt.prompt,
            prompt_n = global_opt.neg_prompt,
            steps = global_opt.steps,
            adapter_features = copy.deepcopy(adapter_features),
            guidance_scale = global_opt.scale,
            size = (cond.shape[-2], cond.shape[-1]),
            seed = global_opt.seed,
        )

    # save results
    root_results = os.path.join('results', cond_name)
    if not os.path.exists(root_results):
        os.makedirs(root_results)
    now = datetime.datetime.now()
    formatted_date = now.strftime("%Y-%m-%d")
    formatted_time = now.strftime("%H:%M:%S")
    im_cond = tensor2img(cond)
    cv2.imwrite(os.path.join(root_results, formatted_date+'-'+formatted_time+'_image.png'), result)
    cv2.imwrite(os.path.join(root_results, formatted_date+'-'+formatted_time+'_condition.png'), im_cond)

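# run SDXL guided by the line-art condition and save the result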
prompt = "Ice dragon roar, 4k photo"
negative_prompt = "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured"
gen_images = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    image=image,
    num_inference_steps=30,
    adapter_conditioning_scale=0.8,
    guidance_scale=7.5,
).images[0]
gen_images.save('out_lin.png')
