Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wavesplit implementation #70

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions egs/wham/WaveSplit/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
### WaveSplit

We currently train on 1-second segments.

We tried a speaker embedding dimension of 256.

Training still does not work with oracle speaker embeddings.

It is not clear how the loss at every layer of the separation stack is computed (is the same output layer used for all of them?).
The paper also does not mention the output layer, nor that the first convolution has no skip connection.

118 changes: 118 additions & 0 deletions egs/wham/WaveSplit/eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import os
import random
import soundfile as sf
import torch
import yaml
import json
import argparse
import pandas as pd
from tqdm import tqdm
from pprint import pprint

from asteroid.metrics import get_metrics
from asteroid.losses import PITLossWrapper, pairwise_neg_sisdr
from asteroid.data.wham_dataset import WhamDataset
from asteroid.utils import tensors_to_device

from model import load_best_model

# Command-line interface for the evaluation script, driven by a spec table.
parser = argparse.ArgumentParser()
_CLI_SPECS = [
    ('--task', dict(type=str, required=True,
                    help='One of `enh_single`, `enh_both`, '
                         '`sep_clean` or `sep_noisy`')),
    ('--test_dir', dict(type=str, required=True,
                        help='Test directory including the json files')),
    ('--use_gpu', dict(type=int, default=0,
                       help='Whether to use the GPU for model execution')),
    ('--exp_dir', dict(default='exp/tmp', help='Experiment root')),
    ('--n_save_ex', dict(type=int, default=50,
                         help='Number of audio examples to save, -1 means all')),
]
for _flag, _options in _CLI_SPECS:
    parser.add_argument(_flag, **_options)

# Metrics computed for every test utterance (plus their `input_` baselines).
compute_metrics = ['si_sdr', 'sdr', 'sir', 'sar', 'stoi']


def main(conf):
    """Evaluate the best saved model on the WHAM test set.

    Computes the metrics in ``compute_metrics`` for every test utterance,
    saves a few example wavs with their metrics, and writes per-utterance
    and aggregate results to the experiment folder.

    Args:
        conf (dict): merged CLI arguments and training configuration
            (keys used: 'train_conf', 'exp_dir', 'use_gpu', 'test_dir',
            'task', 'sample_rate', 'n_save_ex').
    """
    model = load_best_model(conf['train_conf'], conf['exp_dir'])
    # Handle device placement
    if conf['use_gpu']:
        model.cuda()
    model_device = next(model.parameters()).device
    test_set = WhamDataset(conf['test_dir'], conf['task'],
                           sample_rate=conf['sample_rate'],
                           nondefault_nsrc=model.masker.n_src,
                           segment=None)  # Uses all segment length
    # Used to reorder sources only
    loss_func = PITLossWrapper(pairwise_neg_sisdr, mode='pairwise')

    # Randomly choose the indexes of sentences to save.
    ex_save_dir = os.path.join(conf['exp_dir'], 'examples/')
    if conf['n_save_ex'] == -1:
        conf['n_save_ex'] = len(test_set)
    save_idx = random.sample(range(len(test_set)), conf['n_save_ex'])
    series_list = []
    # Bug fix: the original called `torch.no_grad().__enter__()` and never
    # exited the context. Use the context manager so gradient tracking is
    # disabled for the whole inference loop and restored afterwards.
    with torch.no_grad():
        for idx in tqdm(range(len(test_set))):
            # Forward the network on the mixture.
            mix, sources = tensors_to_device(test_set[idx],
                                             device=model_device)
            est_sources = model(mix[None, None])
            loss, reordered_sources = loss_func(est_sources, sources[None],
                                                return_est=True)
            mix_np = mix[None].cpu().data.numpy()
            sources_np = sources.squeeze().cpu().data.numpy()
            est_sources_np = reordered_sources.squeeze().cpu().data.numpy()
            utt_metrics = get_metrics(mix_np, sources_np, est_sources_np,
                                      sample_rate=conf['sample_rate'])
            utt_metrics['mix_path'] = test_set.mix[idx][0]
            series_list.append(pd.Series(utt_metrics))

            # Save some examples in a folder. Wav files and metrics as text.
            if idx in save_idx:
                local_save_dir = os.path.join(ex_save_dir,
                                              'ex_{}/'.format(idx))
                os.makedirs(local_save_dir, exist_ok=True)
                sf.write(local_save_dir + "mixture.wav", mix_np[0],
                         conf['sample_rate'])
                # Loop over the sources and estimates
                for src_idx, src in enumerate(sources_np):
                    sf.write(local_save_dir + "s{}.wav".format(src_idx + 1),
                             src, conf['sample_rate'])
                for src_idx, est_src in enumerate(est_sources_np):
                    sf.write(local_save_dir + "s{}_estimate.wav".format(src_idx + 1),
                             est_src, conf['sample_rate'])
                # Write local metrics to the example folder.
                with open(local_save_dir + 'metrics.json', 'w') as f:
                    json.dump(utt_metrics, f, indent=0)

    # Save all metrics to the experiment folder.
    all_metrics_df = pd.DataFrame(series_list)
    all_metrics_df.to_csv(os.path.join(conf['exp_dir'], 'all_metrics.csv'))

    # Print and save summary metrics: mean value and mean improvement
    # over the corresponding `input_` (mixture) metric.
    final_results = {}
    for metric_name in compute_metrics:
        input_metric_name = 'input_' + metric_name
        ldf = all_metrics_df[metric_name] - all_metrics_df[input_metric_name]
        final_results[metric_name] = all_metrics_df[metric_name].mean()
        final_results[metric_name + '_imp'] = ldf.mean()
    print('Overall metrics :')
    pprint(final_results)
    with open(os.path.join(conf['exp_dir'], 'final_metrics.json'), 'w') as f:
        json.dump(final_results, f, indent=0)


