Skip to content

Commit

Permalink
Remove asserts (#383)
Browse files Browse the repository at this point in the history
* Remove asserts

Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>

* Update generative/networks/nets/diffusion_model_unet.py

Co-authored-by: Eric Kerfoot <[email protected]>
Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>

* Edit docstring

Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>

* Fix test

Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>

---------

Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>
Co-authored-by: Eric Kerfoot <[email protected]>
  • Loading branch information
Warvito and ericspod authored May 3, 2023
1 parent 639b6eb commit 60a3851
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 5 deletions.
18 changes: 13 additions & 5 deletions generative/networks/nets/diffusion_model_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,8 @@ def get_timestep_embedding(timesteps: torch.Tensor, embedding_dim: int, max_peri
embedding_dim: the dimension of the output.
max_period: controls the minimum frequency of the embeddings.
"""
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
if timesteps.ndim != 1:
raise ValueError("Timesteps should be a 1d-array")

half_dim = embedding_dim // 2
exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
Expand All @@ -491,7 +492,8 @@ class Downsample(nn.Module):
Args:
spatial_dims: number of spatial dimensions.
num_channels: number of input channels.
use_conv: if True uses Convolution instead of Pool average to perform downsampling.
use_conv: if True uses Convolution instead of Pool average to perform downsampling. In case that use_conv is
False, the number of output channels must be the same as the number of input channels.
out_channels: number of output channels.
padding: controls the amount of implicit zero-paddings on both sides for padding number of points
for each dimension.
Expand All @@ -515,12 +517,17 @@ def __init__(
conv_only=True,
)
else:
assert self.num_channels == self.out_channels
if self.num_channels != self.out_channels:
raise ValueError("num_channels and out_channels must be equal when use_conv=False")
self.op = Pool[Pool.AVG, spatial_dims](kernel_size=2, stride=2)

def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor:
del emb
assert x.shape[1] == self.num_channels
if x.shape[1] != self.num_channels:
raise ValueError(
f"Input number of channels ({x.shape[1]}) is not equal to expected number of channels "
f"({self.num_channels})"
)
return self.op(x)


Expand Down Expand Up @@ -559,7 +566,8 @@ def __init__(

def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor:
del emb
assert x.shape[1] == self.num_channels
if x.shape[1] != self.num_channels:
raise ValueError("Input channels should be equal to num_channels")

# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
# https://github.com/pytorch/pytorch/issues/86679
Expand Down
14 changes: 14 additions & 0 deletions tests/test_diffusion_model_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,20 @@ def test_shape_unconditioned_models(self, input_param):
result = net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long())
self.assertEqual(result.shape, (1, 1, 16, 16))

def test_timestep_with_wrong_shape(self):
net = DiffusionModelUNet(
spatial_dims=2,
in_channels=1,
out_channels=1,
num_res_blocks=1,
num_channels=(8, 8, 8),
attention_levels=(False, False, False),
norm_num_groups=8,
)
with self.assertRaises(ValueError):
with eval_mode(net):
net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1, 1)).long())

def test_shape_with_different_in_channel_out_channel(self):
in_channels = 6
out_channels = 3
Expand Down

0 comments on commit 60a3851

Please sign in to comment.