From 8f49ffc5455aa705ec5933ea4529787c499e97cf Mon Sep 17 00:00:00 2001
From: Vikram Voleti
Date: Fri, 15 Mar 2024 08:18:57 +0000
Subject: [PATCH] Make initial changes for SV3D

---
 .../configs/sv3d_p_image_decoder.yaml    | 132 +++++++++++++
 .../configs/sv3d_u_image_decoder.yaml    | 120 ++++++++++++
 scripts/sampling/simple_video_sample.py  | 179 ++++++++++++++----
 sgm/modules/diffusionmodules/guiders.py  |  34 +++-
 .../diffusionmodules/sigma_sampling.py   |   5 +
 5 files changed, 430 insertions(+), 40 deletions(-)
 create mode 100644 scripts/sampling/configs/sv3d_p_image_decoder.yaml
 create mode 100644 scripts/sampling/configs/sv3d_u_image_decoder.yaml

diff --git a/scripts/sampling/configs/sv3d_p_image_decoder.yaml b/scripts/sampling/configs/sv3d_p_image_decoder.yaml
new file mode 100644
index 00000000..ddde3878
--- /dev/null
+++ b/scripts/sampling/configs/sv3d_p_image_decoder.yaml
@@ -0,0 +1,132 @@
+model:
+  target: sgm.models.diffusion.DiffusionEngine
+  params:
+    scale_factor: 0.18215
+    disable_first_stage_autocast: True
+    ckpt_path: checkpoints/sv3d_p.safetensors
+
+    denoiser_config:
+      target: sgm.modules.diffusionmodules.denoiser.Denoiser
+      params:
+        scaling_config:
+          target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
+
+    network_config:
+      target: sgm.modules.diffusionmodules.video_model.VideoUNet
+      params:
+        adm_in_channels: 1280
+        num_classes: sequential
+        use_checkpoint: True
+        in_channels: 8
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [4, 2, 1]
+        num_res_blocks: 2
+        channel_mult: [1, 2, 4, 4]
+        num_head_channels: 64
+        use_linear_in_transformer: True
+        transformer_depth: 1
+        context_dim: 1024
+        spatial_transformer_attn_type: softmax-xformers
+        extra_ff_mix_layer: True
+        use_spatial_context: True
+        merge_strategy: learned_with_images
+        video_kernel_size: [3, 1, 1]
+
+    conditioner_config:
+      target: sgm.modules.GeneralConditioner
+      params:
+        emb_models:
+        - input_key: cond_frames_without_noise
+          is_trainable: False
+          target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
+          params:
+            n_cond_frames: 1
+            n_copies: 21
+            open_clip_embedding_config:
+              target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
+              params:
+                freeze: True
+
+        - input_key: cond_frames_without_noise
+          is_trainable: False
+          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
+          params:
+            disable_encoder_autocast: True
+            n_cond_frames: 1
+            n_copies: 21
+            is_ae: True
+            encoder_config:
+              target: sgm.models.autoencoder.AutoencoderKLModeOnly
+              params:
+                embed_dim: 4
+                monitor: val/rec_loss
+                ddconfig:
+                  attn_type: vanilla-xformers
+                  double_z: True
+                  z_channels: 4
+                  resolution: 256
+                  in_channels: 3
+                  out_ch: 3
+                  ch: 128
+                  ch_mult: [1, 2, 4, 4]
+                  num_res_blocks: 2
+                  attn_resolutions: []
+                  dropout: 0.0
+                lossconfig:
+                  target: torch.nn.Identity
+            sigma_sampler_config:
+              target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler
+            sigma_cond_config:
+              target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+              params:
+                outdim: 256
+
+        - input_key: polars_rad
+          is_trainable: False
+          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+          params:
+            outdim: 512
+
+        - input_key: azimuths_rad
+          is_trainable: False
+          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+          params:
+            outdim: 512
+
+    first_stage_config:
+      target: sgm.models.autoencoder.AutoencodingEngine
+      params:
+        loss_config:
+          target: torch.nn.Identity
+        regularizer_config:
+          target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
+        decoder_config:
+          target: sgm.modules.diffusionmodules.model.Decoder
+          params:
+            attn_type: vanilla-xformers
+            double_z: True
+            z_channels: 4
+            resolution: 256
+            in_channels: 3
+            out_ch: 3
+            ch: 128
+            ch_mult: [1, 2, 4, 4]
+            num_res_blocks: 2
+            attn_resolutions: []
+            dropout: 0.0
+        encoder_config:
+          target: torch.nn.Identity
+
+    sampler_config:
+      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
+      params:
+        discretization_config:
+          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
+          params:
+            sigma_max: 700.0
+
+        guider_config:
+          target: sgm.modules.diffusionmodules.guiders.TrianglePredictionGuider
+          params:
+            max_scale: 2.5
diff --git a/scripts/sampling/configs/sv3d_u_image_decoder.yaml b/scripts/sampling/configs/sv3d_u_image_decoder.yaml
new file mode 100644
index 00000000..a4d806a7
--- /dev/null
+++ b/scripts/sampling/configs/sv3d_u_image_decoder.yaml
@@ -0,0 +1,120 @@
+model:
+  target: sgm.models.diffusion.DiffusionEngine
+  params:
+    scale_factor: 0.18215
+    disable_first_stage_autocast: True
+    ckpt_path: checkpoints/sv3d_u.safetensors
+
+    denoiser_config:
+      target: sgm.modules.diffusionmodules.denoiser.Denoiser
+      params:
+        scaling_config:
+          target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
+
+    network_config:
+      target: sgm.modules.diffusionmodules.video_model.VideoUNet
+      params:
+        adm_in_channels: 256
+        num_classes: sequential
+        use_checkpoint: True
+        in_channels: 8
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [4, 2, 1]
+        num_res_blocks: 2
+        channel_mult: [1, 2, 4, 4]
+        num_head_channels: 64
+        use_linear_in_transformer: True
+        transformer_depth: 1
+        context_dim: 1024
+        spatial_transformer_attn_type: softmax-xformers
+        extra_ff_mix_layer: True
+        use_spatial_context: True
+        merge_strategy: learned_with_images
+        video_kernel_size: [3, 1, 1]
+
+    conditioner_config:
+      target: sgm.modules.GeneralConditioner
+      params:
+        emb_models:
+        - input_key: cond_frames_without_noise
+          is_trainable: False
+          target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
+          params:
+            n_cond_frames: 1
+            n_copies: 21
+            open_clip_embedding_config:
+              target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
+              params:
+                freeze: True
+
+        - input_key: cond_frames_without_noise
+          is_trainable: False
+          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
+          params:
+            disable_encoder_autocast: True
+            n_cond_frames: 1
+            n_copies: 21
+            is_ae: True
+            encoder_config:
+              target: sgm.models.autoencoder.AutoencoderKLModeOnly
+              params:
+                embed_dim: 4
+                monitor: val/rec_loss
+                ddconfig:
+                  attn_type: vanilla-xformers
+                  double_z: True
+                  z_channels: 4
+                  resolution: 256
+                  in_channels: 3
+                  out_ch: 3
+                  ch: 128
+                  ch_mult: [1, 2, 4, 4]
+                  num_res_blocks: 2
+                  attn_resolutions: []
+                  dropout: 0.0
+                lossconfig:
+                  target: torch.nn.Identity
+            sigma_sampler_config:
+              target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler
+            sigma_cond_config:
+              target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+              params:
+                outdim: 256
+
+    first_stage_config:
+      target: sgm.models.autoencoder.AutoencodingEngine
+      params:
+        loss_config:
+          target: torch.nn.Identity
+        regularizer_config:
+          target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
+        decoder_config:
+          target: sgm.modules.diffusionmodules.model.Decoder
+          params:
+            attn_type: vanilla-xformers
+            double_z: True
+            z_channels: 4
+            resolution: 256
+            in_channels: 3
+            out_ch: 3
+            ch: 128
+            ch_mult: [1, 2, 4, 4]
+            num_res_blocks: 2
+            attn_resolutions: []
+            dropout: 0.0
+        encoder_config:
+          target: torch.nn.Identity
+
+    sampler_config:
+      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
+      params:
+        discretization_config:
+          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
+          params:
+            sigma_max: 700.0
+
+        guider_config:
+          target: sgm.modules.diffusionmodules.guiders.TrianglePredictionGuider
+          params:
+            max_scale: 3.0
diff --git a/scripts/sampling/simple_video_sample.py b/scripts/sampling/simple_video_sample.py
index c3f4ad2a..a5f1019f 100644
--- a/scripts/sampling/simple_video_sample.py
+++ b/scripts/sampling/simple_video_sample.py
@@ -1,11 +1,19 @@
 import math
 import os
+import io
+import sys
 from glob import glob
 from pathlib import Path
-from typing import Optional
+from typing import List, Optional
+
+import imageio
+
+sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), "../../")))

 import cv2
 import numpy as np
+import requests
 import torch
 from einops import rearrange, repeat
 from fire import Fire
@@ -13,12 +21,30 @@
 from PIL import Image
 from torchvision.transforms import ToTensor

-from scripts.util.detection.nsfw_and_watermark_dectection import \
-    DeepFloydDataFiltering
+from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
 from sgm.inference.helpers import embed_watermark
 from sgm.util import default, instantiate_from_config


+def remove_bg_stable_PIL(PIL_img):
+    # Uploads the image to the Stability background-removal endpoint and
+    # returns the result; on failure the original image is returned unchanged.
+    # The API key is read from an environment variable (name assumed here):
+    # secrets must never be hard-coded in the repository.
+    img_byte_arr = io.BytesIO()
+    PIL_img.save(img_byte_arr, format="PNG")
+    img_byte_arr.seek(0)
+    response = requests.post(
+        "https://dev.apiv2.stability.ai/v2alpha/generation/stable-image/remove-background",
+        headers={"authorization": f"Bearer {os.environ['STABILITY_API_KEY']}"},
+        files={"image": io.BufferedReader(img_byte_arr)},
+        data={"output_format": "png"},
+    )
+    if response.status_code == 200:
+        return Image.open(io.BytesIO(response.content))
+    else:
+        print("ERROR: Could not remove background! " + str(response.json()))
+        return PIL_img
+
+
 def sample(
     input_path: str = "assets/test_image.png",  # Can either be image file or folder with image files
     num_frames: Optional[int] = None,
@@ -28,9 +54,12 @@
     motion_bucket_id: int = 127,
     cond_aug: float = 0.02,
     seed: int = 23,
-    decoding_t: int = 14,  # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
+    decoding_t: int = 7,  # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
     device: str = "cuda",
     output_folder: Optional[str] = None,
+    elevations_deg: Optional[float | List[float]] = 10.0,  # For SV3D
+    azimuths_deg: Optional[float | List[float]] = None,  # For SV3D
+    image_frame_ratio: Optional[float] = None,
 ):
     """
     Simple script to generate a single sample conditioned on an image `input_path` or multiple images, one for each
@@ -61,6 +90,28 @@
             output_folder, "outputs/simple_video_sample/svd_xt_image_decoder/"
         )
         model_config = "scripts/sampling/configs/svd_xt_image_decoder.yaml"
+    elif version == "sv3d_u":
+        num_frames = 21
+        num_steps = default(num_steps, 50)
+        output_folder = default(
+            output_folder, "outputs/simple_video_sample/sv3d_u_image_decoder/"
+        )
+        model_config = "scripts/sampling/configs/sv3d_u_image_decoder.yaml"
+        cond_aug = 0.0
+    elif version == "sv3d_p":
+        num_frames = 21
+        num_steps = default(num_steps, 50)
+        output_folder = default(
+            output_folder, "outputs/simple_video_sample/sv3d_p_image_decoder/"
+        )
+        model_config = "scripts/sampling/configs/sv3d_p_image_decoder.yaml"
+        cond_aug = 0.0
+        if isinstance(elevations_deg, (float, int)):
+            elevations_deg = [elevations_deg] * num_frames
+        polars_rad = [np.deg2rad(90 - e) for e in elevations_deg]
+        if azimuths_deg is None:
+            azimuths_deg = np.linspace(0, 360, num_frames + 1)[1:] % 360
+        azimuths_rad = [np.deg2rad(a) for a in azimuths_deg]
     else:
         raise ValueError(f"Version {version} does not exist.")

@@ -93,20 +144,57 @@
             raise ValueError

     for input_img_path in all_img_paths:
-        with Image.open(input_img_path) as image:
-            if image.mode == "RGBA":
-                image = image.convert("RGB")
-            w, h = image.size
-
-            if h % 64 != 0 or w % 64 != 0:
-                width, height = map(lambda x: x - x % 64, (w, h))
-                image = image.resize((width, height))
-                print(
-                    f"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!"
-                )
+        if "sv3d" in version:
+            image = Image.open(input_img_path)
+            if image.mode != "RGBA":
+                # remove the background before compositing
+                image.thumbnail([768, 768], Image.Resampling.LANCZOS)
+                image = remove_bg_stable_PIL(image)
+
+            # resize object in frame
+            image_arr = np.array(image)
+            in_h, in_w = image_arr.shape[:2]  # numpy arrays are (height, width)
+            _, mask = cv2.threshold(
+                np.array(image.split()[-1]), 0, 255, cv2.THRESH_BINARY
+            )
+            x, y, w, h = cv2.boundingRect(mask)
+            max_size = max(w, h)
+            side_len = (
+                int(max_size / image_frame_ratio)
+                if image_frame_ratio is not None
+                else in_w
+            )
+            padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8)
+            center = side_len // 2
+            padded_image[
+                center - h // 2 : center - h // 2 + h,
+                center - w // 2 : center - w // 2 + w,
+            ] = image_arr[y : y + h, x : x + w]
+            # resize frame to 576x576
+            rgba = Image.fromarray(padded_image).resize((576, 576), Image.LANCZOS)
+            # composite onto a white background
+            rgba_arr = np.array(rgba) / 255.0
+            rgb = rgba_arr[..., :3] * rgba_arr[..., -1:] + (1 - rgba_arr[..., -1:])
+            input_image = Image.fromarray((rgb * 255).astype(np.uint8))
+
+        else:
+            with Image.open(input_img_path) as image:
+                # convert() always returns a loaded copy, so input_image is
+                # defined for every input mode and stays valid after the
+                # file handle closes
+                input_image = image.convert("RGB")
+            w, h = input_image.size
+
+            if h % 64 != 0 or w % 64 != 0:
+                width, height = map(lambda x: x - x % 64, (w, h))
+                input_image = input_image.resize((width, height))
+                print(
+                    f"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!"
+                )

-            image = ToTensor()(image)
-            image = image * 2.0 - 1.0
+        image = ToTensor()(input_image)
+        image = image * 2.0 - 1.0
         image = image.unsqueeze(0).to(device)

         H, W = image.shape[2:]
@@ -114,10 +202,14 @@
         F = 8
         C = 4
         shape = (num_frames, C, H // F, W // F)
-        if (H, W) != (576, 1024):
+        if (H, W) != (576, 1024) and "sv3d" not in version:
             print(
                 "WARNING: The conditioning frame you provided is not 576x1024. This leads to suboptimal performance as model was only trained on 576x1024. Consider increasing `cond_aug`."
             )
+        if (H, W) != (576, 576) and "sv3d" in version:
+            print(
+                "WARNING: The conditioning frame you provided is not 576x576. This leads to suboptimal performance as model was only trained on 576x576."
+            )
         if motion_bucket_id > 255:
             print(
                 "WARNING: High motion bucket! This may lead to suboptimal performance."
@@ -130,12 +222,15 @@
             print("WARNING: Large fps value! This may lead to suboptimal performance.")

         value_dict = {}
-        value_dict["motion_bucket_id"] = motion_bucket_id
-        value_dict["fps_id"] = fps_id
-        value_dict["cond_aug"] = cond_aug
         value_dict["cond_frames_without_noise"] = image
-        value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image)
-        value_dict["cond_aug"] = cond_aug
+        if version == "sv3d_p":
+            value_dict["polars_rad"] = polars_rad
+            value_dict["azimuths_rad"] = azimuths_rad
+        elif "sv3d" not in version:
+            value_dict["motion_bucket_id"] = motion_bucket_id
+            value_dict["fps_id"] = fps_id
+        value_dict["cond_aug"] = cond_aug
+        value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image)

         with torch.no_grad():
             with torch.autocast(device):
@@ -155,7 +250,8 @@
                     ],
                 )

-                for k in ["crossattn", "concat"]:
+                repeat_conds = [] if "sv3d" in version else ["crossattn", "concat"]
+                for k in repeat_conds:
                     uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
                     uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
                     c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
@@ -177,18 +273,24 @@ def denoiser(input, sigma, c):
                 samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)
                 model.en_and_decode_n_samples_a_time = decoding_t
                 samples_x = model.decode_first_stage(samples_z)
+                if "sv3d" in version:
+                    samples_x[-1:] = value_dict["cond_frames_without_noise"]
                 samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)

                 os.makedirs(output_folder, exist_ok=True)
                 base_count = len(glob(os.path.join(output_folder, "*.mp4")))
-                video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
-                writer = cv2.VideoWriter(
-                    video_path,
-                    cv2.VideoWriter_fourcc(*"MP4V"),
-                    fps_id + 1,
-                    (samples.shape[-1], samples.shape[-2]),
-                )
+
+                imageio.imwrite(
+                    os.path.join(output_folder, f"{base_count:06d}.jpg"), input_image
+                )
+                video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")

                 samples = embed_watermark(samples)
                 samples = filter(samples)
                 vid = (
@@ -197,10 +299,11 @@
                     .numpy()
                     .astype(np.uint8)
                 )
-                for frame in vid:
-                    frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
-                    writer.write(frame)
-                writer.release()
+                imageio.mimwrite(video_path, vid)


 def get_unique_embedder_keys_from_conditioner(conditioner):
@@ -230,12 +333,10 @@ def get_batch(keys, value_dict, N, T, device):
                 "1 -> b",
                 b=math.prod(N),
             )
-        elif key == "cond_frames":
-            batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
-        elif key == "cond_frames_without_noise":
-            batch[key] = repeat(
-                value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0]
-            )
+        elif key in ("cond_frames", "cond_frames_without_noise"):
+            batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=N[0])
+        elif key in ("polars_rad", "azimuths_rad"):
+            batch[key] = torch.tensor(value_dict[key]).to(device).repeat(N[0])
         else:
             batch[key] = value_dict[key]

diff --git a/sgm/modules/diffusionmodules/guiders.py b/sgm/modules/diffusionmodules/guiders.py
index e8eca43e..bcaa01c5 100644
--- a/sgm/modules/diffusionmodules/guiders.py
+++ b/sgm/modules/diffusionmodules/guiders.py
@@ -1,6 +1,6 @@
 import logging
 from abc import ABC, abstractmethod
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Literal, Optional, Tuple, Union

 import torch
 from einops import rearrange, repeat
@@ -97,3 +97,35 @@ def prepare_inputs(
                 assert c[k] == uc[k]
                 c_out[k] = c[k]
         return torch.cat([x] * 2), torch.cat([s] * 2), c_out
+
+
+class TrianglePredictionGuider(LinearPredictionGuider):
+    def __init__(
+        self,
+        max_scale: float,
+        num_frames: int,
+        min_scale: float = 1.0,
+        period: Union[float, List[float]] = 1.0,
+        period_fusing: Literal["mean", "multiply", "max"] = "max",
+        additional_cond_keys: Optional[Union[List[str], str]] = None,
+    ):
+        super().__init__(max_scale, num_frames, min_scale, additional_cond_keys)
+        values = torch.linspace(0, 1, num_frames)
+        # Construct one triangle wave per period, then fuse them into a
+        # single per-frame guidance scale.
+        if isinstance(period, (float, int)):
+            period = [period]
+
+        scales = []
+        for p in period:
+            scales.append(self.triangle_wave(values, p))
+
+        if period_fusing == "mean":
+            scale = sum(scales) / len(period)
+        elif period_fusing == "multiply":
+            scale = torch.prod(torch.stack(scales), dim=0)
+        elif period_fusing == "max":
+            scale = torch.max(torch.stack(scales), dim=0).values
+        else:
+            raise ValueError(f"Unknown period_fusing mode: {period_fusing}")
+        self.scale = (scale * (max_scale - min_scale) + min_scale).unsqueeze(0)
+
+    def triangle_wave(self, values: torch.Tensor, period: float) -> torch.Tensor:
+        # Triangle wave in [0, 1]: 0 at integer multiples of `period`, 1 halfway between
+        return 2 * (values / period - torch.floor(values / period + 0.5)).abs()

diff --git a/sgm/modules/diffusionmodules/sigma_sampling.py b/sgm/modules/diffusionmodules/sigma_sampling.py
index d54724c6..c2bac44b 100644
--- a/sgm/modules/diffusionmodules/sigma_sampling.py
+++ b/sgm/modules/diffusionmodules/sigma_sampling.py
@@ -29,3 +29,8 @@ def __call__(self, n_samples, rand=None):
             torch.randint(0, self.num_idx, (n_samples,)),
         )
         return self.idx_to_sigma(idx)
+
+
+class ZeroSampler:
+    def __call__(self, n_samples: int, rand=None) -> torch.Tensor:
+        # Near-zero sigmas: the small epsilon keeps sigma-dependent scalings
+        # numerically stable while the conditioning stays effectively noise-free.
+        return torch.zeros_like(default(rand, torch.randn((n_samples,)))) + 1.0e-5
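
Note: a minimal standalone sketch (PyTorch only; num_frames=21 and
max_scale=2.5 come from the sv3d_p config above, everything else is
illustrative and not part of the patch) of how TrianglePredictionGuider maps
the triangle wave to a per-frame CFG scale. Guidance is weakest (min_scale) at
the conditioning view (frame 0, azimuth 0/360) and strongest (max_scale)
halfway around the orbit:

    import torch

    def triangle_wave(values: torch.Tensor, period: float) -> torch.Tensor:
        # 0 at integer multiples of `period`, 1 halfway between
        return 2 * (values / period - torch.floor(values / period + 0.5)).abs()

    num_frames, min_scale, max_scale = 21, 1.0, 2.5  # sv3d_p values
    values = torch.linspace(0, 1, num_frames)
    scale = triangle_wave(values, period=1.0) * (max_scale - min_scale) + min_scale
    print(scale[0].item())                # 1.0 at the conditioning view
    print(scale[num_frames // 2].item())  # 2.5 on the far side of the orbit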
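
Usage note (hypothetical invocation, not part of the patch): with the Fire
entry point and the existing `version` flag of `sample`, the new configs can
be exercised via `python scripts/sampling/simple_video_sample.py --input_path
assets/test_image.png --version sv3d_p --elevations_deg 10.0`; for non-RGBA
inputs, background removal requires the STABILITY_API_KEY environment variable
assumed above to be set.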