import datetime
import math
import os
import random
import argparse
import sys
import time
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
from fastprogress import master_bar, progress_bar
from torch.nn import utils
from torch.utils import tensorboard

from data import AutoVCDataset, get_loader
from hp import hp
from model_vc import Generator

def train(args):
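    """Train the AutoVC Generator end-to-end.

    Splits the speakers under the dataset root into train/test sets, computes (or loads
    cached) speaker embeddings with the RF5/simple-speaker-embedding hub model, optionally
    patches in LJSpeech, then optimises the reconstruction, postnet, and content-code
    losses, logging to TensorBoard and saving a final checkpoint under hp.output_path.
    """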
print("[BACKEND] Setting up paths and training.")
out_path = Path(hp.output_path)
os.makedirs(out_path, exist_ok=True)
device = torch.device(hp.device)
if args.mel_path is None:
ds_root = Path(hp.data_root)/'wav48_silence_trimmed'
else:
print("[DATA] Using precomputed mels from ", args.mel_path)
ds_root = Path(args.mel_path)
spk_folders = sorted(list(ds_root.iterdir()))
print(f"[DATA] Found a total of {len(spk_folders)} speakers")
# Gather training / testing paths.
random.seed(hp.seed)
train_spk_folders = sorted(random.sample(spk_folders, k=hp.n_train_speakers))
test_spk_folders = sorted(list(set(spk_folders) - set(train_spk_folders)))
train_files = []
for pth in train_spk_folders: train_files.extend(list(pth.iterdir()))
test_files = []
for pth in test_spk_folders: test_files.extend(list(pth.iterdir()))
print(f"[DATA] Split into {len(train_spk_folders)} train speakers ({len(train_files)} files)")
print(f"[DATA] and {len(test_spk_folders)} test speakers ({len(test_files)} files)")
# Getting embeddings:
sse = torch.hub.load('RF5/simple-speaker-embedding', 'gru_embedder').to(device)
sse.eval()
mb = master_bar(spk_folders)
spk_out_path = Path(hp.speaker_embedding_dir)
spk_embs = {}
os.makedirs(spk_out_path, exist_ok=True)
print("[SPEAKER EMBEDDING] Gathering speaker embeddings")
for spk_folder in mb:
random.seed(hp.seed)
sample_uttrs = random.sample(list(spk_folder.iterdir()), k=hp.n_uttr_per_spk_embedding)
embs = []
if (spk_out_path/f"{spk_folder.stem}_sse_emb.pt").is_file():
spk_embs[spk_folder.stem] = torch.load(spk_out_path/f"{spk_folder.stem}_sse_emb.pt")
continue
for i, uttr_pth in progress_bar(enumerate(sample_uttrs), total=len(sample_uttrs), parent=mb):
mb.child.comment = f"processing speaker {spk_folder.stem} ({i} of {len(sample_uttrs)})"
            if str(uttr_pth).endswith('.pt'):
                raise NotImplementedError(
                    "Speaker embeddings must be computed from raw waveforms; they cannot be "
                    "derived from precomputed mel-spectrogram (.pt) files. Either point the "
                    "script at raw audio or precompute the speaker embeddings.")
            mel = sse.melspec_from_file(uttr_pth).to(device)
            with torch.no_grad():
                embedding = sse(mel[None])[0]
            embs.append(embedding.cpu())
        emb = torch.stack(embs, dim=0)
        emb = emb.mean(dim=0)
        spk_embs[spk_folder.stem] = emb
        torch.save(emb, spk_out_path/f"{spk_folder.stem}_sse_emb.pt")
    del sse
    torch.cuda.empty_cache()
    # Patch in LJSpeech:
    if args.lj_path is not None:
        print("[DATA] Adding LJSpeech")
        ljpath = Path(args.lj_path)
        split_folder = [v for v in ljpath.iterdir() if v.is_dir() and 'split' in v.stem]
        if len(split_folder) != 1: raise AssertionError("Split folder not found.")
        split_folder = split_folder[0]
        with open(split_folder/'train.txt', 'r') as f: lj_trn_files = f.readlines()
        with open(split_folder/'validation.txt', 'r') as f: lj_eval_files = f.readlines()
        lj_trn_files = [ljpath/f"{f.strip()}.wav" for f in lj_trn_files]
        lj_eval_files = [ljpath/f"{f.strip()}.wav" for f in lj_eval_files]
        lj_emb = torch.load(ljpath/'lj_sse_emb100.pt').cpu()
        spk_embs['wavs'] = lj_emb  # cheeky hack to make it work nicely with VCTK loader
        train_files += lj_trn_files
        test_files += lj_eval_files
print("[DATA] Constructing final dataloaders")
train_dl = get_loader(train_files, spk_embs, hp.len_crop, hp.bs,
shuffle=True, shift=hp.mel_shift, scale=hp.mel_scale)
test_dl = get_loader(test_files, spk_embs, hp.len_crop, hp.bs,
shuffle=False, shift=hp.mel_shift, scale=hp.mel_scale)
print("[LOGGING] Setting up logger")
writer = tensorboard.writer.SummaryWriter(out_path)
keys = ['G/loss_id','G/loss_id_psnt','G/loss_cd']
print("[MODEL] Setting up model")
G = Generator(32, 256, 512, 32).to(device)
opt = torch.optim.Adam(G.parameters(), hp.lr)
if args.fp16:
print("[TRAIN] Using fp16 training.")
scaler = GradScaler()
if args.checkpoint is not None:
ckpt = torch.load(args.checkpoint, map_location=device)
epoch = ckpt['epoch']
ite = ckpt['iter']
G.load_state_dict(ckpt['model_state_dict'])
opt.load_state_dict(ckpt['opt_state_dict'])
print(f"[CHECKPOINT] Loaded checkpoint starting from epoch {epoch} (iter {ite})",
f" with last known loss {ckpt['loss']:6.5f}")
print("[TRAIN] Beginning training")
start_time = time.time()
running_loss = 0.0
n_epochs = math.ceil(hp.n_iters / len(train_dl))
iter = 0
mb = master_bar(range(n_epochs))
    for epoch in mb:
        G.train()
        pb = progress_bar(enumerate(train_dl), total=len(train_dl), parent=mb)
        for i, (x_src, s_src) in pb:
            x_src = x_src.to(device)
            s_src = s_src.to(device)
            opt.zero_grad()

            # fp16 enable
            if args.fp16:
                with autocast():
                    # Identity mapping loss
                    x_identic, x_identic_psnt, code_real = G(x_src, s_src, s_src)
                    g_loss_id = F.mse_loss(x_src, x_identic.squeeze(1))
                    g_loss_id_psnt = F.mse_loss(x_src, x_identic_psnt.squeeze(1))
                    # Code semantic loss.
                    code_reconst = G(x_identic_psnt.squeeze(1), s_src, None)
                    g_loss_cd = F.l1_loss(code_real, code_reconst)
                    g_loss = g_loss_id + hp.mu*g_loss_id_psnt + hp.lamb*g_loss_cd
                scaler.scale(g_loss).backward()
                scaler.step(opt)
                scaler.update()
            else:
                # Identity mapping loss
                x_identic, x_identic_psnt, code_real = G(x_src, s_src, s_src)
                g_loss_id = F.mse_loss(x_src, x_identic.squeeze(1))
                g_loss_id_psnt = F.mse_loss(x_src, x_identic_psnt.squeeze(1))
                # Code semantic loss.
                code_reconst = G(x_identic_psnt.squeeze(1), s_src, None)
                g_loss_cd = F.l1_loss(code_real, code_reconst)
                g_loss = g_loss_id + hp.mu*g_loss_id_psnt + hp.lamb*g_loss_cd
                g_loss.backward()
                opt.step()

            loss = {}
            loss['G/loss_id'] = g_loss_id.item()
            loss['G/loss_id_psnt'] = g_loss_id_psnt.item()
            loss['G/loss_cd'] = g_loss_cd.item()
            # lerp smooth running loss
            running_loss = running_loss + 0.1*(float(g_loss) - running_loss)
            mb.child.comment = f"loss = {float(running_loss):6.5f}"

            if iter % hp.print_log_interval == 0:
                et = time.time() - start_time
                et = str(datetime.timedelta(seconds=et))[:-7]
                log = "Elapsed [{}], Iteration [{}/{}]".format(et, iter+1, hp.n_iters)
                for tag in keys:
                    log += ", {}: {:.4f}".format(tag, loss[tag])
                mb.write(log)
            if iter % hp.tb_log_interval == 0:
                for tag in keys: writer.add_scalar(tag, loss[tag], iter)
                writer.add_scalar('G/loss', g_loss.item(), iter)

            iter += 1
            if iter >= hp.n_iters:
                print("[TRAIN] Training completed.")
                break
mb.write(f"[TRAIN] epoch {epoch} completed. Beginning eval.")
G.eval()
pb = progress_bar(enumerate(test_dl), total=len(test_dl), parent=mb)
valid_losses = {tag: [] for tag in keys}
valid_losses['G/loss'] = []
for i, (x_src, s_src) in pb:
x_src = x_src.to(device)
s_src = s_src.to(device)
with torch.no_grad():
# Identity mapping loss
x_identic, x_identic_psnt, code_real = G(x_src, s_src, s_src)
g_loss_id = F.mse_loss(x_src, x_identic.squeeze(1))
g_loss_id_psnt = F.mse_loss(x_src, x_identic_psnt.squeeze(1))
# Code semantic loss.
code_reconst = G(x_identic_psnt.squeeze(1), s_src, None)
g_loss_cd = F.l1_loss(code_real, code_reconst)
g_loss = g_loss_id + hp.mu*g_loss_id_psnt + hp.lamb*g_loss_cd
valid_losses['G/loss_id'].append(g_loss_id.item())
valid_losses['G/loss_id_psnt'].append(g_loss_id_psnt.item())
valid_losses['G/loss_cd'].append(g_loss_cd.item())
valid_losses['G/loss'].append(g_loss.item())
mb.child.comment = f"loss = {float(g_loss):6.5f}"
valid_losses = {k: np.mean(valid_losses[k]) for k in valid_losses.keys()}
for tag in valid_losses.keys(): writer.add_scalar('valid/' + tag, valid_losses[tag], iter)
pst = [f"{k}: {valid_losses[k]:5.4f}" for k in valid_losses.keys()]
mb.write(f"[TRAIN] epoch {epoch} eval metrics: " + '\t'.join(pst))
if iter >= hp.n_iters: break
print("[CLEANUP] Saving model")
torch.save({
'epoch': epoch,
'iter': iter,
'model_state_dict': G.state_dict(),
'opt_state_dict': opt.state_dict(),
'loss': valid_losses['G/loss']
}, out_path/'checkpoint_last.pth')
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='autovc trainer')
    parser.add_argument('--checkpoint', action='store', required=False, default=None,
                        help='checkpoint to restore training from')
    parser.add_argument('--fp16', required=False, action='store_true',
                        help='use fp16 (mixed-precision) training')
    parser.add_argument('--mel_path', required=False, default=None, action='store',
                        help='path to precomputed mel-spectrograms; if not given, they are computed on the fly')
    parser.add_argument('--lj_path', required=False, default=None,
                        help='path to the LJSpeech dataset to add to the training/eval data')
    args = parser.parse_args()
    train(args)
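# Example invocations (a minimal sketch: the flags are defined above, but the paths,
# the output location, and the LJSpeech split layout are assumptions about a local setup):
#
#   python train.py                                   # compute mels on the fly from hp.data_root
#   python train.py --mel_path ./precomputed_mels --fp16
#   python train.py --lj_path ./LJSpeech-1.1 --checkpoint ./outputs/checkpoint_last.pth
#
# Restoring the saved checkpoint elsewhere (assuming the same Generator(32, 256, 512, 32)
# configuration used in train(); the checkpoint path depends on hp.output_path):
#
#   ckpt = torch.load('outputs/checkpoint_last.pth', map_location='cpu')
#   G = Generator(32, 256, 512, 32)
#   G.load_state_dict(ckpt['model_state_dict'])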