Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wavesplit 2021 #454

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 153 additions & 0 deletions egs/wham/wavesplit/dataloading.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import torch
from torch.utils import data
import json
import os
import numpy as np
import soundfile as sf

# Name of the corpus handled by this module.
DATASET = "WHAM"

# WHAM task descriptions.
# "mixture" is the stem of the json file listing the examples of the task and
# "default_nsrc" is the number of sources used when none is requested.
# "sources" / "infos" presumably name the target and side-information signals
# (they are not consumed by the code in this file).
enh_single = dict(mixture="mix_single", sources=["s1"], infos=["noise"], default_nsrc=1)
enh_both = dict(mixture="mix_both", sources=["mix_clean"], infos=["noise"], default_nsrc=1)
sep_clean = dict(mixture="mix_clean", sources=["s1", "s2"], infos=[], default_nsrc=2)
sep_noisy = dict(mixture="mix_both", sources=["s1", "s2"], infos=["noise"], default_nsrc=2)

# Lookup table from task name to task description.
WHAM_TASKS = dict(
    enhance_single=enh_single,
    enhance_both=enh_both,
    sep_clean=sep_clean,
    sep_noisy=sep_noisy,
)
# Short aliases pointing at the very same description objects.
WHAM_TASKS.update(
    enh_single=WHAM_TASKS["enhance_single"],
    enh_both=WHAM_TASKS["enhance_both"],
)


class WHAMID(data.Dataset):
    """Dataset class for WHAM source separation and speech enhancement tasks.

    In addition to the mixture and its sources, every item carries the integer
    indices of the speakers present in the mixture.

    Args:
        json_dir (str): The path to the directory containing the json files.
        task (str): One of ``'enh_single'``, ``'enh_both'``, ``'sep_clean'`` or
            ``'sep_noisy'``.

            * ``'enh_single'`` for single speaker speech enhancement.
            * ``'enh_both'`` for multi speaker speech enhancement.
            * ``'sep_clean'`` for two-speaker clean source separation.
            * ``'sep_noisy'`` for two-speaker noisy source separation.

        sample_rate (int, optional): The sampling rate of the wav files.
        segment (float, optional): Length of the segments used for training,
            in seconds. If None, use full utterances (e.g. for test).
        nondefault_nsrc (int, optional): Number of sources in the training
            targets.
            If None, defaults to one for enhancement tasks and two for
            separation tasks.

    Raises:
        ValueError: If ``task`` is not a key of ``WHAM_TASKS``.
    """

    def __init__(self, json_dir, task, sample_rate=8000, segment=4.0, nondefault_nsrc=None):
        super(WHAMID, self).__init__()
        if task not in WHAM_TASKS.keys():
            raise ValueError(
                "Unexpected task {}, expected one of " "{}".format(task, WHAM_TASKS.keys())
            )
        # Task setting
        self.json_dir = json_dir
        self.task = task
        self.task_dict = WHAM_TASKS[task]
        self.sample_rate = sample_rate
        self.seg_len = None if segment is None else int(segment * sample_rate)
        if not nondefault_nsrc:
            self.n_src = self.task_dict["default_nsrc"]
        else:
            assert nondefault_nsrc >= self.task_dict["default_nsrc"]
            self.n_src = nondefault_nsrc
        # Without a segment length the dataset serves full utterances.
        self.like_test = self.seg_len is None

        # Load json examples
        ex_json = os.path.join(json_dir, self.task_dict["mixture"] + ".json")
        with open(ex_json, "r") as f:
            examples = json.load(f)

        # Filter out short utterances only when segment is specified
        orig_len = len(examples)
        drop_utt, drop_len = 0, 0
        if self.like_test:
            # Full-utterance mode keeps every example; previously the list
            # was left empty in this case.
            self.examples = list(examples)
        else:
            self.examples = []
            for ex in examples:
                if ex["length"] < self.seg_len:
                    drop_utt += 1
                    drop_len += ex["length"]
                else:
                    self.examples.append(ex)

        # samples / (samples per second) / (3600 seconds per hour) -> hours
        print(
            "Drop {} utts({:.2f} h) from {} (shorter than {} samples)".format(
                drop_utt, drop_len / sample_rate / 3600, orig_len, self.seg_len
            )
        )

        # Count total number of speakers; a speaker is identified by the
        # first three characters of its id.
        speakers = set()
        for ex in self.examples:
            speakers.update(spk[:3] for spk in ex["spk_id"])
        print("Total number of speakers {}".format(len(speakers)))

        # Convert speaker ids into integers. Sorting makes the mapping
        # reproducible: plain set iteration order can change between runs.
        self.spk2indx = {spk: indx for indx, spk in enumerate(sorted(speakers))}
        for ex in self.examples:
            ex["spk_id"] = [self.spk2indx[spk[:3]] for spk in ex["spk_id"]]

    def __len__(self):
        """Returns the number of examples kept after filtering."""
        return len(self.examples)

    def __getitem__(self, idx):
        """Gets a mixture/sources pair together with the speaker indices.

        Args:
            idx (int): Index of the example.

        Returns:
            mixture, vstack([source_arrays]), speaker indices (long tensor).
            For two-source examples the order of the sources (and of the
            matching speaker indices) is swapped at random.
        """
        c_ex = self.examples[idx]
        if self.like_test:
            rand_start, stop = 0, None
        else:
            # Random start. The upper bound of randint is exclusive, hence
            # the + 1 so that the last valid offset can be drawn as well
            # (this also covers length == seg_len, which yields 0).
            rand_start = np.random.randint(0, c_ex["length"] - self.seg_len + 1)
            stop = rand_start + self.seg_len

        # Load mixture
        x, _ = sf.read(c_ex["mix"], start=rand_start, stop=stop, dtype="float32")
        # Load sources
        source_arrays = []
        for src in c_ex["sources"]:
            s, _ = sf.read(src, start=rand_start, stop=stop, dtype="float32")
            source_arrays.append(s)
        sources = torch.from_numpy(np.vstack(source_arrays))

        # Work on a copy: swapping the stored list in place would leave the
        # labels out of sync with the (always re-loaded) source order on
        # later accesses of the same example.
        spk_ids = list(c_ex["spk_id"])
        # Randomly permute the two sources; skipped when there is a single
        # source, where indexing sources[1] would fail.
        if len(source_arrays) == 2 and np.random.random() > 0.5:
            sources = torch.stack((sources[1], sources[0]))
            spk_ids.reverse()

        return torch.from_numpy(x), sources, torch.tensor(spk_ids, dtype=torch.long)


