-
Notifications
You must be signed in to change notification settings - Fork 424
/
model.py
179 lines (162 loc) · 6.36 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
import json
import os
import torch
from torch import nn
from sklearn.cluster import KMeans
from asteroid import torch_utils
import asteroid_filterbanks as fb
from asteroid.engine.optimizers import make_optimizer
from asteroid_filterbanks.transforms import mag, apply_mag_mask
from asteroid.dsp.vad import ebased_vad
from asteroid.masknn.recurrent import SingleRNN
from asteroid.utils.torch_utils import pad_x_to_y
def make_model_and_optimizer(conf):
    """Build the Chimera model and its optimizer from a config dictionary.

    Centralising construction here makes it trivial to rebuild the exact same
    model when resuming training or running evaluation.

    Args:
        conf (dict): Output of the hierarchical argparse. Must contain the
            ``"filterbank"``, ``"masknet"`` and ``"optim"`` sub-dictionaries.

    Returns:
        tuple: ``(model, optimizer)``.
    """
    filterbank_conf = conf["filterbank"]
    encoder, decoder = fb.make_enc_dec("stft", **filterbank_conf)
    # The masker works on mag(tf_rep); presumably that has half of the encoder's
    # output channels (real/imag halves) — hence the integer halving here.
    n_freqs = encoder.n_feats_out // 2
    masker = Chimera(n_freqs, **conf["masknet"])
    model = Model(encoder, masker, decoder)
    optimizer = make_optimizer(model.parameters(), **conf["optim"])
    return model, optimizer
class Chimera(nn.Module):
    """Chimera network: a shared recurrent trunk feeding two output heads.

    The deep-clustering (DC) head yields one L2-normalised embedding per
    time-frequency bin; the mask-inference head yields one sigmoid mask per
    source.

    Args:
        in_chan (int): Number of input frequency channels per frame.
        n_src (int): Number of sources, i.e. number of masks produced.
        rnn_type (str): RNN type forwarded to ``SingleRNN``.
        n_layers (int): Number of recurrent layers.
        hidden_size (int): Hidden size of each RNN direction.
        bidirectional (bool): Whether the RNN is bidirectional.
        dropout (float): Dropout rate passed to the RNN and applied to its output.
        embedding_dim (int): Size of each DC embedding vector.
        take_log (bool): If True, the input is replaced by ``log(x + EPS)``.
        EPS (float): Small constant used in the log and in the normalisation.
    """

    def __init__(
        self,
        in_chan,
        n_src,
        rnn_type="lstm",
        n_layers=2,
        hidden_size=600,
        bidirectional=True,
        dropout=0.3,
        embedding_dim=20,
        take_log=False,
        EPS=1e-8,
    ):
        super().__init__()
        self.input_dim = in_chan
        self.n_src = n_src
        self.take_log = take_log
        self.embedding_dim = embedding_dim
        # Shared trunk. Submodule names and creation order are kept as-is:
        # they define state_dict keys, parameter order and init RNG order.
        self.rnn = SingleRNN(
            rnn_type,
            in_chan,
            hidden_size,
            n_layers=n_layers,
            dropout=dropout,
            bidirectional=bidirectional,
        )
        self.dropout = nn.Dropout(dropout)
        n_directions = 2 if bidirectional else 1
        trunk_dim = hidden_size * n_directions
        # Mask-inference head: in_chan mask values per source and frame.
        self.mask_layer = nn.Linear(trunk_dim, in_chan * self.n_src)
        self.mask_act = nn.Sigmoid()
        # Deep-clustering head: an embedding_dim vector per frequency bin and frame.
        self.embedding_layer = nn.Linear(trunk_dim, in_chan * embedding_dim)
        self.embedding_act = nn.Tanh()
        self.EPS = EPS

    def forward(self, input_data):
        """Run the trunk and both heads.

        Args:
            input_data (torch.Tensor): Shape (batch, in_chan, n_frames).

        Returns:
            tuple: ``(projection_final, mask_out)``. ``projection_final`` has
            shape (batch, in_chan * n_frames, embedding_dim) with unit-norm
            rows; ``mask_out`` has shape (batch, n_src, in_chan, n_frames).
        """
        n_batch, _, n_frames = input_data.shape
        feats = torch.log(input_data + self.EPS) if self.take_log else input_data
        # The RNN scans time, so frames move to dim 1: (batch, frames, in_chan).
        hidden = self.dropout(self.rnn(feats.permute(0, 2, 1)))
        embeddings = self._dc_head(hidden, n_batch, n_frames)
        masks = self._mask_head(hidden, n_batch, n_frames)
        return embeddings, masks

    def _dc_head(self, hidden, n_batch, n_frames):
        """Map trunk features to unit-norm embeddings, one per TF bin."""
        emb = self.embedding_act(self.embedding_layer(hidden))  # (batch, frames, in_chan * emb)
        emb = emb.view(n_batch, n_frames, -1, self.embedding_dim)
        # (batch, in_chan, frames, emb) -> (batch, in_chan * frames, emb): frequency-major bins.
        emb = emb.transpose(1, 2).reshape(n_batch, -1, self.embedding_dim)
        norms = emb.norm(p=2, dim=-1, keepdim=True)
        return emb / (norms + self.EPS)

    def _mask_head(self, hidden, n_batch, n_frames):
        """Map trunk features to one sigmoid mask per source."""
        logits = self.mask_layer(hidden).view(n_batch, n_frames, self.n_src, self.input_dim)
        # (batch, frames, n_src, in_chan) -> (batch, n_src, in_chan, frames)
        return self.mask_act(logits.permute(0, 2, 3, 1))
