import torch
from torch.nn import functional as F

from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.modeling.ops import mixup
from dassl.modeling.ops.utils import (
    sharpen_prob, create_onehot, linear_rampup, shuffle_index
)
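
# Illustrative note, not part of the original file: per the MixMatch paper
# (https://arxiv.org/abs/1905.02249), the guessed label is sharpened with
#
#     Sharpen(p, T)_i = p_i^(1/T) / sum_j p_j^(1/T)
#
# A minimal sketch of that step (the trainer below relies on dassl's
# sharpen_prob, whose exact internals may differ in detail):
#
#     def sharpen(p, T):
#         p = p ** (1.0 / T)
#         return p / p.sum(dim=1, keepdim=True)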


@TRAINER_REGISTRY.register()
class MixMatch(TrainerXU):
    """MixMatch: A Holistic Approach to Semi-Supervised Learning.

    https://arxiv.org/abs/1905.02249.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.weight_u = cfg.TRAINER.MIXMATCH.WEIGHT_U  # final weight on the unlabeled loss
        self.temp = cfg.TRAINER.MIXMATCH.TEMP  # sharpening temperature T
        self.beta = cfg.TRAINER.MIXMATCH.MIXUP_BETA  # Beta(alpha, alpha) parameter for mixup
        self.rampup = cfg.TRAINER.MIXMATCH.RAMPUP  # ramp-up length (in steps) for weight_u

    def check_cfg(self, cfg):
        # MixMatch averages predictions over multiple augmented views of each
        # unlabeled image, so more than one transform per image is required.
        assert cfg.DATALOADER.K_TRANSFORMS > 1
    def forward_backward(self, batch_x, batch_u):
        input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)
        num_x = input_x.shape[0]
        global_step = self.batch_idx + self.epoch * self.num_batches
        # Linearly ramp up the unlabeled-loss weight from 0 to weight_u
        weight_u = self.weight_u * linear_rampup(global_step, self.rampup)

        # Generate pseudo-labels for the unlabeled data: average the
        # predictions over the K augmented views, then sharpen the result
        with torch.no_grad():
            output_u = 0
            for input_ui in input_u:
                output_ui = F.softmax(self.model(input_ui), 1)
                output_u += output_ui
            output_u /= len(input_u)
            label_u = sharpen_prob(output_u, self.temp)
            # Every augmented view shares the same pseudo-label
            label_u = [label_u] * len(input_u)
            label_u = torch.cat(label_u, 0)
            input_u = torch.cat(input_u, 0)

        # Combine and shuffle labeled and unlabeled data
        input_xu = torch.cat([input_x, input_u], 0)
        label_xu = torch.cat([label_x, label_u], 0)
        input_xu, label_xu = shuffle_index(input_xu, label_xu)

        # Mixup: mix each example with a shuffled partner; preserve_order
        # keeps the original example dominant in the mixture
        input_x, label_x = mixup(
            input_x,
            input_xu[:num_x],
            label_x,
            label_xu[:num_x],
            self.beta,
            preserve_order=True,
        )

        input_u, label_u = mixup(
            input_u,
            input_xu[num_x:],
            label_u,
            label_xu[num_x:],
            self.beta,
            preserve_order=True,
        )

        # Losses: cross-entropy against the mixed labeled targets, and a
        # squared-L2 consistency loss against the mixed unlabeled targets
        output_x = F.softmax(self.model(input_x), 1)
        loss_x = (-label_x * torch.log(output_x + 1e-5)).sum(1).mean()
        output_u = F.softmax(self.model(input_u), 1)
        loss_u = ((label_u - output_u)**2).mean()
        loss = loss_x + loss_u * weight_u
        self.model_backward_and_update(loss)

        loss_summary = {"loss_x": loss_x.item(), "loss_u": loss_u.item()}

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary
    def parse_batch_train(self, batch_x, batch_u):
        # Labeled data: take the first augmented view; one-hot encode labels
        input_x = batch_x["img"][0]
        label_x = batch_x["label"]
        label_x = create_onehot(label_x, self.num_classes)
        # Unlabeled data: keep all K augmented views as a list of tensors
        input_u = batch_u["img"]

        input_x = input_x.to(self.device)
        label_x = label_x.to(self.device)
        input_u = [input_ui.to(self.device) for input_ui in input_u]

        return input_x, label_x, input_u
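
# Usage sketch (an assumption for illustration, not part of the original
# file): with a Dassl config whose TRAINER.NAME is "MixMatch", the registered
# trainer is built through dassl.engine.build_trainer. The cfg fields below
# mirror the keys read in __init__ and check_cfg; dataset and model fields
# are omitted.
#
#     from dassl.config import get_cfg_default
#     from dassl.engine import build_trainer
#
#     cfg = get_cfg_default()
#     cfg.TRAINER.NAME = "MixMatch"
#     cfg.DATALOADER.K_TRANSFORMS = 2  # required: check_cfg asserts > 1
#     trainer = build_trainer(cfg)
#     trainer.train()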