-
Notifications
You must be signed in to change notification settings - Fork 26
/
finetune_distributed.py
202 lines (164 loc) · 9.38 KB
/
finetune_distributed.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
import torch
import json
import datetime
import os
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.optim import AdamW
from functools import partial
from util.vision_util import process_vision_info
from util.logutil import init_logger, get_logger
from accelerate import Accelerator
# One Accelerator per process; gradients are accumulated over 2 micro-batches
# inside `accelerator.accumulate(...)` blocks before each optimizer update.
accelerator = Accelerator(gradient_accumulation_steps=2)
# Device assigned to this process by Accelerate; passed to collate_fn in train().
device = accelerator.device
# Only the local main process defines the timestamped run directory and the
# logger. `output_dir` and `logger` are NOT defined on other processes, so every
# use of them must stay guarded by `accelerator.is_local_main_process`.
if accelerator.is_local_main_process:
    output_dir = f'train_output/{datetime.datetime.now().strftime("%Y%m%d%H%M%S")}/'
    init_logger(output_dir)
    logger = get_logger()
class ToyDataSet(Dataset):
    """Minimal in-memory dataset for the toy demo (train_data/data.json).

    The whole JSON file is loaded at construction time. Items are returned
    exactly as stored; this script expects a top-level JSON list whose elements
    are dicts with a "messages" key (consumed by `collate_fn`).
    """

    def __init__(self, data_path):
        super().__init__()
        # JSON is UTF-8 by specification and the sample data contains non-ASCII
        # (Chinese) text. Without an explicit encoding, open() falls back to the
        # locale's preferred encoding and can fail or mis-decode on such systems.
        with open(data_path, "r", encoding="utf-8") as f:
            self.data = json.load(f)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
def find_assistant_content_sublist_indexes(l, start_seq=(151644, 77091, 198), end_seq=(151645, 198)):
    r"""Find the assistant-response spans in a tokenized chat, for building labels.

    A record from train_data/data.json may look like:
        {
            "messages": [
                {'role': 'user', 'content': [{'type': 'image', 'image': 'train_data/1.jpeg'}, {'type': 'text', 'text': 'Describe this image'}]},
                {'role': 'assistant', 'content': [{'type': 'text', 'text': 'The image shows a young woman playing with her dog on a beach. ...'}]}
            ]
        }
    After apply_chat_template the text looks like:
        '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe this image<|im_end|>\n<|im_start|>assistant\nThe image shows ...<|im_end|>\n'
    This function returns the index ranges of the assistant content inside the
    corresponding input_ids, so only those tokens contribute to the loss.

    Args:
        l: sequence of token ids (one row of input_ids, e.g. from ``.tolist()``).
        start_seq: ids that open an assistant turn. Default is Qwen2-VL's
            ``tokenizer.encode("<|im_start|>assistant\n")`` == [151644, 77091, 198].
        end_seq: ids that close a turn. Default is
            ``tokenizer.encode("<|im_end|>\n")`` == [151645, 198].

    Returns:
        List of half-open ``(begin, end)`` index pairs in order of appearance.
        ``begin`` is the first token after ``start_seq``. ``end`` is just past
        the first ``end_seq`` found at or after ``begin``. The closing tokens are
        therefore included, so the model learns to emit ``<|im_end|>\n`` and stop.
        A ``start_seq`` with no later ``end_seq`` (e.g. a truncated row) yields no span.

    Raises:
        ValueError: if ``start_seq`` or ``end_seq`` is empty.
    """
    start = tuple(start_seq)
    stop = tuple(end_seq)
    if not start or not stop:
        raise ValueError("start_seq and end_seq must be non-empty")
    n_start, n_stop = len(start), len(stop)

    spans = []
    for i in range(len(l) - n_start + 1):
        # tuple() so list, tuple and other sequence inputs compare the same way.
        if tuple(l[i:i + n_start]) != start:
            continue
        begin = i + n_start
        for j in range(begin, len(l) - n_stop + 1):
            if tuple(l[j:j + n_stop]) == stop:
                spans.append((begin, j + n_stop))
                break  # only the first closing sequence ends this turn
    return spans
def collate_fn(batch, processor, device):
    """Turn a list of dataset records into model inputs and training labels.

    Args:
        batch: records from the dataset; each must contain a "messages" chat list.
        processor: the Qwen2-VL processor used to render the chat template,
            tokenize the text and preprocess images/videos.
        device: device the returned tensors are placed on.

    Returns:
        ``(inputs, labels)``. ``inputs`` is the processor output moved to
        ``device``. ``labels`` is an int64 tensor on the same device, with the
        same shape as ``inputs["input_ids"]``. It holds -100 (ignored by the
        loss) everywhere except inside assistant-response spans, where it
        copies the input ids.
    """
    messages = [record["messages"] for record in batch]
    texts = [
        processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=False)
        for msg in messages
    ]
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=texts,
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(device)

    input_ids_lists = inputs["input_ids"].tolist()
    assert len(messages) == len(input_ids_lists)

    labels_list = []
    for ids_list in input_ids_lists:
        label_ids = [-100] * len(ids_list)  # -100 is the loss function's ignore index
        for begin, end in find_assistant_content_sublist_indexes(ids_list):
            label_ids[begin:end] = ids_list[begin:end]
        labels_list.append(label_ids)

    # Build labels on the same device as the inputs so the pair is always
    # co-located, even when this collate function is used without Accelerate.
    labels_ids = torch.tensor(labels_list, dtype=torch.int64, device=inputs["input_ids"].device)
    return inputs, labels_ids
