-
Notifications
You must be signed in to change notification settings - Fork 0
/
modeling_mamba_transformer.py
334 lines (304 loc) · 15.1 KB
/
modeling_mamba_transformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
from __future__ import annotations
import json
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from einops import rearrange, repeat, einsum
from transformers import GPTNeoXLayer, GPTNeoXConfig, PreTrainedModel, PretrainedConfig, AutoModelForCausalLM
from mamba_ssm.models.mixer_seq_simple import create_block
from safetensors.torch import load_file
from transformers.debug_utils import detect_overflow
@dataclass
class ModelArgs:
    """Hyper-parameters needed to assemble a ``MambaTransformer``.

    The field order below is also the positional order of the generated
    constructor, so it must not be rearranged.
    """

    # Hidden size of the input embedding, the Mamba blocks and the output projection.
    d_model: int
    # How many GPTNeoXLayer modules run before the Mamba stack.
    first_transformer_layers: int
    # How many Mamba blocks (create_block) are stacked after those transformer layers.
    mamba_layers: int
    # Number of rows in the input embedding / columns of the output projection.
    vocab_size: int
    # Config shared by every GPTNeoXLayer; also supplies dropout, hidden size and LayerNorm eps.
    transformer_config: GPTNeoXConfig
    # Raw Mamba config.json contents, as loaded in MambaTransformer.from_pretrained.
    mamba_config: dict
class Adapter(nn.Module):
    """Residual bottleneck adapter computing ``x + up_proj(sigmoid(down_proj(x)))``.

    Every projection weight and bias starts at zero, so a freshly constructed
    adapter returns its input unchanged.

    Args:
        hidden_dim: size of the last dimension of the input and of the output.
        reduction: bottleneck factor; the inner width is ``hidden_dim // reduction``.
            Defaults to 4, the value previously hard-coded.

    Raises:
        ValueError: if ``reduction`` is smaller than 1 or the resulting
            bottleneck width would be 0.
    """

    def __init__(self, hidden_dim, reduction=4):
        super().__init__()
        if reduction < 1 or hidden_dim // reduction < 1:
            raise ValueError(
                f"Adapter bottleneck would be empty: hidden_dim={hidden_dim}, reduction={reduction}"
            )
        bottleneck = hidden_dim // reduction
        self.down_proj = nn.Linear(hidden_dim, bottleneck)
        self.up_proj = nn.Linear(bottleneck, hidden_dim)
        self.non_linear = nn.Sigmoid()
        # Zero init makes the adapter an exact identity at the start of training.
        for proj in (self.down_proj, self.up_proj):
            nn.init.zeros_(proj.weight)
            nn.init.zeros_(proj.bias)

    def forward(self, x):
        """Return ``x`` plus the adapter's bottleneck transform of ``x``."""
        t = self.down_proj(x)
        t = self.non_linear(t)
        return self.up_proj(t) + x
