-
Notifications
You must be signed in to change notification settings - Fork 0
/
inference_script.py
104 lines (74 loc) · 3.24 KB
/
inference_script.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
"""
Script to generate GTV segmentation outputs given a trained model and a validation set.
- The outputs can be either GTV foreground probability maps or binary label maps, depending on the user's choice
- Optionally, the average Dice metric can be computed
- Optionally, the model outputs and the metrics can be saved to disk
"""
import argparse
import logging
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from datautils.preprocessing import Preprocessor
from datasets.hecktor_unimodal_dataset import HECKTORUnimodalDataset
from datasets.hecktor_petct_dataset import HECKTORPETCTDataset
from datautils.patch_sampling import PatchSampler3D, PatchQueue
from datautils.patch_aggregation import PatchAggregator3D
import nnmodules
from evalutils.inferer import Inferer
import config_utils
# Constants
# Default config files used when the corresponding CLI flag is omitted.
# These are relative paths, presumably resolved against the directory the
# script is launched from.
DEFAULT_DATA_CONFIG_FILE = "./config_files/data-crFHN_rs113-petct_default.yaml"
DEFAULT_NN_CONFIG_FILE = "./config_files/nn-msam3d_default.yaml"
DEFAULT_INFERENCE_CONFIG_FILE = "./config_files/infer-default.yaml"


def get_cli_args(argv=None):
    """Parse the command-line options that select the three config files.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what the
            script entry point relies on.

    Returns:
        argparse.Namespace with the attributes ``data_config_file``,
        ``nn_config_file`` and ``infer_config_file``. Each falls back to the
        matching ``DEFAULT_*`` constant when its flag is not given.
    """
    parser = argparse.ArgumentParser()

    # Config filepaths
    parser.add_argument("--data_config_file",
                        type=str,
                        help="Path to the data config file",
                        default=DEFAULT_DATA_CONFIG_FILE)
    parser.add_argument("--nn_config_file",
                        type=str,
                        help="Path to the network config file",
                        default=DEFAULT_NN_CONFIG_FILE)
    parser.add_argument("--infer_config_file",
                        type=str,
                        help="Path to the inference config file",
                        default=DEFAULT_INFERENCE_CONFIG_FILE)

    args = parser.parse_args(argv)
    return args
def main(global_config):
    """Build the data pipeline and the network, then run patch-based inference.

    Args:
        global_config: Merged configuration dict, as produced by
            ``config_utils.build_config`` in the script entry point. Keys read
            here: ``'nn-name'``, ``'nn-kwargs'``, ``'preprocessor-kwargs'``,
            ``'dataset-kwargs'``, ``'patch-sampler-kwargs'``,
            ``'patch-aggregator-kwargs'`` and ``'inferer-kwargs'``. The
            ``['inferer-kwargs']['input_data_config']['is-bimodal']`` flag
            selects between the unimodal and the PET/CT dataset class.

    Raises:
        ValueError: If ``global_config['nn-name']`` is not one of the
            supported networks ("unet3d", "msam3d").
    """
    # Validate the network choice before any data is set up. An unknown name
    # previously left `model` unbound, which only failed with an
    # UnboundLocalError after the whole data pipeline had been built.
    nn_name = global_config['nn-name']
    supported_nn_names = ("unet3d", "msam3d")
    if nn_name not in supported_nn_names:
        raise ValueError(
            "Unsupported 'nn-name' {!r}; expected one of {}".format(
                nn_name, supported_nn_names))

    # -----------------------------------------------
    # Data pipeline
    # -----------------------------------------------
    # Dataset: the input data config decides between single-modality and PET/CT
    preprocessor = Preprocessor(**global_config['preprocessor-kwargs'])
    if global_config['inferer-kwargs']['input_data_config']['is-bimodal']:
        dataset = HECKTORPETCTDataset(**global_config['dataset-kwargs'], preprocessor=preprocessor)
    else:
        dataset = HECKTORUnimodalDataset(**global_config['dataset-kwargs'], preprocessor=preprocessor)

    # Patch based inference stuff.
    # One volume per iteration, in dataset order. Patch extraction and
    # re-assembly are done by the sampler/aggregator handed to the Inferer.
    volume_loader = DataLoader(dataset, batch_size=1, shuffle=False)
    patch_sampler = PatchSampler3D(**global_config['patch-sampler-kwargs'])
    patch_aggregator = PatchAggregator3D(**global_config['patch-aggregator-kwargs'])

    # -----------------------------------------------
    # Network
    # -----------------------------------------------
    if nn_name == "unet3d":
        model = nnmodules.UNet3D(**global_config['nn-kwargs'])
    else:  # "msam3d" -- the only other name allowed by the check above
        model = nnmodules.MSAM3D(**global_config['nn-kwargs'])

    # -----------------------------------------------
    # Inference
    # -----------------------------------------------
    inferer = Inferer(model,
                      volume_loader, patch_sampler, patch_aggregator,
                      **global_config['inferer-kwargs'])
    inferer.run_inference()
if __name__ == '__main__':
    # Script entry point: CLI flags -> merged inference config
    # (built with training=False) -> inference run.
    parsed_args = get_cli_args()
    inference_config = config_utils.build_config(parsed_args, training=False)
    main(inference_config)