Commit d8e7ef7

Updated the WaveSplit implementation; only the distance loss has been tested.
1 parent 8d1ca2e commit d8e7ef7

8 files changed, +154 -184 lines


egs/wham/WaveSplit/README.md

+9-15
```diff
@@ -1,17 +1,11 @@
 ### WaveSplit
 
-things currently not clear:
----
-- not clear if different encoders are used for separation and speaker stack. (from image in the paper it seems so)
-- what is embedding dimension ? It seems 512 but it is not explicit in the paper
-- mask used (sigmoid ?)
-- when speakers in an example < sep stack outputs loss is simply masked or an embedding for silence is used ? (Probably masked)
-- is VAD used in WSJ02MiX/ WHAM for determining speech activity at frame level ? Some files can have pauses of even one second
-- loss right now is prone to go NaN especially if we don't take the mean after l2-distances computation.
-
----
-structure:
-- train.py contains training loop (nets instantiation lines 48-60, training loop lines 100-116)
-- losses.py wavesplit losses
-- wavesplit.py sep and speaker stacks nets
-- wavesplitwham.py dataset parsing
+we train on 1 sec now.
+
+tried with 256 embedding dimension.
+
+still does not work with oracle embeddings.
+
+not clear how in sep stack loss at every layer is computed (is the same output layer used in all ?).
+Also no mention in the paper about output layer and that first conv has no skip connection.
```

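One plausible reading of the open question above (reuse a single output layer on every separation-stack block and average the per-layer reconstruction losses) is sketched below. This is only an assumption about the paper, not its confirmed scheme; `SepBlock`, `shared_output`, and the MSE stand-in for the real reconstruction loss are all illustrative names.

```python
import torch
from torch import nn

class SepBlock(nn.Module):
    """Stand-in for one residual dilated-conv block of the separation stack."""
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv1d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        return x + torch.relu(self.conv(x))

channels, n_src, n_blocks = 64, 2, 4
blocks = nn.ModuleList([SepBlock(channels) for _ in range(n_blocks)])
shared_output = nn.Conv1d(channels, n_src, kernel_size=1)  # hypothetical shared output head

def per_layer_estimates(mix_feats):
    x, estimates = mix_feats, []
    for block in blocks:
        x = block(x)
        estimates.append(shared_output(x))  # same head reused at every layer
    return estimates

feats = torch.randn(3, channels, 100)   # (batch, channels, frames)
targets = torch.randn(3, n_src, 100)
# Average a placeholder reconstruction loss (MSE here) over all layer outputs.
loss = torch.stack([nn.functional.mse_loss(est, targets)
                    for est in per_layer_estimates(feats)]).mean()
print(loss.item())
```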
egs/wham/WaveSplit/local/conf.yml

+2-2
```diff
@@ -23,7 +23,7 @@ training:
   num_workers: 4
   half_lr: yes
   early_stop: yes
-  gradient_clipping: 5
+  gradient_clipping: 5000
 # Optim config
 optim:
   optimizer: adam
@@ -38,4 +38,4 @@ data:
   nondefault_nsrc:
   sample_rate: 8000
   mode: min
-  segment: 0.750
+  segment: 1
```

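For context on the `gradient_clipping: 5000` value above: in PyTorch training loops such a setting is usually passed to `torch.nn.utils.clip_grad_norm_` between `backward()` and the optimizer step. A minimal sketch, assuming the value is used as a maximum gradient norm (the model and optimizer here are placeholders, not the actual WaveSplit nets):

```python
import torch
from torch import nn

model = nn.Linear(10, 2)                                  # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
clip_value = 5000                                         # mirrors gradient_clipping above

x, y = torch.randn(8, 10), torch.randn(8, 2)
loss = nn.functional.mse_loss(model(x), y)

optimizer.zero_grad()
loss.backward()
# Rescale gradients so their global L2 norm does not exceed clip_value.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_value)
optimizer.step()
```

A threshold this high only intervenes on pathological batches, which fits a loss that occasionally spikes toward NaN.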
egs/wham/WaveSplit/local/preprocess_wham.py

+5-5
```diff
@@ -14,8 +14,8 @@ def preprocess_task(task, in_dir, out_dir):
     examples = []
     for mix in mix_both:
         filename = mix.split("/")[-1]
-        spk1_id = filename.split("_")[0][:3]
-        spk2_id = filename.split("_")[2][:3]
+        spk1_id = filename.split("_")[0]
+        spk2_id = filename.split("_")[2]
         length = len(sf.SoundFile(mix))
 
         noise = os.path.join(in_dir, "noise", filename)
@@ -33,8 +33,8 @@ def preprocess_task(task, in_dir, out_dir):
     examples = []
     for mix in mix_clean:
         filename = mix.split("/")[-1]
-        spk1_id = filename.split("_")[0][:3]
-        spk2_id = filename.split("_")[2][:3]
+        spk1_id = filename.split("_")[0]
+        spk2_id = filename.split("_")[2]
         length = len(sf.SoundFile(mix))
 
         s1 = os.path.join(in_dir, "s1", filename)
@@ -51,7 +51,7 @@ def preprocess_task(task, in_dir, out_dir):
     examples = []
     for mix in mix_single:
         filename = mix.split("/")[-1]
-        spk1_id = filename.split("_")[0][:3]
+        spk1_id = filename.split("_")[0]
         length = len(sf.SoundFile(mix))
 
         s1 = os.path.join(in_dir, "s1", filename)
```

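The change above stops truncating the speaker ID to its first three characters and keeps the whole first (and third) underscore-separated token instead. A small sketch of what each variant extracts from a wsj0-2mix / WHAM style mixture filename; the example filename is made up but follows the usual `<utt1>_<snr1>_<utt2>_<snr2>.wav` pattern:

```python
# Illustrative wsj0-2mix / WHAM mixture filename.
filename = "011a0101_1.2106_20do010f_-1.2106.wav"

# Old behaviour: 3-character speaker IDs.
spk1_short = filename.split("_")[0][:3]   # "011"
spk2_short = filename.split("_")[2][:3]   # "20d"

# New behaviour: keep the whole token as the ID.
spk1_full = filename.split("_")[0]        # "011a0101"
spk2_full = filename.split("_")[2]        # "20do010f"

print(spk1_short, spk2_short, spk1_full, spk2_full)
```

If the first token is a WSJ0 utterance ID, keeping it whole makes the label effectively per-utterance rather than per-speaker, which changes how many rows the embedding table in losses.py needs.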
egs/wham/WaveSplit/losses.py

+30-44
```diff
@@ -3,7 +3,8 @@
 import numpy as np
 from torch.nn import functional as F
 from itertools import permutations
-from asteroid.losses.sdr import MultiSrcNegSDR
+from asteroid.losses.sdr import MultiSrcNegSDR, SingleSrcNegSDR
+from asteroid.losses import PITLossWrapper, PairwiseNegSDR, pairwise_neg_sisdr
 import math
 
 
@@ -12,7 +13,7 @@ class ClippedSDR(nn.Module):
     def __init__(self, clip_value=-30):
         super(ClippedSDR, self).__init__()
 
-        self.snr = MultiSrcNegSDR("snr")
+        self.snr = PITLossWrapper(pairwise_neg_sisdr)
         self.clip_value = float(clip_value)
 
     def forward(self, est_targets, targets):
@@ -23,12 +24,9 @@ def forward(self, est_targets, targets):
 class SpeakerVectorLoss(nn.Module):
 
     def __init__(self, n_speakers, embed_dim=32, learnable_emb=True, loss_type="global",
-                 weight=10, distance_reg=0.3, gaussian_reg=0.2, return_oracle=True):
+                 weight=2, distance_reg=0.3, gaussian_reg=0.2, return_oracle=False):
         super(SpeakerVectorLoss, self).__init__()
 
-
-        # not clear how embeddings are initialized.
-
         self.learnable_emb = learnable_emb
         self.loss_type = loss_type
         self.weight = float(weight)
@@ -38,36 +36,30 @@ def __init__(self, n_speakers, embed_dim=32, learnable_emb=True, loss_type="glob
 
         assert loss_type in ["distance", "global", "local"]
 
-        # I initialize embeddings to be on unit sphere as speaker stack uses euclidean normalization
-
-        spk_emb = torch.rand((n_speakers, embed_dim))
-        norms = torch.sum(spk_emb ** 2, -1, keepdim=True).sqrt()
-        spk_emb = spk_emb / norms # generate points on n-dimensional unit sphere
+        spk_emb = torch.eye(max(n_speakers, embed_dim))  # one-hot init works better according to Neil
+        spk_emb = spk_emb[:n_speakers, :embed_dim]
 
         if learnable_emb == True:
            self.spk_embeddings = nn.Parameter(spk_emb)
        else:
            self.register_buffer("spk_embeddings", spk_emb)
 
-        if loss_type != "dist":
-            self.alpha = nn.Parameter(torch.Tensor([1.])) # not clear how these are initialized...
+        if loss_type != "distance":
+            self.alpha = nn.Parameter(torch.Tensor([1.]))
             self.beta = nn.Parameter(torch.Tensor([0.]))
 
-
-        ### losses go to NaN if I follow strictly the formulas maybe I am missing something...
-
     @staticmethod
     def _l_dist_speaker(c_spk_vec_perm, spk_embeddings, spk_labels, spk_mask):
 
         utt_embeddings = spk_embeddings[spk_labels].unsqueeze(-1) * spk_mask.unsqueeze(2)
         c_spk = c_spk_vec_perm[:, 0]
         pair_dist = ((c_spk.unsqueeze(1) - c_spk_vec_perm)**2).sum(2)
-        pair_dist = pair_dist[:, 1:].sqrt()
-        distance = ((c_spk_vec_perm - utt_embeddings)**2).sum(2).sqrt()
-        return (distance + F.relu(1. - pair_dist).sum(1).unsqueeze(1)).sum(1)
+        pair_dist = pair_dist[:, 1:]
+        distance = ((c_spk_vec_perm - utt_embeddings)**2).sum(dim=(1, 2))
+        return distance + F.relu(1. - pair_dist).sum(dim=(1))
 
     def _l_local_speaker(self, c_spk_vec_perm, spk_embeddings, spk_labels, spk_mask):
-
+        raise NotImplemented
         utt_embeddings = spk_embeddings[spk_labels].unsqueeze(-1) * spk_mask.unsqueeze(2)
         alpha = torch.clamp(self.alpha, 1e-8)
 
@@ -79,42 +71,37 @@ def _l_local_speaker(self, c_spk_vec_perm, spk_embeddings, spk_labels, spk_mask)
         return out.sum(1)
 
     def _l_global_speaker(self, c_spk_vec_perm, spk_embeddings, spk_labels, spk_mask):
-
+        raise NotImplemented
         utt_embeddings = spk_embeddings[spk_labels].unsqueeze(-1) * spk_mask.unsqueeze(2)
         alpha = torch.clamp(self.alpha, 1e-8)
 
-        distance_utt = alpha*((c_spk_vec_perm - utt_embeddings)**2).sum(2).sqrt() + self.beta
+        distance_utt = alpha*((c_spk_vec_perm - utt_embeddings)**2).sum(2) + self.beta
 
         B, src, embed_dim, frames = c_spk_vec_perm.size()
         spk_embeddings = spk_embeddings.reshape(1, spk_embeddings.shape[0], embed_dim, 1).expand(B, -1, -1, frames)
         distances = alpha * ((c_spk_vec_perm.unsqueeze(1) - spk_embeddings.unsqueeze(2)) ** 2).sum(3).sqrt() + self.beta
         # exp normalize trick
-        with torch.no_grad():
-            b = torch.max(distances, dim=1, keepdim=True)[0]
-        out = -distance_utt + b.squeeze(1) - torch.log(torch.exp(-distances + b).sum(1))
-        return out.sum(1)
+        #with torch.no_grad():
+        #    b = torch.max(distances, dim=1, keepdim=True)[0]
+        #out = -distance_utt + b.squeeze(1) - torch.log(torch.exp(-distances + b).sum(1))
+        #return out.sum(1)
 
-    def forward(self, speaker_vectors, spk_mask, spk_labels):
 
-        # spk_mask ideally would be the speaker activty at frame level. Because WHAM speakers can be considered always two and active we fix this for now.
-        # mask with ones and zeros B, SRC, FRAMES
+    def forward(self, speaker_vectors, spk_mask, spk_labels):
 
         if self.gaussian_reg:
             noise = torch.randn(self.spk_embeddings.size(), device=speaker_vectors.device)*math.sqrt(self.gaussian_reg)
             spk_embeddings = self.spk_embeddings + noise
         else:
             spk_embeddings = self.spk_embeddings
 
-        if self.learnable_emb or self.gaussian_reg: # re project on unit sphere after noise has been applied and before computing the distance reg
+        if self.learnable_emb or self.gaussian_reg: # re project on unit sphere
 
             spk_embeddings = spk_embeddings / torch.sum(spk_embeddings ** 2, -1, keepdim=True).sqrt()
 
         if self.distance_reg:
 
-            pairwise_dist = ((spk_embeddings.unsqueeze(0) - spk_embeddings.unsqueeze(1))**2).sum(-1)
-            idx = torch.arange(0, pairwise_dist.shape[0])
-            pairwise_dist[idx, idx] = np.inf # masking with itself
-            pairwise_dist = pairwise_dist.sqrt()
+            pairwise_dist = (torch.abs(spk_embeddings.unsqueeze(0) - spk_embeddings.unsqueeze(1))).mean(-1).fill_diagonal_(np.inf)
             distance_reg = -torch.sum(torch.min(torch.log(pairwise_dist), dim=-1)[0])
 
         # speaker vectors B, n_src, dim, frames
@@ -145,10 +132,8 @@ def forward(self, speaker_vectors, spk_mask, spk_labels):
         min_loss_perm = min_loss_perm.transpose(0, 1).reshape(B, n_src, 1, frames).expand(-1, -1, embed_dim, -1)
         # tot_loss
 
-
         spk_loss = self.weight*min_loss.mean()
         if self.distance_reg:
-
            spk_loss += self.distance_reg*distance_reg
         reordered_sources = torch.gather(speaker_vectors, dim=1, index=min_loss_perm)
 
@@ -160,23 +145,24 @@ def forward(self, speaker_vectors, spk_mask, spk_labels):
 
 
 if __name__ == "__main__":
+    n_speakers = 101
+    emb_speaker = 256
 
     # testing exp normalize average
-    distances = torch.ones((1, 101, 4000))*99
-    with torch.no_grad():
-        b = torch.max(distances, dim=1, keepdim=True)[0]
-    out = b.squeeze(1) - torch.log(torch.exp(-distances + b).sum(1))
-    out2 = - torch.log(torch.exp(-distances).sum(1))
+    #distances = torch.ones((1, 101, 4000))
+    #with torch.no_grad():
+    #    b = torch.max(distances, dim=1, keepdim=True)[0]
+    #out = b.squeeze(1) - torch.log(torch.exp(-distances + b).sum(1))
+    #out2 = - torch.log(torch.exp(-distances).sum(1))
 
-    loss_spk = SpeakerVectorLoss(1000, 32, loss_type="distance") # 1000 speakers in training set
+    loss_spk = SpeakerVectorLoss(n_speakers, emb_speaker, loss_type="global")
 
-    speaker_vectors = torch.rand(2, 3, 32, 200)
+    speaker_vectors = torch.rand(2, 3, emb_speaker, 200)
     speaker_labels = torch.from_numpy(np.array([[1, 2, 0], [5, 2, 10]]))
     speaker_mask = torch.randint(0, 2, (2, 3, 200)) # silence where there are no speakers actually thi is test
     speaker_mask[:, -1, :] = speaker_mask[:, -1, :]*0
     loss_spk(speaker_vectors, speaker_mask, speaker_labels)
 
-
     c = ClippedSDR(-30)
     a = torch.rand((2, 3, 200))
     print(c(a, a))
```

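The blocks commented out above (in `_l_global_speaker` and in the `__main__` test) revolve around the exp-normalize / log-sum-exp trick: subtract the maximum before exponentiating so that `-log(sum(exp(-d)))` stays finite for large distances. A standalone sketch of just that trick; the shapes mirror the 101-speaker test above, the distance value is arbitrary:

```python
import torch

distances = torch.ones((1, 101, 4000)) * 200.0   # large distances, e.g. from a bad init

# Naive form: exp(-200) underflows to 0 in float32, so the log returns -inf.
naive = -torch.log(torch.exp(-distances).sum(1))

# Exp-normalize / log-sum-exp trick: pull the max out before exponentiating.
b = torch.max(-distances, dim=1, keepdim=True)[0]
stable = -(b.squeeze(1) + torch.log(torch.exp(-distances - b).sum(1)))

print(naive[0, 0].item())   # inf
print(stable[0, 0].item())  # ~195.38, i.e. 200 - log(101)
```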
egs/wham/WaveSplit/run.sh

+2-6
```diff
@@ -42,10 +42,8 @@ mode=min
 nondefault_src= # If you want to train a network with 3 output streams for example.
 
 # Training
-batch_size=1
-num_workers=8
-kernel_size=16
-stride=8
+batch_size=4
+num_workers=4
 #optimizer=adam
 lr=0.001
 epochs=400
@@ -134,8 +132,6 @@ if [[ $stage -le 3 ]]; then
   --epochs $epochs \
   --batch_size $batch_size \
   --num_workers $num_workers \
-  --kernel_size $kernel_size \
-  --stride $stride \
   --exp_dir ${expdir}/ | tee logs/train_${tag}.log
 fi
```

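The `batch_size` and `num_workers` variables above are forwarded to `train.py` as command-line flags. Below is a sketch of how such flags are typically consumed on the Python side; whether the actual `train.py` parses them exactly like this is an assumption, and the dataset is a placeholder:

```python
import argparse
import torch
from torch.utils.data import DataLoader, TensorDataset

parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--num_workers", type=int, default=4)
args = parser.parse_args([])  # run.sh passes the real values on the command line

# Placeholder dataset: 32 one-second segments at 8 kHz, as in conf.yml.
dataset = TensorDataset(torch.randn(32, 8000))
loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True,
                    num_workers=args.num_workers, drop_last=True)
print(len(loader))  # 8 batches of 4
```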