Skip to content

Commit

Permalink
Updates along torchaudio_contrib
Browse files (browse the repository at this point in the history)
  • Loading branch information
ksanjeevan committed Jun 8, 2019
1 parent 6dd8093 commit 41884b9
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 25 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ Example results:
- [Inference](#inference)
- [Training](#training)
- [Evaluation](#evaluation)
- [Improvements](#improvements)
- [To Do](#to-do)

#### Dependencies

Expand Down
31 changes: 14 additions & 17 deletions net/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,7 @@
import torch.nn.functional as F
from torch.distributions import Normal, Uniform, HalfNormal

from torchaudio_contrib import STFT, StretchSpecTime, MelFilterbank, ComplexNorm, ApplyFilterbank


# Returns torch.atan2 of the index-1 slice over the index-0 slice along the last axis of `t`.
# Presumably `t` stores complex values as (real, imag) pairs in its last dimension, making
# the result the phase angle in radians — TODO confirm against the callers.
def torch_angle(t):
return torch.atan2(t[...,1], t[...,0])
from torchaudio_contrib import STFT, TimeStretch, MelFilterbank, ComplexNorm, ApplyFilterbank


def amplitude_to_db(spec, ref=1.0, amin=1e-10, top_db=80):
Expand Down Expand Up @@ -61,13 +57,14 @@ def spec_whiten(spec, eps=1):

return resu

# Evaluates (lengths + 2 * pad - fft_len + hop_len) // hop_len with floor division, which is
# algebraically 1 + (lengths + 2 * pad - fft_len) // hop_len.
# Presumably this is the number of STFT frames produced for signals of `lengths` samples
# padded by `pad` samples on each side — TODO confirm against the STFT implementation in use.
def _num_stft_bins(lengths, fft_len, hop_len, pad):
return (lengths + 2 * pad - fft_len + hop_len) // hop_len

# Evaluates (lengths + 2 * pad - fft_length + hop_length) // hop_length with floor division,
# which is algebraically 1 + (lengths + 2 * pad - fft_length) // hop_length.
# Presumably this is the number of STFT frames produced for signals of `lengths` samples
# padded by `pad` samples on each side — TODO confirm against the STFT implementation in use.
def _num_stft_bins(lengths, fft_length, hop_length, pad):
return (lengths + 2 * pad - fft_length + hop_length) // hop_length


class MelspectrogramStretch(nn.Module):

def __init__(self, hop_len=None, num_bands=128, fft_len=2048, norm='whiten', stretch_param=[0.4, 0.4]):
def __init__(self, hop_length=None, num_mels=128, fft_length=2048, norm='whiten', stretch_param=[0.4, 0.4]):

super(MelspectrogramStretch, self).__init__()

Expand All @@ -78,16 +75,16 @@ def __init__(self, hop_len=None, num_bands=128, fft_len=2048, norm='whiten', str
'db' : amplitude_to_db
}.get(norm, None)

self.stft = STFT(fft_len=fft_len, hop_len=hop_len)
self.pv = StretchSpecTime(hop_len=self.stft.hop_len, num_bins=fft_len//2+1)
self.stft = STFT(fft_length=fft_length, hop_length=fft_length//4)
self.pv = TimeStretch(hop_length=self.stft.hop_length, num_freqs=fft_length//2+1)
self.cn = ComplexNorm(power=2.)

fb = MelFilterbank(num_bands=num_bands).get_filterbank()
fb = MelFilterbank(num_mels=num_mels, max_freq=1.0).get_filterbank()
self.app_fb = ApplyFilterbank(fb)

self.fft_len = fft_len
self.hop_len = self.stft.hop_len
self.num_bands = num_bands
self.fft_length = fft_length
self.hop_length = self.stft.hop_length
self.num_mels = num_mels
self.stretch_param = stretch_param

self.counter = 0
Expand All @@ -96,7 +93,7 @@ def forward(self, x, lengths=None):
x = self.stft(x)

if lengths is not None:
lengths = _num_stft_bins(lengths, self.fft_len, self.hop_len, self.fft_len//2)
lengths = _num_stft_bins(lengths, self.fft_length, self.hop_length, self.fft_length//2)

if torch.rand(1)[0] <= self.prob and self.training:
rate = 1 - self.dist.sample()
Expand All @@ -114,6 +111,6 @@ def forward(self, x, lengths=None):
return x

# Returns the class name followed by a parenthesised summary of the mel-band count, FFT length,
# normalisation function name and stretch parameters.
# NOTE(review): this is a diff rendering without +/- markers — the first param_str assignment
# (num_bands / fft_len) looks like the removed version and the second (num_mels / fft_length)
# the added one; confirm against the repository before relying on either.
# NOTE(review): self.norm appears to come from a dict `.get(norm, None)` lookup in __init__, so
# `self.norm.__name__` would raise AttributeError for an unrecognised `norm` value — verify.
def __repr__(self):
param_str = '(num_bands={}, fft_len={}, norm={}, stretch_param={})'.format(
self.num_bands, self.fft_len, self.norm.__name__, self.stretch_param)
param_str = '(num_mels={}, fft_length={}, norm={}, stretch_param={})'.format(
self.num_mels, self.fft_length, self.norm.__name__, self.stretch_param)
return self.__class__.__name__ + param_str
8 changes: 4 additions & 4 deletions net/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ def __init__(self, classes, config={}, state_dict=None):
self.classes = classes
self.lstm_units = 64
self.lstm_layers = 2
self.spec = MelspectrogramStretch(hop_len=None,
num_bands=128,
fft_len=2048,
self.spec = MelspectrogramStretch(hop_length=None,
num_mels=128,
fft_length=2048,
norm='whiten',
stretch_param=[0.4, 0.4])

# shape -> (channel, freq, token_time)
self.net = parse_cfg(config['cfg'], in_shape=[in_chan, self.spec.num_bands, 400])
self.net = parse_cfg(config['cfg'], in_shape=[in_chan, self.spec.num_mels, 400])

def _many_to_one(self, t, lengths):
return t[torch.arange(t.size(0)), lengths - 1]
Expand Down
3 changes: 0 additions & 3 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,6 @@ def train_main(config, resume):

data_config = config['data']

tsf_name = config['transforms']['type']
tsf_args = config['transforms']['args']

t_transforms = _get_transform(config, 'train')
v_transforms = _get_transform(config, 'val')
print(t_transforms)
Expand Down

0 comments on commit 41884b9

Please sign in to comment.