forked from jcjohnson/densecap
-
Notifications
You must be signed in to change notification settings - Fork 0
/
models.lua
36 lines (31 loc) · 1.25 KB
/
models.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
--- Builds the DenseCap model used for training, either freshly constructed
-- or restored from a checkpoint with the current run's options applied.
local M = {}

--- Copy the current run's hyperparameters from `opt` onto a model restored
-- from a checkpoint, so this run's settings replace the saved ones.
-- @param model model table taken from the checkpoint's `.model` field
-- @tparam table opt training options
local function apply_checkpoint_overrides(model, opt)
  local loc = model.nets.localization_layer
  model.opt.end_objectness_weight = opt.end_objectness_weight
  loc.opt.mid_objectness_weight = opt.mid_objectness_weight
  loc.opt.mid_box_reg_weight = opt.mid_box_reg_weight
  model.crits.box_reg_crit.w = opt.end_box_reg_weight

  -- findModules returns a (possibly empty) list of matches. Fail with a
  -- descriptive message rather than "attempt to index a nil value" when the
  -- checkpoint's RPN contains no RegularizeLayer.
  local reg_layer = loc.nets.rpn:findModules('nn.RegularizeLayer')[1]
  assert(reg_layer,
    'checkpoint RPN has no nn.RegularizeLayer; cannot apply box_reg_decay')
  reg_layer.w = opt.box_reg_decay

  model.opt.train_remove_outbounds_boxes = opt.train_remove_outbounds_boxes
  model.opt.captioning_weight = opt.captioning_weight

  -- `cudnn` is read as a global here; presumably the caller defines it only
  -- when cuDNN is available (TODO confirm against the training script).
  if cudnn then
    cudnn.convert(model.net, cudnn)
    cudnn.convert(loc.nets.rpn, cudnn)
  end
end

--- Create or restore the DenseCap model.
-- If `opt.checkpoint_start_from` is the empty string, a new model is built
-- with `DenseCapModel(opt)`. Otherwise the model is loaded with `torch.load`
-- and the run's hyperparameters are applied to it. In both cases, Dropout
-- layers in the recognition base network are set to `opt.drop_prob`, and the
-- model is converted with `model:float()`.
-- @tparam table opt training options
-- @return the model
function M.setup(opt)
  local model
  if opt.checkpoint_start_from == '' then
    print('initializing a DenseCap model from scratch...')
    model = DenseCapModel(opt)
  else
    print('initializing a DenseCap model from ' .. opt.checkpoint_start_from)
    model = torch.load(opt.checkpoint_start_from).model
    apply_checkpoint_overrides(model, opt)
  end

  -- Only Dropout layers inside the recognition base network are updated.
  for _, dropout in ipairs(model.nets.recog_base:findModules('nn.Dropout')) do
    dropout.p = opt.drop_prob
  end

  model:float()
  return model
end

return M