train_model_pairwise_cls.py
import random
import os
import numpy as np
import json
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoConfig
from transformers import BertPreTrainedModel, BertModel
from transformers import AdamW, get_scheduler
from tqdm.auto import tqdm
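
# Fix every source of randomness (Python, hash seeding, NumPy, PyTorch on CPU and GPU)
# so that results are reproducible across runs.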
def seed_everything(seed=1029):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # some cudnn methods can be random even after fixing the seed
    # unless you tell it to be deterministic
    torch.backends.cudnn.deterministic = True

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device} device')
seed_everything(42)
learning_rate = 1e-5
batch_size = 4
epoch_num = 3
checkpoint = "bert-base-chinese"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
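
# AFQMC is distributed as a JSON-lines file: each line is a JSON object with
# 'sentence1', 'sentence2' and a binary 'label' indicating whether the two
# sentences express the same intent.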
class AFQMC(Dataset):
    def __init__(self, data_file):
        self.data = self.load_data(data_file)

    def load_data(self, data_file):
        Data = {}
        with open(data_file, 'rt') as f:
            for idx, line in enumerate(f):
                sample = json.loads(line.strip())
                Data[idx] = sample
        return Data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

train_data = AFQMC('data/afqmc_public/train.json')
valid_data = AFQMC('data/afqmc_public/dev.json')
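
# Tokenize both sentences of each pair together so BERT sees a single
# "[CLS] sentence1 [SEP] sentence2 [SEP]" sequence, padded and truncated per batch.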
def collote_fn(batch_samples):
    batch_sentence_1, batch_sentence_2 = [], []
    batch_label = []
    for sample in batch_samples:
        batch_sentence_1.append(sample['sentence1'])
        batch_sentence_2.append(sample['sentence2'])
        batch_label.append(int(sample['label']))
    X = tokenizer(
        batch_sentence_1,
        batch_sentence_2,
        padding=True,
        truncation=True,
        return_tensors="pt"
    )
    y = torch.tensor(batch_label)
    return X, y

train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True, collate_fn=collote_fn)
valid_dataloader = DataLoader(valid_data, batch_size=batch_size, shuffle=False, collate_fn=collote_fn)
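
# Classification head on top of the bare BERT encoder: take the final hidden state
# of the [CLS] token, apply dropout, and project it to the two classes.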
class BertForPairwiseCLS(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(768, 2)
        self.post_init()

    def forward(self, x):
        outputs = self.bert(**x)
        cls_vectors = outputs.last_hidden_state[:, 0, :]
        cls_vectors = self.dropout(cls_vectors)
        logits = self.classifier(cls_vectors)
        return logits

config = AutoConfig.from_pretrained(checkpoint)
model = BertForPairwiseCLS.from_pretrained(checkpoint, config=config).to(device)
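
# One training epoch; total_loss is carried across epochs so the progress bar shows
# the average loss over all steps completed so far, not just the current epoch.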
def train_loop(dataloader, model, loss_fn, optimizer, lr_scheduler, epoch, total_loss):
    progress_bar = tqdm(range(len(dataloader)))
    progress_bar.set_description(f'loss: {0:>7f}')
    finish_step_num = (epoch-1)*len(dataloader)

    model.train()
    for step, (X, y) in enumerate(dataloader, start=1):
        X, y = X.to(device), y.to(device)
        pred = model(X)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        total_loss += loss.item()
        progress_bar.set_description(f'loss: {total_loss/(finish_step_num + step):>7f}')
        progress_bar.update(1)
    return total_loss

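
# Evaluate accuracy over the whole validation/test set with gradients disabled.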
def test_loop(dataloader, model, mode='Test'):
    assert mode in ['Valid', 'Test']
    size = len(dataloader.dataset)
    correct = 0

    model.eval()
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    correct /= size
    print(f"{mode} Accuracy: {(100*correct):>0.1f}%\n")
    return correct

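
# Optimizer and learning-rate schedule: AdamW with a linear decay to zero over
# all training steps (no warmup).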
loss_fn = nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=learning_rate)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=epoch_num*len(train_dataloader),
)

total_loss = 0.
best_acc = 0.
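
# Main loop: train for epoch_num epochs, validate after each, and save the model
# weights whenever validation accuracy improves.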
for t in range(epoch_num):
    print(f"Epoch {t+1}/{epoch_num}\n-------------------------------")
    total_loss = train_loop(train_dataloader, model, loss_fn, optimizer, lr_scheduler, t+1, total_loss)
    valid_acc = test_loop(valid_dataloader, model, mode='Valid')
    if valid_acc > best_acc:
        best_acc = valid_acc
        print('saving new weights...\n')
        torch.save(model.state_dict(), f'epoch_{t+1}_valid_acc_{(100*valid_acc):0.1f}_model_weights.bin')
print("Done!")