"""
3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation
Paper URL: https://arxiv.org/abs/1606.06650
Author: Amir Aghdam
"""
import time

import torch
from torch import nn
from torchsummary import summary


class Conv3DBlock(nn.Module):
"""
The basic block for double 3x3x3 convolutions in the analysis path
-- __init__()
:param in_channels -> number of input channels
:param out_channels -> desired number of output channels
:param bottleneck -> specifies the bottlneck block
-- forward()
:param input -> input Tensor to be convolved
:return -> Tensor
"""

    def __init__(self, in_channels, out_channels, bottleneck=False) -> None:
        super(Conv3DBlock, self).__init__()
        self.conv1 = nn.Conv3d(in_channels=in_channels, out_channels=out_channels//2, kernel_size=(3,3,3), padding=1)
        self.bn1 = nn.BatchNorm3d(num_features=out_channels//2)
        self.conv2 = nn.Conv3d(in_channels=out_channels//2, out_channels=out_channels, kernel_size=(3,3,3), padding=1)
        self.bn2 = nn.BatchNorm3d(num_features=out_channels)
        self.relu = nn.ReLU()
        self.bottleneck = bottleneck
        if not bottleneck:
            self.pooling = nn.MaxPool3d(kernel_size=(2,2,2), stride=2)

    def forward(self, input):
        res = self.relu(self.bn1(self.conv1(input)))
        res = self.relu(self.bn2(self.conv2(res)))
        out = res if self.bottleneck else self.pooling(res)
        return out, res
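

# A hypothetical shape check for Conv3DBlock (not part of the original file); the names
# and sizes below are illustrative only. An analysis block halves the spatial dimensions
# in `out` via max pooling, while `res` keeps the full resolution for the skip connection:
#
#   block = Conv3DBlock(in_channels=3, out_channels=64)
#   x = torch.rand(1, 3, 16, 128, 128)
#   out, res = block(x)
#   # out: (1, 64, 8, 64, 64), res: (1, 64, 16, 128, 128)

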
class UpConv3DBlock(nn.Module):
"""
The basic block for upsampling followed by double 3x3x3 convolutions in the synthesis path
-- __init__()
:param in_channels -> number of input channels
:param out_channels -> number of residual connections' channels to be concatenated
:param last_layer -> specifies the last output layer
:param num_classes -> specifies the number of output channels for dispirate classes
-- forward()
:param input -> input Tensor
:param residual -> residual connection to be concatenated with input
:return -> Tensor
"""

    def __init__(self, in_channels, res_channels=0, last_layer=False, num_classes=None) -> None:
        super(UpConv3DBlock, self).__init__()
        assert (not last_layer and num_classes is None) or (last_layer and num_classes is not None), 'Invalid arguments'
        self.upconv1 = nn.ConvTranspose3d(in_channels=in_channels, out_channels=in_channels, kernel_size=(2,2,2), stride=2)
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm3d(num_features=in_channels//2)
        self.conv1 = nn.Conv3d(in_channels=in_channels+res_channels, out_channels=in_channels//2, kernel_size=(3,3,3), padding=(1,1,1))
        self.conv2 = nn.Conv3d(in_channels=in_channels//2, out_channels=in_channels//2, kernel_size=(3,3,3), padding=(1,1,1))
        self.last_layer = last_layer
        if last_layer:
            self.conv3 = nn.Conv3d(in_channels=in_channels//2, out_channels=num_classes, kernel_size=(1,1,1))

    def forward(self, input, residual=None):
        out = self.upconv1(input)
        if residual is not None:
            out = torch.cat((out, residual), 1)
        out = self.relu(self.bn(self.conv1(out)))
        out = self.relu(self.bn(self.conv2(out)))
        if self.last_layer:
            out = self.conv3(out)
        return out
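

# A hypothetical shape check for UpConv3DBlock (not part of the original file); the values
# below are illustrative only. The transposed convolution doubles the spatial dimensions,
# the skip tensor is concatenated along the channel axis, and the two 3x3x3 convolutions
# reduce the channels to in_channels // 2:
#
#   block = UpConv3DBlock(in_channels=512, res_channels=256)
#   x = torch.rand(1, 512, 2, 16, 16)
#   skip = torch.rand(1, 256, 4, 32, 32)
#   out = block(x, skip)
#   # out: (1, 256, 4, 32, 32)

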
class UNet3D(nn.Module):
"""
The 3D UNet model
-- __init__()
:param in_channels -> number of input channels
:param num_classes -> specifies the number of output channels or masks for different classes
:param level_channels -> the number of channels at each level (count top-down)
:param bottleneck_channel -> the number of bottleneck channels
:param device -> the device on which to run the model
-- forward()
:param input -> input Tensor
:return -> Tensor
"""

    def __init__(self, in_channels, num_classes, level_channels=[64, 128, 256], bottleneck_channel=512) -> None:
        super(UNet3D, self).__init__()
        level_1_chnls, level_2_chnls, level_3_chnls = level_channels[0], level_channels[1], level_channels[2]
        self.a_block1 = Conv3DBlock(in_channels=in_channels, out_channels=level_1_chnls)
        self.a_block2 = Conv3DBlock(in_channels=level_1_chnls, out_channels=level_2_chnls)
        self.a_block3 = Conv3DBlock(in_channels=level_2_chnls, out_channels=level_3_chnls)
        self.bottleNeck = Conv3DBlock(in_channels=level_3_chnls, out_channels=bottleneck_channel, bottleneck=True)
        self.s_block3 = UpConv3DBlock(in_channels=bottleneck_channel, res_channels=level_3_chnls)
        self.s_block2 = UpConv3DBlock(in_channels=level_3_chnls, res_channels=level_2_chnls)
        self.s_block1 = UpConv3DBlock(in_channels=level_2_chnls, res_channels=level_1_chnls, num_classes=num_classes, last_layer=True)

    def forward(self, input):
        # Analysis path forward feed
        out, residual_level1 = self.a_block1(input)
        out, residual_level2 = self.a_block2(out)
        out, residual_level3 = self.a_block3(out)
        out, _ = self.bottleNeck(out)
        # Synthesis path forward feed
        out = self.s_block3(out, residual_level3)
        out = self.s_block2(out, residual_level2)
        out = self.s_block1(out, residual_level1)
        return out


if __name__ == '__main__':
    # Configurations according to the Xenopus kidney dataset
    model = UNet3D(in_channels=3, num_classes=1)
    start_time = time.time()
    summary(model=model, input_size=(3, 16, 128, 128), batch_size=-1, device="cpu")
    print("--- %s seconds ---" % (time.time() - start_time))