
segan's Introduction

SEGAN

A PyTorch implementation of SEGAN based on the INTERSPEECH 2017 paper "SEGAN: Speech Enhancement Generative Adversarial Network".

Requirements

  • PyTorch
conda install pytorch torchvision -c pytorch
  • librosa
pip install librosa

Datasets

The clean and noisy speech datasets can be downloaded from DataShare. Download the 56-speaker training dataset and the test dataset (both sampled at 48 kHz), then extract them into the data directory.

To use other datasets, change the data paths defined in data_preprocess.py.

Usage

Data Pre-process

python data_preprocess.py

The pre-processed data are saved in data/serialized_train_data and data/serialized_test_data.
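The SEGAN paper downsamples the audio to 16 kHz. It then trains on chunks of about one second (16384 samples) taken every 500 ms, which gives a 50% overlap. Below is a minimal sketch of that slicing step. It assumes data_preprocess.py does something similar. The helper name `slice_signal` and its defaults are illustrative and are not taken from the repository.

```python
import numpy as np


def slice_signal(signal, window_size=16384, stride=0.5):
    """Split a 1-D waveform into fixed-length, overlapping windows.

    Hypothetical helper: consecutive windows start window_size * stride
    samples apart (50% overlap by default). A tail shorter than one
    window is dropped.
    """
    hop = int(window_size * stride)
    slices = [
        signal[start:start + window_size]
        for start in range(0, len(signal) - window_size + 1, hop)
    ]
    if not slices:
        # Signal shorter than one window: return an empty (0, window_size) array.
        return np.empty((0, window_size), dtype=signal.dtype)
    return np.stack(slices)


# Example: 3 seconds of (silent) 16 kHz audio.
audio = np.zeros(3 * 16000, dtype=np.float32)
chunks = slice_signal(audio)
print(chunks.shape)  # (4, 16384)
```

For real files, the waveform would first be loaded and resampled, for example with `librosa.load(path, sr=16000)`.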

Train Model and Test

python main.py --batch_size 128 --num_epochs 300
optional arguments:
--batch_size             train batch size [default value is 50]
--num_epochs             train epochs number [default value is 86]
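The two training options above map naturally onto Python's argparse. The sketch below shows how such flags could be declared with the documented defaults (50 and 86). The actual parser in main.py may differ, and the help strings here are illustrative.

```python
import argparse


def build_parser():
    # Flag names and defaults mirror the README's option list.
    parser = argparse.ArgumentParser(description="Train SEGAN")
    parser.add_argument("--batch_size", type=int, default=50,
                        help="train batch size")
    parser.add_argument("--num_epochs", type=int, default=86,
                        help="train epochs number")
    return parser


# Equivalent to: python main.py --batch_size 128 --num_epochs 300
opt = build_parser().parse_args(["--batch_size", "128", "--num_epochs", "300"])
print(opt.batch_size, opt.num_epochs)  # 128 300
```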

The test results are saved in the results directory.

Test Audio

python test_audio.py --file_name p232_160.wav --epoch_name generator-80.pkl
optional arguments:
--file_name              audio file name
--epoch_name             generator epoch name

The enhanced audio is written to the same directory as the input audio.

Results

The example results and the pre-trained Generator weights can be downloaded from BaiduYun (access code: tzdd).

segan's People

Contributors

leftthomas


segan's Issues

results

Hello, why are the different segments of the same utterance saved in results all identical?

TypeError: 'str' object is not callable

When calling load_state_dict in test_audio.py:

File "test_audio.py", line 25, in
generator.load_state_dict(torch.load('epochs/' + EPOCH_NAME, map_location='cpu'))

TypeError: 'str' object is not callable

Any hints?

Thank you!

pre-train Generator weights

Hi @leftthomas

I have been trying to download the weights and samples from the link provided, but I haven't been able to download them from Baidu. Could you provide another download link?

thanks a lot,

regards.

Out of memory error

I tried a batch size of 64 with 300 epochs.

I got this error:

CUDA out of memory. Tried to allocate 32.00 MiB (GPU 0; 6.00 GiB total capacity; 4.06 GiB already allocated; 26.44 MiB free; 4.11 GiB reserved in total by PyTorch)

I searched for this error and found suggestions to call torch.cuda.empty_cache() or to run the garbage collector with gc.collect().
I've been looking through your code for a place to add these snippets.

If you have any idea where to add them, or any alternative solution, let me know.
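`torch.cuda.empty_cache()` only releases cached, unused blocks held by PyTorch's allocator. It does not reduce the memory that a batch actually needs. A more reliable fix is to lower `--batch_size`. If the larger effective batch matters, a common alternative is gradient accumulation over several smaller micro-batches.

Below is a minimal sketch of gradient accumulation. A toy linear model stands in for the SEGAN generator. The model, data, optimizer and `accum_steps` are placeholders, not the repository's code.

```python
import torch
from torch import nn


def accumulate_gradients(model, loss_fn, inputs, targets, accum_steps):
    """Run backward over accum_steps micro-batches without an optimizer step.

    Each micro-batch loss is divided by accum_steps, so the summed gradients
    equal the gradient of the mean loss over the full batch (exact when the
    micro-batches are equal-sized and the loss is a mean).
    """
    model.zero_grad()
    for x, y in zip(inputs.chunk(accum_steps), targets.chunk(accum_steps)):
        loss = loss_fn(model(x), y) / accum_steps
        loss.backward()  # .grad buffers add up across iterations


torch.manual_seed(0)
model = nn.Linear(8, 1)   # toy stand-in for the generator
loss_fn = nn.L1Loss()     # mean absolute error
optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-4)  # arbitrary choice
inputs, targets = torch.randn(64, 8), torch.randn(64, 1)

# Four micro-batches of 16 examples stand in for one batch of 64.
accumulate_gradients(model, loss_fn, inputs, targets, accum_steps=4)
accumulated = model.weight.grad.clone()

# Reference: a single backward pass over all 64 examples at once.
model.zero_grad()
loss_fn(model(inputs), targets).backward()
print(torch.allclose(accumulated, model.weight.grad, atol=1e-6))  # True

optimizer.step()  # one parameter update for the whole logical batch
```

Layers that compute statistics over the batch, such as the discriminator's virtual batch normalization, only see each micro-batch. For those layers, accumulation is not exactly equivalent to a larger batch.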

Version of PyTorch

Could you specify which PyTorch version you used? I ran the code locally and hit many errors caused by version differences. Please also list the versions of any other packages you consider important.

I doubt whether this model can really be called a GAN.

Hi there, I have recently been trying to reproduce this SEGAN model and have run into some questions.

  1. The biggest question is about the discriminator's loss function. The original GAN discriminator performs binary classification, so it uses a sigmoid at the output layer and binary cross-entropy as the loss. This model's discriminator seems to perform a regression task instead: its loss minimizes the distance between the outputs and 1 (or 0). So I think the discriminator contributes nothing to the final performance. Minimizing the L1 loss between the clean and generated speech is what makes the whole system work.

  2. So I discarded the discriminator and trained only the generator for speech enhancement, and it performs very close to SEGAN. If only the generator is trained, the model can be seen as a denoising autoencoder.

  3. I'm confused about how much the discriminator contributes during the adversarial process.

Many thanks!
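For reference, the SEGAN paper trains with the least-squares GAN (LSGAN) objective rather than a sigmoid with binary cross-entropy. The discriminator regresses its output towards 1 for (clean, noisy) pairs and towards 0 for (enhanced, noisy) pairs. The generator combines its adversarial term with an L1 term weighted by λ = 100. Regressing to 1/0 targets is therefore still an adversarial game.

The sketch below implements those two losses. It assumes the discriminator outputs `d_real` and `d_fake` have already been computed. The tensor and function names are illustrative, not the repository's. The L1 term is written as a mean absolute error, which is a normalized form of the paper's L1 norm.

```python
import torch


def discriminator_loss(d_real, d_fake):
    # LSGAN: push D(clean, noisy) towards 1 and D(enhanced, noisy) towards 0.
    return 0.5 * torch.mean((d_real - 1.0) ** 2) + 0.5 * torch.mean(d_fake ** 2)


def generator_loss(d_fake, enhanced, clean, l1_weight=100.0):
    # Adversarial term: the generator wants D(enhanced, noisy) close to 1.
    adversarial = 0.5 * torch.mean((d_fake - 1.0) ** 2)
    # L1 term keeps the enhanced waveform close to the clean target.
    l1 = torch.mean(torch.abs(enhanced - clean))
    return adversarial + l1_weight * l1


clean = torch.randn(4, 1, 16384)
# A perfect discriminator: real pairs scored 1, enhanced pairs scored 0.
print(discriminator_loss(torch.ones(4), torch.zeros(4)).item())  # 0.0
# Even a perfect reconstruction pays the adversarial term if D is not fooled.
print(generator_loss(torch.zeros(4), clean, clean).item())  # 0.5
```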

How to use more than one GPU to train model

Hi,

Thanks a lot for submitting the codes.

I am trying your code; however, training only uses a single GPU.

Do you have any idea how to use more than one GPU so that I can increase the batch size for training?

Thanks.
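One common way to spread each batch over several GPUs in PyTorch is to wrap the model in `torch.nn.DataParallel`. It replicates the module on every visible GPU and splits each input batch along dimension 0. PyTorch's documentation recommends `DistributedDataParallel` instead for better performance, at the cost of more setup.

Below is a minimal sketch. A toy convolutional module, loosely shaped like a SEGAN layer, stands in for the generator; it is not the repository's code.

```python
import torch
from torch import nn

# Toy stand-in for the generator: (batch, 1, 16384) -> (batch, 16, 16384).
model = nn.Sequential(nn.Conv1d(1, 16, kernel_size=31, padding=15), nn.PReLU())

if torch.cuda.device_count() > 1:
    # Replicate the module on every GPU; each forward call scatters the batch.
    model = nn.DataParallel(model)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

batch = torch.randn(8, 1, 16384, device=device)
out = model(batch)
print(out.shape)  # torch.Size([8, 16, 16384])
```

A wrapped model's `state_dict` keys gain a `module.` prefix. Saving `model.module.state_dict()` keeps checkpoints loadable into an unwrapped generator, for example when loading weights in test_audio.py.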

Shouldn't ref_batch be inside the loop?

Currently, ref_batch is sampled only once, before the training loop starts (line 32). Shouldn't it be randomly sampled for each batch? I am guessing this from its implementation (I haven't read the virtual batch normalization paper or anything else), but the comments suggest it should be. What do you think?

    def reference_batch(self, batch_size):
        """
        Randomly selects a reference batch from dataset.
        Reference batch is used for calculating statistics for virtual batch normalization operation.

        Args:
            batch_size(int): batch size

        Returns:
            ref_batch: reference batch
        """
        ref_file_names = np.random.choice(self.file_names, batch_size)
        ref_batch = np.stack([np.load(f) for f in ref_file_names])

        ref_batch = emphasis(ref_batch, emph_coeff=0.95)
        return torch.from_numpy(ref_batch).type(torch.FloatTensor)
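The `emphasis(ref_batch, emph_coeff=0.95)` call applies a pre-emphasis filter before the batch is used. The helper itself is not shown in the issue. Assuming it is the standard first-order filter y[n] = x[n] - 0.95 * x[n-1] (the SEGAN paper uses pre-emphasis with coefficient 0.95), the filter and its inverse can be sketched as below. The function names are hypothetical.

```python
import numpy as np


def pre_emphasis(x, coeff=0.95):
    """y[n] = x[n] - coeff * x[n-1] along the last axis, with y[0] = x[0]."""
    x = np.asarray(x, dtype=np.float64)
    y = x.copy()
    y[..., 1:] -= coeff * x[..., :-1]
    return y


def de_emphasis(y, coeff=0.95):
    """Inverse filter: x[n] = y[n] + coeff * x[n-1], applied recursively."""
    x = np.asarray(y, dtype=np.float64).copy()
    for n in range(1, x.shape[-1]):
        x[..., n] += coeff * x[..., n - 1]
    return x


batch = np.random.default_rng(0).standard_normal((2, 50))  # (batch, samples)
restored = de_emphasis(pre_emphasis(batch))
print(np.allclose(restored, batch))  # True
```

For long signals, the same filters can be written as `scipy.signal.lfilter([1, -coeff], [1], x)` and `scipy.signal.lfilter([1], [1, -coeff], x)`, which avoids the Python loop.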

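ref_batch's mean and mean_sq are recomputed at every step

Hi, thank you very much for sharing the code; it is simple, concise and clear!
A small question and suggestion: in virtual batch normalization, every call to discriminator(train_batch, ref_batch) recomputes the mean and mean_sq of the same ref_batch inside forward, which seems redundant. Wouldn't it be better to compute them once and store them?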
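Within one training step, as long as no optimizer update happens in between, every discriminator call sees the same reference activations. Their statistics could therefore be computed once and reused within that step. Across steps, the situation is different: the reference batch passes through convolution layers whose weights change after each update, so the statistics at each VBN layer change too. A cache kept across steps would therefore go stale.

Below is a simplified sketch of a VBN layer that accepts precomputed reference statistics, so the caller decides when to recompute them. Each example is normalized with statistics that combine the reference batch (weight N/(N+1), where N is the reference batch size) with the example itself (weight 1/(N+1)). The class name and interface are illustrative and are not the repository's `VirtualBatchNorm1d`.

```python
import torch
from torch import nn


class VirtualBatchNorm1dSketch(nn.Module):
    """Simplified virtual batch norm for (batch, channels, time) tensors."""

    def __init__(self, num_features, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(1, num_features, 1))
        self.beta = nn.Parameter(torch.zeros(1, num_features, 1))

    @staticmethod
    def reference_stats(ref):
        # Per-channel mean and mean of squares over batch and time.
        mean = ref.mean(dim=(0, 2), keepdim=True)
        mean_sq = (ref ** 2).mean(dim=(0, 2), keepdim=True)
        return mean, mean_sq

    def forward(self, x, ref_mean, ref_mean_sq, ref_size):
        # Mix each example's own statistics with the reference statistics.
        new_coeff = 1.0 / (ref_size + 1.0)
        old_coeff = 1.0 - new_coeff
        mean = new_coeff * x.mean(dim=2, keepdim=True) + old_coeff * ref_mean
        mean_sq = new_coeff * (x ** 2).mean(dim=2, keepdim=True) + old_coeff * ref_mean_sq
        var = (mean_sq - mean ** 2).clamp(min=0.0)
        return (x - mean) / torch.sqrt(var + self.eps) * self.gamma + self.beta


vbn = VirtualBatchNorm1dSketch(32)
ref = torch.randn(16, 32, 100)                     # reference activations at this layer
ref_mean, ref_mean_sq = vbn.reference_stats(ref)   # compute once per forward pass
out = vbn(torch.randn(4, 32, 100), ref_mean, ref_mean_sq, ref_size=ref.size(0))
print(out.shape)  # torch.Size([4, 32, 100])
```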

How is the performance of this model?

Hi. I recently tried to reproduce the SEGAN model, but my trained model does not seem to perform very well, and I'm trying to find out why. How does your model perform? Reference numbers for your approach (loss, STOI, or any other performance measure) would help.
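STOI and PESQ need dedicated implementations. A quick sanity check is to compare two SNRs, both measured against the clean reference: the SNR of the enhanced signal and the SNR of the noisy input.

Below is a minimal numpy sketch. The function name and the synthetic signals are illustrative. The SEGAN paper itself reports PESQ, composite measures and segmental SNR, none of which are implemented here.

```python
import numpy as np


def snr_db(clean, estimate, eps=1e-12):
    """SNR in dB: 10 * log10(sum(clean^2) / sum((clean - estimate)^2))."""
    clean = np.asarray(clean, dtype=np.float64)
    noise = clean - np.asarray(estimate, dtype=np.float64)
    return 10.0 * np.log10((np.sum(clean ** 2) + eps) / (np.sum(noise ** 2) + eps))


rng = np.random.default_rng(0)
clean = np.sin(2 * np.pi * 440 * np.arange(16000) / 16000)   # 1 s tone at 16 kHz
noisy = clean + 0.1 * rng.standard_normal(clean.shape)
enhanced = clean + 0.05 * rng.standard_normal(clean.shape)   # pretend model output
print(snr_db(clean, noisy), snr_db(clean, enhanced))  # roughly 17 dB and 23 dB
```

The useful number is the improvement, `snr_db(clean, enhanced) - snr_db(clean, noisy)`, averaged over the test set.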

BaiduYun is empty

The BaiduYun folder is empty, so I cannot download the example results or the pre-trained Generator weights. Could you send them to me?

normal() error

When I run main.py, I get the following error:

Traceback (most recent call last):
  File "main.py", line 35, in <module>
    discriminator = Discriminator()
  File "D:\users\XiaoFY\SEGAN-master\model.py", line 228, in __init__
    self.vbn1 = VirtualBatchNorm1d(32)
  File "D:\users\XiaoFY\SEGAN-master\model.py", line 21, in __init__
    self.gamma = Parameter(torch.normal(means=1.0, std=0.02, out=(1, num_features, 1)))
TypeError: normal() received an invalid combination of arguments - got (means=float, std=float, out=tuple, ), but expected one of:
 * (Tensor mean, Tensor std, torch.Generator generator, Tensor out)
 * (Tensor mean, float std, torch.Generator generator, Tensor out)
 * (float mean, Tensor std, torch.Generator generator, Tensor out)

Could you please help me fix it?

License

Is there an open source license on this?

56kHz dataset

Hi,
in the README you suggest downloading the 56kHz dataset from DataShare, but DataShare seems to offer only 48kHz datasets. It does, however, offer one dataset with 56 speakers (and another with 28).
I suppose this is just a typo, as 56kHz is an unusual audio sampling rate.
Best regards!

Files in enhanced cannot be opened

Hello, in the results you shared along with the trained model in the README, the audio files in the enhanced folder cannot be opened, while the noisy and clean files both open and play normally. What could be the cause?

TypeError: normal() received an invalid combination of arguments - got (std=float, means=Tensor, )

/home/machine/.local/lib/python3.6/site-packages/numba/errors.py:137: UserWarning: Insufficiently recent colorama version found. Numba requires colorama >= 0.3.9
  warnings.warn(msg)
loading data...
Traceback (most recent call last):
  File "main.py", line 35, in <module>
    discriminator = Discriminator()
  File "/home/machine/code/SEGAN/model.py", line 228, in __init__
    self.vbn1 = VirtualBatchNorm1d(32)
  File "/home/machine/code/SEGAN/model.py", line 21, in __init__
    self.gamma = Parameter(torch.normal(means=torch.ones(1, num_features, 1), std=0.02))
TypeError: normal() received an invalid combination of arguments - got (std=float, means=Tensor, ), but expected one of:
 * (Tensor mean, Tensor std, torch.Generator generator, Tensor out)
 * (Tensor mean, float std, torch.Generator generator, Tensor out)
 * (float mean, Tensor std, torch.Generator generator, Tensor out)
 * (float mean, float std, tuple of ints size, torch.Generator generator, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
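Both `normal()` errors above come from the `torch.normal(...)` call in `VirtualBatchNorm1d.__init__`. In current PyTorch releases the keyword is named `mean`, not `means`. When both the mean and the std are floats, the output shape is passed through `size` (the last overload listed in the error message), not through `out`.

The sketch below shows a replacement for that line, in isolation with `num_features = 32`. It should work on recent PyTorch versions.

```python
import torch
from torch.nn import Parameter

num_features = 32  # example value, as in VirtualBatchNorm1d(32) from the traceback

# Matches the overload (float mean, float std, tuple of ints size, ...).
gamma = Parameter(torch.normal(mean=1.0, std=0.02, size=(1, num_features, 1)))

# Equivalent alternative: a tensor of means with a float std.
gamma_alt = Parameter(torch.normal(mean=torch.ones(1, num_features, 1), std=0.02))

print(gamma.shape, gamma_alt.shape)  # torch.Size([1, 32, 1]) torch.Size([1, 32, 1])
```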
