# train_generator.py (forked from wengong-jin/hgraph2graph)

import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader
import rdkit
import math, random, sys
import numpy as np
import argparse
import os
from tqdm.auto import tqdm
from hgraph import *  # PairVocab, HierVAE, DataFolder, common_atom_vocab

# Silence RDKit logging; only critical errors are shown.
lg = rdkit.RDLogger.logger()
lg.setLevel(rdkit.RDLogger.CRITICAL)
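
# Command-line arguments: training data, vocab, checkpoint paths, and random seed.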
parser = argparse.ArgumentParser()
parser.add_argument('--train', required=True)
parser.add_argument('--vocab', required=True)
parser.add_argument('--atom_vocab', default=common_atom_vocab)
parser.add_argument('--save_dir', required=True)
parser.add_argument('--load_model', default=None)
parser.add_argument('--seed', type=int, default=7)
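
# Model architecture and batch size.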
parser.add_argument('--rnn_type', type=str, default='LSTM')
parser.add_argument('--hidden_size', type=int, default=250)
parser.add_argument('--embed_size', type=int, default=250)
parser.add_argument('--batch_size', type=int, default=50)
parser.add_argument('--latent_size', type=int, default=32)
parser.add_argument('--depthT', type=int, default=15)
parser.add_argument('--depthG', type=int, default=15)
parser.add_argument('--diterT', type=int, default=1)
parser.add_argument('--diterG', type=int, default=3)
parser.add_argument('--dropout', type=float, default=0.0)
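
# Optimization, gradient clipping, KL annealing, and logging/checkpoint intervals.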
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--clip_norm', type=float, default=5.0)
parser.add_argument('--step_beta', type=float, default=0.001)
parser.add_argument('--max_beta', type=float, default=1.0)
parser.add_argument('--warmup', type=int, default=10000)
parser.add_argument('--kl_anneal_iter', type=int, default=2000)
parser.add_argument('--epoch', type=int, default=20)
parser.add_argument('--anneal_rate', type=float, default=0.9)
parser.add_argument('--anneal_iter', type=int, default=25000)
parser.add_argument('--print_iter', type=int, default=50)
parser.add_argument('--save_iter', type=int, default=5000)
args = parser.parse_args()
print(args)

os.makedirs(args.save_dir, exist_ok=True)  # checkpoints are saved here
torch.manual_seed(args.seed)
random.seed(args.seed)

# Motif vocabulary: one whitespace-separated entry per line, wrapped in a PairVocab.
with open(args.vocab) as f:
    vocab = [x.strip("\r\n ").split() for x in f]
args.vocab = PairVocab(vocab)

model = HierVAE(args).cuda()
print("Model #Params: %dK" % (sum([x.nelement() for x in model.parameters()]) / 1000,))

# Initialize weights: Xavier for matrices, zeros for biases and other 1-D parameters.
for param in model.parameters():
    if param.dim() == 1:
        nn.init.constant_(param, 0)
    else:
        nn.init.xavier_normal_(param)

optimizer = optim.Adam(model.parameters(), lr=args.lr)
scheduler = lr_scheduler.ExponentialLR(optimizer, args.anneal_rate)

if args.load_model:
    # Resume: the checkpoint stores (model state, optimizer state, step counter, KL weight).
    print('continuing from checkpoint ' + args.load_model)
    model_state, optimizer_state, total_step, beta = torch.load(args.load_model)
    model.load_state_dict(model_state)
    optimizer.load_state_dict(optimizer_state)
else:
    total_step = beta = 0

# Overall L2 norm of parameters / gradients, reported alongside the training metrics.
param_norm = lambda m: math.sqrt(sum([p.norm().item() ** 2 for p in m.parameters()]))
grad_norm = lambda m: math.sqrt(sum([p.grad.norm().item() ** 2 for p in m.parameters() if p.grad is not None]))

# Running sums of [kl_div, loss, wacc, iacc, tacc, sacc], averaged every print_iter steps.
meters = np.zeros(6)

# Training loop: DataFolder iterates over the preprocessed tensor batches in args.train.
for epoch in range(args.epoch):
    dataset = DataFolder(args.train, args.batch_size)

    for batch in tqdm(dataset):
        total_step += 1
        model.zero_grad()
        loss, kl_div, wacc, iacc, tacc, sacc = model(*batch, beta=beta)

        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), args.clip_norm)
        optimizer.step()

        meters = meters + np.array([kl_div, loss.item(), wacc * 100, iacc * 100, tacc * 100, sacc * 100])

        if total_step % args.print_iter == 0:
            meters /= args.print_iter
            print("[%d] Beta: %.3f, KL: %.2f, loss: %.3f, Word: %.2f, %.2f, Topo: %.2f, Assm: %.2f, PNorm: %.2f, GNorm: %.2f" % (total_step, beta, meters[0], meters[1], meters[2], meters[3], meters[4], meters[5], param_norm(model), grad_norm(model)))
            sys.stdout.flush()
            meters *= 0

        # Periodic checkpoint: model, optimizer, step counter, and current beta.
        if total_step % args.save_iter == 0:
            ckpt = (model.state_dict(), optimizer.state_dict(), total_step, beta)
            torch.save(ckpt, os.path.join(args.save_dir, f"model.ckpt.{total_step}"))

        # Exponential learning-rate decay every anneal_iter steps.
        if total_step % args.anneal_iter == 0:
            scheduler.step()
            print("learning rate: %.6f" % scheduler.get_last_lr()[0])

        # After warmup, raise the KL weight by step_beta every kl_anneal_iter steps, up to max_beta.
        if total_step >= args.warmup and total_step % args.kl_anneal_iter == 0:
            beta = min(args.max_beta, beta + args.step_beta)