if __name__ == "__main__":
    # Manual smoke test: build the dataset from a json directory and print
    # the speaker indices of every item.
    import argparse

    parser = argparse.ArgumentParser("WHAMID smoke test")
    parser.add_argument(
        "json_dir",
        nargs="?",
        default="data/wav8k/min/tt",
        help="Directory containing the json files of one split",
    )
    cli_args = parser.parse_args()

    dataset = WHAMID(cli_args.json_dir, "sep_clean")

    for item in dataset:
        print(item[-1])
Comment on lines +147 to +153
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be removed

25 changes: 25 additions & 0 deletions egs/wham/wavesplit/local/conf.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Network config
masknet:
# Presumably the number of sources estimated by the network (two speakers
# for the sep_clean task selected below) - confirm against the model code.
n_src: 2

# Training config
training:
epochs: 200
batch_size: 4
num_workers: 6
# NOTE(review): presumably enables halving the learning rate when validation
# stalls - confirm against the training script.
half_lr: yes
early_stop: yes
gradient_clipping: 5
# Optim config
optim:
optimizer: adam
lr: 0.001
# Data config
data:
# Directories holding the json example lists of the training and validation splits.
train_dir: data/wav8k/min/tr/
valid_dir: data/wav8k/min/cv/
# Two-speaker clean source separation.
task: sep_clean
# Left empty so that the number of sources falls back to the task default.
nondefault_nsrc:
sample_rate: 8000
# NOTE(review): presumably the "min" mixture variant, matching the
# data/wav8k/min paths above - confirm.
mode: min
# Length of the training segments, in seconds.
segment: 1.0
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A segment of 1.0 seconds (or 0.75 seconds, as in the paper) is enough.

