-
Notifications
You must be signed in to change notification settings - Fork 172
/
mean_teacher.py
54 lines (40 loc) · 1.73 KB
/
mean_teacher.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import copy
from torch.nn import functional as F
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.modeling.ops.utils import sigmoid_rampup, ema_model_update
@TRAINER_REGISTRY.register()
class MeanTeacher(TrainerXU):
    """Mean Teacher semi-supervised trainer.

    Reference: https://arxiv.org/abs/1703.01780.

    The student is ``self.model``. The teacher is a deep copy of the student
    whose parameters never receive gradients. They are refreshed from the
    student after every optimisation step via ``ema_model_update``.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        mt_cfg = cfg.TRAINER.MEANTEACHER
        # Base multiplier on the unlabeled consistency loss (before ramp-up).
        self.weight_u = mt_cfg.WEIGHT_U
        # Upper bound on the EMA decay used when updating the teacher.
        self.ema_alpha = mt_cfg.EMA_ALPHA
        # Ramp-up length, passed to sigmoid_rampup alongside the epoch.
        self.rampup = mt_cfg.RAMPUP

        # The teacher starts as an exact copy of the student and is kept in
        # train mode. Presumably this is so that dropout/batch-norm behave as
        # in training when it produces targets (NOTE(review): confirm intent).
        # Its parameters are frozen, so they change only through the EMA
        # update in forward_backward.
        self.teacher = copy.deepcopy(self.model)
        self.teacher.train()
        for weight in self.teacher.parameters():
            weight.requires_grad_(False)

    def forward_backward(self, batch_x, batch_u):
        """Run one Mean Teacher training step on a labeled/unlabeled batch pair.

        Args:
            batch_x: labeled batch, unpacked by ``parse_batch_train``.
            batch_u: unlabeled batch, unpacked by ``parse_batch_train``.

        Returns:
            dict: ``loss_x`` (supervised cross-entropy), ``acc_x`` (first
            value returned by ``compute_accuracy`` on the labeled batch) and
            ``loss_u`` (consistency loss), each converted with ``.item()``.
        """
        x, y, u = self.parse_batch_train(batch_x, batch_u)

        # Supervised term: student cross-entropy on the labeled batch.
        student_logits_x = self.model(x)
        loss_x = F.cross_entropy(student_logits_x, y)

        # Consistency term on the unlabeled batch: squared difference
        # between the teacher and student softmax outputs. It is summed over
        # dim 1 (the softmax dim) and averaged over the remaining elements.
        # The teacher forward deliberately runs before the student forward.
        # Swapping them could change the random draws made by any stochastic
        # layers in the networks.
        teacher_probs = F.softmax(self.teacher(u), dim=1)
        student_probs = F.softmax(self.model(u), dim=1)
        loss_u = (student_probs - teacher_probs).pow(2).sum(dim=1).mean()

        # The unlabeled weight is scaled by sigmoid_rampup(epoch, rampup).
        ramped_weight = self.weight_u * sigmoid_rampup(self.epoch, self.rampup)
        total_loss = loss_x + loss_u * ramped_weight
        self.model_backward_and_update(total_loss)

        # EMA teacher update. The decay is 1 - 1/(step + 1), capped at
        # self.ema_alpha, so it is 0 on the very first step and grows
        # towards the cap. Assumption: ema_model_update weights the teacher
        # by this decay (see dassl) — TODO confirm.
        step = self.epoch * self.num_batches + self.batch_idx
        decay = min(1 - 1 / (step + 1), self.ema_alpha)
        ema_model_update(self.model, self.teacher, decay)

        summary = {
            "loss_x": loss_x.item(),
            "acc_x": compute_accuracy(student_logits_x, y)[0].item(),
            "loss_u": loss_u.item(),
        }

        # Step the learning-rate schedule once, after the final batch
        # (batch_idx + 1 == num_batches).
        is_last_batch = self.batch_idx + 1 == self.num_batches
        if is_last_batch:
            self.update_lr()

        return summary