-
Notifications
You must be signed in to change notification settings - Fork 116
/
train.py
100 lines (91 loc) · 4.41 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import torch
from sqlnet.utils import *
from sqlnet.model.sqlnet import SQLNet
import argparse
if __name__ == '__main__':
    # Script entry point: train SQLNet, evaluate on the dev set after every
    # epoch, and checkpoint the weights whenever the dev logic-form accuracy
    # improves.
    #
    # All print calls use the single-argument parenthesised form so the script
    # parses and prints identically under both Python 2 and Python 3.
    import os  # local import: only needed to prepare the checkpoint directory

    parser = argparse.ArgumentParser()
    parser.add_argument('--bs', type=int, default=16, help='Batch size')
    parser.add_argument('--epoch', type=int, default=100, help='Epoch number')
    parser.add_argument('--gpu', action='store_true', help='Whether use gpu to train')
    parser.add_argument('--toy', action='store_true', help='If set, use small data for fast debugging')
    parser.add_argument('--ca', action='store_true', help='Whether use column attention')
    parser.add_argument('--train_emb', action='store_true', help='Train word embedding for SQLNet')
    parser.add_argument('--restore', action='store_true', help='Whether restore trained model')
    # NOTE: --logdir is parsed but not read anywhere in this script.
    parser.add_argument('--logdir', type=str, default='', help='Path of save experiment log')
    args = parser.parse_args()

    n_word = 300
    learning_rate = 1e-3
    model_path = 'saved_model/best_model'

    gpu = args.gpu
    use_small = args.toy
    # Toy mode pins the batch size to 16 and ignores --bs.
    batch_size = 16 if args.toy else args.bs

    # load dataset
    train_sql, train_table, train_db, dev_sql, dev_table, dev_db = load_dataset(use_small=use_small)

    word_emb = load_word_emb('data/char_embedding')
    model = SQLNet(word_emb, N_word=n_word, use_ca=args.ca, gpu=gpu, trainable_emb=args.train_emb)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0)

    if args.restore:
        # NOTE(review): a checkpoint written on GPU may need `map_location`
        # to be restored on a CPU-only machine -- confirm before relying on it.
        print("Loading trained model from %s" % model_path)
        model.load_state_dict(torch.load(model_path))

    # Make sure the checkpoint directory exists up front, so the first
    # torch.save() cannot fail after a whole epoch of training.
    model_dir = os.path.dirname(model_path)
    if model_dir and not os.path.isdir(model_dir):
        os.makedirs(model_dir)

    # Best score and the (1-based) epoch it was reached at, for each of the
    # eight sub-tasks, in the order they are printed below:
    # Sel-Num, Sel-Col, Sel-Agg, W-Num, W-Col, W-Op, W-Val, W-Rel.
    n_sub_tasks = 8
    best_sub = [0] * n_sub_tasks
    best_sub_idx = [0] * n_sub_tasks
    best_lf, best_lf_idx = 0.0, 0
    best_ex, best_ex_idx = 0.0, 0

    print("#"*20+" Star to Train " + "#"*20)
    for i in range(args.epoch):
        epoch = i + 1
        print('Epoch %d' % epoch)

        # train on the train dataset
        train_loss = epoch_train(model, optimizer, batch_size, train_sql, train_table)
        # evaluate on the dev dataset
        dev_acc = epoch_acc(model, batch_size, dev_sql, dev_table, dev_db)
        # dev_acc[0]: per-sub-task accuracies, dev_acc[1]: logic-form
        # accuracy, dev_acc[2]: execution accuracy (as labelled in the
        # prints below).
        sub_acc, lf_acc, ex_acc = dev_acc[0], dev_acc[1], dev_acc[2]

        # accuracy of each sub-task
        print('Sel-Num: %.3f, Sel-Col: %.3f, Sel-Agg: %.3f, W-Num: %.3f, W-Col: %.3f, W-Op: %.3f, W-Val: %.3f, W-Rel: %.3f' % (
            tuple(sub_acc[k] for k in range(n_sub_tasks))))

        # save the best model (selected on logic-form accuracy only)
        if lf_acc > best_lf:
            best_lf = lf_acc
            best_lf_idx = epoch
            torch.save(model.state_dict(), model_path)
        if ex_acc > best_ex:
            best_ex = ex_acc
            best_ex_idx = epoch

        # record the best score of each sub-task
        for k in range(n_sub_tasks):
            if sub_acc[k] > best_sub[k]:
                best_sub[k] = sub_acc[k]
                best_sub_idx[k] = epoch

        print('Train loss = %.3f' % train_loss)
        print('Dev Logic Form Accuracy: %.3f, Execution Accuracy: %.3f' % (lf_acc, ex_acc))
        print('Best Logic Form: %.3f at epoch %d' % (best_lf, best_lf_idx))
        print('Best Execution: %.3f at epoch %d' % (best_ex, best_ex_idx))

        # Every 10 epochs, summarise the per-sub-task bests. All eight
        # sub-tasks are reported, including W-Rel, which used to be tracked
        # but left out of this summary.
        if epoch % 10 == 0:
            print('Best val acc: %s\nOn epoch individually %s' % (
                tuple(best_sub), tuple(best_sub_idx)))