"""A popular speaker recognition and diarization model.
Authors
* Hwidong Na 2020
"""
# import os
from functools import partial

import entmax
import torch  # noqa: F401

from speechbrain.dataio.dataio import length_to_mask
from speechbrain.lobes.models.ECAPA_TDNN import ECAPA_TDNN, \
    AttentiveStatisticsPooling

from scorer import SelfAdditiveScorer

# Mapping from configuration strings to the (sparse) max activations that
# can replace softmax in the attentive pooling layer.
available_max_activations = {
    'softmax': torch.softmax,
    'sparsemax': entmax.sparsemax,
    'entmax15': entmax.entmax15,
    'entmax1333': partial(entmax.entmax_bisect, alpha=1.333),
}


class DiscreteAttentiveStatisticsPooling(AttentiveStatisticsPooling):
    """Attentive statistics pooling whose attention weights can be produced
    by a sparse max activation (sparsemax/entmax) and, optionally, by a
    self-additive scorer instead of the default TDNN scorer."""

    def __init__(self, channels, attention_channels=128, global_context=True,
                 scorer='default', max_activation='softmax'):
        super().__init__(channels, attention_channels, global_context)
        self.scorer = scorer
        self.max_activation = max_activation
        self.alphas = None
        if self.scorer == 'self_add':
            self.add_scorer = SelfAdditiveScorer(
                channels * 3 if global_context else channels,
                attention_channels,
                scaled=False
            )

    def forward(self, x, lengths=None):
        """Calculates mean and std for a batch (input tensor).

        Arguments
        ---------
        x : torch.Tensor
            Tensor of shape [N, C, L].
        lengths : torch.Tensor
            Relative lengths of each utterance in the batch, shape [N].
        """
        L = x.shape[-1]

        def _compute_statistics(x, m, dim=2, eps=self.eps):
            mean = (m * x).sum(dim)
            std = torch.sqrt(
                (m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps)
            )
            return mean, std

        if lengths is None:
            lengths = torch.ones(x.shape[0], device=x.device)

        # Make binary mask of shape [N, 1, L]
        mask = length_to_mask(lengths * L, max_len=L, device=x.device)
        mask = mask.unsqueeze(1)

        # Expand the temporal context of the pooling layer by allowing the
        # self-attention to look at global properties of the utterance.
        if self.global_context:
            # torch.std is unstable for backward computation
            # https://github.com/pytorch/pytorch/issues/4320
            total = mask.sum(dim=2, keepdim=True).float()
            mean, std = _compute_statistics(x, mask / total)
            mean = mean.unsqueeze(2).repeat(1, 1, L)
            std = std.unsqueeze(2).repeat(1, 1, L)
            attn = torch.cat([x, mean, std], dim=1)
        else:
            attn = x

        # Compute attention scores with the configured scorer
        if self.scorer == 'default':
            attn = self.conv(self.tanh(self.tdnn(attn)))
        elif self.scorer == 'self_add':
            attn = self.add_scorer(attn, attn)

        # Filter out zero-paddings
        attn = attn.masked_fill(mask == 0, float("-inf"))

        # Normalize the scores over time with the selected (sparse) max activation
        max_activation = available_max_activations[self.max_activation]
        attn = max_activation(attn, dim=2)
        mean, std = _compute_statistics(x, attn)

        # Append mean and std of the batch
        pooled_stats = torch.cat((mean, std), dim=1)
        pooled_stats = pooled_stats.unsqueeze(2)

        # Keep the attention weights for later inspection
        self.alphas = attn

        return pooled_stats


class DiscreteECAPA_TDNN(ECAPA_TDNN):
    """ECAPA-TDNN speaker embedding model whose attentive statistics pooling
    is replaced by DiscreteAttentiveStatisticsPooling, so the attention
    weights can be made sparse via sparsemax/entmax activations."""

    def __init__(
        self,
        input_size,
        device="cpu",
        lin_neurons=192,
        activation=torch.nn.ReLU,
        channels=[512, 512, 512, 512, 1536],
        kernel_sizes=[5, 3, 3, 3, 1],
        dilations=[1, 2, 3, 4, 1],
        attention_channels=128,
        res2net_scale=8,
        se_channels=128,
        global_context=True,
        attn_scorer='default',
        attn_max_activation='softmax',
    ):
        super().__init__(
            input_size, device, lin_neurons, activation, channels,
            kernel_sizes, dilations, attention_channels,
            res2net_scale, se_channels, global_context
        )
        # Replace the standard attentive statistics pooling with the
        # discrete/sparse variant defined above.
        self.asp = DiscreteAttentiveStatisticsPooling(
            channels[-1],
            attention_channels=attention_channels,
            global_context=global_context,
            scorer=attn_scorer,
            max_activation=attn_max_activation
        )