class Model(nn.Module):
    """Encoder -> masker -> decoder wrapper used by this recipe.

    Args:
        encoder (nn.Module): Filterbank encoder applied to the waveforms
            (built by ``fb.make_enc_dec`` in this recipe).
        masker (nn.Module): Network returning ``(embeddings, masks)`` from a
            magnitude spectrogram, e.g. :class:`Chimera`.
        decoder (nn.Module): Filterbank decoder producing output waveforms.
    """

    def __init__(self, encoder, masker, decoder):
        super().__init__()
        self.encoder = encoder
        self.masker = masker
        self.decoder = decoder

    @staticmethod
    def _as_3d(x):
        """Add a channel dim to 2D (batch, time) input; other ranks pass through."""
        if x.dim() == 2:
            return x.unsqueeze(1)
        return x

    def forward(self, x):
        """Return ``(embeddings, masks)`` from the masker's two heads."""
        tf_rep = self.encoder(self._as_3d(x))
        final_proj, mask_out = self.masker(mag(tf_rep))
        return final_proj, mask_out

    def separate(self, x):
        """Separate with mask-inference head, output waveforms"""
        x = self._as_3d(x)
        tf_rep = self.encoder(x)
        proj, mask_out = self.masker(mag(tf_rep))
        # Insert a source dim so per-source masks broadcast over tf_rep.
        masked = apply_mag_mask(tf_rep.unsqueeze(1), mask_out)
        wavs = torch_utils.pad_x_to_y(self.decoder(masked), x)
        dic_out = dict(tfrep=tf_rep, mask=mask_out, masked_tfrep=masked, proj=proj)
        return wavs, dic_out

    def dc_head_separate(self, x):
        """Cluster embeddings to produce binary masks, output waveforms.

        Each utterance in the batch is clustered independently with k-means
        (``n_clusters = masker.n_src``) over the bins selected by
        ``ebased_vad``; bins it does not select are set to 1 in every mask.
        (Previously only batch size 1 worked.)

        Returns:
            tuple: ``(wavs, dic_out)`` with the same keys as :meth:`separate`.
        """
        x = self._as_3d(x)
        tf_rep = self.encoder(x)
        mag_spec = mag(tf_rep)
        proj, mask_out = self.masker(mag_spec)
        active_bins = ebased_vad(mag_spec)
        n_src = self.masker.n_src
        # One (n_src, ...) mask stack per utterance -> (batch, n_src, ...).
        est_masks = torch.stack(
            [
                self._cluster_masks(item_proj, item_active, n_src)
                for item_proj, item_active in zip(proj, active_bins)
            ],
            dim=0,
        )
        # Same broadcasting as `separate`: add the source dim to tf_rep.
        masked = apply_mag_mask(tf_rep.unsqueeze(1), est_masks)
        wavs = pad_x_to_y(self.decoder(masked), x)
        dic_out = dict(tfrep=tf_rep, mask=mask_out, masked_tfrep=masked, proj=proj)
        return wavs, dic_out

    @staticmethod
    def _cluster_masks(proj, active_bins, n_src):
        """Build n_src binary masks for one utterance by k-means on its active bins."""
        kmeans = KMeans(n_clusters=n_src)
        # proj rows and the flattened VAD mask are assumed to share the same
        # (freq, frames) bin order, as produced by Chimera's DC head.
        active_proj = proj[active_bins.reshape(-1)]
        bin_clusters = kmeans.fit_predict(active_proj.detach().cpu().numpy())
        est_mask_list = []
        for i in range(n_src):
            # Inactive bins start as 1 in every mask; active bins get cluster membership.
            mask = ~active_bins
            mask[active_bins] = torch.from_numpy(bin_clusters == i).to(mask.device)
            est_mask_list.append(mask.float())  # Need float, not bool
        return torch.stack(est_mask_list, dim=0)
