Commit
fix: fix spleeter architecture as in #2
Tuan Nguyen Duc committed Jul 12, 2020
1 parent 6abf4fe commit 853d4bb
Showing 5 changed files with 43 additions and 32 deletions.
Binary file modified output/out_0.wav
Binary file modified output/out_1.wav
58 changes: 34 additions & 24 deletions spleeter/unet.py
@@ -3,12 +3,11 @@


 def down_block(in_filters, out_filters):
-    return nn.Sequential(
-        nn.Conv2d(in_filters, out_filters, kernel_size=5,
-                  stride=2, padding=2,
-                  ),
-        nn.BatchNorm2d(out_filters, track_running_stats=True),
-        nn.LeakyReLU(0.3)
+    return nn.Conv2d(in_filters, out_filters, kernel_size=5,
+                     stride=2, padding=2,
+                     ), nn.Sequential(
+        nn.BatchNorm2d(out_filters, track_running_stats=True, eps=1e-3, momentum=0.01),
+        nn.LeakyReLU(0.2)
     )

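The new eps and momentum arguments track the TF Spleeter layers: Keras BatchNormalization defaults to epsilon=1e-3 and momentum=0.99, and because Keras and PyTorch define momentum from opposite ends of the running-average update, Keras 0.99 corresponds to PyTorch 0.01. The LeakyReLU slope likewise drops from 0.3 (the Keras default) to the 0.2 the reference weights appear to use. A minimal sketch of the momentum conversion; the helper name keras_bn_to_torch is illustrative, not part of this repo:

    import torch.nn as nn

    # Hypothetical helper (not repo code): build a BatchNorm2d whose running
    # statistics update like a Keras BatchNormalization with the given settings.
    def keras_bn_to_torch(num_features, keras_momentum=0.99, keras_epsilon=1e-3):
        # Keras:   running = momentum * running + (1 - momentum) * batch
        # PyTorch: running = (1 - momentum) * running + momentum * batch
        # so the PyTorch momentum is 1 - keras_momentum (0.99 -> 0.01).
        return nn.BatchNorm2d(num_features, eps=keras_epsilon,
                              momentum=1.0 - keras_momentum,
                              track_running_stats=True)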

@@ -18,7 +17,7 @@ def up_block(in_filters, out_filters, dropout=False):
                            stride=2, padding=2, output_padding=1
                            ),
         nn.ReLU(),
-        nn.BatchNorm2d(out_filters, track_running_stats=True)
+        nn.BatchNorm2d(out_filters, track_running_stats=True, eps=1e-3, momentum=0.01)
     ]
     if dropout:
         layers.append(nn.Dropout(0.5))
@@ -29,12 +28,12 @@ def up_block(in_filters, out_filters, dropout=False):
 class UNet(nn.Module):
     def __init__(self, in_channels=2):
         super(UNet, self).__init__()
-        self.down1 = down_block(in_channels, 16)
-        self.down2 = down_block(16, 32)
-        self.down3 = down_block(32, 64)
-        self.down4 = down_block(64, 128)
-        self.down5 = down_block(128, 256)
-        self.down6 = down_block(256, 512)
+        self.down1_conv, self.down1_act = down_block(in_channels, 16)
+        self.down2_conv, self.down2_act = down_block(16, 32)
+        self.down3_conv, self.down3_act = down_block(32, 64)
+        self.down4_conv, self.down4_act = down_block(64, 128)
+        self.down5_conv, self.down5_act = down_block(128, 256)
+        self.down6_conv, self.down6_act = down_block(256, 512)

         self.up1 = up_block(512, 256, dropout=True)
         self.up2 = up_block(512, 128, dropout=True)
@@ -48,19 +47,30 @@ def __init__(self, in_channels=2):
         )

     def forward(self, x):
-        d1 = self.down1(x)
-        d2 = self.down2(d1)
-        d3 = self.down3(d2)
-        d4 = self.down4(d3)
-        d5 = self.down5(d4)
-        d6 = self.down6(d5)
+        d1_conv = self.down1_conv(x)
+        d1 = self.down1_act(d1_conv)
+
+        d2_conv = self.down2_conv(d1)
+        d2 = self.down2_act(d2_conv)
+
+        d3_conv = self.down3_conv(d2)
+        d3 = self.down3_act(d3_conv)
+
+        d4_conv = self.down4_conv(d3)
+        d4 = self.down4_act(d4_conv)
+
+        d5_conv = self.down5_conv(d4)
+        d5 = self.down5_act(d5_conv)
+
+        d6_conv = self.down6_conv(d5)
+        d6 = self.down6_act(d6_conv)

         u1 = self.up1(d6)
-        u2 = self.up2(torch.cat([d5, u1], axis=1))
-        u3 = self.up3(torch.cat([d4, u2], axis=1))
-        u4 = self.up4(torch.cat([d3, u3], axis=1))
-        u5 = self.up5(torch.cat([d2, u4], axis=1))
-        u6 = self.up6(torch.cat([d1, u5], axis=1))
+        u2 = self.up2(torch.cat([d5_conv, u1], axis=1))
+        u3 = self.up3(torch.cat([d4_conv, u2], axis=1))
+        u4 = self.up4(torch.cat([d3_conv, u3], axis=1))
+        u5 = self.up5(torch.cat([d2_conv, u4], axis=1))
+        u6 = self.up6(torch.cat([d1_conv, u5], axis=1))
         u7 = self.up7(u6)
         return u7 * x

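The point of splitting each down block in two: in the reference Spleeter U-Net, the decoder's skip connections concatenate the encoder conv outputs taken before batch norm and activation, which the old nn.Sequential blocks could not expose. The rewritten forward therefore keeps both tensors and feeds the _conv ones into torch.cat. A quick shape-level smoke test of the rewiring; the input size is a made-up assumption, and it presumes the elided up7 head maps back to in_channels, as the final u7 * x implies:

    import torch
    from spleeter.unet import UNet

    net = UNet(in_channels=2).eval()
    # Made-up size: any spatial extent divisible by 2**6 survives the six
    # stride-2 encoder blocks and returns to the same shape in the decoder.
    x = torch.rand(1, 2, 512, 1024)
    with torch.no_grad():
        mask = net(x)
    # forward() returns u7 * x, a soft mask applied to the input, so shapes match.
    assert mask.shape == x.shape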
14 changes: 7 additions & 7 deletions spleeter/util.py
@@ -34,19 +34,19 @@ def tf2pytorch(checkpoint_path, num_instrumments):
         else:
             bn_suffix = "_" + str(bn_idx)

-        output['down{}.0.weight'.format(j)] = np.transpose(
+        output['down{}_conv.weight'.format(j)] = np.transpose(
             tf_vars["conv2d{}/kernel".format(conv_suffix)], (3, 2, 0, 1))
         # print('conv dtype: ',output['down{}.0.weight'.format(j)].dtype)
-        output['down{}.0.bias'.format(
+        output['down{}_conv.bias'.format(
             j)] = tf_vars["conv2d{}/bias".format(conv_suffix)]

-        output['down{}.1.weight'.format(
+        output['down{}_act.0.weight'.format(
             j)] = tf_vars["batch_normalization{}/gamma".format(bn_suffix)]
-        output['down{}.1.bias'.format(
+        output['down{}_act.0.bias'.format(
             j)] = tf_vars["batch_normalization{}/beta".format(bn_suffix)]
-        output['down{}.1.running_mean'.format(
+        output['down{}_act.0.running_mean'.format(
             j)] = tf_vars['batch_normalization{}/moving_mean'.format(bn_suffix)]
-        output['down{}.1.running_var'.format(
+        output['down{}_act.0.running_var'.format(
             j)] = tf_vars['batch_normalization{}/moving_variance'.format(bn_suffix)]

         conv_idx += 1
@@ -65,7 +65,7 @@ def tf2pytorch(checkpoint_path, num_instrumments):
             bn_suffix= "_" + str(bn_idx)

         output['up{}.0.weight'.format(j)] = np.transpose(
-            tf_vars["conv2d_transpose{}/kernel".format(tconv_suffix)], (3, 2, 0, 1))
+            tf_vars["conv2d_transpose{}/kernel".format(tconv_suffix)], (3,2,0, 1))
         output['up{}.0.bias'.format(
             j)] = tf_vars["conv2d_transpose{}/bias".format(tconv_suffix)]
         output['up{}.2.weight'.format(
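The converter renames follow directly from how PyTorch derives state_dict keys from attribute paths: down{i} used to be an nn.Sequential (index 0 the conv, index 1 the batch norm), whereas the model now has a bare Conv2d at down{i}_conv and a Sequential whose index 0 is the batch norm at down{i}_act. A standalone sketch, not repo code, showing the keys the converted checkpoint must now provide:

    import torch.nn as nn

    # state_dict keys mirror the module attribute layout.
    class Demo(nn.Module):
        def __init__(self):
            super().__init__()
            self.down1_conv = nn.Conv2d(2, 16, 5, stride=2, padding=2)
            self.down1_act = nn.Sequential(
                nn.BatchNorm2d(16, track_running_stats=True, eps=1e-3, momentum=0.01),
                nn.LeakyReLU(0.2))

    print(sorted(Demo().state_dict().keys()))
    # ['down1_act.0.bias', 'down1_act.0.num_batches_tracked',
    #  'down1_act.0.running_mean', 'down1_act.0.running_var',
    #  'down1_act.0.weight', 'down1_conv.bias', 'down1_conv.weight']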
3 changes: 2 additions & 1 deletion test_estimator.py
@@ -2,6 +2,7 @@
 import torchaudio
 from librosa.core import load
 from librosa.output import write_wav
+import numpy as np

 from spleeter.estimator import Estimator

@@ -17,4 +18,4 @@
 for i in range(len(wavs)):
     fname = 'output/out_{}.wav'.format(i)
     print('Writing ',fname)
-    write_wav(fname, wavs[i].squeeze().numpy(), sr)
+    write_wav(fname, np.asfortranarray(wavs[i].squeeze().numpy()), sr)
