Skip to content

Commit

Permalink
Update train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
yeon-lab authored Jun 10, 2023
1 parent da469a6 commit 19e064f
Showing 1 changed file with 48 additions and 67 deletions.
115 changes: 48 additions & 67 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,72 +2,46 @@
import collections
import numpy as np
import pandas as pd
import os.path
import os

import random
import torch
import pickle

import model.loss as module_loss
import model.models as model_arch
import model.metric as module_metric
from parse_config import ConfigParser
from utils.util import *
from utils.parse_config import ConfigParser
from trainer import Trainer
from model.sample_weighting import *
from model import *
from model.sampling_weight import *
from utils.load_data import init_data

import torch
import torch.nn as nn
def main(params, config, dataset, version):

# fix random seeds for reproducibility
SEED = 1111
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = False
criterion = getattr(module_loss, config['loss'])
metrics = [getattr(module_metric, met) for met in config['metrics']]

def weights_init_normal(m):
if type(m) == nn.Linear:
torch.nn.init.uniform_(m.weight.data, -0.1, 0.1)
torch.nn.init.constant_(m.bias.data, 0.0)
elif type(m) == nn.GRU:
for layer_p in m._all_weights:
for p in layer_p:
if 'weight' in p:
torch.nn.init.uniform_(m.__getattr__(p), -0.1, 0.1)
if 'bias' in p:
torch.nn.init.constant_(m.__getattr__(p), 0.0)

def main(params, config, dataset):
if config['arch']['type'] == 'Retain':
model = Retain(config=config)
elif config['arch']['type'] == 'Dipole':
model = Dipole(config=config)
elif config['arch']['type'] == 'LSTM':
model = LSTM(config=config)
elif config['arch']['type'] == 'GRU':
model = GRU(config=config)
elif config['arch']['type'] == 'Concare':
model = Concare(config=config)
elif config['arch']['type'] == 'Stagenet':
model = Stagenet(config=config)
model = getattr(model_arch, params.model)
model = model(config, criterion)
model.weights_init()

model.apply(weights_init_normal)
logger = config.get_logger('train')
logger.info('='*100)
logger.info(' {:25s}: {}'.format("Model", params.model))
logger.info("-"*100)
for key, value in config['hyper_params'].items():
logger.info(' {:25s}: {}'.format(str(key), value))
logger.info("-"*100)
logger.info(model)
logger.info("-"*100)


criterion = getattr(module_loss, config['loss'])
metrics = [getattr(module_metric, met) for met in config['metrics']]

# build optimizer
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
optimizer = config.init_obj('optimizer', torch.optim, model_parameters)

if params.version == 'weight':
dataset = reweight_data(dataset, config['hyper_params']["max_visit"], config['hyper_params']['n_feat'], params.steps, params.step_lr, params.kl_weight,
params.dist_weight, params.kl_dim)

if version == 'weight':
dataset = reweight_data(dataset, config['hyper_params']["max_visit"], config['hyper_params']['n_feat'], params.steps, params.step_lr, params.kl_weight, params.dist_weight, params.kl_dim)
trainer = Trainer(model,
optimizer,
criterion,
metrics,
config=config,
dataset=dataset
Expand All @@ -83,7 +57,9 @@ def main(params, config, dataset):
args.add_argument('-d', '--device', default="0", type=str,
help='indices of GPUs to enable (default: all)')
args.add_argument('-c', '--config', type=str)
args.add_argument('--version', type=str)
args.add_argument('--version', default='basic', type=str)
args.add_argument('--time', type=int)
args.add_argument('--target', type=str)
args.add_argument('--day_dim', type=int)
args.add_argument('--rnn_hidden', type=int)
args.add_argument('--model', type=str)
Expand All @@ -93,27 +69,32 @@ def main(params, config, dataset):
args.add_argument('--kl_weight', type=float)
args.add_argument('--dist_weight', type=float)
args.add_argument('--kl_dim', type=int)
args.add_argument('--np_data_dir', type=str,
help='Directory containing numpy files')
args.add_argument('--data_file',type=str)


CustomArgs = collections.namedtuple('CustomArgs', 'flags type target')
params = args.parse_args()
config = ConfigParser.from_args(args)
config['hyper_params']['day_dim']= params.day_dim
config['hyper_params']['rnn_hidden']= params.rnn_hidden

if 'mimic' in params.np_data_dir:
config['hyper_params']["max_visit"] = 29
config['hyper_params']["min_visit"] = 3
else:
config['hyper_params']["max_visit"] = 30
config['hyper_params']["min_visit"] = 10
data_file = 'pickles/input_{}_{}.pkl'.format(params.target, params.time)

version = params.version
if params.model == 'AdaDiag' or params.model == 'DG':
version = 'basic'

dataset, feat_dict, y_dict = init_data(data_file = args2.data_file, npy_path=args2.np_data_dir, config=config)
exper_name = '{}_{}/{}_{}/day_dim_{}_rnn_hidden_{}'.format(params.target, params.time, params.model, version, params.day_dim, params.rnn_hidden)
if version == 'weight':
exper_name = exper_name + '/steps_{}_step_lr_{}_kl_{}_dist_{}_kl_dim_{}'.format(params.steps, params.step_lr, params.kl_weight, params.dist_weight, params.kl_dim)
exper_name += '/iter_{}'.format(iter)
config = ConfigParser.from_args(args, exper_name)
config['hyper_params']['model'] = params.model
config['hyper_params']['day_dim']= params.day_dim
config['hyper_params']['rnn_hidden']= params.rnn_hidden
config['hyper_params']['version'] = version
config['optimizer']['args']['weight_decay']= params.weight_decay
config['hyper_params']["min_visit"], config['hyper_params']["max_visit"] = 10, 30

dataset, config, feat_dict, y_dict = init_data(data_file, config)
config['hyper_params']['n_feat'], config['hyper_params']['n_class'] = len(feat_dict.keys()), len(y_dict.keys())

log, log_per_month = main(params, config, dataset)
np.save('{}_{}_log_per_month'.format(params.version, config['arch']['type']), log_per_month)
np.save('{}_{}_log'.format(params.version, config['arch']['type']), log)
log = main(params, config, dataset, version)


0 comments on commit 19e064f

Please sign in to comment.