class MambaTransformer(nn.Module):
    """Hybrid causal LM: GPT-NeoX layers, then Mamba blocks, then one final GPT-NeoX layer.

    Data flow in :meth:`forward`::

        embed_in -> emb_dropout -> first_transformer_layers -> mamba_layers
        -> final_transformer_layer -> final_layer_norm -> embed_out
    """

    def __init__(self, args: ModelArgs):
        """Build the hybrid stack described by *args* with freshly initialised weights."""
        super().__init__()
        self.args = args
        self.max_len = 1024  # stored for reference; not read inside this class
        cfg = args.transformer_config
        self.embed_in = nn.Embedding(args.vocab_size, args.d_model)
        self.emb_dropout = nn.Dropout(cfg.hidden_dropout)
        self.first_transformer_layers = nn.ModuleList(
            [GPTNeoXLayer(cfg) for _ in range(args.first_transformer_layers)]
        )
        # Only d_model and d_intermediate are passed: rms_norm, residual_in_fp32,
        # fused_add_norm and layer_idx keep create_block's defaults, and
        # args.mamba_config is not consulted here.
        self.mamba_layers = nn.ModuleList(
            [create_block(d_model=args.d_model, d_intermediate=0) for _ in range(args.mamba_layers)]
        )
        self._use_flash_attention_2 = cfg._attn_implementation == "flash_attention_2"
        self.final_transformer_layer = GPTNeoXLayer(cfg)
        self.final_layer_norm = nn.LayerNorm(cfg.hidden_size, eps=cfg.layer_norm_eps)
        self.embed_out = nn.Linear(args.d_model, args.vocab_size, bias=False)
        # Tie the output projection to the input embedding (see the "Weight Tying" paper).
        # Note: from_pretrained later replaces embed_in with a new module, which breaks this tie.
        self.embed_out.weight = self.embed_in.weight

    def forward(self, input_ids, attention_mask, attn_dtype):
        """Return logits for every position of *input_ids*.

        Args:
            input_ids: token ids, unpacked as ``(batch_size, seq_length)``.
            attention_mask: padding mask (1 = attend, 0 = masked), reshaped to
                ``(batch_size, -1)``.
            attn_dtype: dtype used to build the additive mask on the non-flash path.

        Raises:
            ValueError: if the batch is empty.
        """
        batch_size, seq_length = input_ids.size()
        if batch_size <= 0:
            raise ValueError("batch_size has to be defined and > 0")
        attention_mask = attention_mask.view(batch_size, -1)
        if self._use_flash_attention_2:
            # Keep the 2-D mask, or drop it entirely when no position is masked.
            attention_mask = attention_mask if 0 in attention_mask else None
        else:
            # Broadcastable additive mask of shape (batch_size, 1, 1, to_seq_length):
            # 0.0 where the mask is 1 and the dtype's most negative value where it is 0,
            # so masked positions vanish once added to the raw scores before the softmax.
            attention_mask = attention_mask[:, None, None, :].to(attn_dtype)  # fp16 compatibility
            attention_mask = (1.0 - attention_mask) * torch.finfo(attn_dtype).min

        past_length = 0  # no key/value cache is threaded through this forward
        position_ids = torch.arange(
            past_length, seq_length + past_length, dtype=torch.long, device=input_ids.device
        ).unsqueeze(0)

        x = self.embed_in(input_ids)
        x = self.emb_dropout(x)
        head_mask = [None] * (self.args.first_transformer_layers + 1)
        for i, layer in enumerate(self.first_transformer_layers):
            outputs = layer(
                x,
                attention_mask=attention_mask,
                position_ids=position_ids,
                head_mask=head_mask[i],
                use_cache=True,
            )
            x = outputs[0]

        # Each Mamba block takes and returns a (hidden_states, residual) pair.
        residual = None
        for layer in self.mamba_layers:
            x, residual = layer(x, residual)
        if residual is not None:  # None only when there are no Mamba layers
            x = x + residual

        x = self.final_transformer_layer(
            x,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask[self.args.first_transformer_layers],
            use_cache=True,
        )[0]
        x = self.final_layer_norm(x)
        return self.embed_out(x)

    @staticmethod
    def from_pretrained(pretrained_mamba_name: str, pretrained_pythia_name: str, first_transformer_layers=None,
                        mamba_start_layer=None, mamba_end_layer=None):
        """Build a hybrid model from a pretrained Mamba and a pretrained Pythia checkpoint.

        Args:
            pretrained_mamba_name: HF hub id of a Mamba checkpoint, e.g. one of
                'state-spaces/mamba-2.8b-slimpj', 'state-spaces/mamba-2.8b',
                'state-spaces/mamba-1.4b', 'state-spaces/mamba-790m',
                'state-spaces/mamba-370m', 'state-spaces/mamba-130m'.
                Its config.json supplies ``d_model``.
            pretrained_pythia_name: HF hub id of a GPT-NeoX (Pythia) checkpoint. It supplies
                the transformer config, vocabulary size, embeddings, final layer norm,
                the first ``first_transformer_layers`` layers and its last layer.
            first_transformer_layers: number of leading Pythia layers to keep (0 .. N-1).
            mamba_start_layer: first Mamba layer to copy (inclusive).
            mamba_end_layer: last Mamba layer to copy (inclusive). Mamba layer
                ``start + k`` becomes ``mamba_layers[k]``.

        Returns:
            A ``MambaTransformer`` moved to CUDA with the selected weights loaded.

        Raises:
            ValueError: if a layer argument is missing or the Mamba range is empty.
        """
        from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
        from transformers.utils.hub import cached_file

        if first_transformer_layers is None or mamba_start_layer is None or mamba_end_layer is None:
            raise ValueError(
                "first_transformer_layers, mamba_start_layer and mamba_end_layer must all be provided"
            )
        if mamba_end_layer < mamba_start_layer:
            raise ValueError(
                f"mamba_end_layer ({mamba_end_layer}) must be >= mamba_start_layer ({mamba_start_layer})"
            )

        def load_config_hf(model_name):
            resolved_archive_file = cached_file(model_name, CONFIG_NAME,
                                                _raise_exceptions_for_missing_entries=False)
            with open(resolved_archive_file, encoding="utf-8") as f:
                return json.load(f)

        def load_state_dict_hf(model_name):
            resolved_archive_file = cached_file(model_name, WEIGHTS_NAME,
                                                _raise_exceptions_for_missing_entries=True)
            return torch.load(resolved_archive_file, weights_only=True, map_location='cuda', mmap=True)

        mamba_config_data = load_config_hf(pretrained_mamba_name)
        pythia_config_data = GPTNeoXConfig.from_pretrained(pretrained_pythia_name)
        # Example: from a 12-layer Pythia keep layers 0-7, replace layers 8-10 with
        # 4 Mamba blocks, and still keep the last transformer layer (11).
        args = ModelArgs(
            d_model=mamba_config_data['d_model'],
            mamba_layers=mamba_end_layer - mamba_start_layer + 1,
            first_transformer_layers=first_transformer_layers,
            vocab_size=pythia_config_data.vocab_size,
            transformer_config=pythia_config_data,
            mamba_config=mamba_config_data,
        )
        model = MambaTransformer(args).to('cuda')

        mamba_state_dict = load_state_dict_hf(pretrained_mamba_name)
        pythia_state_dict = load_state_dict_hf(pretrained_pythia_name)

        # Match the whole "<prefix>.layers.<n>." head and rebuild each key from the captured
        # suffix. Replacing the layer number with str.replace would also rewrite digits
        # elsewhere in the key (e.g. "conv1d", "dense_h_to_4h"), and strict=False below
        # would then drop those weights silently.
        mamba_layer_key = re.compile(r"^backbone\.layers\.(\d+)\.(.+)$")
        pythia_layer_key = re.compile(r"^gpt_neox\.layers\.(\d+)\.(.+)$")
        last_pythia_layer = pythia_config_data.num_hidden_layers - 1

        new_state_dict = {}
        for key, tensor in mamba_state_dict.items():
            match = mamba_layer_key.match(key)
            if match is None:
                continue
            layer_index = int(match.group(1))
            if mamba_start_layer <= layer_index <= mamba_end_layer:
                new_state_dict[f"mamba_layers.{layer_index - mamba_start_layer}.{match.group(2)}"] = tensor

        for key, tensor in pythia_state_dict.items():
            if 'embed' in key or 'final_layer_norm' in key:
                new_state_dict[key.replace('gpt_neox.', '')] = tensor
                continue
            match = pythia_layer_key.match(key)
            if match is None:
                continue
            layer_index = int(match.group(1))
            suffix = match.group(2)
            if layer_index == last_pythia_layer:
                new_state_dict[f"final_transformer_layer.{suffix}"] = tensor
            elif layer_index < first_transformer_layers:
                new_state_dict[f"first_transformer_layers.{layer_index}.{suffix}"] = tensor

        # embed_in and embed_out still share one Parameter here, so both checkpoint
        # embedding tensors are copied into that single Parameter.
        model.load_state_dict(new_state_dict, strict=False)
        # Give embed_in its own copy of the Pythia input embedding. The new module no longer
        # shares its weight with embed_out, and nn.Embedding.from_pretrained defaults to
        # freeze=True.
        model.embed_in = nn.Embedding.from_pretrained(
            pythia_state_dict['gpt_neox.embed_in.weight'].to(model.embed_in.weight.dtype)
        )
        return model

    def freeze_layers(self, freeze_transformers=False, freeze_mamba=False):
        """Optionally freeze part of the network for staged training.

        With both flags False nothing changes. Otherwise every parameter is frozen
        first and one group is re-enabled:

        * ``freeze_mamba=True``: ``embed_in`` and ``first_transformer_layers`` train.
        * otherwise (``freeze_transformers=True``): ``mamba_layers``,
          ``final_transformer_layer``, ``final_layer_norm`` and ``embed_out`` train.

        ``freeze_mamba`` takes precedence when both flags are set. If ``embed_out`` is
        still tied to ``embed_in`` (as after ``__init__``), re-enabling one re-enables both.
        """
        if not (freeze_transformers or freeze_mamba):
            return
        self.requires_grad_(False)
        if freeze_mamba:
            # Call the method: assigning to ``requires_grad_`` would only shadow it
            # and leave the embedding weights frozen.
            self.embed_in.requires_grad_(True)
            self.first_transformer_layers.requires_grad_(True)
        else:
            # NOTE(review): confirm whether the head below should also train when freeze_mamba=True.
            for module in (self.mamba_layers, self.final_transformer_layer,
                           self.final_layer_norm, self.embed_out):
                module.requires_grad_(True)
