Add class conditioning (#140)
* Add class conditioning and tests

Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>

* Add missing test (#140)

Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>

Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>
Warvito authored Dec 12, 2022
1 parent 3f72d37 commit 6151176
Showing 2 changed files with 60 additions and 5 deletions.
27 changes: 22 additions & 5 deletions generative/networks/nets/diffusion_model_unet.py
@@ -1439,6 +1439,8 @@ class DiffusionModelUNet(nn.Module):
with_conditioning: if True add spatial transformers to perform conditioning.
transformer_num_layers: number of layers of Transformer blocks to use.
cross_attention_dim: number of context dimensions to use.
num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds`
classes.
"""

def __init__(
@@ -1455,6 +1457,7 @@ def __init__(
with_conditioning: bool = False,
transformer_num_layers: int = 1,
cross_attention_dim: Optional[int] = None,
num_class_embeds: Optional[int] = None,
) -> None:
super().__init__()
if with_conditioning is True and cross_attention_dim is None:
@@ -1499,6 +1502,11 @@ def __init__(
nn.Linear(time_embed_dim, time_embed_dim),
)

# class embedding
self.num_class_embeds = num_class_embeds
if num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)

# down
self.down_blocks = nn.ModuleList([])
output_channel = num_channels[0]
@@ -1591,37 +1599,46 @@ def forward(
x: torch.Tensor,
timesteps: torch.Tensor,
context: Optional[torch.Tensor] = None,
class_labels: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Args:
x: input tensor (N, C, SpatialDims).
timesteps: timestep tensor (N,).
context: context tensor (N, 1, ContextDim).
class_labels: class labels tensor (N,).
"""
# 1. time
t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0])
emb = self.time_embed(t_emb)

# 2. initial convolution
# 2. class
if self.num_class_embeds is not None:
if class_labels is None:
raise ValueError("class_labels should be provided when num_class_embeds > 0")
class_emb = self.class_embedding(class_labels)
emb = emb + class_emb

# 3. initial convolution
h = self.conv_in(x)

# 3. down
# 4. down
down_block_res_samples: List[torch.Tensor] = [h]
for downsample_block in self.down_blocks:
h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context)
for residual in res_samples:
down_block_res_samples.append(residual)

# 4. mid
# 5. mid
h = self.middle_block(hidden_states=h, temb=emb, context=context)

# 5. up
# 6. up
for upsample_block in self.up_blocks:
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
h = upsample_block(hidden_states=h, res_hidden_states_list=res_samples, temb=emb, context=context)

# 6. output block
# 7. output block
h = self.out(h)

return h
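
For reference, here is a minimal usage sketch of the class-conditioning API added in this commit. It mirrors the shapes used in the new test below; the import path, hyperparameter values, and tensor sizes are illustrative assumptions rather than part of the diff.

import torch

from generative.networks.nets import DiffusionModelUNet

# Build a small 2D model with class conditioning enabled (2 classes).
# Hyperparameters are illustrative, chosen to match the new test.
net = DiffusionModelUNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_res_blocks=1,
    num_channels=(8, 8, 8),
    attention_levels=(False, False, True),
    norm_num_groups=8,
    num_head_channels=8,
    num_class_embeds=2,
)

x = torch.rand((1, 1, 16, 32))                   # noisy input batch
timesteps = torch.randint(0, 1000, (1,)).long()  # one timestep per batch element
class_labels = torch.randint(0, 2, (1,)).long()  # one class label per batch element

# class_labels must be provided whenever num_class_embeds is set;
# forward raises a ValueError otherwise (see the new tests below).
prediction = net(x=x, timesteps=timesteps, class_labels=class_labels)
assert prediction.shape == x.shape
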
38 changes: 38 additions & 0 deletions tests/test_diffusion_model_unet.py
@@ -168,6 +168,44 @@ def test_with_conditioning_cross_attention_dim_none(self):
norm_num_groups=8,
)

def test_shape_conditioned_models_class_conditioning(self):
net = DiffusionModelUNet(
spatial_dims=2,
in_channels=1,
out_channels=1,
num_res_blocks=1,
num_channels=(8, 8, 8),
attention_levels=(False, False, True),
norm_num_groups=8,
num_head_channels=8,
num_class_embeds=2,
)
with eval_mode(net):
result = net.forward(
x=torch.rand((1, 1, 16, 32)),
timesteps=torch.randint(0, 1000, (1,)).long(),
class_labels=torch.randint(0, 2, (1,)).long(),
)
self.assertEqual(result.shape, (1, 1, 16, 32))

def test_conditioned_models_no_class_labels(self):
with self.assertRaises(ValueError):
net = DiffusionModelUNet(
spatial_dims=2,
in_channels=1,
out_channels=1,
num_res_blocks=1,
num_channels=(8, 8, 8),
attention_levels=(False, False, True),
norm_num_groups=8,
num_head_channels=8,
num_class_embeds=2,
)
net.forward(
x=torch.rand((1, 1, 16, 32)),
timesteps=torch.randint(0, 1000, (1,)).long(),
)

def test_script_unconditioned_2d_models(self):
net = DiffusionModelUNet(
spatial_dims=2,
