Refactoring schedulers (#285)
* Adding components and refactoring of schedulers (DDPM only)

Signed-off-by: Eric Kerfoot <[email protected]>

* Fix

Signed-off-by: Eric Kerfoot <[email protected]>

* Adding tests I forgot to add

Signed-off-by: Eric Kerfoot <[email protected]>

* Updates from comments

* Updates to other schedulers

Signed-off-by: Eric Kerfoot <[email protected]>

* Update

Signed-off-by: Eric Kerfoot <[email protected]>

* Tutorials updates

Signed-off-by: Eric Kerfoot <[email protected]>

* Update

Signed-off-by: Eric Kerfoot <[email protected]>

* Update

Signed-off-by: Eric Kerfoot <[email protected]>

* Fixes

Signed-off-by: Eric Kerfoot <[email protected]>

* Updates from comments

Signed-off-by: Eric Kerfoot <[email protected]>

* Update generative/networks/schedulers/ddpm.py

Co-authored-by: Mark Graham <[email protected]>
Signed-off-by: Eric Kerfoot <[email protected]>

* Autofixin'

Signed-off-by: Eric Kerfoot <[email protected]>

* Fixes

Signed-off-by: Eric Kerfoot <[email protected]>

---------

Signed-off-by: Eric Kerfoot <[email protected]>
Signed-off-by: Eric Kerfoot <[email protected]>
Co-authored-by: Mark Graham <[email protected]>
ericspod and marksgraham authored May 16, 2023
1 parent 798a2ef commit 5a8f75f
Showing 26 changed files with 649 additions and 304 deletions.
1 change: 1 addition & 0 deletions generative/networks/schedulers/__init__.py
@@ -14,3 +14,4 @@
 from .ddim import DDIMScheduler
 from .ddpm import DDPMScheduler
 from .pndm import PNDMScheduler
+from .scheduler import NoiseSchedules, Scheduler
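The newly exported `NoiseSchedules` is the component store from which the refactored `Scheduler` base class looks up named beta schedules. A minimal sketch of registering a custom schedule, assuming the store follows MONAI's `ComponentStore` pattern with an `add_def` registration decorator (the schedule name and function below are hypothetical):

```python
import torch

from generative.networks.schedulers import DDPMScheduler, NoiseSchedules

# Hypothetical registration: `add_def` is assumed to take a name and a short description.
@NoiseSchedules.add_def("quadratic_beta", "Quadratic beta schedule")
def quadratic_beta(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2):
    # Betas spaced quadratically between beta_start and beta_end.
    return torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2

# Schedulers built on the new Scheduler base class can then select it by name.
scheduler = DDPMScheduler(num_train_timesteps=1000, schedule="quadratic_beta")
```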
135 changes: 42 additions & 93 deletions generative/networks/schedulers/ddim.py
@@ -33,66 +33,62 @@
 
 import numpy as np
 import torch
-import torch.nn as nn
+from monai.utils import StrEnum
 
+from .scheduler import Scheduler
 
-class DDIMScheduler(nn.Module):
 
+class DDIMPredictionType(StrEnum):
+    """
+    Set of valid prediction type names for the DDIM scheduler's `prediction_type` argument.
+    epsilon: predicting the noise of the diffusion process
+    sample: directly predicting the noisy sample
+    v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf
+    """
+
+    EPSILON = "epsilon"
+    SAMPLE = "sample"
+    V_PREDICTION = "v_prediction"
+
+
+class DDIMScheduler(Scheduler):
     """
     Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
     diffusion probabilistic models (DDPMs) with non-Markovian guidance. Based on: Song et al. "Denoising Diffusion
     Implicit Models" https://arxiv.org/abs/2010.02502
     Args:
         num_train_timesteps: number of diffusion steps used to train the model.
-        beta_start: the starting `beta` value of inference.
-        beta_end: the final `beta` value.
-        beta_schedule: {``"linear"``, ``"scaled_linear"``}
-            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model.
+        schedule: member of NoiseSchedules, name of noise schedule function in component store
         clip_sample: option to clip predicted sample between -1 and 1 for numerical stability.
         set_alpha_to_one: each diffusion step uses the value of alphas product at that step and at the previous one.
             For the final step there is no previous alpha. When this option is `True` the previous alpha product is
             fixed to `1`, otherwise it uses the value of alpha at step 0.
         steps_offset: an offset added to the inference steps. You can use a combination of `steps_offset=1` and
             `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
             stable diffusion.
-        prediction_type: {``"epsilon"``, ``"sample"``, ``"v_prediction"``}
-            prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
-            process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
-            https://imagen.research.google/video/paper.pdf)
+        prediction_type: member of DDIMPredictionType
+        schedule_args: arguments to pass to the schedule function
     """
 
     def __init__(
         self,
         num_train_timesteps: int = 1000,
-        beta_start: float = 1e-4,
-        beta_end: float = 2e-2,
-        beta_schedule: str = "linear",
+        schedule: str = "linear_beta",
         clip_sample: bool = True,
         set_alpha_to_one: bool = True,
         steps_offset: int = 0,
-        prediction_type: str = "epsilon",
+        prediction_type: str = DDIMPredictionType.EPSILON,
+        **schedule_args,
     ) -> None:
-        super().__init__()
-        self.beta_schedule = beta_schedule
-        if beta_schedule == "linear":
-            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
-        elif beta_schedule == "scaled_linear":
-            # this schedule is very specific to the latent diffusion model.
-            self.betas = (
-                torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
-            )
-        else:
-            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+        super().__init__(num_train_timesteps, schedule, **schedule_args)
 
-        if prediction_type.lower() not in ["epsilon", "sample", "v_prediction"]:
-            raise ValueError(
-                f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample`, or" " `v_prediction`"
-            )
+        if prediction_type not in DDIMPredictionType.__members__.values():
+            raise ValueError("Argument `prediction_type` must be a member of DDIMPredictionType")
 
         self.prediction_type = prediction_type
-        self.num_train_timesteps = num_train_timesteps
-        self.alphas = 1.0 - self.betas
-        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
 
         # At every step in ddim, we are looking into the previous alphas_cumprod
        # For the final step, there is no previous alphas_cumprod because we are already at 0
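A note on the membership check introduced above: `monai.utils.StrEnum` members subclass `str`, so `prediction_type not in DDIMPredictionType.__members__.values()` accepts plain strings as well as enum members. A minimal sketch with a stand-in enum:

```python
from monai.utils import StrEnum

class Fruit(StrEnum):  # illustrative stand-in for DDIMPredictionType
    APPLE = "apple"

# StrEnum members compare equal to their string values...
assert Fruit.APPLE == "apple"
# ...so both spellings pass the validation pattern used in __init__:
assert "apple" in Fruit.__members__.values()
assert Fruit.APPLE in Fruit.__members__.values()
```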
@@ -103,13 +99,13 @@ def __init__(
         # standard deviation of the initial noise distribution
         self.init_noise_sigma = 1.0
 
-        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].astype(np.int64))
+        self.timesteps = torch.from_numpy(np.arange(0, self.num_train_timesteps)[::-1].astype(np.int64))
 
         self.clip_sample = clip_sample
         self.steps_offset = steps_offset
 
         # default the number of inference timesteps to the number of train steps
-        self.set_timesteps(num_train_timesteps)
+        self.set_timesteps(self.num_train_timesteps)
 
     def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None:
         """
@@ -190,13 +186,13 @@ def step(
 
         # 3. compute predicted original sample from predicted noise also called
         # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-        if self.prediction_type == "epsilon":
-            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+        if self.prediction_type == DDIMPredictionType.EPSILON:
+            pred_original_sample = (sample - (beta_prod_t**0.5) * model_output) / (alpha_prod_t**0.5)
             pred_epsilon = model_output
-        elif self.prediction_type == "sample":
+        elif self.prediction_type == DDIMPredictionType.SAMPLE:
             pred_original_sample = model_output
-            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
-        elif self.prediction_type == "v_prediction":
+            pred_epsilon = (sample - (alpha_prod_t**0.5) * pred_original_sample) / (beta_prod_t**0.5)
+        elif self.prediction_type == DDIMPredictionType.V_PREDICTION:
             pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
             pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
 
@@ … @@
         # 5. compute variance: "sigma_t(η)" -> see formula (16)
         # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
         variance = self._get_variance(timestep, prev_timestep)
-        std_dev_t = eta * variance ** (0.5)
+        std_dev_t = eta * variance**0.5
 
         # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon
+        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** 0.5 * pred_epsilon
 
         # 7. compute x_t-1 without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-        pred_prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+        pred_prev_sample = alpha_prod_t_prev**0.5 * pred_original_sample + pred_sample_direction
 
         if eta > 0:
             # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072
             device = model_output.device if torch.is_tensor(model_output) else "cpu"
             noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device)
-            variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise
+            variance = self._get_variance(timestep, prev_timestep) ** 0.5 * eta * noise
 
             pred_prev_sample = pred_prev_sample + variance
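For orientation, steps 5 to 7 above implement equation (12) of Song et al., with the noise scale σ_t(η) from equation (16); in the paper's notation, where α is the cumulative product written as ᾱ in DDPM papers:

```latex
% DDIM update, Song et al. 2020, eq. (12), with sigma_t(eta) from eq. (16)
x_{t-1} = \sqrt{\alpha_{t-1}}
          \underbrace{\left( \frac{x_t - \sqrt{1 - \alpha_t}\, \epsilon_\theta(x_t)}{\sqrt{\alpha_t}} \right)}_{\text{predicted } x_0}
        + \underbrace{\sqrt{1 - \alpha_{t-1} - \sigma_t^2}\; \epsilon_\theta(x_t)}_{\text{direction to } x_t}
        + \sigma_t \varepsilon_t,
\qquad
\sigma_t(\eta) = \eta \sqrt{\frac{1 - \alpha_{t-1}}{1 - \alpha_t}} \sqrt{1 - \frac{\alpha_t}{\alpha_{t-1}}}
```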

@@ -263,13 +259,13 @@ def reversed_step(
         # 3. compute predicted original sample from predicted noise also called
         # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
 
-        if self.prediction_type == "epsilon":
+        if self.prediction_type == DDIMPredictionType.EPSILON:
             pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
             pred_epsilon = model_output
-        elif self.prediction_type == "sample":
+        elif self.prediction_type == DDIMPredictionType.SAMPLE:
             pred_original_sample = model_output
             pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
-        elif self.prediction_type == "v_prediction":
+        elif self.prediction_type == DDIMPredictionType.V_PREDICTION:
             pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
             pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
 
@@ -284,50 +280,3 @@
         pred_post_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
 
         return pred_post_sample, pred_original_sample

-    def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
-        """
-        Add noise to the original samples.
-        Args:
-            original_samples: original samples
-            noise: noise to add to samples
-            timesteps: timesteps tensor indicating the timestep to be computed for each sample.
-        Returns:
-            noisy_samples: sample with added noise
-        """
-        # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
-        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
-        timesteps = timesteps.to(original_samples.device)
-
-        sqrt_alpha_cumprod = self.alphas_cumprod[timesteps] ** 0.5
-        sqrt_alpha_cumprod = sqrt_alpha_cumprod.flatten()
-        while len(sqrt_alpha_cumprod.shape) < len(original_samples.shape):
-            sqrt_alpha_cumprod = sqrt_alpha_cumprod.unsqueeze(-1)
-
-        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
-        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
-        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
-            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
-
-        noisy_samples = sqrt_alpha_cumprod * original_samples + sqrt_one_minus_alpha_prod * noise
-        return noisy_samples
-
-    def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
-        # Make sure alphas_cumprod and timestep have same device and dtype as sample
-        self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
-        timesteps = timesteps.to(sample.device)
-
-        sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
-        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
-        while len(sqrt_alpha_prod.shape) < len(sample.shape):
-            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
-
-        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
-        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
-        while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
-            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
-
-        velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
-        return velocity
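Since `add_noise` and `get_velocity` are deleted here without replacement in this file, they presumably move up to the shared `Scheduler` base class, so the usual training-loop usage stays unchanged. A minimal sketch under that assumption (shapes and the batch are placeholders):

```python
import torch

from generative.networks.schedulers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)  # any scheduler on the new base class

images = torch.randn(8, 1, 64, 64)  # placeholder training batch
noise = torch.randn_like(images)
timesteps = torch.randint(0, scheduler.num_train_timesteps, (images.shape[0],))

# Forward noising per q(x_t | x_0): sqrt(a_t) * x_0 + sqrt(1 - a_t) * noise,
# exactly the computation in the removed add_noise above.
noisy_images = scheduler.add_noise(images, noise, timesteps)

# For v-prediction training, the regression target is
# v = sqrt(a_t) * noise - sqrt(1 - a_t) * x_0, from the removed get_velocity.
target = scheduler.get_velocity(images, noise, timesteps)
```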