# run_finqa_cot_gpt3.py
#
# Few-shot chain-of-thought evaluation on the FinQA dev set with GPT-3
# (text-davinci-002), using either greedy or self-consistency decoding.
import json
from time import sleep
from tqdm import tqdm
import os
import openai
from datetime import datetime
from eval_tatqa.tatqa_utils import extract_one_num_from_str
from tool import *
from typing import Dict, Any
import argparse
from collections import Counter
# Command-line options for one evaluation run.
parser = argparse.ArgumentParser()
# --key is the NAME of the environment variable holding the API key
# (it is passed to os.getenv), not the key itself.
parser.add_argument("--key", type=str, default='OPENAI_KEY')
# --start / --end are integers that appear in the output filename.
parser.add_argument("--start", type=int, default=0)
# store_true flags are False unless given on the command line.
parser.add_argument("--greedy", action="store_true")
parser.add_argument("--dry_run", action="store_true")
parser.add_argument("--end", type=int, default=-1)
args = parser.parse_args()
def create_reader_request_processed(example: Dict[str, Any]):
    """Build the query part of the prompt for a single example.

    Args:
        example: Mapping with a ``'table'`` string and a ``'question'``
            value; ``'text'`` is optional and is skipped when missing,
            ``None`` or empty.

    Returns:
        The instruction line, the optional text, the stripped table and a
        ``Question: ...`` line, each terminated by a newline.

    Raises:
        KeyError: If ``'table'`` or ``'question'`` is missing.
    """
    segments = ['Read the following text and table, and then answer a question:']
    # A missing 'text' key is treated the same as empty text.
    text = example.get('text')
    if text:
        segments.append(text)
    segments.append(example['table'].strip())
    segments.append('Question: {}'.format(example['question']))
    # Every segment, including the last one, ends with a newline.
    return '\n'.join(segments) + '\n'
# Four worked chain-of-thought demonstrations (context, question, rationale
# ending in "So the answer is ...") that are prepended to every query.
# The string is runtime data sent to the model: editing any character changes
# what the model sees, so it is kept verbatim here.
# NOTE(review): some rationales disagree with their own context - the third
# demonstration quotes 3.5 billion for 2011 while its text gives 4.0 billion,
# and the fourth lists 47/49/51 where the table's ending balances read
# 46/47/49; "compoensation" is also misspelled. Confirm whether these are
# intentional before changing them.
# NOTE(review): the requests use '\n\n' as the stop sequence, so the
# demonstrations were presumably separated by blank lines - verify against
# the upstream file.
prompt_4shot = """Read the following text and table, and then answer a question:
$ in millions | year ended december 2014 | year ended december 2013 | year ended december 2012
fixed income currency and commodities client execution | $ 8461 | $ 8651 | $ 9914
equities client execution1 | 2079 | 2594 | 3171
commissions and fees | 3153 | 3103 | 3053
securities services | 1504 | 1373 | 1986
total equities | 6736 | 7070 | 8210
total net revenues | 15197 | 15721 | 18124
operating expenses | 10880 | 11792 | 12490
pre-tax earnings | $ 4317 | $ 3929 | $ 5634
Question: what was the percentage change in pre-tax earnings for the institutional client services segment between 2012 and 2013?
The pre-tax earnings for the institutional client services segment in 2012 was $ 5634 million , and in 2013 was $ 3929 million. The net change in pre-tax earnings was $ 1705 million, and the percentage change was 30.3%. So the answer is 30.3%.
Read the following text and table, and then answer a question:
during the year ended march 31 , 2012 , the company has recorded $ 3.3 million in stock-based compensation expense for equity awards in which the prescribed performance milestones have been achieved or are probable of being achieved .
- | number of shares ( in thousands ) | weighted average grant date fair value ( per share )
restricted stock and restricted stock units at beginning of year | 407 | $ 9.84
granted | 607 | 18.13
vested | -134 ( 134 ) | 10.88
forfeited | -9 ( 9 ) | 13.72
restricted stock and restricted stock units at end of year | 871 | $ 15.76
Question: during the 2012 year , did the equity awards in which the prescribed performance milestones were achieved exceed the equity award compensation expense for equity granted during the year?
The prescribed performance milestones is 3.3 million. The number of shares is 607 thousand. The fair value is 18.13. The compoensation expense is 607 thousand * 18.13 = 11 million. So the answer is no.
Read the following text and table, and then answer a question:
annual sales of printing papers and graphic arts supplies and equipment totaled $ 3.5 billion in 2012 compared with $ 4.0 billion in 2011 and $ 4.2 billion in 2010 , reflecting declining demand and the exiting of unprofitable businesses .
in millions | 2012 | 2011 | 2010
sales | $ 6040 | $ 6630 | $ 6735
operating profit | 22 | 34 | 78
Question: what percent of distribution sales where attributable to printing papers and graphic arts supplies and equipment in 2011?
The sales of print papers and graphic arts supplies and equipment in 2011 is 3.5 billion. The total sales in 2011 is 6.63 billion. The percentage is 52.8%. So the answer is 52.8%.
Read the following text and table, and then answer a question:
- | september 24 2005 | september 25 2004 | september 27 2003
beginning allowance balance | $ 47 | $ 49 | $ 51
charged to costs and expenses | 8 | 3 | 4
deductions ( a ) | -9 ( 9 ) | -5 ( 5 ) | -6 ( 6 )
ending allowance balance | $ 46 | $ 47 | $ 49
Question: what was the highest ending allowance balance , in millions?
The ending allowance balance in 2005 is 47. The ending allowance balance in 2004 is 49. The ending allowance balance in 2003 is 51. The highest ending allowance balance is 51. So the answer is 51.
"""
def _request_completions(prompt, greedy, api_key):
    """Query text-davinci-002 for `prompt`, retrying until a call succeeds.

    Args:
        prompt: Full few-shot prompt to complete.
        greedy: If true, request one sample at temperature 0.0; otherwise
            request 30 samples at temperature 0.5 (self-consistency).
        api_key: Key forwarded to the OpenAI client (may be ``None``).

    Returns:
        The raw response object of ``openai.Completion.create``.
    """
    temperature, num_samples = (0.0, 1) if greedy else (0.5, 30)
    while True:
        try:
            return openai.Completion.create(
                engine='text-davinci-002',
                prompt=prompt,
                api_key=api_key,
                max_tokens=128,
                temperature=temperature,
                top_p=1,
                n=num_samples,
                stop=['\n\n'],
                logprobs=1,
            )
        except Exception as err:
            # Retrying forever is kept on purpose (best effort against rate
            # limits), but the error is reported so that a permanent failure
            # such as a bad key does not look like a silent hang.
            print(f'OpenAI request failed ({err!r}); retrying in 3s')
            sleep(3)


