-
Notifications
You must be signed in to change notification settings - Fork 1
/
test_feature.py
176 lines (145 loc) · 5.86 KB
/
test_feature.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import datetime
import functools
import json
import logging
import os
import sys
import time
from pprint import pprint

import yaml

from utils.parse_args_util import get_parsed_params
from utils.pipeline_analysis_util import run_pipeline
from utils.setup_analysis_environment_util import setup_analysis_environment
# =============================================================================================== #
# MAIN UTILS #
# =============================================================================================== #
def read_neural_network_params(cmd_line_params):
    """Load the network parameters named on the command line.

    Raises an Exception when no parameters file was supplied; otherwise
    parses the file and attaches the (possibly None) pre-trained model
    path from the command line to the resulting dict.
    """
    if cmd_line_params.network_parameters is None:
        raise Exception('[ERROR] Please define a valid parameters\' filename')

    params_file_path = cmd_line_params.network_parameters
    # Parameters read from file.
    loaded_params = get_neural_network_params_from_file(params_file_path)
    # If it exists, weights of a pre-trained model are loaded.
    loaded_params['pretrained_model'] = cmd_line_params.pretrained_model
    return loaded_params
def get_neural_network_params_from_file(network_params_path: str) -> dict:
    """Parse a network-parameters file (JSON or YAML) into a dict.

    Args:
        network_params_path: path to a file whose name ends in 'json',
            'yaml' or 'yml'.

    Returns:
        The parsed parameters dictionary.

    Raises:
        ValueError: if the extension is not a supported format. (The
            original silently returned None here, which made the caller
            crash later with an opaque TypeError.)
    """
    with open(network_params_path, "r") as f:
        if network_params_path.endswith('json'):
            return json.load(f)
        if network_params_path.endswith(('yaml', 'yml')):
            # safe_load refuses arbitrary-object construction; bare
            # yaml.load without a Loader is deprecated and a TypeError
            # on PyYAML >= 6.
            return yaml.safe_load(f)
    raise ValueError(
        f"Unsupported parameters file format: '{network_params_path}' "
        "(expected .json, .yaml or .yml)"
    )
def run_test_decorator(a_test):
    """Decorator that runs *a_test* framed by a PASSED/FAILED banner.

    On any exception the error is printed, the banner reports FAILED and
    the whole process exits with code -1 (the suite is fail-fast), so the
    wrapper only returns a value on success.
    """
    @functools.wraps(a_test)  # preserve the test's __name__ for the banners
    def wrapper_function(a_dict):
        flag_test_passed: bool = True
        try:
            message: str = f" [*] Running TEST for function {a_test.__name__}"
            print()
            print(f"{message}", '-' * len(message), sep='\n')
            result = a_test(a_dict)
        except Exception as err:
            print(f'ERROR: {str(err)}')
            flag_test_passed = False
            # Fail-fast: abort the whole run on the first failing test.
            sys.exit(-1)
        finally:
            status_test: str = 'PASSED' if flag_test_passed else 'FAILED'
            message = f" [*] TEST on function {a_test.__name__} ended with STATUS = {status_test}"
            print()
            print(f"{message}", '-' * len(message), sep='\n')
        return result
    return wrapper_function
# =============================================================================================== #
# TESTS SECTION #
# =============================================================================================== #
@run_test_decorator
def test_pipeline_util(test_info_dict: dict):
    """Run the full pipeline with the configuration bundled in *test_info_dict*."""
    # Forward each pre-built configuration piece straight to the pipeline.
    run_pipeline(
        conf_load_dict=test_info_dict['conf_load_dict'],
        conf_preprocess_dict=test_info_dict['conf_preprocess_dict'],
        cmd_line_params=test_info_dict['cmd_line_params'],
        network_params=test_info_dict['network_params'],
        meta_info_project_dict=test_info_dict['meta_info_project_dict'],
        main_logger=test_info_dict['main_logger'],
    )
# =============================================================================================== #
# MAIN FUNCTION #
# =============================================================================================== #
def main(cmd_line_params: dict):
    """Set up the analysis environment and run the pipeline-util test."""
    base_dir: str = 'bioinfo_project'
    network_params = read_neural_network_params(cmd_line_params)

    print(f"----> Set up analysis environment.")
    logger, meta_info_project_dict = setup_analysis_environment(
        logger_name=__name__,
        base_dir=base_dir,
        params=cmd_line_params,
        flag_test=True,
    )
    logger.info("\n" + json.dumps(network_params, indent=4))

    # ------------------------------------------------------ #
    # Here - Test pipeline util
    test_pipeline_util({
        'conf_load_dict': {
            'sequence_type': cmd_line_params.sequence_type,
            'path': cmd_line_params.network_parameters,
            'columns_names': [
                'Sequences', 'Count', 'Unnamed: 0', 'Label',
                'Translated_sequences', 'Protein_length',
            ],
            'train_bins': [1, 2, 3],
            'val_bins': [4],
            'test_bins': [5],
        },
        'conf_preprocess_dict': {
            'padding': 'post',
            'maxlen': network_params['maxlen'],
            'onehot_flag': False,
        },
        'cmd_line_params': cmd_line_params,
        'network_params': network_params,
        'meta_info_project_dict': meta_info_project_dict,
        'main_logger': logger,
    })
if __name__ == "__main__":
    # NOTE(review): a large `dict_images` figure-configuration dict used to
    # live here; it was never read (the original comment said "Useless right
    # now. Just ignore"), so it has been removed as dead code.
    cmd_line_params, _ = get_parsed_params()
    main(cmd_line_params)