-
Notifications
You must be signed in to change notification settings - Fork 17
/
attention_intervention_model.py
713 lines (588 loc) · 28.5 KB
/
attention_intervention_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
"""
Drop-in replacements for HuggingFace transformer self-attention modules
(GPT-2, Transformer-XL, XLNet, BERT, DistilBERT) that allow interventions in
the attention distribution.

Each ``*AttentionOverride`` class copies the submodules and hyper-parameters
of an existing attention module. After the attention probabilities are
computed (softmax, attention dropout, optional head mask), the entries
selected by the boolean ``attn_override_mask`` are replaced with the
corresponding values from ``attn_override`` before the probabilities are used
to weight the value vectors.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class AttentionOverride(nn.Module):
    """A copy of `modeling_gpt2.Attention` class, but with overridden attention values."""

    def __init__(self, attention, attn_override, attn_override_mask):
        """
        Args:
            attention: instance of modeling_gpt2.Attention from which variables will be
                copied.
            attn_override: values to override the computed attention weights.
                Shape is [bsz, num_heads, override_q_len, override_k_len]; leading
                dimensions may be omitted when they broadcast (e.g.
                [num_heads, override_q_len, override_k_len]).
            attn_override_mask: boolean tensor indicating which attention weights to
                override. Same shape convention as `attn_override`.
        """
        super(AttentionOverride, self).__init__()
        # Copy submodules, buffers and hyper-parameters from `attention`.
        self.output_attentions = attention.output_attentions
        self.register_buffer("bias", attention._buffers["bias"])
        self.n_head = attention.n_head
        self.split_size = attention.split_size
        self.scale = attention.scale
        self.c_attn = attention.c_attn
        self.c_proj = attention.c_proj
        self.attn_dropout = attention.attn_dropout
        self.resid_dropout = attention.resid_dropout
        # Set attention override values
        self.attn_override = attn_override
        self.attn_override_mask = attn_override_mask

    def _attn(self, q, k, v, attention_mask=None, head_mask=None):
        """Compute attention probabilities, apply the override, and attend to `v`.

        Args:
            q: query heads, (batch, head, q_len, head_features).
            k: key heads, already transposed to (batch, head, head_features, k_len).
            v: value heads, (batch, head, k_len, head_features).
            attention_mask: optional additive mask, added to the scores before softmax.
            head_mask: optional multiplicative mask applied to the probabilities.

        Returns:
            [context] or, when `output_attentions` is set, [context, attention_probs].
        """
        w = torch.matmul(q, k)
        if self.scale:
            w = w / math.sqrt(v.size(-1))
        nd, ns = w.size(-2), w.size(-1)
        # Causal mask: rows ns - nd .. ns of the triangular `bias` buffer line up
        # with the current queries when cached past keys are prepended.
        b = self.bias[:, :, ns - nd : ns, :ns]
        w = w * b - 1e4 * (1 - b)
        if attention_mask is not None:
            # Apply the attention mask
            w = w + attention_mask
        w = nn.Softmax(dim=-1)(w)
        w = self.attn_dropout(w)
        # Mask heads if we want to
        if head_mask is not None:
            w = w * head_mask
        # Intervention: overwrite the masked entries of the top-left
        # [override_q_len x override_k_len] block of the attention weights.
        # Query and key lengths are read separately so non-square overrides
        # (e.g. fewer overridden query rows than keys) work.
        # NOTE(review): this writes into `w` in place; if `w` is still the softmax
        # output, autograd backward may fail. Presumably this module is used under
        # torch.no_grad() — confirm before relying on gradients.
        override_q_len, override_k_len = self.attn_override_mask.shape[-2:]
        w[:, :, :override_q_len, :override_k_len] = torch.where(
            self.attn_override_mask,
            self.attn_override,
            w[:, :, :override_q_len, :override_k_len],
        )
        outputs = [torch.matmul(w, v)]
        if self.output_attentions:
            outputs.append(w)
        return outputs

    def merge_heads(self, x):
        """Inverse of `split_heads`: (batch, head, seq, feat) -> (batch, seq, head * feat)."""
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states

    def split_heads(self, x, k=False):
        """Split the last dimension into `n_head` heads.

        Returns (batch, head, seq, feat), or (batch, head, feat, seq) when `k` is
        True so keys can be multiplied directly with queries.
        """
        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
        if k:
            return x.permute(0, 2, 3, 1)  # (batch, head, head_features, seq_length)
        else:
            return x.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    def forward(self, x, layer_past=None, attention_mask=None, head_mask=None):
        """Run GPT-2 self-attention with the attention override applied.

        Args:
            x: hidden states; projected by `c_attn` and split along dim 2 into
                query/key/value.
            layer_past: optional cached (key, value) stack from a previous step.
            attention_mask: optional additive attention mask.
            head_mask: optional multiplicative head mask.

        Returns:
            [a, present] followed by the attention weights when
            `output_attentions` is set. `present` stacks (key, value) for caching.
        """
        x = self.c_attn(x)
        query, key, value = x.split(self.split_size, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)
        if layer_past is not None:
            # Cached keys are stored transposed (see `present` below); transpose back.
            past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1]
            key = torch.cat((past_key, key), dim=-1)
            value = torch.cat((past_value, value), dim=-2)
        # transpose to have same shapes for stacking
        present = torch.stack((key.transpose(-2, -1), value))
        attn_outputs = self._attn(query, key, value, attention_mask, head_mask)
        a = attn_outputs[0]
        a = self.merge_heads(a)
        a = self.c_proj(a)
        a = self.resid_dropout(a)
        outputs = [a, present] + attn_outputs[1:]
        return outputs  # a, present, (attentions)
class TXLAttentionOverride(nn.Module):
    """A copy of `modeling_transfo_xl.RelPartialLearnableMultiHeadAttn` class,
    but with overridden attention values.

    The override is applied to the attention probabilities after softmax,
    attention dropout and head masking, and before they weight the values.
    """
    def __init__(self, module, attn_override, attn_override_mask):
        """
        Args:
            module: instance of modeling_transfo_xl.RelPartialLearnableMultiHeadAttn
                from which variables will be copied
            attn_override: values to override the computed attention weights.
                Shape is [bsz, num_heads, override_q_len, override_k_len]; it is
                permuted internally to the [qlen x klen x bsz x n_head] layout of
                the attention probabilities.
            attn_override_mask: boolean tensor indicating which attention weights
                to override. Must be 4-D:
                [bsz, num_heads, override_q_len, override_k_len]
        """
        super(TXLAttentionOverride, self).__init__()
        # Copy values from module
        self.output_attentions = module.output_attentions
        self.n_head = module.n_head
        self.d_model = module.d_model
        self.d_head = module.d_head
        self.dropout = module.dropout
        self.qkv_net = module.qkv_net
        self.drop = module.drop
        self.dropatt = module.dropatt
        self.o_net = module.o_net
        self.layer_norm = module.layer_norm
        self.scale = module.scale
        self.pre_lnorm = module.pre_lnorm
        self.r_r_bias = module.r_r_bias
        self.r_w_bias = module.r_w_bias
        self.r_net = module.r_net
        # Set attention override values
        self.attn_override = attn_override
        self.attn_override_mask = attn_override_mask
    def _rel_shift(self, x):
        """Relative-position shift along dims 0/1.

        Prepends a zero column on dim 1, reinterprets the padded tensor with
        dims 0/1 swapped in size, drops the first row, and views the result
        back to `x`'s original shape.
        """
        zero_pad_shape = (x.size(0), 1) + x.size()[2:]
        zero_pad = torch.zeros(zero_pad_shape, device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=1)
        x_padded_shape = (x.size(1) + 1, x.size(0)) + x.size()[2:]
        x_padded = x_padded.view(*x_padded_shape)
        x = x_padded[1:].view_as(x)
        return x
    def forward(self, w, r, attn_mask=None, mems=None, head_mask=None):
        """Relative multi-head attention with the attention override applied.

        Args:
            w: input hidden states, [qlen x bsz x ...] (dim 0 is qlen, dim 1 is bsz).
            r: relative position embeddings; dim 0 is rlen.
            attn_mask: optional mask; entries equal to 1 are masked out.
                Ignored when it sums to zero.
            mems: optional memory, concatenated before `w` along dim 0 to form
                the keys/values.
            head_mask: optional multiplicative mask on the attention probabilities.

        Returns:
            [output], plus the attention probabilities
            ([qlen x klen x bsz x n_head]) when `output_attentions` is set.
        """
        qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)
        if mems is not None:
            # Keys/values span [mems; w]; queries keep only the last qlen positions.
            cat = torch.cat([mems, w], 0)
            if self.pre_lnorm:
                w_heads = self.qkv_net(self.layer_norm(cat))
            else:
                w_heads = self.qkv_net(cat)
            r_head_k = self.r_net(r)
            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
            w_head_q = w_head_q[-qlen:]
        else:
            if self.pre_lnorm:
                w_heads = self.qkv_net(self.layer_norm(w))
            else:
                w_heads = self.qkv_net(w)
            r_head_k = self.r_net(r)
            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
        klen = w_head_k.size(0)
        w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)  # qlen x bsz x n_head x d_head
        w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)  # klen x bsz x n_head x d_head
        w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)  # klen x bsz x n_head x d_head
        r_head_k = r_head_k.view(rlen, self.n_head, self.d_head)  # rlen x n_head x d_head
        # compute attention score
        rw_head_q = w_head_q + self.r_w_bias  # qlen x bsz x n_head x d_head
        AC = torch.einsum("ibnd,jbnd->ijbn", (rw_head_q, w_head_k))  # qlen x klen x bsz x n_head
        rr_head_q = w_head_q + self.r_r_bias
        BD = torch.einsum("ibnd,jnd->ijbn", (rr_head_q, r_head_k))  # qlen x rlen x bsz x n_head
        BD = self._rel_shift(BD)
        # [qlen x klen x bsz x n_head]
        # NOTE(review): AC + BD assumes rlen == klen (or broadcastable) — confirm with caller.
        attn_score = AC + BD
        attn_score.mul_(self.scale)
        # compute attention probability
        # Positions where attn_mask == 1 are filled with a large negative value;
        # the whole block is skipped when the mask is all zeros.
        if attn_mask is not None and torch.sum(attn_mask).item():
            attn_mask = attn_mask == 1  # Switch to bool
            if attn_mask.dim() == 2:
                # Smaller fill value for fp16 parameters, presumably because -1e30
                # is not representable in float16 (max ~65504).
                if next(self.parameters()).dtype == torch.float16:
                    attn_score = (
                        attn_score.float().masked_fill(attn_mask[None, :, :, None], -65000).type_as(attn_score)
                    )
                else:
                    attn_score = attn_score.float().masked_fill(attn_mask[None, :, :, None], -1e30).type_as(attn_score)
            elif attn_mask.dim() == 3:
                if next(self.parameters()).dtype == torch.float16:
                    attn_score = attn_score.float().masked_fill(attn_mask[:, :, :, None], -65000).type_as(attn_score)
                else:
                    attn_score = attn_score.float().masked_fill(attn_mask[:, :, :, None], -1e30).type_as(attn_score)
        # [qlen x klen x bsz x n_head]; softmax over the key axis (dim 1)
        attn_prob = F.softmax(attn_score, dim=1)
        attn_prob = self.dropatt(attn_prob)
        # Mask heads if we want to
        if head_mask is not None:
            attn_prob = attn_prob * head_mask
        # Intervention:
        # attn_override and attn_override_mask are of shape (bsz, n_heads, query_seq_len, key_seq_len)
        # attn_prob is of shape (query_seq_len, key_seq_len, bsz, n_heads)
        # The override tensors are permuted to that layout before torch.where.
        _, _, override_q_len, override_k_len = self.attn_override_mask.shape
        attn_prob[:override_q_len, :override_k_len, :, :] = torch.where(
            self.attn_override_mask.permute(2, 3, 0, 1),
            self.attn_override.permute(2, 3, 0, 1),
            attn_prob[:override_q_len, :override_k_len, :, :])
        # compute attention vector
        attn_vec = torch.einsum("ijbn,jbnd->ibnd", (attn_prob, w_head_v))
        # [qlen x bsz x n_head x d_head]
        attn_vec = attn_vec.contiguous().view(attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
        # linear projection
        attn_out = self.o_net(attn_vec)
        attn_out = self.drop(attn_out)
        if self.pre_lnorm:
            # residual connection
            outputs = [w + attn_out]
        else:
            # residual connection + layer normalization
            outputs = [self.layer_norm(w + attn_out)]
        if self.output_attentions:
            outputs.append(attn_prob)
        return outputs
class XLNetAttentionOverride(nn.Module):
    """A copy of `modeling_xlnet.XLNetRelativeAttention` class,
    but with overridden attention values.

    Only two-stream attention with a `target_mapping` is supported; other
    paths raise NotImplementedError.
    """

    def __init__(self, module, attn_override, attn_override_mask):
        """
        Args:
            module: instance of modeling_xlnet.XLNetRelativeAttention
                from which variables will be copied
            attn_override: values to override the computed attention weights.
                Tuple (content_override, query_override) for the two streams of
                two-stream self-attention, each of shape
                [bsz, num_heads, override_q_len, override_k_len]
            attn_override_mask: boolean tensor indicating which attention weights
                to override, shared by both streams.
                Shape is [bsz, num_heads, override_q_len, override_k_len]
        """
        super().__init__()
        self.output_attentions = module.output_attentions
        self.n_head = module.n_head
        self.d_head = module.d_head
        self.d_model = module.d_model
        self.scale = module.scale
        self.q = module.q
        self.k = module.k
        self.v = module.v
        self.o = module.o
        self.r = module.r
        self.r_r_bias = module.r_r_bias
        self.r_s_bias = module.r_s_bias
        self.r_w_bias = module.r_w_bias
        self.seg_embed = module.seg_embed
        self.layer_norm = module.layer_norm
        self.dropout = module.dropout
        # Set attention override values
        self.content_attn_override, self.query_attn_override = attn_override
        self.attn_override_mask = attn_override_mask

    def prune_heads(self, heads):
        raise NotImplementedError

    @staticmethod
    def rel_shift(x, klen=-1):
        """perform relative shift to form the relative attention score."""
        x_size = x.shape
        x = x.reshape(x_size[1], x_size[0], x_size[2], x_size[3])
        x = x[1:, ...]
        x = x.reshape(x_size[0], x_size[1] - 1, x_size[2], x_size[3])
        x = torch.index_select(x, 1, torch.arange(klen, device=x.device, dtype=torch.long))
        return x

    @staticmethod
    def rel_shift_bnij(x, klen=-1):
        """Relative shift for scores in [bsz x n_head x qlen x rlen] layout; keeps `klen` columns."""
        x_size = x.shape
        x = x.reshape(x_size[0], x_size[1], x_size[3], x_size[2])
        x = x[:, :, 1:, :]
        x = x.reshape(x_size[0], x_size[1], x_size[2], x_size[3] - 1)
        # Note: the tensor-slice form was faster in my testing than torch.index_select
        # However, tracing doesn't like the nature of the slice, and if klen changes
        # during the run then it'll fail, whereas index_select will be fine.
        x = torch.index_select(x, 3, torch.arange(klen, device=x.device, dtype=torch.long))
        return x

    def rel_attn_core(self, q_head, k_head_h, v_head_h, k_head_r, attn_override, seg_mat=None, attn_mask=None, head_mask=None):
        """Core relative positional attention operations, with the override applied.

        Args:
            q_head: query heads, [qlen x bsz x n_head x d_head].
            k_head_h: content key heads, [klen x bsz x n_head x d_head].
            v_head_h: content value heads, [klen x bsz x n_head x d_head].
            k_head_r: positional key heads, [rlen x bsz x n_head x d_head].
            attn_override: replacement attention probabilities for this stream,
                [bsz x n_head x override_q_len x override_k_len].
            seg_mat: optional segment matrix, [qlen x klen x bsz x n_seg].
            attn_mask: optional mask in [qlen x klen x bsz x n_head] layout; a
                large value times the mask is subtracted from the scores.
            head_mask: optional multiplicative mask, [qlen x klen x bsz x n_head] layout.

        Returns:
            attn_vec [qlen x bsz x n_head x d_head], and additionally the attention
            probabilities in [qlen x klen x bsz x n_head] layout when
            `output_attentions` is set.
        """
        # content based attention score
        ac = torch.einsum("ibnd,jbnd->bnij", q_head + self.r_w_bias, k_head_h)
        # position based attention score
        bd = torch.einsum("ibnd,jbnd->bnij", q_head + self.r_r_bias, k_head_r)
        bd = self.rel_shift_bnij(bd, klen=ac.shape[3])
        # segment based attention score
        if seg_mat is None:
            ef = 0
        else:
            ef = torch.einsum("ibnd,snd->ibns", q_head + self.r_s_bias, self.seg_embed)
            ef = torch.einsum("ijbs,ibns->bnij", seg_mat, ef)
        # merge attention scores and perform masking
        attn_score = (ac + bd + ef) * self.scale
        if attn_mask is not None:
            # Smaller penalty for fp16 masks, presumably to stay within float16 range.
            if attn_mask.dtype == torch.float16:
                attn_score = attn_score - 65500 * torch.einsum("ijbn->bnij", attn_mask)
            else:
                attn_score = attn_score - 1e30 * torch.einsum("ijbn->bnij", attn_mask)
        # attention probability
        attn_prob = F.softmax(attn_score, dim=3)
        attn_prob = self.dropout(attn_prob)
        # Mask heads if we want to
        if head_mask is not None:
            attn_prob = attn_prob * torch.einsum("ijbn->bnij", head_mask)
        # Intervention: overwrite the masked entries of the top-left
        # [override_q_len x override_k_len] block. Query and key lengths are read
        # separately so non-square overrides work.
        override_q_len, override_k_len = self.attn_override_mask.shape[-2:]
        attn_prob[:, :, :override_q_len, :override_k_len] = torch.where(
            self.attn_override_mask,
            attn_override,
            attn_prob[:, :, :override_q_len, :override_k_len],
        )
        # attention output
        attn_vec = torch.einsum("bnij,jbnd->ibnd", attn_prob, v_head_h)
        if self.output_attentions:
            return attn_vec, torch.einsum("bnij->ijbn", attn_prob)
        return attn_vec

    def post_attention(self, h, attn_vec, residual=True):
        """Post-attention processing."""
        # post-attention projection (back to `d_model`)
        attn_out = torch.einsum("ibnd,hnd->ibh", attn_vec, self.o)
        attn_out = self.dropout(attn_out)
        if residual:
            attn_out = attn_out + h
        output = self.layer_norm(attn_out)
        return output

    def forward(self, h, g, attn_mask_h, attn_mask_g, r, seg_mat, mems=None, target_mapping=None, head_mask=None):
        """Two-stream relative attention with the content/query overrides applied.

        Returns:
            (output_h, output_g), plus a tuple (attn_prob_h, attn_prob_g) when
            `output_attentions` is set.

        Raises:
            NotImplementedError: if `g` is None (single-stream attention) or
                `target_mapping` is None; the override is not implemented for
                those paths. (Previously `assert False`, which `python -O` strips,
                after which the calls failed with a TypeError.)
        """
        if g is None:
            raise NotImplementedError(
                "XLNetAttentionOverride only supports two-stream attention (g must not be None)"
            )
        if target_mapping is None:
            raise NotImplementedError(
                "XLNetAttentionOverride requires target_mapping for the query stream"
            )

        # Two-stream attention with relative positional encoding.
        if mems is not None and mems.dim() > 1:
            cat = torch.cat([mems, h], dim=0)
        else:
            cat = h
        # content-based key and value heads (over memory + current segment)
        k_head_h = torch.einsum("ibh,hnd->ibnd", cat, self.k)
        v_head_h = torch.einsum("ibh,hnd->ibnd", cat, self.v)
        # position-based key head
        k_head_r = torch.einsum("ibh,hnd->ibnd", r, self.r)

        # h-stream (content stream)
        q_head_h = torch.einsum("ibh,hnd->ibnd", h, self.q)
        attn_vec_h = self.rel_attn_core(
            q_head_h, k_head_h, v_head_h, k_head_r,
            attn_override=self.content_attn_override,
            seg_mat=seg_mat, attn_mask=attn_mask_h, head_mask=head_mask,
        )
        if self.output_attentions:
            attn_vec_h, attn_prob_h = attn_vec_h
        output_h = self.post_attention(h, attn_vec_h)

        # g-stream (query stream): map the query heads onto the sequence positions
        # with target_mapping, attend, then map the result back.
        q_head_g = torch.einsum("ibh,hnd->ibnd", g, self.q)
        q_head_g = torch.einsum("mbnd,mlb->lbnd", q_head_g, target_mapping)
        attn_vec_g = self.rel_attn_core(
            q_head_g, k_head_h, v_head_h, k_head_r,
            attn_override=self.query_attn_override,
            seg_mat=seg_mat, attn_mask=attn_mask_g, head_mask=head_mask,
        )
        if self.output_attentions:
            attn_vec_g, attn_prob_g = attn_vec_g
        attn_vec_g = torch.einsum("lbnd,mlb->mbnd", attn_vec_g, target_mapping)
        output_g = self.post_attention(g, attn_vec_g)

        outputs = (output_h, output_g)
        if self.output_attentions:
            outputs = outputs + ((attn_prob_h, attn_prob_g),)
        return outputs
class BertAttentionOverride(nn.Module):
    """A copy of `modeling_bert.BertSelfAttention` class, but with overridden attention values"""

    def __init__(self, module, attn_override, attn_override_mask):
        """
        Args:
            module: instance of modeling_bert.BertSelfAttention
                from which variables will be copied
            attn_override: values to override the computed attention weights.
                Shape is [bsz, num_heads, override_q_len, override_k_len]
            attn_override_mask: boolean tensor indicating which attention weights
                to override. Shape is [bsz, num_heads, override_q_len, override_k_len]
        """
        super().__init__()
        self.output_attentions = module.output_attentions
        self.num_attention_heads = module.num_attention_heads
        self.attention_head_size = module.attention_head_size
        self.all_head_size = module.all_head_size
        self.query = module.query
        self.key = module.key
        self.value = module.value
        self.dropout = module.dropout
        # Set attention override values
        self.attn_override = attn_override
        self.attn_override_mask = attn_override_mask

    def transpose_for_scores(self, x):
        """Reshape (batch, seq, all_head_size) into (batch, heads, seq, head_size)."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
    ):
        """Scaled dot-product self-(or cross-)attention with the override applied.

        Returns:
            (context_layer,) or, when `output_attentions` is set,
            (context_layer, attention_probs).
        """
        mixed_query_layer = self.query(hidden_states)
        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        if encoder_hidden_states is not None:
            mixed_key_layer = self.key(encoder_hidden_states)
            mixed_value_layer = self.value(encoder_hidden_states)
            attention_mask = encoder_attention_mask
        else:
            mixed_key_layer = self.key(hidden_states)
            mixed_value_layer = self.value(hidden_states)
        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)
        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask
        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)
        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)
        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask
        # Intervention: overwrite the masked entries of the top-left
        # [override_q_len x override_k_len] block. Query and key lengths are read
        # separately so non-square overrides (e.g. cross-attention) work.
        override_q_len, override_k_len = self.attn_override_mask.shape[-2:]
        attention_probs[:, :, :override_q_len, :override_k_len] = torch.where(
            self.attn_override_mask,
            self.attn_override,
            attention_probs[:, :, :override_q_len, :override_k_len],
        )
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,)
        return outputs
