-
Notifications
You must be signed in to change notification settings - Fork 0
/
params.py
79 lines (63 loc) · 1.9 KB
/
params.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
def get_hyperparameters(opt, dataset='sample', rep='musicnn'):
    """Populate ``opt`` with representation- and dataset-specific settings.

    Two independent lookups are applied, in this order:

    1. ``opt.rep`` selects ``opt.input_dim`` and ``opt.projection``.
    2. When ``dataset`` is ``'music4all-onion'`` or ``'m4a'``, the ``rep``
       argument selects the remaining hyperparameters.

    A representation or dataset name that is not listed leaves the
    corresponding attributes untouched; no error is raised.

    NOTE(review): step 1 keys on ``opt.rep`` while step 2 keys on the ``rep``
    argument. Confirm that callers always keep the two in sync.

    Args:
        opt: Options object exposing a ``rep`` attribute. It is modified in
            place.
        dataset: Dataset name.
        rep: Representation name used for the dataset-specific lookup.

    Returns:
        The same ``opt`` object that was passed in.
    """
    # (input_dim, projection) for each representation named by ``opt.rep``.
    input_settings = {
        "musicnn": (50, 0),
        "jukebox": (4800, 1),
        "maest": (768, 1),
    }
    chosen = input_settings.get(opt.rep)
    if chosen is not None:
        opt.input_dim, opt.projection = chosen

    if dataset not in ('music4all-onion', 'm4a'):
        return opt

    # The only values that vary between representations for this dataset:
    # (batch_size, neighbors).
    per_rep = {
        "musicnn": (60000, 40),
        "jukebox": (40000, 20),
        "maest": (40000, 20),
    }
    if rep not in per_rep:
        return opt
    batch_size, neighbors = per_rep[rep]

    # Dict order fixes the order in which attributes are assigned.
    settings = {
        "num_node": 109267,
        "batch_size": batch_size,
        "neighbors": neighbors,
        "folds": 10,
        "layers": 2,
        "de_1": 0.05,
        "df_1": 0.15,
        "de_2": 0.1,
        "df_2": 0.05,
        "gnn_dropout": 0.0,
        "dropout": 0.2,
        "alpha": 0.01,
        "beta": 0.07,
        "clusters": 10,
        "confidence_threshold": 0.5,
        "k": 5,
    }
    for attr_name, attr_value in settings.items():
        setattr(opt, attr_name, attr_value)
    return opt