-
Notifications
You must be signed in to change notification settings - Fork 1
/
Head.py
87 lines (77 loc) · 2.87 KB
/
Head.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import torch
from torch import nn
import math
class CARHead(torch.nn.Module):
    """Convolutional head with classification, centerness and mask branches.

    The input feature map goes through two parallel towers of
    ``Conv2d -> GroupNorm -> ReLU`` stages. The classification tower feeds
    both the class-logit and the centerness predictors; the mask tower feeds
    the mask-regression predictor. All convolutions are 3x3 with stride 1 and
    padding 1, so the spatial size of the input is preserved.
    """

    def __init__(self, in_channels, num_classes=2, num_convs=4,
                 num_vertex=36, num_groups=32):
        """
        Arguments:
            in_channels (int): number of channels of the input feature
            num_classes (int): number of classification outputs. The default
                of 2 predicts only background or foreground.
            num_convs (int): number of conv stages in each tower.
            num_vertex (int): number of mask-regression outputs. The default
                of 36 uses 36 lines from a specific point, each line has an
                angle of 10 degrees.
            num_groups (int): number of groups for every GroupNorm layer;
                ``in_channels`` must be divisible by it.
        """
        super(CARHead, self).__init__()
        # TODO: Implement the sigmoid version first.

        # Attribute names and Sequential indices are kept as they were so
        # that existing state-dict keys still match.
        self.cls_tower = self._build_tower(in_channels, num_convs, num_groups)
        self.mask_tower = self._build_tower(in_channels, num_convs, num_groups)

        self.cls_logits = nn.Conv2d(
            in_channels, num_classes, kernel_size=3, stride=1, padding=1
        )
        self.mask_pred = nn.Conv2d(
            in_channels, num_vertex, kernel_size=3, stride=1, padding=1
        )
        self.centerness = nn.Conv2d(
            in_channels, 1, kernel_size=3, stride=1, padding=1
        )

        # initialization: small Gaussian weights and zero bias for every conv
        for module in (self.cls_tower, self.mask_tower,
                       self.cls_logits, self.mask_pred,
                       self.centerness):
            for layer in module.modules():
                if isinstance(layer, nn.Conv2d):
                    torch.nn.init.normal_(layer.weight, std=0.01)
                    torch.nn.init.constant_(layer.bias, 0)

        # initialize the bias for focal loss: with this bias,
        # sigmoid(bias) == prior_prob.
        prior_prob = 0.01
        bias_value = -math.log((1 - prior_prob) / prior_prob)
        torch.nn.init.constant_(self.cls_logits.bias, bias_value)

    @staticmethod
    def _build_tower(in_channels, num_convs, num_groups):
        """Return a Sequential of ``num_convs`` Conv2d/GroupNorm/ReLU stages."""
        layers = []
        for _ in range(num_convs):
            layers.append(
                nn.Conv2d(
                    in_channels,
                    in_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                )
            )
            layers.append(nn.GroupNorm(num_groups, in_channels))
            layers.append(nn.ReLU())
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the head on a feature map.

        Arguments:
            x (Tensor): input feature map with ``in_channels`` channels.

        Returns:
            tuple: ``(logits, mask_reg, centerness)``. ``logits`` and
            ``centerness`` come from the classification tower; ``mask_reg``
            is the exponential of the mask-tower prediction, so it is never
            negative.
        """
        cls_features = self.cls_tower(x)
        logits = self.cls_logits(cls_features)
        centerness = self.centerness(cls_features)
        mask_reg = torch.exp(self.mask_pred(self.mask_tower(x)))
        return logits, mask_reg, centerness
class Scale(nn.Module):
    """Multiply the input by a single learnable scalar."""

    def __init__(self, init_value=1.0):
        """
        Arguments:
            init_value (float): starting value of the learnable scale.
        """
        super(Scale, self).__init__()
        # torch.tensor with an explicit dtype replaces the legacy
        # torch.FloatTensor constructor; the parameter is still a
        # one-element float32 tensor.
        self.scale = nn.Parameter(
            torch.tensor([init_value], dtype=torch.float32)
        )

    def forward(self, input):
        """Return ``input`` multiplied by the learnable scale."""
        return input * self.scale