class DistilBertAttentionOverride(nn.Module):
    """A copy of `modeling_distilbert.MultiHeadSelfAttention` class, but with overridden attention values"""

    def __init__(self, module, attn_override, attn_override_mask):
        """
        Args:
            module: instance of modeling_distilbert.MultiHeadSelfAttention
                from which variables will be copied
            attn_override: values to override the computed attention weights.
                Shape is [bsz, num_heads, override_q_len, override_k_len]
            attn_override_mask: boolean tensor indicating which attention weights
                to override. Shape is [bsz, num_heads, override_q_len, override_k_len]
        """
        super().__init__()
        self.n_heads = module.n_heads
        self.dim = module.dim
        self.dropout = module.dropout
        self.output_attentions = module.output_attentions
        assert self.dim % self.n_heads == 0
        self.q_lin = module.q_lin
        self.k_lin = module.k_lin
        self.v_lin = module.v_lin
        self.out_lin = module.out_lin
        self.pruned_heads = module.pruned_heads
        # Set attention override values
        self.attn_override = attn_override
        self.attn_override_mask = attn_override_mask

    @staticmethod
    def _prune_linear_layer(layer, index, dim=0):
        """Return a new nn.Linear keeping only the entries `index` along `dim`.

        dim=0 keeps output features (rows of `weight` and the matching bias
        entries); dim=1 keeps input features (columns of `weight`; bias copied
        unchanged). Self-contained equivalent of HuggingFace's
        `prune_linear_layer`, which is not imported in this file.
        """
        index = index.to(layer.weight.device)
        weight = layer.weight.index_select(dim, index).detach().clone()
        has_bias = layer.bias is not None
        if has_bias:
            bias = layer.bias if dim == 1 else layer.bias[index]
            bias = bias.detach().clone()
        out_features, in_features = layer.weight.size()
        if dim == 0:
            out_features = len(index)
        else:
            in_features = len(index)
        new_layer = nn.Linear(in_features, out_features, bias=has_bias)
        new_layer = new_layer.to(device=layer.weight.device, dtype=layer.weight.dtype)
        with torch.no_grad():
            new_layer.weight.copy_(weight)
            if has_bias:
                new_layer.bias.copy_(bias)
        return new_layer

    def prune_heads(self, heads):
        """Remove the given attention heads from the q/k/v/out projections.

        Args:
            heads: sequence of head indices in the original (unpruned) numbering.
                Heads already pruned are ignored.
        """
        attention_head_size = self.dim // self.n_heads
        if len(heads) == 0:
            return
        mask = torch.ones(self.n_heads, attention_head_size)
        heads = set(heads) - self.pruned_heads
        for head in heads:
            # Shift the index down by the number of already-pruned heads before it.
            head -= sum(1 if h < head else 0 for h in self.pruned_heads)
            mask[head] = 0
        mask = mask.view(-1).contiguous().eq(1)
        index = torch.arange(len(mask))[mask].long()
        # Prune linear layers (q/k/v on their outputs, out_lin on its inputs)
        self.q_lin = self._prune_linear_layer(self.q_lin, index)
        self.k_lin = self._prune_linear_layer(self.k_lin, index)
        self.v_lin = self._prune_linear_layer(self.v_lin, index)
        self.out_lin = self._prune_linear_layer(self.out_lin, index, dim=1)
        # Update hyper params
        self.n_heads = self.n_heads - len(heads)
        self.dim = attention_head_size * self.n_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(self, query, key, value, mask, head_mask=None):
        """
        Parameters
        ----------
        query: torch.tensor(bs, seq_length, dim)
        key: torch.tensor(bs, seq_length, dim)
        value: torch.tensor(bs, seq_length, dim)
        mask: torch.tensor(bs, seq_length)
            Key positions where mask == 0 are excluded from attention.
        head_mask: optional multiplicative mask on the attention weights.

        Outputs
        -------
        context: torch.tensor(bs, seq_length, dim)
            Contextualized layer.
        weights: torch.tensor(bs, n_heads, seq_length, seq_length)
            Attention weights (after the override). Optional: only if
            `output_attentions=True`
        """
        bs, _, _ = query.size()
        k_length = key.size(1)
        dim_per_head = self.dim // self.n_heads
        mask_reshp = (bs, 1, 1, k_length)

        def shape(x):
            """ separate heads """
            return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)

        def unshape(x):
            """ group heads """
            return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)

        q = shape(self.q_lin(query))  # (bs, n_heads, q_length, dim_per_head)
        k = shape(self.k_lin(key))  # (bs, n_heads, k_length, dim_per_head)
        v = shape(self.v_lin(value))  # (bs, n_heads, k_length, dim_per_head)
        q = q / math.sqrt(dim_per_head)  # (bs, n_heads, q_length, dim_per_head)
        scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, q_length, k_length)
        mask = (mask == 0).view(mask_reshp).expand_as(scores)  # (bs, n_heads, q_length, k_length)
        scores.masked_fill_(mask, -float("inf"))  # (bs, n_heads, q_length, k_length)
        weights = nn.Softmax(dim=-1)(scores)  # (bs, n_heads, q_length, k_length)
        weights = self.dropout(weights)  # (bs, n_heads, q_length, k_length)
        # Mask heads if we want to
        if head_mask is not None:
            weights = weights * head_mask
        # Intervention: overwrite the masked entries of the top-left
        # [override_q_len x override_k_len] block. Query and key lengths are read
        # separately because q_length and k_length may differ.
        override_q_len, override_k_len = self.attn_override_mask.shape[-2:]
        weights[:, :, :override_q_len, :override_k_len] = torch.where(
            self.attn_override_mask,
            self.attn_override,
            weights[:, :, :override_q_len, :override_k_len],
        )
        context = torch.matmul(weights, v)  # (bs, n_heads, q_length, dim_per_head)
        context = unshape(context)  # (bs, q_length, dim)
        context = self.out_lin(context)  # (bs, q_length, dim)
        if self.output_attentions:
            return (context, weights)
        else:
            return (context,)