-
Notifications
You must be signed in to change notification settings - Fork 0
/
run.py
52 lines (45 loc) · 2.3 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import json
import os
import random
from experiments.exp_basic import *
from datetime import datetime
import argparse
def main(argv=None):
    """Train and/or evaluate one model as described by a JSON config file.

    Parses ``--cfg_file`` from *argv* (the process arguments when ``None``),
    loads the JSON config, builds an ``Exp_Basic`` experiment whose artifacts
    live under ``cache/<model_name>/<dataset_name>/<horizon>``, then either
    trains it or loads a previously saved model, and finally runs the
    test/evaluation step.

    Training runs when ``cfg['exp']['train']['training']`` is truthy, or when
    the cache directory did not exist yet / was empty (nothing to load).

    Args:
        argv: Optional list of command-line arguments for programmatic use;
            ``None`` means "use sys.argv[1:]" (argparse default).

    Raises:
        SystemExit: If ``--cfg_file`` is missing (raised by argparse).
        FileNotFoundError: If the config file does not exist.
        KeyError: If the config lacks one of the keys read below.
    """
    # Monotonic clock for measuring durations; imported locally so the
    # file-level import order (which includes a star import) is untouched.
    import time

    parser = argparse.ArgumentParser(
        description='Autoformer & Transformer family for Time Series Forecasting')
    # required=True means argparse never falls back to a default, so none is set.
    parser.add_argument('--cfg_file', type=str, required=True,
                        help='path to the config file')
    args = parser.parse_args(argv)

    with open(args.cfg_file, 'r', encoding='utf-8') as f:
        cfg = json.load(f)

    model_name = cfg['model']['model_name']
    dataset_name = cfg['data']['dataset_name']
    model_save_dir = 'cache/{}/{}/{}'.format(
        model_name, dataset_name, cfg['data']['horizon'])

    # Probe for earlier artifacts BEFORE creating the directory. Creating it
    # first and then testing `not os.path.exists(...)` made that test always
    # False, so the "nothing cached -> train" fallback never fired and
    # load_model() ran against a freshly created, empty directory.
    # NOTE(review): assumes Exp_Basic stores its checkpoint inside
    # model_save_dir — confirm against experiments/exp_basic.py.
    has_cached_model = (os.path.isdir(model_save_dir)
                        and bool(os.listdir(model_save_dir)))
    os.makedirs(model_save_dir, exist_ok=True)

    exp = Exp_Basic(cfg, model_save_dir)
    if cfg['exp']['train']['training'] or not has_cached_model:
        print("Start training for model:", model_name, " dataset:", dataset_name)
        before_train = time.perf_counter()
        print("===================Train-Start=========================")
        exp.train()
        after_train = time.perf_counter()
        print(f'Training took {(after_train - before_train) / 60} minutes')
        print("===================Train-End=========================")
    else:
        exp.load_model()

    print("Start testing for model:", model_name, " dataset:", dataset_name)
    before_evaluation = time.perf_counter()
    exp.test()
    after_evaluation = time.perf_counter()
    print('Test/evaluation took: {} minutes'.format(
        (after_evaluation - before_evaluation) / 60))


if __name__ == '__main__':
    main()