diff --git a/README.md b/README.md index d9f5d7d..ad618ab 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ Example results: - [Inference](#inference) - [Training](#training) - [Evaluation](#evaluation) -- [Improvements](#improvements) +- [To Do](#to-do) #### Dependencies diff --git a/net/audio.py b/net/audio.py index 28c89b2..7c4b594 100644 --- a/net/audio.py +++ b/net/audio.py @@ -4,11 +4,7 @@ import torch.nn.functional as F from torch.distributions import Normal, Uniform, HalfNormal -from torchaudio_contrib import STFT, StretchSpecTime, MelFilterbank, ComplexNorm, ApplyFilterbank - - -def torch_angle(t): - return torch.atan2(t[...,1], t[...,0]) +from torchaudio_contrib import STFT, TimeStretch, MelFilterbank, ComplexNorm, ApplyFilterbank def amplitude_to_db(spec, ref=1.0, amin=1e-10, top_db=80): @@ -61,13 +57,14 @@ def spec_whiten(spec, eps=1): return resu -def _num_stft_bins(lengths, fft_len, hop_len, pad): - return (lengths + 2 * pad - fft_len + hop_len) // hop_len + +def _num_stft_bins(lengths, fft_length, hop_length, pad): + return (lengths + 2 * pad - fft_length + hop_length) // hop_length class MelspectrogramStretch(nn.Module): - def __init__(self, hop_len=None, num_bands=128, fft_len=2048, norm='whiten', stretch_param=[0.4, 0.4]): + def __init__(self, hop_length=None, num_mels=128, fft_length=2048, norm='whiten', stretch_param=[0.4, 0.4]): super(MelspectrogramStretch, self).__init__() @@ -78,16 +75,16 @@ def __init__(self, hop_len=None, num_bands=128, fft_len=2048, norm='whiten', str 'db' : amplitude_to_db }.get(norm, None) - self.stft = STFT(fft_len=fft_len, hop_len=hop_len) - self.pv = StretchSpecTime(hop_len=self.stft.hop_len, num_bins=fft_len//2+1) + self.stft = STFT(fft_length=fft_length, hop_length=fft_length//4) + self.pv = TimeStretch(hop_length=self.stft.hop_length, num_freqs=fft_length//2+1) self.cn = ComplexNorm(power=2.) - fb = MelFilterbank(num_bands=num_bands).get_filterbank() + fb = MelFilterbank(num_mels=num_mels, max_freq=1.0).get_filterbank() self.app_fb = ApplyFilterbank(fb) - self.fft_len = fft_len - self.hop_len = self.stft.hop_len - self.num_bands = num_bands + self.fft_length = fft_length + self.hop_length = self.stft.hop_length + self.num_mels = num_mels self.stretch_param = stretch_param self.counter = 0 @@ -96,7 +93,7 @@ def forward(self, x, lengths=None): x = self.stft(x) if lengths is not None: - lengths = _num_stft_bins(lengths, self.fft_len, self.hop_len, self.fft_len//2) + lengths = _num_stft_bins(lengths, self.fft_length, self.hop_length, self.fft_length//2) if torch.rand(1)[0] <= self.prob and self.training: rate = 1 - self.dist.sample() @@ -114,6 +111,6 @@ def forward(self, x, lengths=None): return x def __repr__(self): - param_str = '(num_bands={}, fft_len={}, norm={}, stretch_param={})'.format( - self.num_bands, self.fft_len, self.norm.__name__, self.stretch_param) + param_str = '(num_mels={}, fft_length={}, norm={}, stretch_param={})'.format( + self.num_mels, self.fft_length, self.norm.__name__, self.stretch_param) return self.__class__.__name__ + param_str diff --git a/net/model.py b/net/model.py index 835da65..e4a417d 100644 --- a/net/model.py +++ b/net/model.py @@ -19,14 +19,14 @@ def __init__(self, classes, config={}, state_dict=None): self.classes = classes self.lstm_units = 64 self.lstm_layers = 2 - self.spec = MelspectrogramStretch(hop_len=None, - num_bands=128, - fft_len=2048, + self.spec = MelspectrogramStretch(hop_length=None, + num_mels=128, + fft_length=2048, norm='whiten', stretch_param=[0.4, 0.4]) # shape -> (channel, freq, token_time) - self.net = parse_cfg(config['cfg'], in_shape=[in_chan, self.spec.num_bands, 400]) + self.net = parse_cfg(config['cfg'], in_shape=[in_chan, self.spec.num_mels, 400]) def _many_to_one(self, t, lengths): return t[torch.arange(t.size(0)), lengths - 1] diff --git a/run.py b/run.py index b6529a1..84ed93b 100755 --- a/run.py +++ b/run.py @@ -71,9 +71,6 @@ def train_main(config, resume): data_config = config['data'] - tsf_name = config['transforms']['type'] - tsf_args = config['transforms']['args'] - t_transforms = _get_transform(config, 'train') v_transforms = _get_transform(config, 'val') print(t_transforms)