def write_chat_template(processor, output_dir):
    """Write ``processor.chat_template`` to ``<output_dir>/chat_template.json``.

    This is a workaround. ``processor.save_pretrained(output_dir)`` is expected to
    write this file itself, but a transformers change from around 2024-09-05
    appears to have stopped it from doing so (see
    src/transformers/processing_utils.py in
    https://github.com/huggingface/transformers/commit/43df47d8e78238021a4273746fc469336f948314#diff-6505546ec5a9ab74b2ce6511681dd31194eb91e9fa3ce26282e487a5e61f9356).
    Remove this helper once that is fixed upstream.

    The file holds one JSON object with a single "chat_template" key. It is
    indented by 2 spaces with sorted keys and ends with a newline. The path is
    logged afterwards.
    """
    target_path = os.path.join(output_dir, "chat_template.json")
    payload = {"chat_template": processor.chat_template}
    serialized = json.dumps(payload, indent=2, sort_keys=True)
    with open(target_path, "w", encoding="utf-8") as handle:
        handle.write(serialized)
        handle.write("\n")
    logger.info(f"chat template saved in {target_path}")
def train():
    """Fine-tune Qwen/Qwen2-VL-2B-Instruct on train_data/data.json with Accelerate.

    Uses the module-level ``accelerator`` and ``device`` (defined on every
    process) and ``output_dir`` / ``logger`` (defined only on the local main
    process). After ``epochs`` passes over the data, the local main process
    writes the model, the processor files and chat_template.json to ``output_dir``.
    """
    # flash_attention_2 (attn_implementation="flash_attention_2") is recommended for
    # better speed and memory use, especially with multiple images or videos; it is
    # not enabled here.
    # Loading may log:
    #   Unrecognized keys in `rope_scaling` for 'rope_type'='default': {'mrope_section'}
    # See https://github.com/huggingface/transformers/issues/33401.
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct", torch_dtype="bfloat16"
    )

    # min_pixels / max_pixels bound the visual tokens per image (model default range
    # 4-16384); a narrower range trades image detail for speed and memory.
    #
    # padding_side="right" (the processor default is "left"): when training with
    # batch size > 1, shorter inputs are padded on the right and the padding is
    # masked out via attention_mask, which keeps the causal mask simple (see
    # notes.txt in this repo). Batched *generation* must use left padding instead,
    # because generation continues from the last token of each row.
    # Background: https://github.com/huggingface/transformers/pull/26572,
    # https://github.com/pytorch/pytorch/issues/110213, and
    # AttentionMaskConverter._unmask_unattended in transformers' modeling_qwen2_vl.py.
    processor = AutoProcessor.from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct",
        min_pixels=256 * 28 * 28,
        max_pixels=512 * 28 * 28,
        padding_side="right",
    )

    train_loader = DataLoader(
        ToyDataSet("train_data/data.json"),
        batch_size=1,
        collate_fn=partial(collate_fn, processor=processor, device=device),
    )

    model.train()
    epochs = 10
    optimizer = AdamW(model.parameters(), lr=1e-5)
    model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)

    for epoch in range(epochs):
        for steps, (inputs, labels) in enumerate(train_loader, start=1):
            with accelerator.accumulate(model):
                outputs = model(**inputs, labels=labels)
                loss = outputs.loss
                accelerator.backward(loss)
                optimizer.step()
                # zero_grad() must come AFTER step(). Inside `accumulate`, the
                # prepared optimizer only really steps/zeroes on the micro-batch
                # that completes an accumulation cycle. Zeroing at the START of
                # that micro-batch would discard the gradients accumulated by
                # the earlier micro-batches, silently disabling accumulation.
                optimizer.zero_grad()
            if accelerator.is_local_main_process:
                logger.info(f"Batch {steps} of epoch {epoch + 1}/{epochs}, training loss : {loss.item()}")

    accelerator.wait_for_everyone()
    # Gather the weights on EVERY process. Under sharded setups (DeepSpeed ZeRO-3,
    # FSDP) collecting the full state dict is a collective operation; under plain
    # DDP it is simply the unwrapped model's state_dict().
    state_dict = accelerator.get_state_dict(model)
    if accelerator.is_local_main_process:
        os.makedirs(output_dir, exist_ok=True)
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(
            output_dir,
            is_main_process=accelerator.is_main_process,
            save_function=accelerator.save,
            state_dict=state_dict,
        )
        processor.save_pretrained(output_dir)
        # Workaround: save_pretrained() may not write chat_template.json itself.
        write_chat_template(processor, output_dir)
# Start training only when this file is executed directly, not when imported.
if __name__ == "__main__":
    train()