diffusion-separation's Introduction

Diffusion-based Generative Speech Source Separation

This repository contains the code to reproduce the results of the paper Diffusion-based Generative Speech Source Separation presented at ICASSP 2023.

We propose DiffSep, a new single-channel source separation method based on score matching of a stochastic differential equation (SDE). We craft a tailored continuous-time diffusion-mixing process starting from the separated sources and converging to a Gaussian distribution centered on their mixture. This formulation lets us apply the machinery of score-based generative modelling. First, we train a neural network to approximate the score function of the marginal probabilities of the diffusion-mixing process. Then, we use it to solve the reverse-time SDE that progressively separates the sources starting from their mixture. We propose a modified training strategy to handle model mismatch and source permutation ambiguity. Experiments on the WSJ0-2mix dataset demonstrate the potential of the method. Furthermore, the method is also suitable for speech enhancement and shows performance competitive with prior work on the VoiceBank-DEMAND dataset.

Show Me How to Separate Wav Files!

We've got you covered. Just run the following command (after setting up the environment as described under Training).

python separate.py path/to/wavfiles/folder path/to/output/folder

where path/to/wavfiles/folder points to a folder containing wavfiles. The input files should be sampled at 8 kHz for the default model. Two speakers are separated and stored in path/to/output/folder/s1 and path/to/output/folder/s2, respectively. The model weights are stored on huggingface.
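Since the default model expects 8 kHz input, it can help to verify the sample rate of the wav files before running separation. A minimal sketch using only the standard library (`check_sample_rate` is a hypothetical helper, not part of this repository):

```python
import wave
from pathlib import Path


def check_sample_rate(folder, expected_rate=8000):
    """Return (filename, rate) for wav files whose sample rate differs from expected_rate."""
    mismatched = []
    for path in sorted(Path(folder).glob("*.wav")):
        with wave.open(str(path), "rb") as wav:
            rate = wav.getframerate()
            if rate != expected_rate:
                mismatched.append((path.name, rate))
    return mismatched
```

Any file reported by this check should be resampled to 8 kHz before calling separate.py.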

Configuration

Configuration is done using the hydra hierarchical configuration package. The hierarchy is as follows.

config/
|-- config.yaml  # main config file
|-- datamodule  # config of dataset and dataloaders
|   |-- default.yaml
|   `-- diffuse.yaml  # smaller batch size for CDiffuse
|-- model
|   |-- default.yaml  # NCSN++ model
|   `-- diffuse.yaml  # CDiffuse model
`-- trainer
    `-- default.yaml  # config of pytorch-lightning trainer

Dataset

The wsj0_mix dataset is expected in data/wsj0_mix:

data/wsj0_mix/
|-- 2speakers
|   |-- wav16k
|   |   |-- max
|   |   |   |-- cv
|   |   |   |-- tr
|   |   |   `-- tt
|   |   `-- min
|   |       |-- cv
|   |       |-- tr
|   |       `-- tt
|   `-- wav8k
|       |-- max
|       |   |-- cv
|       |   |-- tr
|       |   `-- tt
|       `-- min
|           |-- cv
|           |-- tr
|           `-- tt
`-- 3speakers
    |-- wav16k
    |   `-- max
    |       |-- cv
    |       |-- tr
    |       `-- tt
    `-- wav8k
        `-- max
            |-- cv
            |-- tr
            `-- tt

The VCTK-DEMAND dataset is expected in data/VCTK_DEMAND:

data/VCTK_DEMAND/
|-- train
|   |-- noisy
|   `-- clean
`-- test
    |-- noisy
    `-- clean

Training

Preparation

conda env create -f environment.yaml
conda activate diff-sep

Run training. Training logs and TensorBoard files are stored in ./exp/.

python ./train.py

Thanks to hydra, parameters can be overridden easily from the command line

python ./train.py model.sde.sigma_min=0.1
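A dotted override like model.sde.sigma_min=0.1 selects a nested key in the composed config. A simplified sketch of how such overrides map onto a nested dict (an illustration only, not Hydra's actual implementation):

```python
def apply_override(config, override):
    """Apply a single 'a.b.c=value' override to a nested dict in place."""
    key, _, value = override.partition("=")
    *parents, leaf = key.split(".")
    node = config
    for name in parents:
        node = node.setdefault(name, {})
    # interpret the value as a number when possible, else keep it as a string
    try:
        node[leaf] = float(value) if "." in value else int(value)
    except ValueError:
        node[leaf] = value
```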

The training can be run in a multi-GPU setting by overriding the trainer config with trainer=allgpus. Since validation is quite expensive, we set trainer.check_val_every_n_epoch=5 to run it only every 5 epochs. The train and validation batch sizes are multiplied by the number of GPUs.

Evaluation

The evaluate.py script can be used to run inference on the val and test datasets.

$ python ./evaluate.py --help
usage: evaluate.py [-h] [-d DEVICE] [-l LIMIT] [--save-n SAVE_N] [--val] [--test] [-N N] [--snr SNR] [--corrector-steps CORRECTOR_STEPS] [--denoise DENOISE] ckpt

Run evaluation on validation or test dataset

positional arguments:
  ckpt                  Path to checkpoint to use

options:
  -h, --help            show this help message and exit
  -d DEVICE, --device DEVICE
                        Device to use (default: cuda:0)
  -l LIMIT, --limit LIMIT
                        Limit the number of samples to process
  --save-n SAVE_N       Save a limited number of output samples
  --val                 Run on validation dataset
  --test                Run on test dataset
  -N N                  Number of steps
  --snr SNR             Step size of corrector
  --corrector-steps CORRECTOR_STEPS
                        Number of corrector steps
  --denoise DENOISE     Use denoising in solver
  --enhance             Run evaluation for speech enhancement task (default: false)

This will save the results in a folder named results/{exp_name}_{ckpt_name}_{infer_params}. The option --save-n N saves the first N samples as figures and audio files.
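The results folder name encodes the experiment, checkpoint, and inference parameters. A small sketch of how such a name might be assembled (the exact formatting of infer_params is an assumption, not the script's actual format):

```python
def results_dir_name(exp_name, ckpt_name, infer_params):
    """Build a folder name of the form results/{exp_name}_{ckpt_name}_{infer_params}."""
    params = "_".join(f"{k}-{v}" for k, v in sorted(infer_params.items()))
    return f"results/{exp_name}_{ckpt_name}_{params}"
```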

Reproduce

Separation

# train
python ./train.py experiment=icassp-separation

# evaluate
python ./evaluate_mp.py exp/default/<YYYY-MM-DD_hh-mm-ss>_experiment-icassp-separation/checkpoints/epoch-<NNN>_si_sdr-<F.FFF>.ckpt --split test libri-clean

Enhancement

# train
python ./train.py experiment=noise-reduction

# evaluate
python ./evaluate.py exp/enhancement/<YYYY-MM-DD_hh-mm-ss>_experiment-noise-reduction/checkpoints/epoch-<NNN>_si_sdr-<F.FFF>.ckpt --test --pesq-mode wb

License

2023 (c) LINE Corporation

The repo is released under the MIT license, but please refer to individual files for their specific licenses.

diffusion-separation's People

Contributors

fakufaku


diffusion-separation's Issues

Regarding the issue of model convergence

Hello, I have recently been trying to train your model. Could you please let me know at what value of the score loss the model can be considered converged?

Make pre-trained checkpoints available

Dear author, hello! Would it be possible for me to have the privilege of obtaining your pre-trained weights? I've tried running your code, but due to limitations in computational resources, it may take a long time to complete training and may not achieve the results reported in your paper.

Clarification Regarding Default Parameters in Relation to Paper Results

Hello,

I am reaching out to gain a better understanding of the default parameters specified in this repository. I would like to know if the default parameters set in the code are the ones that correspond to the optimal results mentioned in the paper.

Could you please clarify if any adjustments are needed to replicate the results mentioned in the paper, or if the default settings are indeed the configurations that yielded the best results?

I appreciate your time and effort in maintaining this repository and look forward to your response.

Thank you!