if __name__ == '__main__':
    # Merge CLI arguments with the training-time configuration stored in
    # the experiment folder, then run the evaluation.
    parsed = parser.parse_args()
    arg_dic = dict(vars(parsed))

    # Load training config
    with open(os.path.join(parsed.exp_dir, 'conf.yml')) as conf_file:
        train_conf = yaml.safe_load(conf_file)
    arg_dic['sample_rate'] = train_conf['data']['sample_rate']
    arg_dic['train_conf'] = train_conf

    # Evaluating on a different task than the training one is allowed,
    # but most likely a mistake, so warn about it.
    if parsed.task != train_conf['data']['task']:
        print("Warning : the task used to test is different than "
              "the one from training, be sure this is what you want.")

    main(arg_dic)
41 changes: 41 additions & 0 deletions egs/wham/WaveSplit/local/conf.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Filterbank config
filterbank:
  n_filters: 64
  kernel_size: 16
  stride: 8
# Network config
masknet:
  in_chan: 64
  n_src: 2              # number of speakers to separate
  out_chan: 64
  bn_chan: 128
  hid_size: 128
  chunk_size: 250
  hop_size: 125
  n_repeats: 6
  mask_act: 'sigmoid'
  bidirectional: true
  dropout: 0
# Training config
training:
  epochs: 200
  batch_size: 4
  num_workers: 4
  half_lr: yes          # halve the learning rate on plateau
  early_stop: yes
  # NOTE(review): 5000 is an unusually large clipping threshold — confirm
  # it is intended (it effectively disables gradient clipping).
  gradient_clipping: 5000
# Optim config
optim:
  optimizer: adam
  lr: 0.001
  weight_decay: 0.
# Data config
data:
  train_dir: data/wav8k/min/tr/
  valid_dir: data/wav8k/min/cv/
  data_augmentation: True
  task: sep_clean
  nondefault_nsrc:      # empty: use the dataset's default number of sources
  sample_rate: 8000
  mode: min
  segment: 1            # segment length in seconds (see README: 1 s for now)
38 changes: 38 additions & 0 deletions egs/wham/WaveSplit/local/convert_sphere2wav.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#!/bin/bash
# MIT Copyright (c) 2018 Kaituo XU
# Convert WSJ0 sphere-format audio to wav, downloading and building the
# sph2pipe tool if needed.
# Usage: convert_sphere2wav.sh --sphere_dir <wsj0_sphere_root> --wav_dir <out_root>

sphere_dir=tmp
wav_dir=tmp

# Parse --sphere_dir / --wav_dir overrides from the command line.
. utils/parse_options.sh || exit 1;


