-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathTest.py
96 lines (62 loc) · 2.61 KB
/
Test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import unittest
import time_frequence as tf
import numpy as np
import torch
from torch.autograd import Variable
import librosa
def CalSNR(ref, sig):
    """Return the signal-to-noise ratio of ``sig`` against ``ref``, in decibels.

    The "signal" power is the mean square of ``ref`` and the "noise" power is
    the mean square of the residual ``sig - ref``.  The result is
    ``10 * (log10(signal_power) - log10(noise_power))``.  An exact match (zero
    residual) therefore yields ``inf``, with numpy emitting a divide-by-zero
    warning from ``log10(0)``.

    Args:
        ref: reference signal; must support ``-`` with ``sig`` (e.g. a NumPy array).
        sig: estimate of ``ref``, broadcast-compatible with it.

    Returns:
        The SNR in dB as a NumPy float.
    """
    residual = sig - ref
    log_signal_power, log_noise_power = (
        np.log10(np.mean(np.square(x))) for x in (ref, residual)
    )
    return 10 * (log_signal_power - log_noise_power)
class TimeFrequencyTestCase(unittest.TestCase):
    """Round-trip tests for the ``time_frequence`` module (imported as ``tf``).

    Each test builds a random real signal, sends it through the transform(s)
    under test, and requires the reconstruction to match the input with an SNR
    (as computed by ``CalSNR``) strictly above ``MIN_SNR_DB``.
    """

    # Minimum acceptable reconstruction SNR, in dB.
    MIN_SNR_DB = 60
    # FFT size / frame length used by the tests.
    N_FFT = 1024
    # Hop size used for the librosa STFT and tf.istft in test_istft.
    HOP_LENGTH = 512
    # Signal length, in multiples of N_FFT samples.
    N_SIGNAL_BLOCKS = 1016
    # Fixed RNG seed so a failing run can be reproduced exactly.
    SEED = 0

    def _random_signal(self, length):
        """Return ``length`` uniform samples in [0, 1) from a seeded generator."""
        return np.random.RandomState(self.SEED).random_sample(length)

    def _assert_high_snr(self, ref, sig):
        """Fail unless ``sig`` matches ``ref`` with SNR above MIN_SNR_DB."""
        snr = CalSNR(ref, sig)
        self.assertGreater(
            snr, self.MIN_SNR_DB,
            "reconstruction SNR {:.2f} dB is not above {} dB".format(
                snr, self.MIN_SNR_DB))

    @unittest.skip("tf.ifft test was disabled (string-commented) in the original "
                   "source; re-enable once tf.ifft is verified")
    def test_ifft(self):
        """tf.ifft should invert np.fft.fft of a random real signal.

        This test was previously disabled by wrapping it in a string literal
        (which silently became the class docstring).  It is kept as a skipped
        test so it stays syntax-checked and is visible in test reports.
        """
        N = self.N_FFT
        signal = self._random_signal(N)
        spectrum = np.fft.fft(signal, n=N)
        ac = Variable(torch.from_numpy(np.real(spectrum[0]) * np.ones((1, N, 1, 1))).float())
        # Bins 1..N/2, reshaped to (1, 1, N/2, 1).
        half = np.reshape(spectrum[1:N // 2 + 1], (1, 1, N // 2, 1))
        input_real = Variable(torch.from_numpy(np.real(half)).float())
        input_imag = Variable(torch.from_numpy(np.imag(half)).float())
        model = tf.ifft(n_fft=N)
        output = model.forward(input_real, input_imag, ac).data.numpy().flatten()
        self._assert_high_snr(signal, output)

    def test_istft(self):
        """tf.istft should invert a librosa STFT computed with center=False."""
        n_fft, hop = self.N_FFT, self.HOP_LENGTH
        signal = self._random_signal(self.N_SIGNAL_BLOCKS * n_fft)
        spec = librosa.stft(signal, n_fft=n_fft, hop_length=hop, center=False)
        # Real and imaginary parts of the STFT (NOT magnitude/phase), with two
        # leading singleton axes added: shape (1, 1, bins, frames).
        spec_real = np.real(spec)[np.newaxis, np.newaxis, :, :]
        spec_imag = np.imag(spec)[np.newaxis, np.newaxis, :, :]
        # Frequency bin 0 of the real part is passed separately as `ac`; the
        # remaining bins are passed as the real/imaginary inputs.
        ac = Variable(torch.from_numpy(spec_real[:, :, 0, :]).float())
        real = Variable(torch.from_numpy(spec_real[:, :, 1:, :]).float())
        imag = Variable(torch.from_numpy(spec_imag[:, :, 1:, :]).float())
        model = tf.istft(n_fft, hop)
        re_signal = model.forward(real, imag, ac).data.numpy().flatten()
        # Compare only the interior, skipping n_fft samples at each end
        # (presumably where overlap-add coverage is incomplete).
        self._assert_high_snr(signal[n_fft:-n_fft], re_signal[n_fft:-n_fft])

    def test_stft(self):
        """tf.istft(tf.stft(x)) should reconstruct x using both models' defaults."""
        n = self.N_FFT
        signal = self._random_signal(self.N_SIGNAL_BLOCKS * n)
        # Shape (1, samples); named `waveform` to avoid shadowing builtin `input`.
        waveform = Variable(torch.from_numpy(signal[np.newaxis, :]).float())
        stft_model = tf.stft()
        istft_model = tf.istft()
        magn, phase, ac = stft_model(waveform)
        re_signal = istft_model.forward(magn, phase, ac).data.numpy().flatten()
        # Edge trim assumes the models' default frame length is N_FFT — TODO confirm.
        self._assert_high_snr(signal[n:-n], re_signal[n:-n])
if __name__ == "__main__":
    # Run the test case(s) above when this file is executed as a script.
    unittest.main()