-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathutility.py
221 lines (193 loc) · 7.64 KB
/
utility.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
import torch
import torchaudio
import numpy as np
import os
import time
import shutil
import uuid
import scipy.io
import matplotlib.pylab as plt
def set_device(device):
    '''
    select the compute device, falling back to cpu when cuda cannot be used.
    When 'cuda' is requested and available, the default tensor type is set to
    torch.cuda.FloatTensor.
    Args    device (string): requested device; only the value 'cuda' selects the GPU
    Output  device: the input value if it equals 'cuda' and cuda is available, 'cpu' otherwise
    '''
    # logical `and` (instead of bitwise `&`) short-circuits, so CUDA is only
    # probed when it was actually requested
    if device == 'cuda' and torch.cuda.is_available():
        torch.set_default_tensor_type(torch.cuda.FloatTensor)
        return device
    return 'cpu'
@torch.no_grad()
def get_response(x, net):
    '''
    get the frequency response and the impulse response produced by net for the input x
    Args    x (torch.tensor): input forwarded to net
                (presumably the samples returned by get_frequency_samples - confirm against caller)
            net (nn.Module): trained FDN() network
    Output  H (torch.tensor): net(x) summed over its last dimension
            h (torch.tensor): torch.fft.irfft of H
    '''
    # the decorator already disables gradient tracking for the whole call,
    # so no additional no_grad context is needed here
    H = torch.sum(net(x), dim=-1)
    h = torch.fft.irfft(H)
    return H, h
def get_frequency_samples(num):
    '''
    get complex frequency samples at linearly spaced angles along the upper half of the unit circle
    Args    num (int): number of frequency samples
    Output  (torch.tensor): complex tensor of length num with unit modulus and
                angles linearly spaced in [0, pi], both end points included
    '''
    # normalized angles in [0, 1]; scaled to [0, pi] below
    angle = torch.linspace(0, 1, num)
    # unit modulus for every sample (named so that the builtin abs is not shadowed)
    magnitude = torch.ones(num)
    return torch.polar(magnitude, angle * np.pi)
def weights_init_normal(m):
    '''
    Takes in a module and, if its class name contains 'Linear', initializes
    its weights with values taken from a zero-mean normal distribution and
    sets its bias (when present) to zero. Other modules are left untouched.
    Args    m (nn.Module): module to initialize, e.g. as passed by net.apply()
    '''
    # the layer is recognized by its class name
    if 'Linear' not in m.__class__.__name__:
        return
    y = m.in_features * m.in_features
    # standard deviation is 1/sqrt(in_features^2)
    m.weight.data.normal_(0.0, 1 / np.sqrt(y))
    # layers created with bias=False carry bias=None: nothing to reset then
    bias = getattr(m, 'bias', None)
    if bias is not None:
        bias.data.fill_(0)
def save_parametes(net, dir_path, filename, scattering=False):
    '''
    save parameters of FDN() net to .mat file
    Args    net (nn.Module): trained FDN() network
            dir_path (string): path to output directory (created if missing)
            filename (string): name of the file
            scattering (bool): if True, net.get_param_dict() is additionally
                saved to "scat_" + filename in the same directory
    Output  param (dictionary of tensors): FDN() net parameters
            param_np (dictionary of numpy arrays): FDN() net parameters
    '''
    # exist_ok avoids the check-then-create race of a separate exists() test
    os.makedirs(dir_path, exist_ok=True)
    param = fdn2dir(net)
    param_np = {}
    for name, value in param.items():
        try:
            param_np[name] = value.squeeze().cpu().numpy()
        except AttributeError:
            # values without tensor methods (e.g. plain numbers) are stored unchanged
            param_np[name] = value
    # add m from the network attribute when fdn2dir did not collect it
    if 'm' not in param_np:
        param['m'] = net.m
        param_np['m'] = net.m.squeeze().cpu().numpy()
    # logical `and`: the former `A & net.scattering == True` was parsed as
    # `(A & net.scattering) == True`, which only worked for strict booleans
    if 'm_L' not in param_np and net.scattering:
        param['m_L'] = net.m_L
        param_np['m_L'] = net.m_L.squeeze().cpu().numpy()
        param['m_R'] = net.m_R
        param_np['m_R'] = net.m_R.squeeze().cpu().numpy()
    # save parameters in numpy format
    scipy.io.savemat(os.path.join(dir_path, filename), param_np)
    if scattering:
        scipy.io.savemat(os.path.join(dir_path, "scat_" + filename),
                         net.get_param_dict())
    return param, param_np
def save_filters(net, filter_designer, dir_path, filename):
    '''
    save filter data of the network and of the filter designer to .mat file
    Args    net: object exposing the attributes G_SOS and TC_SOS
            filter_designer: object exposing the attributes T, A, N and f_bands
            dir_path (string): path to output directory, created if it does not exist
            filename (string): name of the .mat file
    '''
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)
    # every entry is stored under the name of the attribute it comes from
    filter_data = {
        'G_SOS': net.G_SOS,
        'TC_SOS': net.TC_SOS,
        'T': filter_designer.T,
        'A': filter_designer.A,
        'N': filter_designer.N,
        'f_bands': filter_designer.f_bands,
    }
    target = os.path.join(dir_path, filename)
    scipy.io.savemat(target, filter_data)
