forked from RUCAIBox/MML
-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_attention_fusion_train.py
45 lines (37 loc) · 1.58 KB
/
run_attention_fusion_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# @Time : 2020/7/20
# @Author : Shanlei Mu
# @Email : [email protected]
# UPDATE
# @Time : 2020/10/3, 2020/10/1
# @Author : Yupeng Hou, Zihan Lin
# @Email : [email protected], [email protected]
import argparse
import ast

from recbole.quick_start import run_attention_fusion_train
def parse_literal_arg(value, allow_bare=True):
    """Parse a list-valued command-line argument without using ``eval``.

    The value is first parsed with ``ast.literal_eval`` so literals such as
    ``"['BPR', 'NeuMF']"`` keep working, while arbitrary expressions typed on
    the command line are never executed.

    Args:
        value: Raw argument string (may be ``None`` or empty).
        allow_bare: When ``value`` is not a Python literal, fall back to
            treating it as a comma-separated list of bare names, so ``"BPR"``
            becomes ``['BPR']`` and ``"BPR,NeuMF"`` becomes
            ``['BPR', 'NeuMF']``. When False, such a value is rejected.

    Returns:
        The parsed object, or ``None`` when ``value`` is empty or ``None``.

    Raises:
        ValueError: If ``allow_bare`` is False and ``value`` is not a valid
            Python literal.
    """
    if not value:
        return None
    try:
        return ast.literal_eval(value)
    except (ValueError, SyntaxError, TypeError) as exc:
        if not allow_bare:
            raise ValueError(f'could not parse {value!r} as a Python literal') from exc
    # A bare name such as "BPR" is not a literal; the previous eval() call
    # raised NameError on it, which made the --model_list default unusable.
    return [item.strip() for item in value.split(',') if item.strip()]


def _notify(message):
    """Send ``message`` through the optional ``mtjupyter_utils.remind`` hook.

    Best-effort by design: if the module is missing or ``remind`` fails, the
    error is ignored so a notification problem never hides the training
    result or the original exception.
    """
    try:
        from mtjupyter_utils import remind
        remind(message)
    except Exception:
        pass


def main():
    """Parse command-line arguments and run attention-fusion training.

    Unrecognised arguments are ignored by this parser (``parse_known_args``).

    Returns:
        Whatever ``run_attention_fusion_train`` returns; the caller reports
        it as ``best_valid_loss``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_list', '-m', type=str, default='BPR', help='list of name of models')
    parser.add_argument('--dataset', '-d', type=str, default='ml-100k', help='name of datasets')
    parser.add_argument('--config_files', type=str, default=None, help='config files')
    parser.add_argument('--saved', type=str, default='True', help='saved')
    # Accepted for command-line compatibility; not used by this script.
    parser.add_argument('--hint', type=str, default='', help='hint for run_recbole')
    args, _ = parser.parse_known_args()

    model_list = parse_literal_arg(args.model_list)
    # The expected structure of config_file_lists is not visible here, so no
    # shape is guessed from a bare string: only real Python literals are accepted.
    config_file_lists = parse_literal_arg(args.config_files, allow_bare=False)
    saved = args.saved.lower() == 'true'
    return run_attention_fusion_train(
        model_name_list=model_list,
        dataset=args.dataset,
        config_file_lists=config_file_lists,
        saved=saved,
    )


if __name__ == '__main__':
    try:
        result = main()
    except Exception as err:
        _notify(str(err))
        # Bare raise keeps the original traceback intact.
        raise
    _notify(f'best_valid_loss: {result}')