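"""Evaluation script for a trained captioning model.

Loads a checkpoint (model-best.pth plus infos-best.pkl) from --ckpt_path and
evaluates it on an MSCOCO split, or on an arbitrary folder of images via
--image_folder. Optionally computes language metrics and dumps predictions.

Example usage (paths are illustrative; adjust to your setup):
    python test.py --ckpt_path log_svbase --split test --beam_size 5 --gpu_id 0
"""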
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import numpy as np
import time
import os
from six.moves import cPickle
import utils.opts as opts
import models
from utils.dataloader import *
import utils.eval_utils as eval_utils
import argparse
import utils.utils as utils
import torch
# Input arguments and options
parser = argparse.ArgumentParser()
# Input paths
parser.add_argument('--ckpt_path', type=str, default='log_svbase',
                    help='path to the model directory to evaluate')
# Basic options
parser.add_argument('--batch_size', type=int, default=10,
                    help='if > 0, override the checkpoint value; otherwise load from checkpoint.')
parser.add_argument('--num_images', type=int, default=5000,
                    help='how many images to evaluate on (-1 = all)')
parser.add_argument('--language_eval', type=int, default=1,
                    help='Evaluate language metrics (BLEU/CIDEr/METEOR/ROUGE_L)? (1 = yes, 0 = no). Requires the coco-caption code from GitHub.')
parser.add_argument('--dump_images', type=int, default=0,
                    help='Dump images into the vis/imgs folder for visualization? (1 = yes, 0 = no)')
parser.add_argument('--dump_json', type=int, default=0,
                    help='Dump a json with predictions into the vis folder? (1 = yes, 0 = no)')
parser.add_argument('--dump_path', type=int, default=0,
                    help='Write image paths along with predictions into the vis json? (1 = yes, 0 = no)')
# Sampling options
parser.add_argument('--sample_max', type=int, default=1,
                    help='1 = sample argmax words. 0 = sample from distributions.')
parser.add_argument('--beam_size', type=int, default=5,
                    help='used when sample_max = 1; number of beams in beam search. Usually 2 or 3 works well; more is not better. Set to 1 for faster runtime at slightly worse performance.')
parser.add_argument('--temperature', type=float, default=1.0,
                    help='temperature when sampling from distributions (i.e. when sample_max = 0). Lower = "safer" predictions.')
# For evaluation on a folder of images:
parser.add_argument('--image_folder', type=str, default='',
                    help='If nonempty, predict on the images in this folder')
parser.add_argument('--image_root', type=str, default='',
                    help='In case the image paths have to be prepended with a root path to an image folder')
# For evaluation on MSCOCO images from some split:
parser.add_argument('--image_feat_dir', type=str, default='datasets/mscoco/features/frcn-r101',
                    help='path to the directory containing the preprocessed image features')
parser.add_argument('--input_label_h5', type=str, default='datasets/mscoco/annotations/cocotalk_label.h5',
                    help='path to the h5 file containing the preprocessed labels')
parser.add_argument('--input_json', type=str, default='datasets/mscoco/annotations/cocotalk.json',
                    help='path to the json file containing additional info and vocab. empty = fetch from model checkpoint.')
parser.add_argument('--embed_weight_file', type=str, default='datasets/mscoco/annotations/glove_embeding.npy',
                    help='path to the embedding weight file')
parser.add_argument('--split', type=str, default='test',
                    help='if running on MSCOCO images, which split to use: val|test|train')
# misc
parser.add_argument('--id', type=str, default='',
                    help='an id identifying this run/job; used only if language_eval = 1, for naming intermediate files')
parser.add_argument('--gpu_id', type=str, default='0',
                    help='gpu id')
opt = parser.parse_args()
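# Mask GPUs before torch initializes CUDA so only the selected device is visible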
os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_id
# Load infos
with open(os.path.join(opt.ckpt_path, 'infos-best.pkl'), 'rb') as f:
    infos = cPickle.load(f, encoding='latin-1')
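# infos holds the training-time options (infos['opt']) and the vocabulary
# (infos['vocab']) saved alongside the checkpoint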
# override and collect parameters
if len(opt.image_feat_dir) == 0:
    opt.image_feat_dir = infos['opt'].image_feat_dir
    opt.input_label_h5 = infos['opt'].input_label_h5
if len(opt.input_json) == 0:
    opt.input_json = infos['opt'].input_json
if opt.batch_size == 0:
    opt.batch_size = infos['opt'].batch_size
if len(opt.id) == 0:
    opt.id = infos['opt'].id
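# These keys may legitimately differ between training and evaluation (run id,
# batch/beam size, data paths), so they are exempt from the consistency check below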
ignore = ["id", "batch_size", "beam_size", "start_from", "language_eval", "image_feat_dir", "gpu_id", "input_json", "input_label_h5", "embed_weight_file"]
for k in vars(infos['opt']).keys():
    if k not in ignore:
        if k in vars(opt):
            assert vars(opt)[k] == vars(infos['opt'])[k], k + ' option not consistent'
        else:
            vars(opt).update({k: vars(infos['opt'])[k]})  # copy over options from the model checkpoint
vocab = infos['vocab'] # ix -> word mapping
# Setup the model
model = models.setup(opt)
model.load_state_dict(torch.load(os.path.join(opt.ckpt_path, "model-best.pth")))
model.cuda()
model.eval()  # inference mode: disables dropout and freezes batchnorm statistics
crit = utils.LanguageModelCriterion()
# Create the Data Loader instance
loader = DataLoader(opt)
# When evaluating with a provided pretrained model, the vocab may differ from the
# one in your cocotalk.json, so make sure to use the vocab from the infos file.
loader.ix_to_word = infos['vocab']
# Evaluate on the chosen split with the sampling options set above
loss, split_predictions, lang_stats = eval_utils.eval_split(model, loader, vars(opt))
print('loss: ', loss)
if lang_stats:
    print(lang_stats)
if opt.dump_json == 1:
    # dump the predictions to json for visualization
    if not os.path.isdir('vis'):
        os.makedirs('vis')
    with open('vis/vis.json', 'w') as f:
        json.dump(split_predictions, f)