-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
74 lines (72 loc) · 2.57 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import argparse
from GSvqa import GSvqaV2
from dataset import *
import time
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='UnCoRd test.')
parser.add_argument('--image_dir', type=str)
parser.add_argument(
'--questions_file', type=str,
help='A json file with questions, images indices and answers'
)
parser.add_argument(
'--answer_vocab_file', type=str,
help='Vocabulary for VL-BERT estimator.'
)
parser.add_argument(
'--properties_file', type=str,
help='Property categories of your dataset.')
parser.add_argument(
'--num_questions', type=int, default=None,
help='Number of questions.')
parser.add_argument(
'--test_mode', type=bool, default=False,
help='Answers are not required in test mode.'
)
parser.add_argument(
'--shuffle', type=bool, default=False,
help='Shuffles questions if true.'
)
parser.add_argument(
'--device', type=str, default='cpu',
help='Device for VL-BERT estimator.'
)
parser.add_argument(
'--question_index', type=int, default=None,
help='If true, answers a single question with queried index.'
)
args = parser.parse_args()
if args.question_index:
questions = []
gT = []
ques, gt = get_question_by_idx(
args.questions_file, args.question_index, test=args.test_mode
)
questions.append(ques)
gT.append(gt)
else:
questions, gT = get_questions_and_answers(
args.questions_file, args.num_questions,
test=args.test_mode, shuffle=args.shuffle
)
model = GSvqaV2(args.device, args.answer_vocab_file, args.properties_file)
accuracy = 0
with open('test_answers.txt', 'w') as f:
for i, question in enumerate(questions):
start = time.time()
answer = model.get_answer(args.image_dir, question)
end = time.time()
print(f"Question: {question['question']}")
print(f"Model answer: {answer}")
if not args.test_mode:
print(f"Ground Truth: {gT[i]}")
print(f"Image id: {question['image_index']}")
print(f"Question id: {question['question_index']}")
print(f"VL-BERT calls: {model.num_vlbert_calls}")
print(f"Wall time: {(end - start):.4f}s")
if answer == gT[i]:
accuracy += 1
print('\n')
f.write(answer + '\n')
if not args.test_mode:
print(f"Accuracy: {accuracy / len(questions):.2f}")