def fdn2dir(net):
    '''
    collect the learnable parameters of the network in a dictionary
    Args    net (nn.Module): trained FDN() network
    Output  (dictionary): .data of every named parameter that requires grad,
                followed by the entries 'gain_per_sample' and 'N' read from net
    '''
    # only parameters that are being trained are exported
    collected = {
        key: tensor.data
        for key, tensor in net.named_parameters()
        if tensor.requires_grad
    }
    collected.update(gain_per_sample=net.gain_per_sample, N=net.N)
    return collected
def get_str_results(epoch=None, train_loss=None, valid_loss=None, time=None, lossF = None, lossT = None):
    '''
    build the summary string printed at the end of an epoch
    Only the arguments that are not None contribute to the string, always in
    the order: epoch, train_loss, valid_loss, time, lossF, lossT.
    For train_loss and valid_loss the last element of the sequence is reported.
    '''
    pieces = []
    if epoch is not None:
        pieces.append(f'epoch: {epoch:3d} ')
    if train_loss is not None:
        pieces.append(f'- train_loss: {train_loss[-1]:6.4f} ')
    if valid_loss is not None:
        pieces.append(f'- test_loss: {valid_loss[-1]:6.4f} ')
    if time is not None:
        pieces.append(f'- time: {time:6.4f} s')
    if lossF is not None:
        pieces.append(f'- lossF: {lossF:6.4f}')
    if lossT is not None:
        pieces.append(f'- lossT: {lossT:6.4f}')
    return ''.join(pieces)
def save_output(net,
                path_dir,
                save_audio = False,
                samplerate = 48000,):
    '''
    create output directory and save FDN parameters from FDN() network
    Args    net (nn.Module): trained FDN() network
            path_dir (string): path to output directory where a new dedicated,
                time-stamped folder will be created
            save_audio (bool): if true save impulse response as .wav file with sampling rate defined in samplerate
            samplerate (int): sampling rate; also used as the number of
                frequency samples the response is evaluated on
    Output  path_dir (string): path to dedicated output directory
            param (dictionary of tensors): FDN() net parameters
            filename (string): filename of the audiofile if save_audio=True, None otherwise
    '''
    # create the dedicated folder inside the directory given by the caller
    # (it used to be created under a hard-coded 'output' folder instead);
    # exist_ok keeps two calls within the same second from failing
    path_dir = os.path.join(path_dir, time.strftime("%Y%m%d-%H%M%S"))
    os.makedirs(path_dir, exist_ok=True)
    # save parameters
    param, _ = save_parametes(net, path_dir, filename='parameters.mat')
    filename = None
    if save_audio:
        # compute outputs: get_response takes (x, net), so the input samples
        # have to be generated first
        # NOTE(review): assumes net expects the unit-circle samples produced
        # by get_frequency_samples - confirm against the training script
        x = get_frequency_samples(samplerate)
        _, h = get_response(x, net)
        h_norm = h / torch.norm(h)
        # save outputs
        unique_str = str(uuid.uuid4())
        filename = os.path.join(path_dir, unique_str + '_output.wav')
        # duplicate the response on two channels, laid out as (time, channel)
        stereo = torch.stack((h_norm.squeeze(0), h_norm.squeeze(0)), 1)
        # torchaudio.save expects a tensor, so no conversion to numpy
        torchaudio.save(filename,
                        stereo.detach().cpu(),
                        samplerate,
                        bits_per_sample=32,
                        channels_first=False)
    return path_dir, param, filename
def save_loss(train_loss, valid_loss, output_dir, save_plot=True, filename=''):
    '''
    save training and validation loss values in .mat format
    Args    train_loss (list): training loss values at each epoch
            valid_loss (list): validation loss values at each epoch
            output_dir (string): path to output directory (created if missing)
            save_plot (bool): if True saves the plot of the losses in .pdf format
            filename (string): additional string to add before .pdf and .mat
    Output files are named 'losses' + filename + '.pdf' / '.mat'.
    '''
    # exist_ok avoids the check-then-create race of a separate exists() test
    os.makedirs(output_dir, exist_ok=True)
    losses = {'train': train_loss, 'valid': valid_loss}
    if save_plot:
        # draw on a dedicated figure and close it afterwards: plotting on the
        # implicit global figure made every call redraw the curves of the
        # previous calls and left the figures open
        fig, ax = plt.subplots()
        try:
            # each curve gets an x axis of its own length
            ax.plot(range(1, len(train_loss) + 1), train_loss, label='training')
            ax.plot(range(1, len(valid_loss) + 1), valid_loss, label='validation')
            ax.legend()
            ax.set_xlabel('epoch n')
            ax.set_ylabel('loss')
            fig.savefig(os.path.join(output_dir, 'losses' + filename + '.pdf'))
        finally:
            plt.close(fig)
    scipy.io.savemat(os.path.join(output_dir, 'losses' + filename + '.mat'), losses)
def to_complex(X):
    '''
    build a complex tensor whose real part is X and whose imaginary part is zero
    Args    X (torch.tensor): real-valued tensor
    Output  (torch.tensor): result of torch.complex(X, zeros shaped like X)
    '''
    zero_imag = torch.zeros_like(X)
    return torch.complex(X, zero_imag)