-
Notifications
You must be signed in to change notification settings - Fork 63
/
arg_parsing.py
133 lines (128 loc) · 3.71 KB
/
arg_parsing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import argparse
from squeezenet import networks
class ArgParser(object):
    """Command-line argument parser for the SqueezeNet training program.

    Builds one ``argparse.ArgumentParser`` at construction time, stores it
    as ``self.parser``, and delegates ``parse_args`` to it.
    """

    def __init__(self):
        self.parser = self._create_parser()

    def parse_args(self, args=None):
        """Parse command-line arguments into a namespace.

        Args:
            args: Optional list of argument strings to parse. When ``None``,
                ``argparse`` falls back to ``sys.argv[1:]``.

        Returns:
            ``argparse.Namespace`` with one attribute per option defined in
            ``_create_parser``.
        """
        return self.parser.parse_args(args)

    @staticmethod
    def _create_parser():
        """Build the ``ArgumentParser`` holding every training option.

        Returns:
            A configured ``argparse.ArgumentParser``.
        """
        program_name = 'Squeezenet Training Program'
        desc = 'Program for training squeezenet with periodic evaluation.'
        parser = argparse.ArgumentParser(program_name, description=desc)

        # Input/output locations.
        parser.add_argument(
            '--model_dir',
            type=str,
            required=True,
            help='Output directory for checkpoints and summaries.'
        )
        parser.add_argument(
            '--train_tfrecord_filepaths',
            nargs='+',
            type=str,
            required=True,
            help='Filepaths of the TFRecords to be used for training.'
        )
        parser.add_argument(
            '--validation_tfrecord_filepaths',
            nargs='+',
            type=str,
            required=True,
            help='Filepaths of the TFRecords to be used for evaluation.'
        )

        # Model / data shape.
        parser.add_argument(
            '--network',
            type=str,
            required=True,
            choices=networks.catalogue,
            help='Name of the network to train.'
        )
        parser.add_argument(
            '--target_image_size',
            default=[224, 224],
            nargs=2,
            type=int,
            help='Input images will be resized to this.'
        )
        # Optional on purpose: argparse never applies ``default`` to a
        # required option, so ``required=True`` would make this default dead.
        parser.add_argument(
            '--num_classes',
            default=10,
            type=int,
            help='Number of classes (unique labels) in the dataset. '
                 'Ignored if using CIFAR network version.'
        )

        # Training hyperparameters.
        # Optional on purpose so that the default of 1 can take effect.
        parser.add_argument(
            '--num_gpus',
            default=1,
            type=int,
        )
        parser.add_argument(
            '--batch_size',
            type=int,
            required=True
        )
        parser.add_argument(
            '--learning_rate', '-l',
            type=float,
            default=0.001,
            help='Initial learning rate for ADAM optimizer.'
        )
        parser.add_argument(
            '--batch_norm_decay',
            type=float,
            default=0.9
        )
        parser.add_argument(
            '--weight_decay',
            type=float,
            default=0.0,
            help='L2 regularization factor for convolution layer weights. '
                 '0.0 indicates no regularization.'
        )

        # Input pipeline.
        # Optional on purpose so that the default of 1 can take effect.
        parser.add_argument(
            '--num_input_threads',
            default=1,
            type=int,
            help='The number of input elements to process in parallel.'
        )
        parser.add_argument(
            '--shuffle_buffer',
            type=int,
            required=True,
            help='The minimum number of elements in the pool of training '
                 'data from which to randomly sample.'
        )

        # Run schedule, logging, and checkpoint retention.
        parser.add_argument(
            '--seed',
            default=1337,
            type=int
        )
        parser.add_argument(
            '--max_train_steps',
            default=1801,
            type=int
        )
        parser.add_argument(
            '--summary_interval',
            default=100,
            type=int
        )
        parser.add_argument(
            '--checkpoint_interval',
            default=100,
            type=int
        )
        parser.add_argument(
            '--validation_interval',
            default=100,
            type=int
        )
        parser.add_argument(
            '--keep_last_n_checkpoints',
            default=3,
            type=int
        )
        return parser