Skip to content

Commit

Permalink
Add v_prediction and get_velocity method (#134)
Browse files Browse the repository at this point in the history
* Add v_prediction option and get_velocity method

Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>

* [WIP] Add changes in the inferer to use v_prediction

Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>

* Add v_prediction tutorial (#134)

Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>

* Change inferer usage to be compatible with new version (#134)

Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>

* Update tutorials (#134)

Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>

Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>
  • Loading branch information
Warvito authored Dec 12, 2022
1 parent 6151176 commit 3dc9b86
Show file tree
Hide file tree
Showing 14 changed files with 1,464 additions and 28 deletions.
7 changes: 5 additions & 2 deletions generative/inferers/inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __call__(
inputs: torch.Tensor,
diffusion_model: Callable[..., torch.Tensor],
noise: torch.Tensor,
timesteps: torch.Tensor,
condition: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Expand All @@ -48,10 +49,9 @@ def __call__(
inputs: Input image to which noise is added.
diffusion_model: diffusion model.
noise: random noise, of the same shape as the input.
timesteps: random timesteps.
condition: Conditioning for network input.
"""
num_timesteps = self.scheduler.num_train_timesteps
timesteps = torch.randint(0, num_timesteps, (inputs.shape[0],), device=inputs.device).long()
noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
prediction = diffusion_model(x=noisy_image, timesteps=timesteps, context=condition)

Expand Down Expand Up @@ -123,6 +123,7 @@ def __call__(
autoencoder_model: Callable[..., torch.Tensor],
diffusion_model: Callable[..., torch.Tensor],
noise: torch.Tensor,
timesteps: torch.Tensor,
condition: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Expand All @@ -133,6 +134,7 @@ def __call__(
autoencoder_model: first stage model.
diffusion_model: diffusion model.
noise: random noise, of the same shape as the latent representation.
timesteps: random timesteps.
condition: conditioning for network input.
"""
with torch.no_grad():
Expand All @@ -142,6 +144,7 @@ def __call__(
inputs=latent,
diffusion_model=diffusion_model,
noise=noise,
timesteps=timesteps,
condition=condition,
)

Expand Down
37 changes: 36 additions & 1 deletion generative/networks/schedulers/ddim.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ class DDIMScheduler(nn.Module):
steps_offset: an offset added to the inference steps. You can use a combination of `steps_offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
prediction_type: prediction type of the scheduler function, one of `epsilon` (predicting the noise of the
diffusion process), `sample` (directly predicting the original, denoised sample) or `v_prediction` (see section 2.4
of https://imagen.research.google/video/paper.pdf)
"""

def __init__(
Expand All @@ -66,6 +69,7 @@ def __init__(
clip_sample: bool = True,
set_alpha_to_one: bool = True,
steps_offset: int = 0,
prediction_type: str = "epsilon",
) -> None:
super().__init__()
self.beta_schedule = beta_schedule
Expand All @@ -79,6 +83,12 @@ def __init__(
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

if prediction_type.lower() not in ["epsilon", "sample", "v_prediction"]:
raise ValueError(
f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample`, or" " `v_prediction`"
)

self.prediction_type = prediction_type
self.num_train_timesteps = num_train_timesteps
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
Expand Down Expand Up @@ -171,7 +181,14 @@ def step(

# 3. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
if self.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
elif self.prediction_type == "sample":
pred_original_sample = model_output
elif self.prediction_type == "v_prediction":
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
# convert the v-prediction into the equivalent predicted noise (epsilon)
model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample

# 4. Clip "predicted x_0"
if self.clip_sample:
Expand Down Expand Up @@ -231,3 +248,21 @@ def add_noise(

noisy_samples = sqrt_alpha_cumprod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples

def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
    """
    Compute the velocity ``sqrt(alphas_cumprod[t]) * noise - sqrt(1 - alphas_cumprod[t]) * sample``.

    Args:
        sample: tensor scaled by ``sqrt(1 - alphas_cumprod[t])``; presumably the clean (un-noised)
            input of the v-parameterisation -- confirm against the caller.
        noise: tensor scaled by ``sqrt(alphas_cumprod[t])``; must broadcast against `sample`.
        timesteps: integer indices into `self.alphas_cumprod`, typically one per batch element.

    Returns:
        The velocity tensor, computed in the dtype and on the device of `sample`.
    """
    # Cache the schedule on the sample's device, but do NOT overwrite its stored dtype: casting the
    # stored buffer to a lower-precision sample dtype (e.g. float16) would permanently degrade it
    # for every later call. Only a local copy is cast for this computation.
    self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
    alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
    timesteps = timesteps.to(sample.device)

    # Trailing singleton dimensions let a per-timestep coefficient broadcast over `sample`.
    trailing_dims = (1,) * (sample.ndim - 1)

    def _broadcastable(coefficients: torch.Tensor) -> torch.Tensor:
        # Flatten to one value per timestep, then append the singleton dimensions.
        flat = coefficients.flatten()
        return flat.reshape(flat.shape[0], *trailing_dims)

    alpha_prod_t = alphas_cumprod[timesteps]
    sqrt_alpha_prod = _broadcastable(alpha_prod_t**0.5)
    sqrt_one_minus_alpha_prod = _broadcastable((1 - alpha_prod_t) ** 0.5)

    velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
    return velocity
32 changes: 30 additions & 2 deletions generative/networks/schedulers/ddpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def __init__(
beta_schedule: str = "linear",
variance_type: str = "fixed_small",
clip_sample: bool = True,
prediction_type: str = "epsilon",
) -> None:
super().__init__()
self.beta_schedule = beta_schedule
Expand All @@ -74,6 +75,13 @@ def __init__(
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

if prediction_type.lower() not in ["epsilon", "sample", "v_prediction"]:
raise ValueError(
f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample`, or" " `v_prediction`"
)

self.prediction_type = prediction_type

self.num_train_timesteps = num_train_timesteps
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
Expand Down Expand Up @@ -170,10 +178,12 @@ def step(

# 2. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
if predict_epsilon:
if self.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
else:
elif self.prediction_type == "sample":
pred_original_sample = model_output
elif self.prediction_type == "v_prediction":
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output

# 3. Clip "predicted x_0"
if self.clip_sample:
Expand Down Expand Up @@ -233,3 +243,21 @@ def add_noise(

noisy_samples = sqrt_alpha_cumprod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples

def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
    """
    Compute the velocity ``sqrt(alphas_cumprod[t]) * noise - sqrt(1 - alphas_cumprod[t]) * sample``.

    Args:
        sample: tensor scaled by ``sqrt(1 - alphas_cumprod[t])``; presumably the clean (un-noised)
            input of the v-parameterisation -- confirm against the caller.
        noise: tensor scaled by ``sqrt(alphas_cumprod[t])``; must broadcast against `sample`.
        timesteps: integer indices into `self.alphas_cumprod`, typically one per batch element.

    Returns:
        The velocity tensor, computed in the dtype and on the device of `sample`.
    """
    # Keep the stored schedule at its own precision: only its device is updated in place, while the
    # dtype cast is applied to a local copy. Overwriting the stored buffer with a float16 cast would
    # irreversibly lose precision for all subsequent calls.
    self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
    schedule = self.alphas_cumprod.to(dtype=sample.dtype)
    timesteps = timesteps.to(sample.device)

    # Shape suffix of singleton dims so per-timestep coefficients broadcast over `sample`.
    singleton_suffix = (1,) * (sample.ndim - 1)

    def _per_sample(coefficients: torch.Tensor) -> torch.Tensor:
        # One coefficient per timestep followed by singleton dimensions.
        flat = coefficients.flatten()
        return flat.reshape(flat.shape[0], *singleton_suffix)

    alpha_prod_t = schedule[timesteps]
    sqrt_alpha_prod = _per_sample(alpha_prod_t**0.5)
    sqrt_one_minus_alpha_prod = _per_sample((1 - alpha_prod_t) ** 0.5)

    velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
    return velocity
3 changes: 2 additions & 1 deletion tests/test_diffusion_inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ def test_call(self, model_params, input_shape):
)
inferer = DiffusionInferer(scheduler=scheduler)
scheduler.set_timesteps(num_inference_steps=10)
sample = inferer(inputs=input, noise=noise, diffusion_model=model)
timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()
sample = inferer(inputs=input, noise=noise, diffusion_model=model, timesteps=timesteps)
self.assertEqual(sample.shape, input_shape)

@parameterized.expand(TEST_CASES)
Expand Down
5 changes: 4 additions & 1 deletion tests/test_latent_diffusion_inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,10 @@ def test_prediction_shape(self, model_type, autoencoder_params, stage_2_params,
)
inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
scheduler.set_timesteps(num_inference_steps=10)
prediction = inferer(inputs=input, autoencoder_model=autoencoder_model, diffusion_model=stage_2, noise=noise)
timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()
prediction = inferer(
inputs=input, autoencoder_model=autoencoder_model, diffusion_model=stage_2, noise=noise, timesteps=timesteps
)
self.assertEqual(prediction.shape, latent_shape)

@parameterized.expand(TEST_CASES)
Expand Down
12 changes: 10 additions & 2 deletions tutorials/generative/2d_ddpm/2d_ddpm_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -780,8 +780,13 @@
" # Generate random noise\n",
" noise = torch.randn_like(images).to(device)\n",
"\n",
" # Create timesteps\n",
" timesteps = torch.randint(\n",
" 0, inferer.scheduler.num_train_timesteps, (images.shape[0],), device=images.device\n",
" ).long()\n",
"\n",
" # Get model prediction\n",
" noise_pred = inferer(inputs=images, diffusion_model=model, noise=noise)\n",
" noise_pred = inferer(inputs=images, diffusion_model=model, noise=noise, timesteps=timesteps)\n",
"\n",
" loss = F.mse_loss(noise_pred.float(), noise.float())\n",
"\n",
Expand All @@ -806,7 +811,10 @@
" with torch.no_grad():\n",
" with autocast(enabled=True):\n",
" noise = torch.randn_like(images).to(device)\n",
" noise_pred = inferer(inputs=images, diffusion_model=model, noise=noise)\n",
" timesteps = torch.randint(\n",
" 0, inferer.scheduler.num_train_timesteps, (images.shape[0],), device=images.device\n",
" ).long()\n",
" noise_pred = inferer(inputs=images, diffusion_model=model, noise=noise, timesteps=timesteps)\n",
" val_loss = F.mse_loss(noise_pred.float(), noise.float())\n",
"\n",
" val_epoch_loss += val_loss.item()\n",
Expand Down
12 changes: 10 additions & 2 deletions tutorials/generative/2d_ddpm/2d_ddpm_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,13 @@
# Generate random noise
noise = torch.randn_like(images).to(device)

# Create timesteps
timesteps = torch.randint(
0, inferer.scheduler.num_train_timesteps, (images.shape[0],), device=images.device
).long()

# Get model prediction
noise_pred = inferer(inputs=images, diffusion_model=model, noise=noise)
noise_pred = inferer(inputs=images, diffusion_model=model, noise=noise, timesteps=timesteps)

loss = F.mse_loss(noise_pred.float(), noise.float())

Expand All @@ -233,7 +238,10 @@
with torch.no_grad():
with autocast(enabled=True):
noise = torch.randn_like(images).to(device)
noise_pred = inferer(inputs=images, diffusion_model=model, noise=noise)
timesteps = torch.randint(
0, inferer.scheduler.num_train_timesteps, (images.shape[0],), device=images.device
).long()
noise_pred = inferer(inputs=images, diffusion_model=model, noise=noise, timesteps=timesteps)
val_loss = F.mse_loss(noise_pred.float(), noise.float())

val_epoch_loss += val_loss.item()
Expand Down
16 changes: 11 additions & 5 deletions tutorials/generative/2d_ddpm/2d_ddpm_tutorial_ignite.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -394,8 +394,9 @@
")\n",
"model.to(device)\n",
"\n",
"num_train_timesteps = 1000\n",
"scheduler = DDPMScheduler(\n",
" num_train_timesteps=1000,\n",
" num_train_timesteps=num_train_timesteps,\n",
")\n",
"\n",
"optimizer = torch.optim.Adam(params=model.parameters(), lr=2.5e-5)\n",
Expand Down Expand Up @@ -433,13 +434,17 @@
"\n",
" \"\"\"\n",
"\n",
" def __init__(self, condition_name: Optional[str] = None):\n",
" def __init__(self, num_train_timesteps: int, condition_name: Optional[str] = None):\n",
" self.condition_name = condition_name\n",
" self.num_train_timesteps = num_train_timesteps\n",
"\n",
" def get_noise(self, images):\n",
" \"\"\"Returns the noise tensor for input tensor `images`, override this for different noise distributions.\"\"\"\n",
" return torch.randn_like(images)\n",
"\n",
" def get_timesteps(self, images):\n",
" return torch.randint(0, self.num_train_timesteps, (images.shape[0],), device=images.device).long()\n",
"\n",
" def __call__(\n",
" self,\n",
" batchdata: Dict[str, torch.Tensor],\n",
Expand All @@ -449,8 +454,9 @@
" ):\n",
" images, _ = default_prepare_batch(batchdata, device, non_blocking, **kwargs)\n",
" noise = self.get_noise(images).to(device, non_blocking=non_blocking, **kwargs)\n",
" timesteps = self.get_timesteps(images).to(device, non_blocking=non_blocking, **kwargs)\n",
"\n",
" kwargs = {\"noise\": noise}\n",
" kwargs = {\"noise\": noise, \"timesteps\": timesteps}\n",
"\n",
" if self.condition_name is not None and isinstance(batchdata, Mapping):\n",
" kwargs[\"conditioning\"] = batchdata[self.condition_name].to(device, non_blocking=non_blocking, **kwargs)\n",
Expand Down Expand Up @@ -2159,7 +2165,7 @@
" val_data_loader=val_loader,\n",
" network=model,\n",
" inferer=inferer,\n",
" prepare_batch=DiffusionPrepareBatch(),\n",
" prepare_batch=DiffusionPrepareBatch(num_train_timesteps=num_train_timesteps),\n",
" key_val_metric={\"val_mean_abs_error\": MeanAbsoluteError(output_transform=from_engine([\"pred\", \"label\"]))},\n",
" val_handlers=val_handlers,\n",
")\n",
Expand All @@ -2178,7 +2184,7 @@
" optimizer=optimizer,\n",
" loss_function=torch.nn.MSELoss(),\n",
" inferer=inferer,\n",
" prepare_batch=DiffusionPrepareBatch(),\n",
" prepare_batch=DiffusionPrepareBatch(num_train_timesteps=num_train_timesteps),\n",
" key_train_metric={\"train_acc\": MeanSquaredError(output_transform=from_engine([\"pred\", \"label\"]))},\n",
" train_handlers=train_handlers,\n",
")\n",
Expand Down
16 changes: 11 additions & 5 deletions tutorials/generative/2d_ddpm/2d_ddpm_tutorial_ignite.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,9 @@
)
model.to(device)

num_train_timesteps = 1000
scheduler = DDPMScheduler(
num_train_timesteps=1000,
num_train_timesteps=num_train_timesteps,
)

optimizer = torch.optim.Adam(params=model.parameters(), lr=2.5e-5)
Expand All @@ -203,13 +204,17 @@ class DiffusionPrepareBatch(PrepareBatch):
"""

def __init__(self, condition_name: Optional[str] = None):
def __init__(self, num_train_timesteps: int, condition_name: Optional[str] = None):
self.condition_name = condition_name
self.num_train_timesteps = num_train_timesteps

def get_noise(self, images):
    """
    Draw the noise tensor used for the input tensor `images`.

    The noise is produced by `torch.randn_like(images)`; override this method to use a
    different noise distribution.
    """
    noise = torch.randn_like(images)
    return noise

def get_timesteps(self, images):
    """
    Draw one random integer timestep in ``[0, self.num_train_timesteps)`` per element of the
    first dimension of `images`, on the same device as `images`.
    """
    batch_size = images.shape[0]
    steps = torch.randint(0, self.num_train_timesteps, (batch_size,), device=images.device)
    return steps.long()

def __call__(
self,
batchdata: Dict[str, torch.Tensor],
Expand All @@ -219,8 +224,9 @@ def __call__(
):
images, _ = default_prepare_batch(batchdata, device, non_blocking, **kwargs)
noise = self.get_noise(images).to(device, non_blocking=non_blocking, **kwargs)
timesteps = self.get_timesteps(images).to(device, non_blocking=non_blocking, **kwargs)

kwargs = {"noise": noise}
kwargs = {"noise": noise, "timesteps": timesteps}

if self.condition_name is not None and isinstance(batchdata, Mapping):
kwargs["conditioning"] = batchdata[self.condition_name].to(device, non_blocking=non_blocking, **kwargs)
Expand All @@ -244,7 +250,7 @@ def __call__(
val_data_loader=val_loader,
network=model,
inferer=inferer,
prepare_batch=DiffusionPrepareBatch(),
prepare_batch=DiffusionPrepareBatch(num_train_timesteps=num_train_timesteps),
key_val_metric={"val_mean_abs_error": MeanAbsoluteError(output_transform=from_engine(["pred", "label"]))},
val_handlers=val_handlers,
)
Expand All @@ -263,7 +269,7 @@ def __call__(
optimizer=optimizer,
loss_function=torch.nn.MSELoss(),
inferer=inferer,
prepare_batch=DiffusionPrepareBatch(),
prepare_batch=DiffusionPrepareBatch(num_train_timesteps=num_train_timesteps),
key_train_metric={"train_acc": MeanSquaredError(output_transform=from_engine(["pred", "label"]))},
train_handlers=train_handlers,
)
Expand Down
Loading

0 comments on commit 3dc9b86

Please sign in to comment.