fix schedulers for make_fx
PhaneeshB committed Feb 23, 2024
1 parent 76c645d commit 84e78e3
Showing 2 changed files with 30 additions and 14 deletions.
31 changes: 19 additions & 12 deletions src/diffusers/schedulers/scheduling_euler_discrete.py
@@ -223,7 +223,7 @@ def scale_model_input(
if self.step_index is None:
self._init_step_index(timestep)

sigma = self.sigmas[self.step_index]
sigma = self.sigmas.index_select(0, self.step_index)
sample = sample / ((sigma**2 + 1) ** 0.5)

self.is_scale_input_called = True
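The hunk above swaps Python integer indexing (`self.sigmas[self.step_index]`) for `index_select` with a tensor-valued index. A minimal sketch of the difference under `make_fx` — `sigmas` and `step_index` below are stand-ins for the scheduler attributes, with made-up values:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def lookup(sigmas, step_index):
    # With a tensor index, the gather is recorded as a graph op and the traced
    # module keeps `step_index` as an input; a Python int would instead be
    # baked into the graph as a constant at trace time.
    return sigmas.index_select(0, step_index)

sigmas = torch.linspace(14.6, 0.0, 10)       # made-up sigma schedule
step_index = torch.tensor([3])               # 1-D long tensor, as index_select expects
gm = make_fx(lookup)(sigmas, step_index)
print(gm.graph)                               # the lookup appears as an index_select node
```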
@@ -344,18 +344,22 @@ def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)

index_candidates = (self.timesteps == timestep).nonzero()
index_candidates = torch.nonzero(self.timesteps == timestep)

# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if len(index_candidates) > 1:
step_index = index_candidates[1]
else:
step_index = index_candidates[0]

self._step_index = step_index.item()
# if len(index_candidates) > 1:
# step_index = index_candidates[1]
# else:
# step_index = index_candidates[0]

pos = 1 if len(index_candidates) > 1 else 0
step_index = index_candidates[pos]

self._step_index = torch.scalar_tensor(step_index.item())

def step(
self,
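For context, a toy illustration (schedule values made up) of the lookup `_init_step_index` performs: `torch.nonzero` finds every position where the timestep occurs, the second match is preferred when there are duplicates so a run that starts mid-schedule does not skip a sigma, and the result is kept as a tensor rather than a plain Python int.

```python
import torch

timesteps = torch.tensor([999, 749, 749, 499, 249])       # hypothetical schedule with a repeated timestep
timestep = torch.tensor(749)

index_candidates = torch.nonzero(timesteps == timestep)   # shape (num_matches, 1)
pos = 1 if len(index_candidates) > 1 else 0               # take the second match if it exists
step_index = index_candidates[pos]
print(step_index)                                          # tensor([2])
```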
@@ -419,9 +423,13 @@ def step(
if self.step_index is None:
self._init_step_index(timestep)

sigma = self.sigmas[self.step_index]
sigma = self.sigmas.index_select(0, self.step_index)

gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
condition = s_tmin <= sigma
condition1 = sigma <= s_tmax
gamma = torch.where(condition & condition1,
torch.minimum(torch.tensor(s_churn / (len(self.sigmas) - 1)), torch.tensor(2**0.5 - 1)),
torch.tensor(0.0))

noise = randn_tensor(
model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
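The gamma computation above turns a Python conditional on tensor values into a branch-free `torch.where`, so it can be captured in a trace. A small sketch of the same rewrite in isolation, with made-up defaults for `s_churn`, `s_tmin`, `s_tmax`, and the schedule length; the two forms agree on concrete inputs:

```python
import torch

def gamma_eager(sigma, s_tmin=0.05, s_tmax=999.0, s_churn=0.5, n_sigmas=50):
    # original data-dependent Python branch
    return min(s_churn / (n_sigmas - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0

def gamma_traceable(sigma, s_tmin=0.05, s_tmax=999.0, s_churn=0.5, n_sigmas=50):
    # same logic expressed as tensor ops that stay inside the graph
    in_range = (s_tmin <= sigma) & (sigma <= s_tmax)
    cap = torch.minimum(torch.tensor(s_churn / (n_sigmas - 1)), torch.tensor(2**0.5 - 1))
    return torch.where(in_range, cap, torch.tensor(0.0))

sigma = torch.tensor([0.7])
assert torch.isclose(gamma_traceable(sigma), torch.tensor(gamma_eager(sigma.item())))
```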
@@ -430,8 +438,7 @@
eps = noise * s_noise
sigma_hat = sigma * (gamma + 1)

if gamma > 0:
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
sample = torch.where(gamma > 0, sample + eps * (sigma_hat**2 - sigma**2) ** 0.5, sample)

# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
# NOTE: "original_sample" should not be an expected prediction_type but is left in for
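A quick equivalence check (toy tensors) for the `torch.where(gamma > 0, ...)` rewrite above: `torch.where` evaluates both branches, but when `gamma == 0` we have `sigma_hat == sigma`, so the noise term is exactly zero and the original `if gamma > 0` behaviour is recovered.

```python
import torch

sample, eps = torch.randn(2), torch.randn(2)
sigma = torch.tensor(1.5)
for gamma in (torch.tensor(0.0), torch.tensor(0.1)):
    sigma_hat = sigma * (gamma + 1)
    branched = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5 if gamma > 0 else sample
    fused = torch.where(gamma > 0, sample + eps * (sigma_hat**2 - sigma**2) ** 0.5, sample)
    assert torch.allclose(branched, fused)
```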
@@ -451,7 +458,7 @@
# 2. Convert to an ODE derivative
derivative = (sample - pred_original_sample) / sigma_hat

dt = self.sigmas[self.step_index + 1] - sigma_hat
dt = self.sigmas.index_select(0, self.step_index + 1) - sigma_hat

prev_sample = sample + derivative * dt

13 changes: 11 additions & 2 deletions src/diffusers/schedulers/scheduling_pndm.py
@@ -417,8 +417,17 @@ def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
# sample -> x_t
# model_output -> e_θ(x_t, t)
# prev_sample -> x_(t−δ)
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
if not isinstance(timestep, torch.Tensor):
timestep = torch.tensor(timestep)
if not isinstance(prev_timestep, torch.Tensor):
prev_timestep = torch.tensor(prev_timestep)
alpha_prod_t = self.alphas_cumprod.index_select(0, timestep)
updated_prev_timestep = torch.where(
prev_timestep >= 0, prev_timestep, self.alphas_cumprod.size(dim=0) + prev_timestep
)
alpha_prod_t_prev = torch.where(
prev_timestep >= 0, self.alphas_cumprod.index_select(0, updated_prev_timestep), self.final_alpha_cumprod
)
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev

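The PNDM change follows the same `index_select` + `torch.where` pattern, with one extra wrinkle: because `torch.where` evaluates both branches, the gather would still be run with a negative `prev_timestep`, so the index is first wrapped into range and `final_alpha_cumprod` is selected afterwards. A hedged sketch with hypothetical values:

```python
import torch

alphas_cumprod = torch.tensor([0.99, 0.90, 0.75, 0.50])   # hypothetical values
final_alpha_cumprod = torch.tensor(1.0)

def alpha_prod_prev(prev_timestep: int):
    prev_timestep = torch.tensor([prev_timestep])
    # wrap a negative timestep into a legal index so index_select never sees
    # an out-of-range value, even in the branch that will be discarded
    safe_index = torch.where(prev_timestep >= 0, prev_timestep, alphas_cumprod.size(0) + prev_timestep)
    return torch.where(prev_timestep >= 0,
                       alphas_cumprod.index_select(0, safe_index),
                       final_alpha_cumprod)

print(alpha_prod_prev(2))    # tensor([0.7500])
print(alpha_prod_prev(-1))   # tensor([1.])  -> falls back to final_alpha_cumprod
```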