38 changes: 38 additions & 0 deletions egs/wham/wavesplit/local/convert_sphere2wav.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#!/bin/bash
# MIT Copyright (c) 2018 Kaituo XU
#
# Converts the sphere files found under $sphere_dir into wav files under
# $wav_dir, downloading and building sph2pipe into egs/tools when needed.
# The defaults below are presumably overridable from the command line through
# utils/parse_options.sh (not part of this file).


sphere_dir=tmp
wav_dir=tmp

. utils/parse_options.sh || exit 1;


tools_dir=../../tools
sph2pipe=$tools_dir/sph2pipe_v2.5/sph2pipe

# Download and build sph2pipe only when no executable is available yet, so
# that re-running the script does not fetch and compile it again.
if [ ! -x "$sph2pipe" ]; then
  echo "Download sph2pipe_v2.5 into egs/tools"
  mkdir -p "$tools_dir"
  wget http://www.openslr.org/resources/3/sph2pipe_v2.5.tar.gz -P "$tools_dir"
  # The subshell leaves the working directory of this script untouched.
  (
    cd "$tools_dir" \
      && tar -xzvf sph2pipe_v2.5.tar.gz \
      && gcc -o sph2pipe_v2.5/sph2pipe sph2pipe_v2.5/*.c -lm
  )
fi

echo "Convert sphere format to wav format"

if [ ! -x "$sph2pipe" ]; then
  echo "Could not find (or execute) the sph2pipe program at $sph2pipe";
  exit 1;
fi

tmp=data/local/
mkdir -p "$tmp"

# Build the list of sphere files once; an existing list is reused.
if [ ! -f "$tmp/sph.list" ]; then
  find "$sphere_dir" -iname '*.wv*' | grep -e 'si_tr_s' -e 'si_dt_05' -e 'si_et_05' > "$tmp/sph.list"
fi

if [ ! -d "$wav_dir" ]; then
  # -r keeps backslashes in paths intact; IFS= keeps surrounding whitespace.
  while IFS= read -r line; do
    # Output path: $wav_dir followed by the last three path components of
    # the sphere file, with "wv1" replaced by "wav".
    wav=$(echo "$line" | sed "s:wv1:wav:g" | awk -v dir="$wav_dir" -F'/' '{printf("%s/%s/%s/%s", dir, $(NF-2), $(NF-1), $NF)}')
    echo "$wav"
    mkdir -p "$(dirname "$wav")"
    "$sph2pipe" -f wav "$line" > "$wav"
  done < "$tmp/sph.list" > "$tmp/wav.list"
else
  echo "Do you already get wav files? if not, please remove $wav_dir"
fi
32 changes: 32 additions & 0 deletions egs/wham/wavesplit/local/prepare_data.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#!/bin/bash
# Downloads the WHAM noises and the WHAM generation scripts into $out_dir,
# then runs the scripts to create the WHAM mixtures from the wav files in
# $wav_dir.

wav_dir=tmp
out_dir=tmp
python_path=python

. utils/parse_options.sh

## Download WHAM noises
mkdir -p "$out_dir"
echo "Download WHAM noises into $out_dir"
# If downloading stalls for more than 20s, relaunch from previous state.
wget -c --tries=0 --read-timeout=20 https://storage.googleapis.com/whisper-public/wham_noise.zip -P "$out_dir"

echo "Download WHAM scripts into $out_dir"
wget https://storage.googleapis.com/whisper-public/wham_scripts.tar.gz -P "$out_dir"
mkdir -p "$out_dir/wham_scripts"
tar -xzvf "$out_dir/wham_scripts.tar.gz" -C "$out_dir/wham_scripts"
mv "$out_dir/wham_scripts.tar.gz" "$out_dir/wham_scripts"

wait

# The log directory must exist before the redirection below can succeed.
mkdir -p logs
# -d selects the extraction directory; a bare second argument would be taken
# as the name of an archive member to extract.
unzip "$out_dir/wham_noise.zip" -d "$out_dir" >> logs/unzip_wham.log

echo "Run python scripts to create the WHAM mixtures"
# Requires : Numpy, Scipy, Pandas, and Pysoundfile
# NOTE(review): after this cd, relative $wav_dir / $out_dir values resolve
# against the scripts directory - callers presumably pass absolute paths; confirm.
cd "$out_dir/wham_scripts/wham_scripts"
# $python_path is left unquoted on purpose so that it may carry interpreter options.
$python_path create_wham_from_scratch.py \
  --wsj0-root "$wav_dir" \
  --wham-noise-root "$out_dir/wham_noise" \
  --output-dir "$out_dir"
cd -
93 changes: 93 additions & 0 deletions egs/wham/wavesplit/local/preprocess_wham.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import argparse
import json
import os
import soundfile as sf
import glob


def preprocess_task(task, in_dir, out_dir):
    """Write the json example list of one WHAM mixture type.

    Every wav file found in ``in_dir/<task>`` yields one entry holding the
    mixture path, the paths of the matching source files (same file name in
    ``in_dir/s1`` and ``in_dir/s2``), the noise path for ``'mix_both'``, the
    speaker ids and the length of the mixture in frames. The list is written
    to ``out_dir/<task>.json``.

    Args:
        task (str): One of ``'mix_both'``, ``'mix_clean'`` or ``'mix_single'``.
        in_dir (str): Directory containing the ``mix_*``, ``s1``, ``s2`` and
            ``noise`` folders of one split.
        out_dir (str): Directory receiving the json file (created if missing).

    Raises:
        EnvironmentError: If ``task`` is not one of the supported names.
    """
    os.makedirs(out_dir, exist_ok=True)

    # task -> (number of speakers, whether a noise file goes with the mixture)
    task_specs = {
        "mix_both": (2, True),
        "mix_clean": (2, False),
        "mix_single": (1, False),
    }
    if task not in task_specs:
        raise EnvironmentError(
            "Unexpected task {!r}, expected one of {}".format(task, sorted(task_specs))
        )
    n_spk, has_noise = task_specs[task]

    examples = []
    # Sorted so that the json content does not depend on directory listing order.
    for mix in sorted(glob.glob(os.path.join(in_dir, task, "*.wav"))):
        filename = os.path.basename(mix)
        fields = filename.split("_")
        # The id of speaker k is the first three characters of the
        # underscore-separated field at index 2 * k (fields 0 and 2).
        spk_ids = [fields[2 * k][:3] for k in range(n_spk)]
        # Close the file right after reading its length instead of leaking
        # one open handle per mixture.
        with sf.SoundFile(mix) as wav:
            length = len(wav)

        ex = {
            "mix": mix,
            "sources": [os.path.join(in_dir, "s{}".format(k + 1), filename) for k in range(n_spk)],
        }
        if has_noise:
            ex["noise"] = os.path.join(in_dir, "noise", filename)
        ex["spk_id"] = spk_ids
        ex["length"] = length
        examples.append(ex)

    with open(os.path.join(out_dir, task + ".json"), "w") as f:
        json.dump(examples, f, indent=4)


def preprocess(inp_args):
    """Create the json files of every mixture type for the tr, cv and tt splits.

    Args:
        inp_args: Object exposing ``in_dir`` (directory containing the
            ``tr``, ``cv`` and ``tt`` folders) and ``out_dir`` (directory
            receiving one sub-folder of json files per split).
    """
    mixture_types = ("mix_both", "mix_clean", "mix_single")
    for subset in ("tr", "cv", "tt"):
        subset_in = os.path.join(inp_args.in_dir, subset)
        subset_out = os.path.join(inp_args.out_dir, subset)
        for mixture_type in mixture_types:
            preprocess_task(mixture_type, subset_in, subset_out)


if __name__ == "__main__":
    # Command-line entry point: parse the input/output directories and
    # create the json files of every split.
    parser = argparse.ArgumentParser("WHAM data preprocessing")
    # Both paths are joined with the split names, so a missing value would
    # only fail later inside os.path.join; require them up front instead.
    parser.add_argument(
        "--in_dir",
        type=str,
        required=True,
        help="Directory path of wham including tr, cv and tt",
    )
    parser.add_argument(
        "--out_dir", type=str, required=True, help="Directory path to put output files"
    )
    args = parser.parse_args()
    print(args)
    preprocess(args)
Comment on lines +83 to +93
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should create a `def main(args)` at the beginning of the file, put the parser arguments at the beginning as well, and call `preprocess` inside `main(args)`. It's more user-friendly: we can see the arguments and the function being called directly, without scrolling.

Loading