def _majority_answer(answers):
    """Normalise each sampled answer and return the most frequent one.

    Args:
        answers: Answer strings (the text after "answer is") of all samples.

    Returns:
        The most common normalised answer (a float, ``'yes'``/``'no'``, or a
        value taken from a dict result), or ``None`` if nothing was parsed.
    """
    votes = Counter()
    for raw in answers:
        print(raw)
        ans = extract_one_num_from_str(raw)
        if not ans:
            # No usable number: fall back to a yes/no reading of the text.
            lowered = raw.lower()
            if 'yes' in lowered or 'true' in lowered:
                ans = 'yes'
            elif 'no' in lowered or 'false' in lowered:
                ans = 'no'
        if ans is None:
            continue
        if isinstance(ans, dict):
            votes.update(list(ans.values()))
        elif isinstance(ans, (list, tuple)):
            votes.update([float(ans[0])])
        elif isinstance(ans, str):
            votes.update([ans])
        else:
            try:
                votes.update([float(ans)])
            except (TypeError, ValueError, OverflowError):
                # Not convertible to a number: this sample casts no vote.
                continue
    return votes.most_common(1)[0][0] if votes else None


if __name__ == "__main__":
    with open('data/finqa_dev.json') as f:
        finqa_dev = json.load(f)

    # Apply --start/--end, which previously only showed up in the filename.
    # A negative --end (the default, -1) means "up to the last example", so
    # the default invocation still evaluates the whole file.
    slice_end = args.end if args.end >= 0 else None
    finqa_dev = finqa_dev[args.start:slice_end]

    now = datetime.now()
    dt_string = now.strftime("%m_%d_%H_%M")

    correct, wrong = 0, 0

    # "_sc" marks self-consistency runs.
    run_tag = 'finqa_cot_gpt3' if args.greedy else 'finqa_cot_gpt3_sc'
    filename = f'outputs/{run_tag}_s{args.start}_e{args.end}_{dt_string}.jsonl'
    os.makedirs('outputs', exist_ok=True)

    # The context manager closes the file even if an example raises.
    with open(filename, 'w') as writer:
        for example in tqdm(finqa_dev):
            full_prompt = prompt_4shot + "\n\n"
            full_prompt += create_reader_request_processed(example)

            if args.dry_run:
                # Show the first prompt only; no API call is made.
                print(full_prompt)
                print('=======================')
                break

            result = _request_completions(
                full_prompt, args.greedy, os.getenv(args.key)
            )

            # Keep only what follows the final "answer is" of each sample.
            codes = parse_api_result(result)
            codes = [code.split('answer is')[-1].strip() for code in codes]
            prediction = _majority_answer(codes)

            if prediction is not None and finqa_equal(
                prediction, example['answer'], True, True
            ):
                correct += 1
            else:
                wrong += 1

            example.update({'generated': codes, 'executed': prediction})
            writer.write(json.dumps(example) + '\n')

    print()
    evaluated = correct + wrong
    if evaluated:
        print('accuracy: ', correct / evaluated)
    else:
        # Dry runs and empty slices evaluate nothing; avoid dividing by zero.
        print('accuracy: n/a (no examples were evaluated)')