echo "Download sph2pipe_v2.5 into egs/tools"
mkdir -p ../../tools
wget http://www.openslr.org/resources/3/sph2pipe_v2.5.tar.gz -P ../../tools
# Build sph2pipe from source (-lm for the math library).
cd ../../tools && tar -xzvf sph2pipe_v2.5.tar.gz && gcc -o sph2pipe_v2.5/sph2pipe sph2pipe_v2.5/*.c -lm && cd -

echo "Convert sphere format to wav format"
sph2pipe=../../tools/sph2pipe_v2.5/sph2pipe

if [ ! -x $sph2pipe ]; then
  echo "Could not find (or execute) the sph2pipe program at $sph2pipe";
  exit 1;
fi

tmp=data/local/
mkdir -p $tmp

# List every .wv1/.wv2 file from the WSJ0 train/dev/eval subsets once.
[ ! -f $tmp/sph.list ] && find $sphere_dir -iname '*.wv*' | grep -e 'si_tr_s' -e 'si_dt_05' -e 'si_et_05' > $tmp/sph.list

if [ ! -d $wav_dir ]; then
  # Mirror the last three path components (subset/speaker/utt) under $wav_dir.
  while read line; do
    wav=`echo "$line" | sed "s:wv[12]:wav:g" | awk -v dir=$wav_dir -F'/' '{printf("%s/%s/%s/%s", dir, $(NF-2), $(NF-1), $NF)}'`
    echo $wav
    mkdir -p `dirname $wav`
    $sph2pipe -f wav $line > $wav
  done < $tmp/sph.list > $tmp/wav.list
else
  echo "Do you already get wav files? if not, please remove $wav_dir"
fi
32 changes: 32 additions & 0 deletions egs/wham/WaveSplit/local/prepare_data.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#!/bin/bash
# Download the WHAM noise corpus and scripts, then create the WHAM mixtures.
# Usage: prepare_data.sh --wav_dir <wsj0_wav_root> --out_dir <output_root>

wav_dir=tmp
out_dir=tmp
python_path=python

# Parse --wav_dir / --out_dir / --python_path overrides; exit on bad options
# (consistent with convert_sphere2wav.sh).
. utils/parse_options.sh || exit 1

## Download WHAM noises
mkdir -p $out_dir
echo "Download WHAM noises into $out_dir"
# If downloading stalls for more than 20s, relaunch from previous state.
wget -c --tries=0 --read-timeout=20 https://storage.googleapis.com/whisper-public/wham_noise.zip -P $out_dir

echo "Download WHAM scripts into $out_dir"
wget https://storage.googleapis.com/whisper-public/wham_scripts.tar.gz -P $out_dir
mkdir -p $out_dir/wham_scripts
tar -xzvf $out_dir/wham_scripts.tar.gz -C $out_dir/wham_scripts
mv $out_dir/wham_scripts.tar.gz $out_dir/wham_scripts

wait

# Bug fix: `unzip ARCHIVE DIR` tries to extract a member named DIR from the
# archive; the extraction directory must be given with -d. Also make sure the
# log directory exists before redirecting into it.
mkdir -p logs
unzip $out_dir/wham_noise.zip -d $out_dir >> logs/unzip_wham.log

echo "Run python scripts to create the WHAM mixtures"
# Requires : Numpy, Scipy, Pandas, and Pysoundfile
# NOTE(review): after `cd`, a relative $out_dir/$wav_dir would no longer
# resolve — confirm callers pass absolute paths.
cd $out_dir/wham_scripts
$python_path create_wham_from_scratch.py \
  --wsj0-root $wav_dir \
  --wham-noise-root $out_dir/wham_noise \
  --output-dir $out_dir
cd -
84 changes: 84 additions & 0 deletions egs/wham/WaveSplit/local/preprocess_wham.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import argparse
import json
import os
import soundfile as sf
import glob


def preprocess_task(task, in_dir, out_dir):
    """Write a JSON manifest for one WHAM task of one split.

    Scans ``in_dir/<task>/*.wav`` and writes ``out_dir/<task>.json``, a list
    of dicts with the mixture path, source paths, speaker id(s), sample
    length and (for ``mix_both``) the noise path.

    Args:
        task (str): one of 'mix_both', 'mix_clean', 'mix_single'.
        in_dir (str): split directory containing mix_*, s1, s2, noise folders.
        out_dir (str): directory where the JSON manifest is written.

    Raises:
        EnvironmentError: if ``task`` is not a known WHAM task.
    """
    os.makedirs(out_dir, exist_ok=True)

    if task not in ("mix_both", "mix_clean", "mix_single"):
        # Kept as EnvironmentError for backward compatibility with callers;
        # a message is added so failures are diagnosable.
        raise EnvironmentError("Unknown WHAM task: {}".format(task))

    # mix_single has one source (and speaker); the other tasks have two.
    n_src = 1 if task == "mix_single" else 2
    examples = []
    for mix in glob.glob(os.path.join(in_dir, task, "*.wav")):
        # Bug fix: use basename instead of split("/") for portability.
        filename = os.path.basename(mix)
        # WHAM filenames look like "<utt1>_<snr1>_<utt2>_<snr2>.wav"; the
        # speaker id is the first 3 chars of each utterance id.
        name_parts = filename.split("_")
        spk_ids = [name_parts[2 * i][:3] for i in range(n_src)]
        # Length in samples, read from the header without loading audio.
        length = len(sf.SoundFile(mix))
        sources = [os.path.join(in_dir, "s{}".format(i + 1), filename)
                   for i in range(n_src)]

        # Key order matches the original manifests (noise between sources
        # and spk_id for mix_both).
        ex = {"mix": mix, "sources": sources}
        if task == "mix_both":
            ex["noise"] = os.path.join(in_dir, "noise", filename)
        ex["spk_id"] = spk_ids
        ex["length"] = length
        examples.append(ex)

    with open(os.path.join(out_dir, task + '.json'), 'w') as f:
        json.dump(examples, f, indent=4)


def preprocess(inp_args):
    """Build JSON manifests for every WHAM split and task.

    Args:
        inp_args: namespace with ``in_dir`` (WHAM root containing tr/cv/tt)
            and ``out_dir`` (manifest output root).
    """
    for split in ("tr", "cv", "tt"):
        split_in = os.path.join(inp_args.in_dir, split)
        split_out = os.path.join(inp_args.out_dir, split)
        for task in ("mix_both", "mix_clean", "mix_single"):
            preprocess_task(task, split_in, split_out)



if __name__ == "__main__":
    # CLI entry point: build manifests for a full WHAM directory tree.
    arg_parser = argparse.ArgumentParser("WHAM data preprocessing")
    arg_parser.add_argument('--in_dir', type=str, default=None,
                            help='Directory path of wham including tr, cv and tt')
    arg_parser.add_argument('--out_dir', type=str, default=None,
                            help='Directory path to put output files')
    parsed_args = arg_parser.parse_args()
    print(parsed_args)
    preprocess(parsed_args)
29 changes: 29 additions & 0 deletions egs/wham/WaveSplit/local/resample_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import argparse
import os
from glob import glob
from distutils.dir_util import copy_tree
from scipy.signal import resample_poly
import soundfile as sf

# CLI for resampling a dataset in place after copying it to out_dir.
parser = argparse.ArgumentParser("Script for resampling a dataset")
parser.add_argument("source_dir", type=str)
parser.add_argument("out_dir", type=str)
parser.add_argument("original_sr", type=int)
parser.add_argument("target_sr", type=int)
# Bug fix: argparse ignores `default` on a required positional, so
# `extension` was effectively mandatory. nargs="?" makes it optional
# with the documented default of "wav" (backward compatible).
parser.add_argument("extension", type=str, nargs="?", default="wav")


def main(out_dir, original_sr, target_sr, extension):
    """Resample every ``*.extension`` file in ``out_dir`` in place.

    Args:
        out_dir (str): directory whose audio files are overwritten.
        original_sr (int): expected sample rate of the input files.
        target_sr (int): sample rate to resample to (must be <= original_sr).
        extension (str): audio file extension to match, e.g. "wav".
    """
    assert original_sr >= target_sr, "Upsampling not supported"
    # NOTE(review): only top-level files match; subdirectories copied by
    # copy_tree are not resampled — confirm whether recursion is intended.
    wavs = glob(os.path.join(out_dir, "*.{}".format(extension)))
    for wav in wavs:
        data, fs = sf.read(wav)
        assert fs == original_sr
        data = resample_poly(data, target_sr, fs)
        # Bug fix: sf.write requires the samplerate argument; the original
        # `sf.write(wav, data)` raised a TypeError at runtime.
        sf.write(wav, data, target_sr)


if __name__ == "__main__":
    # Bug fix: the original called `parser.add_argument()` here, which
    # raises (no arguments) instead of parsing the command line.
    args = parser.parse_args()
    # Copy the dataset first, then resample the copy in place.
    copy_tree(args.source_dir, args.out_dir)
    main(args.out_dir, args.original_sr, args.target_sr, args.extension)
Loading