-
Notifications
You must be signed in to change notification settings - Fork 3
/
run_lr.py
executable file
·139 lines (111 loc) · 4.03 KB
/
run_lr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from bow_lr import *
import numpy as np
import argparse
from utils import load_train_dev_test_json, extract_key, extract_keys, remap_dict, load_jsonl
from pandas import DataFrame
import pickle
import os
def save_prediction(prediction, path):
    """Write predictions to a text file, one per line.

    Each item is converted with ``str()`` and terminated by a newline.
    Any existing file at *path* is overwritten.

    Args:
        prediction: Iterable of predicted labels.
        path: Destination file path.
    """
    lines = (str(item) + '\n' for item in prediction)
    with open(path, 'w') as out_file:
        out_file.writelines(lines)
def eval_on_dataset(model, dict_name_map, path, args):
    """Evaluate a fitted model on one JSON-lines file and print its accuracy.

    Files whose path contains ``"split_detail"`` are skipped silently. Files
    that load to zero examples are reported and skipped, because there is
    nothing to predict on and the mean over zero examples is undefined.

    Args:
        model: Fitted classifier exposing ``predict``.
        dict_name_map: Mapping from the dataset's title/text keys to
            ``'title'``/``'text'``; only used when ``args.tittxt`` is set.
        path: Path of the file to evaluate (read with ``load_jsonl``).
        args: Parsed command-line arguments; ``tittxt``, ``input_key`` and
            ``target_key`` are read.

    Returns:
        The accuracy, or ``None`` if the file was skipped.
    """
    if "split_detail" in path:
        return None
    print(path)
    eval_data = load_jsonl(path)
    if len(eval_data) == 0:
        # Predicting on zero rows is meaningless (and may raise, depending on
        # the estimator); np.mean of an empty comparison would be NaN anyway.
        print('skipping empty dataset')
        return None

    if args.tittxt:
        # Rename the dataset's title/text fields to the column names the
        # title+text pipeline was trained on.
        eval_data = [remap_dict(ex, dict_name_map) for ex in eval_data]
        features = DataFrame.from_dict(extract_keys(eval_data, ['title', 'text']))
    else:
        features = extract_key(eval_data, args.input_key)

    predicted = model.predict(features)
    acc = np.mean(predicted == extract_key(eval_data, args.target_key))
    print(acc)
    return acc
def _parse_title_text_keys(input_key):
    """Split a ``"TITLE_KEY,TEXT_KEY"`` argument into its two stripped keys."""
    keys = [key.strip() for key in input_key.split(',')]
    if len(keys) != 2 or not all(keys):
        raise ValueError(
            "--tittxt expects --input_key in the form TITLE_KEY,TEXT_KEY; "
            "got {!r}".format(input_key))
    return keys[0], keys[1]


def _iter_eval_files(test_dir):
    """Yield non-hidden ``.jsonl``/``.json`` files under *test_dir*, sorted."""
    for subdir, dirs, files in os.walk(test_dir):
        # Sorting ``dirs`` in place fixes the order os.walk descends in, so the
        # per-file results are printed in a reproducible order.
        dirs.sort()
        for filename in sorted(files):
            if filename.endswith(('.jsonl', '.json')) and not filename.startswith('.'):
                yield os.path.join(subdir, filename)


def run_lr(args):
    """Train the logistic-regression pipeline and report its accuracy.

    The model is fit on the ``train`` split loaded from ``args.data_dir`` and
    evaluated on ``dev`` (or ``test`` when ``args.use_test_set`` is set); the
    accuracy is printed. Optionally the predictions and the pickled model are
    saved, and every ``.jsonl``/``.json`` file under ``args.test_dir`` is
    evaluated with ``eval_on_dataset``.

    With ``args.tittxt`` the title and text fields are used together
    (``lr_plus_clf``) and ``args.input_key`` must be ``"TITLE_KEY,TEXT_KEY"``;
    otherwise the single field ``args.input_key`` is used (``lr_clf``).

    Args:
        args: Parsed command-line arguments.

    Raises:
        ValueError: If ``args.tittxt`` is set and ``args.input_key`` does not
            name exactly two non-empty comma-separated keys.
    """
    skl_data = load_train_dev_test_json(args.data_dir)
    train_set = skl_data['train']
    eval_set = skl_data['test'] if args.use_test_set else skl_data['dev']

    # NOTE(review): lr_plus_clf / lr_clf are presumably module-level pipelines
    # from ``bow_lr`` (star import); fitting mutates that shared object.
    if args.tittxt:
        title_key, text_key = _parse_title_text_keys(args.input_key)
        dict_name_map = {title_key: 'title', text_key: 'text'}
        train_set = [remap_dict(ex, dict_name_map) for ex in train_set]
        eval_set = [remap_dict(ex, dict_name_map) for ex in eval_set]
        model = lr_plus_clf
    else:
        dict_name_map = None
        model = lr_clf

    def to_features(examples):
        # Build model input in the shape the selected pipeline expects.
        if args.tittxt:
            return DataFrame.from_dict(extract_keys(examples, ['title', 'text']))
        return extract_key(examples, args.input_key)

    model.fit(to_features(train_set), extract_key(train_set, args.target_key))
    predicted = model.predict(to_features(eval_set))
    acc = np.mean(predicted == extract_key(eval_set, args.target_key))
    print(acc)

    if args.save_prediction is not None:
        save_prediction(predicted, args.save_prediction)
    if args.save_model is not None:
        # NOTE: only unpickle the saved model from a trusted location.
        with open(args.save_model, 'wb') as f:
            pickle.dump(model, f)
    if args.test_dir is not None:
        for filepath in _iter_eval_files(args.test_dir):
            eval_on_dataset(model, dict_name_map, filepath, args)
def build_arg_parser():
    """Return the command-line argument parser for this script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help="The input data dir, containing train/dev/test.jsonl"
    )
    parser.add_argument(
        "--input_key",
        default=None,
        type=str,
        required=True,
        help="The input key"
    )
    parser.add_argument(
        '--target_key',
        default=None,
        type=str,
        required=True,
        help="The target key"
    )
    parser.add_argument(
        '--use_test_set',
        action="store_true",
        help="if set use test set instead of dev set"
    )
    parser.add_argument(
        '--save_prediction',
        default=None,
        type=str,
        help="If not None, save the prediction to this location"
    )
    parser.add_argument(
        '--save_model',
        default=None,
        type=str,
        help="If not None, save the model to this location"
    )
    parser.add_argument(
        '--tittxt',
        action="store_true",
        help="if set, use both title and text, and the input_key argument should be in this format: [TITLE_KEY],[TEXT_KEY]"
    )
    parser.add_argument(
        '--test_dir',
        default=None,
        type=str,
        help="If not None, test on all .jsonl/.json files found recursively under that dir"
    )
    return parser


if __name__ == '__main__':
    run_lr(build_arg_parser().parse_args())