def load_best_model(train_conf, exp_dir):
    """Load best model after training.

    Args:
        train_conf (dict): dictionary as expected by `make_model_and_optimizer`
        exp_dir (str): Experiment directory. Expects to find
            `'best_k_models.json'` or a `checkpoints` directory in it.

    Returns:
        nn.Module: the best (or last) pretrained model according to the
        val_loss, already switched to eval mode.

    Raises:
        FileNotFoundError: if `best_k_models.json` is missing or empty and no
            numbered checkpoint exists in `checkpoints/`.
    """
    # Create the model from recipe-local function
    model, _ = make_model_and_optimizer(train_conf)
    try:
        # Last best model summary: checkpoint path -> monitored value.
        with open(os.path.join(exp_dir, "best_k_models.json"), "r") as f:
            best_k = json.load(f)
    except FileNotFoundError:
        best_k = {}
    if best_k:
        best_model_path = min(best_k, key=best_k.get)
    else:
        # No (usable) summary: fall back to the last numbered checkpoint.
        best_model_path = _last_checkpoint_path(exp_dir)
    # Load checkpoint
    checkpoint = torch.load(best_model_path, map_location="cpu")
    # Load state_dict into model.
    model = torch_utils.load_state_dict_in(checkpoint["state_dict"], model)
    model.eval()
    return model


def _last_checkpoint_path(exp_dir):
    """Return the path of the `*ckpt*` file in `exp_dir/checkpoints` with the largest number.

    The number is all digits of the file name concatenated; names without any
    digit (e.g. ``last.ckpt``) are skipped instead of crashing on ``int("")``.
    """
    ckpt_dir = os.path.join(exp_dir, "checkpoints")
    numbered = []
    for name in os.listdir(ckpt_dir):
        if "ckpt" not in name:
            continue
        digits = "".join(filter(str.isdigit, name))
        if digits:
            numbered.append((name, int(digits)))
    if not numbered:
        raise FileNotFoundError(f"No numbered checkpoint found in {ckpt_dir}")
    # Stable sort + last element: among equal numbers, the latest-listed wins.
    numbered.sort(key=lambda item: item[1])
    return os.path.join(ckpt_dir, numbered[-1][0])