-
Notifications
You must be signed in to change notification settings - Fork 17
/
optim.py
83 lines (68 loc) · 2.89 KB
/
optim.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#! python
# -*- coding: utf-8 -*-
# Author: kun
# @Time: 2019-10-29 20:41
import torch
import numpy as np
from functools import partial
class Optimizer(object):
    """Wrap a ``torch.optim`` optimizer with step-based LR and teacher-forcing schedules.

    The wrapped optimizer is looked up by name on ``torch.optim``.  Call
    :meth:`pre_step` once per training step before the backward pass (it
    updates the learning rate and clears gradients) and :meth:`step` after it.

    Attributes set by ``__init__``:
        tf_type: ``True`` when ``tf_end != 1`` (reported as "Scheduled
            sampling" by :meth:`create_msg`).
        tf_rate: callable ``step -> rate``; linear interpolation from
            ``tf_start`` towards ``tf_end`` over ``tf_step`` steps, never
            going below ``tf_end``.
        opt_type, init_lr, sch_type: the ``optimizer``, ``lr`` and
            ``lr_scheduler`` arguments, kept for reporting.
        lr_scheduler: callable ``step -> lr`` or ``None`` for a fixed LR.
        opt: the underlying ``torch.optim`` optimizer instance.
    """

    def __init__(self, parameters, optimizer, lr, eps, lr_scheduler, tf_start=1, tf_end=1, tf_step=1, **kwargs):
        """Build the torch optimizer and the two schedules.

        Args:
            parameters: parameters (or parameter groups) handed to the torch
                optimizer constructor.
            optimizer: attribute name on ``torch.optim`` (e.g. ``'Adam'``).
            lr: base learning rate; used as the peak LR by the schedulers.
            eps: forwarded as ``eps=`` to the optimizer constructor, except
                in the ``'warmup'`` branch (see the note there).
            lr_scheduler: ``'warmup'``, ``'spec-aug-basic'``,
                ``'spec-aug-double'``; any other value means a fixed LR.
            tf_start: teacher-forcing rate at step 0.
            tf_end: lower bound of the teacher-forcing rate.
            tf_step: number of steps over which the rate is interpolated.
            **kwargs: accepted and ignored.
        """
        # Setup teacher forcing scheduler
        self.tf_type = tf_end != 1
        self.tf_rate = lambda step: max(
            tf_end, tf_start - (tf_start - tf_end) * step / tf_step)

        # Setup torch optimizer
        self.opt_type = optimizer
        self.init_lr = lr
        self.sch_type = lr_scheduler
        opt = getattr(torch.optim, optimizer)
        if lr_scheduler == 'warmup':
            warmup_step = 4000.0
            init_lr = lr
            # Linear ramp for `warmup_step` steps, then inverse-sqrt decay.
            self.lr_scheduler = lambda step: init_lr * warmup_step ** 0.5 * \
                np.minimum((step + 1) * warmup_step ** -1.5, (step + 1) ** -0.5)
            # lr=1.0 is a placeholder: pre_step() overwrites it every step.
            # NOTE: `eps` is not forwarded here, so the optimizer's own
            # default applies in this branch.
            self.opt = opt(parameters, lr=1.0)
        elif lr_scheduler == 'spec-aug-basic':
            # Scheduler from https://arxiv.org/pdf/1904.08779.pdf
            self.lr_scheduler = partial(speech_aug_scheduler, s_r=500, s_i=20000, s_f=80000, peak_lr=lr)
            self.opt = opt(parameters, lr=lr, eps=eps)
        elif lr_scheduler == 'spec-aug-double':
            # Scheduler from https://arxiv.org/pdf/1904.08779.pdf
            self.lr_scheduler = partial(speech_aug_scheduler, s_r=1000, s_i=40000, s_f=160000, peak_lr=lr)
            self.opt = opt(parameters, lr=lr, eps=eps)
        else:
            # Fixed learning rate: no per-step update.
            self.lr_scheduler = None
            self.opt = opt(parameters, lr=lr, eps=eps)  # ToDo: 1e-8 better?

    def get_opt_state_dict(self):
        """Return the wrapped optimizer's ``state_dict()``."""
        return self.opt.state_dict()

    def load_opt_state_dict(self, state_dict):
        """Load ``state_dict`` into the wrapped optimizer."""
        self.opt.load_state_dict(state_dict)

    def pre_step(self, step):
        """Prepare the optimizer for training step ``step``.

        Applies the scheduled learning rate (if any) to every parameter
        group, zeroes the gradients, and returns ``self.tf_rate(step)``.
        """
        if self.lr_scheduler is not None:
            # The schedulers compute with NumPy and may return numpy.float64.
            # Store a built-in float so the optimizer state dict holds only
            # plain Python numbers (NumPy scalars are rejected by restricted
            # checkpoint loading such as torch.load(weights_only=True)).
            cur_lr = float(self.lr_scheduler(step))
            for param_group in self.opt.param_groups:
                param_group['lr'] = cur_lr
        self.opt.zero_grad()
        return self.tf_rate(step)

    def step(self):
        """Run one update of the wrapped optimizer."""
        self.opt.step()

    def create_msg(self):
        """Return a one-element list with a summary line of the optimizer setup."""
        return ['Optim.spec.| Algo. = {}\t| Lr = {}\t (Scheduler = {})| Scheduled sampling = {}'
                .format(self.opt_type, self.init_lr, self.sch_type, self.tf_type)]
def speech_aug_scheduler(step, s_r, s_i, s_f, peak_lr):
    """Return the learning rate for ``step`` on a ramp / hold / decay schedule.

    With ``cur_step = step + 1`` the schedule is:

    * ``cur_step < s_r``: linear ramp-up, ``peak_lr * cur_step / s_r``
      (so the first value is ``peak_lr / s_r``, not 0);
    * ``s_r <= cur_step < s_i``: hold at ``peak_lr``;
    * ``s_i <= cur_step <= s_f``: exponential decay from ``peak_lr`` down to
      ``0.01 * peak_lr`` at ``s_f``;
    * ``cur_step > s_f``: constant ``0.01 * peak_lr``.

    Args:
        step: zero-based training step.
        s_r: step at which the ramp-up ends.
        s_i: step at which the decay starts.
        s_f: step at which the decay ends.
        peak_lr: learning rate reached after the ramp-up.

    Returns:
        The learning rate as a built-in ``float`` in every phase.
    """
    final_lr_ratio = 0.01
    cur_step = step + 1

    if cur_step < s_r:
        # Ramp-up
        return float(peak_lr * float(cur_step) / s_r)
    if cur_step < s_i:
        # Hold
        return float(peak_lr)
    if cur_step <= s_f:
        # Decay (only reachable with s_i <= cur_step <= s_f)
        decay_span = s_f - s_i
        if decay_span == 0:
            # Empty decay window: the decay factor at its start is 1, and
            # dividing by the zero-length span would produce inf/nan.
            return float(peak_lr)
        # Approx. w/ 10-based exponent; computed here so that a degenerate
        # span never triggers a division by zero in the other phases.
        exp_decay_lambda = -np.log10(final_lr_ratio) / decay_span
        # float(): np.power yields numpy.float64; keep the return type uniform.
        return float(peak_lr * np.power(10, -exp_decay_lambda * (cur_step - s_i)))
    # Converge
    return float(peak_lr * final_lr_ratio)