3
3
import numpy as np
4
4
from torch .nn import functional as F
5
5
from itertools import permutations
6
- from asteroid .losses .sdr import MultiSrcNegSDR
6
+ from asteroid .losses .sdr import MultiSrcNegSDR , SingleSrcNegSDR
7
+ from asteroid .losses import PITLossWrapper , PairwiseNegSDR ,pairwise_neg_sisdr
7
8
import math
8
9
9
10
@@ -12,7 +13,7 @@ class ClippedSDR(nn.Module):
12
13
def __init__ (self , clip_value = - 30 ):
13
14
super (ClippedSDR , self ).__init__ ()
14
15
15
- self .snr = MultiSrcNegSDR ( "snr" )
16
+ self .snr = PITLossWrapper ( pairwise_neg_sisdr )
16
17
self .clip_value = float (clip_value )
17
18
18
19
def forward (self , est_targets , targets ):
@@ -23,12 +24,9 @@ def forward(self, est_targets, targets):
23
24
class SpeakerVectorLoss (nn .Module ):
24
25
25
26
def __init__ (self , n_speakers , embed_dim = 32 , learnable_emb = True , loss_type = "global" ,
26
- weight = 10 , distance_reg = 0.3 , gaussian_reg = 0.2 , return_oracle = True ):
27
+ weight = 2 , distance_reg = 0.3 , gaussian_reg = 0.2 , return_oracle = False ):
27
28
super (SpeakerVectorLoss , self ).__init__ ()
28
29
29
-
30
- # not clear how embeddings are initialized.
31
-
32
30
self .learnable_emb = learnable_emb
33
31
self .loss_type = loss_type
34
32
self .weight = float (weight )
@@ -38,36 +36,30 @@ def __init__(self, n_speakers, embed_dim=32, learnable_emb=True, loss_type="glob
38
36
39
37
assert loss_type in ["distance" , "global" , "local" ]
40
38
41
- # I initialize embeddings to be on unit sphere as speaker stack uses euclidean normalization
42
-
43
- spk_emb = torch .rand ((n_speakers , embed_dim ))
44
- norms = torch .sum (spk_emb ** 2 , - 1 , keepdim = True ).sqrt ()
45
- spk_emb = spk_emb / norms # generate points on n-dimensional unit sphere
39
+ spk_emb = torch .eye (max (n_speakers , embed_dim )) # one-hot init works better according to Neil
40
+ spk_emb = spk_emb [:n_speakers , :embed_dim ]
46
41
47
42
if learnable_emb == True :
48
43
self .spk_embeddings = nn .Parameter (spk_emb )
49
44
else :
50
45
self .register_buffer ("spk_embeddings" , spk_emb )
51
46
52
- if loss_type != "dist " :
53
- self .alpha = nn .Parameter (torch .Tensor ([1. ])) # not clear how these are initialized...
47
+ if loss_type != "distance " :
48
+ self .alpha = nn .Parameter (torch .Tensor ([1. ]))
54
49
self .beta = nn .Parameter (torch .Tensor ([0. ]))
55
50
56
-
57
- ### losses go to NaN if I follow strictly the formulas maybe I am missing something...
58
-
59
51
@staticmethod
60
52
def _l_dist_speaker (c_spk_vec_perm , spk_embeddings , spk_labels , spk_mask ):
61
53
62
54
utt_embeddings = spk_embeddings [spk_labels ].unsqueeze (- 1 ) * spk_mask .unsqueeze (2 )
63
55
c_spk = c_spk_vec_perm [:, 0 ]
64
56
pair_dist = ((c_spk .unsqueeze (1 ) - c_spk_vec_perm )** 2 ).sum (2 )
65
- pair_dist = pair_dist [:, 1 :]. sqrt ()
66
- distance = ((c_spk_vec_perm - utt_embeddings )** 2 ).sum (2 ). sqrt ( )
67
- return ( distance + F .relu (1. - pair_dist ).sum (1 ). unsqueeze (1 )). sum ( 1 )
57
+ pair_dist = pair_dist [:, 1 :]
58
+ distance = ((c_spk_vec_perm - utt_embeddings )** 2 ).sum (dim = ( 1 , 2 ) )
59
+ return distance + F .relu (1. - pair_dist ).sum (dim = (1 ))
68
60
69
61
def _l_local_speaker (self , c_spk_vec_perm , spk_embeddings , spk_labels , spk_mask ):
70
-
62
+ raise NotImplemented
71
63
utt_embeddings = spk_embeddings [spk_labels ].unsqueeze (- 1 ) * spk_mask .unsqueeze (2 )
72
64
alpha = torch .clamp (self .alpha , 1e-8 )
73
65
@@ -79,42 +71,37 @@ def _l_local_speaker(self, c_spk_vec_perm, spk_embeddings, spk_labels, spk_mask)
79
71
return out .sum (1 )
80
72
81
73
def _l_global_speaker (self , c_spk_vec_perm , spk_embeddings , spk_labels , spk_mask ):
82
-
74
+ raise NotImplemented
83
75
utt_embeddings = spk_embeddings [spk_labels ].unsqueeze (- 1 ) * spk_mask .unsqueeze (2 )
84
76
alpha = torch .clamp (self .alpha , 1e-8 )
85
77
86
- distance_utt = alpha * ((c_spk_vec_perm - utt_embeddings )** 2 ).sum (2 ). sqrt () + self .beta
78
+ distance_utt = alpha * ((c_spk_vec_perm - utt_embeddings )** 2 ).sum (2 ) + self .beta
87
79
88
80
B , src , embed_dim , frames = c_spk_vec_perm .size ()
89
81
spk_embeddings = spk_embeddings .reshape (1 , spk_embeddings .shape [0 ], embed_dim , 1 ).expand (B , - 1 , - 1 , frames )
90
82
distances = alpha * ((c_spk_vec_perm .unsqueeze (1 ) - spk_embeddings .unsqueeze (2 )) ** 2 ).sum (3 ).sqrt () + self .beta
91
83
# exp normalize trick
92
- with torch .no_grad ():
93
- b = torch .max (distances , dim = 1 , keepdim = True )[0 ]
94
- out = - distance_utt + b .squeeze (1 ) - torch .log (torch .exp (- distances + b ).sum (1 ))
95
- return out .sum (1 )
84
+ # with torch.no_grad():
85
+ # b = torch.max(distances, dim=1, keepdim=True)[0]
86
+ # out = -distance_utt + b.squeeze(1) - torch.log(torch.exp(-distances + b).sum(1))
87
+ # return out.sum(1)
96
88
97
- def forward (self , speaker_vectors , spk_mask , spk_labels ):
98
89
99
- # spk_mask ideally would be the speaker activty at frame level. Because WHAM speakers can be considered always two and active we fix this for now.
100
- # mask with ones and zeros B, SRC, FRAMES
90
+ def forward (self , speaker_vectors , spk_mask , spk_labels ):
101
91
102
92
if self .gaussian_reg :
103
93
noise = torch .randn (self .spk_embeddings .size (), device = speaker_vectors .device )* math .sqrt (self .gaussian_reg )
104
94
spk_embeddings = self .spk_embeddings + noise
105
95
else :
106
96
spk_embeddings = self .spk_embeddings
107
97
108
- if self .learnable_emb or self .gaussian_reg : # re project on unit sphere after noise has been applied and before computing the distance reg
98
+ if self .learnable_emb or self .gaussian_reg : # re project on unit sphere
109
99
110
100
spk_embeddings = spk_embeddings / torch .sum (spk_embeddings ** 2 , - 1 , keepdim = True ).sqrt ()
111
101
112
102
if self .distance_reg :
113
103
114
- pairwise_dist = ((spk_embeddings .unsqueeze (0 ) - spk_embeddings .unsqueeze (1 ))** 2 ).sum (- 1 )
115
- idx = torch .arange (0 , pairwise_dist .shape [0 ])
116
- pairwise_dist [idx , idx ] = np .inf # masking with itself
117
- pairwise_dist = pairwise_dist .sqrt ()
104
+ pairwise_dist = (torch .abs (spk_embeddings .unsqueeze (0 ) - spk_embeddings .unsqueeze (1 ))).mean (- 1 ).fill_diagonal_ (np .inf )
118
105
distance_reg = - torch .sum (torch .min (torch .log (pairwise_dist ), dim = - 1 )[0 ])
119
106
120
107
# speaker vectors B, n_src, dim, frames
@@ -145,10 +132,8 @@ def forward(self, speaker_vectors, spk_mask, spk_labels):
145
132
min_loss_perm = min_loss_perm .transpose (0 , 1 ).reshape (B , n_src , 1 , frames ).expand (- 1 , - 1 , embed_dim , - 1 )
146
133
# tot_loss
147
134
148
-
149
135
spk_loss = self .weight * min_loss .mean ()
150
136
if self .distance_reg :
151
-
152
137
spk_loss += self .distance_reg * distance_reg
153
138
reordered_sources = torch .gather (speaker_vectors , dim = 1 , index = min_loss_perm )
154
139
@@ -160,23 +145,24 @@ def forward(self, speaker_vectors, spk_mask, spk_labels):
160
145
161
146
162
147
if __name__ == "__main__":
    # Manual smoke test: build both losses and run them once on random data.
    n_speakers = 101
    emb_speaker = 256
    batch, n_src, frames = 2, 3, 200

    # testing exp normalize average
    # distances = torch.ones((1, 101, 4000))
    # with torch.no_grad():
    # b = torch.max(distances, dim=1, keepdim=True)[0]
    # out = b.squeeze(1) - torch.log(torch.exp(-distances + b).sum(1))
    # out2 = - torch.log(torch.exp(-distances).sum(1))

    # NOTE(review): `_l_global_speaker` raises unconditionally; if `forward`
    # dispatches to it for loss_type="global", the call below cannot
    # succeed -- confirm the intended loss_type for this check.
    loss_spk = SpeakerVectorLoss(n_speakers, embed_dim=emb_speaker, loss_type="global")

    speaker_vectors = torch.rand(batch, n_src, emb_speaker, frames)
    speaker_labels = torch.from_numpy(np.array([[1, 2, 0], [5, 2, 10]]))
    # Random 0/1 activity mask; this is only a test, so the last source is
    # forced to silence to exercise the "no speaker" case.
    speaker_mask = torch.randint(0, 2, (batch, n_src, frames))
    speaker_mask[:, -1, :] = 0
    loss_spk(speaker_vectors, speaker_mask, speaker_labels)

    # Same tensor as estimate and target.
    clipped_sdr = ClippedSDR(clip_value=-30)
    signal = torch.rand((batch, n_src, frames))
    print(clipped_sdr(signal, signal))
0 commit comments