From 6151176cdc99c51fe95b172af0ce541e361ebe22 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Mon, 12 Dec 2022 12:57:26 +0000 Subject: [PATCH] Add class conditioning (#140) * Add class conditioning and tests Signed-off-by: Walter Hugo Lopez Pinaya * Add missing test (#140) Signed-off-by: Walter Hugo Lopez Pinaya Signed-off-by: Walter Hugo Lopez Pinaya --- .../networks/nets/diffusion_model_unet.py | 27 ++++++++++--- tests/test_diffusion_model_unet.py | 38 +++++++++++++++++++ 2 files changed, 60 insertions(+), 5 deletions(-) diff --git a/generative/networks/nets/diffusion_model_unet.py b/generative/networks/nets/diffusion_model_unet.py index 3a287774..7f652961 100644 --- a/generative/networks/nets/diffusion_model_unet.py +++ b/generative/networks/nets/diffusion_model_unet.py @@ -1439,6 +1439,8 @@ class DiffusionModelUNet(nn.Module): with_conditioning: if True add spatial transformers to perform conditioning. transformer_num_layers: number of layers of Transformer blocks to use. cross_attention_dim: number of context dimensions to use. + num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` + classes. 
""" def __init__( @@ -1455,6 +1457,7 @@ def __init__( with_conditioning: bool = False, transformer_num_layers: int = 1, cross_attention_dim: Optional[int] = None, + num_class_embeds: Optional[int] = None, ) -> None: super().__init__() if with_conditioning is True and cross_attention_dim is None: @@ -1499,6 +1502,11 @@ def __init__( nn.Linear(time_embed_dim, time_embed_dim), ) + # class embedding + self.num_class_embeds = num_class_embeds + if num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + # down self.down_blocks = nn.ModuleList([]) output_channel = num_channels[0] @@ -1591,37 +1599,46 @@ def forward( x: torch.Tensor, timesteps: torch.Tensor, context: Optional[torch.Tensor] = None, + class_labels: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Args: x: input tensor (N, C, SpatialDims). timesteps: timestep tensor (N,). context: context tensor (N, 1, ContextDim). + class_labels: context tensor (N, ). """ # 1. time t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0]) emb = self.time_embed(t_emb) - # 2. initial convolution + # 2. class + if self.num_class_embeds is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + class_emb = self.class_embedding(class_labels) + emb = emb + class_emb + + # 3. initial convolution h = self.conv_in(x) - # 3. down + # 4. down down_block_res_samples: List[torch.Tensor] = [h] for downsample_block in self.down_blocks: h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context) for residual in res_samples: down_block_res_samples.append(residual) - # 4. mid + # 5. mid h = self.middle_block(hidden_states=h, temb=emb, context=context) - # 5. up + # 6. 
up for upsample_block in self.up_blocks: res_samples = down_block_res_samples[-len(upsample_block.resnets) :] down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] h = upsample_block(hidden_states=h, res_hidden_states_list=res_samples, temb=emb, context=context) - # 6. output block + # 7. output block h = self.out(h) return h diff --git a/tests/test_diffusion_model_unet.py b/tests/test_diffusion_model_unet.py index 08317175..f9925b96 100644 --- a/tests/test_diffusion_model_unet.py +++ b/tests/test_diffusion_model_unet.py @@ -168,6 +168,44 @@ def test_with_conditioning_cross_attention_dim_none(self): norm_num_groups=8, ) + def test_shape_conditioned_models_class_conditioning(self): + net = DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, True), + norm_num_groups=8, + num_head_channels=8, + num_class_embeds=2, + ) + with eval_mode(net): + result = net.forward( + x=torch.rand((1, 1, 16, 32)), + timesteps=torch.randint(0, 1000, (1,)).long(), + class_labels=torch.randint(0, 2, (1,)).long(), + ) + self.assertEqual(result.shape, (1, 1, 16, 32)) + + def test_conditioned_models_no_class_labels(self): + with self.assertRaises(ValueError): + net = DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, True), + norm_num_groups=8, + num_head_channels=8, + num_class_embeds=2, + ) + net.forward( + x=torch.rand((1, 1, 16, 32)), + timesteps=torch.randint(0, 1000, (1,)).long(), + ) + def test_script_unconditioned_2d_models(self): net = DiffusionModelUNet( spatial_dims=2,