class MambaTransformerForLM(PreTrainedModel):
    """``PreTrainedModel`` wrapper that builds a ``MambaTransformer`` and computes training losses.

    ``forward`` returns ``{"logits": ...}`` when no labels are given. Otherwise it adds a
    ``"loss"``, which is one of:

    * next-token cross entropy;
    * an SFT variant that scores only the tail of each sequence;
    * either of the above blended with a KL distillation loss against a frozen teacher,
      when built with ``distilling=True``.
    """

    # Number of distillation batches between two loss log lines.
    _LOG_EVERY = 50

    def __init__(self,
                 config=None,
                 pretrained_mamba_name=None,
                 pretrained_pythia_name=None,
                 check_point_path=None,
                 sft=False,
                 distilling=False,
                 T=4,
                 distill_loss_weight=0.5,
                 first_transformer_layers=None,
                 mamba_start_layer=None,
                 mamba_end_layer=None,
                 freeze_mamba=False,
                 freeze_transformers=False,
                 teacher_name='EleutherAI/pythia-410m'):
        """Build the hybrid model, optionally restore a checkpoint, freeze, and load a teacher.

        Args:
            config: HF config forwarded to ``PreTrainedModel``.
            pretrained_mamba_name, pretrained_pythia_name, first_transformer_layers,
            mamba_start_layer, mamba_end_layer: forwarded to ``MambaTransformer.from_pretrained``.
            check_point_path: optional safetensors file saved from this wrapper; a leading
                ``model.`` is stripped from its keys before loading (non-strict).
            sft: use the SFT loss that ignores the first ~2/3 of each sequence.
            distilling: load ``teacher_name`` onto CUDA and add a distillation loss.
            T: softmax temperature for distillation.
            distill_loss_weight: weight of the distillation loss; CE gets ``1 - weight``.
            freeze_mamba, freeze_transformers: forwarded to ``MambaTransformer.freeze_layers``.
            teacher_name: HF hub id of the teacher model (default unchanged from before).
        """
        super().__init__(config)
        self.pretrained_mamba_name = pretrained_mamba_name
        self.pretrained_pythia_name = pretrained_pythia_name
        self.model = MambaTransformer.from_pretrained(pretrained_mamba_name,
                                                      pretrained_pythia_name,
                                                      first_transformer_layers,
                                                      mamba_start_layer,
                                                      mamba_end_layer)
        if check_point_path is not None:
            self._load_hybrid_checkpoint(check_point_path)
        self.model.freeze_layers(freeze_transformers, freeze_mamba)
        self.teacher = None
        self.sft = sft
        if distilling:
            self.batch_count = 0
            device = 'cuda'
            self.teacher = AutoModelForCausalLM.from_pretrained(teacher_name).to(device)
            # The teacher only runs under no_grad; keep its parameters out of the
            # trainable set so optimizers / DDP do not expect gradients for them.
            self.teacher.requires_grad_(False)
        self.T = T
        self.distill_loss_weight = distill_loss_weight
        self.log_steps = 0
        self.ce_loss_sum = 0
        self.distill_loss_sum = 0

    def _load_hybrid_checkpoint(self, check_point_path):
        """Load a safetensors checkpoint into ``self.model``, stripping a leading ``model.`` from keys."""
        loaded = load_file(check_point_path)
        prefix = 'model.'
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in loaded.items()
        }
        del loaded
        self.model.load_state_dict(state_dict, strict=False)

    def forward(self, input_ids, attention_mask, labels=None):
        """Return ``{"logits"}``, plus ``"loss"`` when *labels* are given.

        Labels are shifted by one position against the logits. In SFT mode the first
        ``(seq_len + 1) // 3 * 2 - 1`` shifted positions are set to -1 and ignored.
        """
        logits = self.model(input_ids, attention_mask, self.dtype)
        if labels is None:
            return {"logits": logits}

        shift_logits = logits[:, :-1, :].contiguous()
        if not self.sft:
            labels = labels[:, 1:].contiguous()
            criterion = nn.CrossEntropyLoss()
        else:
            ignore_index = -1
            # Copy before masking in place: for batch size 1, labels[:, 1:].contiguous()
            # is still a view of the caller's tensor and would be overwritten.
            labels = labels[:, 1:].clone(memory_format=torch.contiguous_format)
            seq_len = labels.size(1)
            labels[:, :(seq_len + 1) // 3 * 2 - 1] = ignore_index
            criterion = nn.CrossEntropyLoss(ignore_index=ignore_index)
        cross_entropy_loss = criterion(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))

        if self.teacher is None:
            return {"loss": cross_entropy_loss, "logits": logits}

        self.batch_count += 1
        with torch.no_grad(), torch.autocast(device_type="cuda", enabled=False):
            teacher_logits = self.teacher(input_ids, attention_mask=attention_mask).logits
        kl_loss = nn.KLDivLoss(reduction="batchmean")
        s_log_probs = F.log_softmax(logits / self.T, dim=-1)
        t_probs = F.softmax(teacher_logits / self.T, dim=-1)
        # "batchmean" divides by the batch size; dividing by dim 1 makes it per position,
        # and T**2 is the usual temperature scaling of Hinton-style distillation.
        distill_loss = kl_loss(s_log_probs, t_probs) / t_probs.size(1) * (self.T ** 2)
        total_loss = (self.distill_loss_weight * distill_loss
                      + (1 - self.distill_loss_weight) * cross_entropy_loss)
        self.ce_loss_sum += cross_entropy_loss.item()
        self.distill_loss_sum += distill_loss.item()
        if self.batch_count == self._LOG_EVERY:
            self._log_distill_losses()
        return {"loss": total_loss, "logits": logits}

    def _log_distill_losses(self):
        """Print and append to record.txt the mean CE / distillation loss since the last log, then reset."""
        batches = self.batch_count
        self.log_steps += batches
        s = (f"Step:{self.log_steps},CE loss:{self.ce_loss_sum / batches}"
             f",Soft loss:{self.distill_loss_sum / batches}\n")
        print(s)
        with open('record.txt', 'a', encoding='utf-8') as file:
            file.write(s)
        self.ce_loss_sum = 0
        self.distill_loss_sum = 0
        self.batch_count = 0
class MambaTransformerConfig(PretrainedConfig):
    """Configuration class paired with ``MambaTransformerForLM``.

    It defines no fields of its own: every keyword argument is handed
    unchanged to ``PretrainedConfig.__init__``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)