# Defaults for finetuning with train.py.
#
#
# You must also include a binding for MODEL.
#
# Required to be set:
#
# - MIXTURE_OR_TASK_NAME
# - TASK_FEATURE_LENGTHS
# - TRAIN_STEPS # includes pretrain steps
# - MODEL_DIR # automatically set when using xm_launch
# - INITIAL_CHECKPOINT_PATH
#
# When running locally, MODEL_DIR needs to be passed via the `--gin.MODEL_DIR`
# flag (see the illustrative command below).
#
# `TRAIN_STEPS` should include pre-training steps, e.g., if pre-trained ckpt
# has 1M steps, TRAIN_STEPS = 1.1M will perform 0.1M fine-tuning steps.
#
# Commonly overridden options:
# - DROPOUT_RATE
# - BATCH_SIZE
# - ModelBasedPjitPartitioner.num_partitions
# - Trainer.num_microbatches
# - USE_CACHED_TASKS: Whether to look for preprocessed SeqIO data, or preprocess
# on the fly. Most common tasks are cached, hence this is set to True by
# default.
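# As a rough illustration only (the model gin file, paths, task name, and step
# count below are placeholders, and launcher flags may differ in your setup),
# the required bindings can be supplied on the command line, e.g.:
#
#   python -m t5x.train \
#     --gin_file="t5x/examples/t5/t5_1_1/small.gin" \
#     --gin_file="t5x/configs/runs/finetune.gin" \
#     --gin.MODEL_DIR="'/tmp/finetune_run'" \
#     --gin.MIXTURE_OR_TASK_NAME="'my_task'" \
#     --gin.TASK_FEATURE_LENGTHS="{'inputs': 512, 'targets': 128}" \
#     --gin.TRAIN_STEPS=1100000 \
#     --gin.INITIAL_CHECKPOINT_PATH="'/path/to/checkpoint_1000000'"
#
# String-valued overrides need the nested quotes; the first gin file supplies
# the MODEL binding mentioned above.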
from __gin__ import dynamic_registration
import __main__ as train_script
import seqio
from t5x import gin_utils
from t5x import partitioning
from t5x import utils
from t5x import trainer
# Must be overridden
MODEL_DIR = %gin.REQUIRED
MIXTURE_OR_TASK_NAME = %gin.REQUIRED
TASK_FEATURE_LENGTHS = %gin.REQUIRED
TRAIN_STEPS = %gin.REQUIRED
INITIAL_CHECKPOINT_PATH = %gin.REQUIRED
# Commonly overridden
DROPOUT_RATE = 0.1
USE_CACHED_TASKS = True
BATCH_SIZE = 128
# Sometimes overridden
EVAL_STEPS = 20
# Convenience overrides.
EVALUATOR_USE_MEMORY_CACHE = True
EVALUATOR_NUM_EXAMPLES = None # Use all examples in the infer_eval dataset.
JSON_WRITE_N_RESULTS = None # Write all inferences.
# HW RNG is faster than SW, but has limited determinism.
# Most notably it is not deterministic across different
# submeshes.
USE_HARDWARE_RNG = False
# A value of None always uses the faster hardware RNG.
RANDOM_SEED = None
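# For a fully reproducible run, a fixed seed can instead be set explicitly,
# e.g. with a command-line override like --gin.RANDOM_SEED=0 (the value is
# arbitrary), while keeping USE_HARDWARE_RNG = False.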
# DEPRECATED: Import this module in your gin file instead.
MIXTURE_OR_TASK_MODULE = None
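# For example, a downstream gin file can register its SeqIO tasks by importing
# the module directly (the module name below is only a placeholder):
#
#   import my_project.tasks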
train_script.train:
  model = %MODEL  # imported from separate gin file
  model_dir = %MODEL_DIR
  train_dataset_cfg = @train/utils.DatasetConfig()
  train_eval_dataset_cfg = @train_eval/utils.DatasetConfig()
  infer_eval_dataset_cfg = @infer_eval/utils.DatasetConfig()
  checkpoint_cfg = @utils.CheckpointConfig()
  partitioner = @partitioning.ModelBasedPjitPartitioner()
  trainer_cls = @trainer.Trainer
  total_steps = %TRAIN_STEPS
  eval_steps = %EVAL_STEPS
  eval_period = 1000
  random_seed = %RANDOM_SEED
  use_hardware_rng = %USE_HARDWARE_RNG
  summarize_config_fn = @gin_utils.summarize_gin_config
  inference_evaluator_cls = @seqio.Evaluator
partitioning.ModelBasedPjitPartitioner:
  num_partitions = 1
  model_parallel_submesh = ()
seqio.Evaluator:
  logger_cls = [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger]
  num_examples = %EVALUATOR_NUM_EXAMPLES
  use_memory_cache = %EVALUATOR_USE_MEMORY_CACHE
seqio.JSONLogger:
  write_n_results = %JSON_WRITE_N_RESULTS
train/utils.DatasetConfig:
  mixture_or_task_name = %MIXTURE_OR_TASK_NAME
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  split = 'train'
  batch_size = %BATCH_SIZE
  shuffle = True
  seed = None  # use a new seed each run/restart
  use_cached = %USE_CACHED_TASKS
  pack = True
  module = %MIXTURE_OR_TASK_MODULE
train_eval/utils.DatasetConfig:
  mixture_or_task_name = %MIXTURE_OR_TASK_NAME
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  split = 'validation'
  batch_size = %BATCH_SIZE
  shuffle = False
  seed = 42
  use_cached = %USE_CACHED_TASKS
  pack = True
  module = %MIXTURE_OR_TASK_MODULE
infer_eval/utils.DatasetConfig:
  mixture_or_task_name = %MIXTURE_OR_TASK_NAME
  task_feature_lengths = None  # compute max
  split = 'validation'
  batch_size = %BATCH_SIZE
  shuffle = False
  seed = 42
  use_cached = %USE_CACHED_TASKS
  pack = False
  module = %MIXTURE_OR_TASK_MODULE
utils.CheckpointConfig:
  restore = @utils.RestoreCheckpointConfig()
  save = @utils.SaveCheckpointConfig()
utils.RestoreCheckpointConfig:
  path = %INITIAL_CHECKPOINT_PATH
  mode = 'specific'
  dtype = 'float32'
utils.SaveCheckpointConfig:
  period = 5000
  dtype = 'float32'
  keep = None  # keep all checkpoints
  save_dataset = False  # don't checkpoint dataset state
trainer.Trainer:
  num_microbatches = None
  learning_rate_fn = @utils.create_learning_rate_scheduler()
utils.create_learning_rate_scheduler:
  factors = 'constant'
  base_learning_rate = 0.001
  warmup_steps = 1000
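# The commonly tuned values above can be adjusted from a downstream gin file or
# via --gin.* command-line overrides; the settings below are placeholders, not
# recommendations:
#
#   BATCH_SIZE = 64
#   DROPOUT_RATE = 0.0
#   utils.create_learning_rate_scheduler.base_learning_rate = 0.0005
#   utils.SaveCheckpointConfig.period = 2000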