This project forked from nttcslab/byol-a

BYOL for Audio: Self-Supervised Learning for General-Purpose Audio Representation

Home Page: https://arxiv.org/abs/2103.06695

This is a demo implementation of BYOL for Audio (BYOL-A), a self-supervised learning method for general-purpose audio representation. It includes:

  • Training code that can train models with arbitrary audio files.
  • Evaluation code that can evaluate trained models with downstream tasks.
  • Pretrained weights.

If you find BYOL-A useful in your research, please use the following BibTeX entry for citation.

@inproceedings{niizumi2021byol-a,
      title={BYOL for Audio: Self-Supervised Learning for General-Purpose Audio Representation}, 
      author={Daisuke Niizumi and Daiki Takeuchi and Yasunori Ohishi and Noboru Harada and Kunio Kashino},
      booktitle = {2021 International Joint Conference on Neural Networks, {IJCNN} 2021},
      year={2021},
      eprint={2103.06695},
      archivePrefix={arXiv},
      primaryClass={eess.AS}
}

Getting Started

  1. Download the external source files and apply a patch. Our implementation uses the following:

    curl -O https://raw.githubusercontent.com/lucidrains/byol-pytorch/2aa84ee18fafecaf35637da4657f92619e83876d/byol_pytorch/byol_pytorch.py
    patch < byol_a/byol_pytorch.diff
    mv byol_pytorch.py byol_a
    curl -O https://raw.githubusercontent.com/daisukelab/general-learning/7b31d31637d73e1a74aec3930793bd5175b64126/MLP/torch_mlp_clf.py
    mv torch_mlp_clf.py utils
  2. Install PyTorch 1.7.1, torchaudio, and the other dependencies listed in requirements.txt.

Evaluating BYOL-A Representations

Downstream Task Evaluation

The following steps perform a downstream task evaluation in a linear-probe fashion. This example uses SPCV2 (Speech Commands dataset v2).

  1. Preprocess the metadata (.csv file) and audio files. The processed files are stored under the folder work.

    # usage: python -m utils.preprocess_ds <downstream task> <path to its dataset>
    python -m utils.preprocess_ds spcv2 /path/to/speech_commands_v0.02
  2. Run the evaluation. This first converts all .wav audio to representation embeddings, then trains a linear layer on them, and finally reports the resulting accuracy.

    python evaluate.py pretrained_weights/AudioNTT2020-BYOLA-64x96d2048.pth spcv2
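The linear-probe idea behind step 2 can be sketched as follows. This is a minimal illustration, not the code in evaluate.py: it uses NumPy softmax regression, and the embeddings are synthetic stand-ins for the features a frozen encoder would produce. The names train_linear_probe and accuracy, and all hyperparameters, are assumptions made for this example.

```python
import numpy as np

def train_linear_probe(x, y, n_classes, lr=0.1, epochs=200):
    """Fit one linear layer (softmax regression) on frozen embeddings."""
    n, d = x.shape
    w = np.zeros((d, n_classes))
    b = np.zeros(n_classes)
    onehot = np.eye(n_classes)[y]
    for _ in range(epochs):
        logits = x @ w + b
        logits -= logits.max(axis=1, keepdims=True)  # numerical stability
        probs = np.exp(logits)
        probs /= probs.sum(axis=1, keepdims=True)
        grad = (probs - onehot) / n  # gradient of mean cross-entropy w.r.t. logits
        w -= lr * (x.T @ grad)
        b -= lr * grad.sum(axis=0)
    return w, b

def accuracy(w, b, x, y):
    """Fraction of samples whose highest-scoring class matches the label."""
    return float(((x @ w + b).argmax(axis=1) == y).mean())

# Synthetic embeddings (hypothetical): one well-separated cluster per class.
rng = np.random.default_rng(0)
n_classes, dim = 3, 16
centers = 4.0 * np.eye(n_classes, dim)
y_train = rng.integers(0, n_classes, size=300)
y_test = rng.integers(0, n_classes, size=100)
x_train = centers[y_train] + 0.5 * rng.normal(size=(300, dim))
x_test = centers[y_test] + 0.5 * rng.normal(size=(100, dim))

w, b = train_linear_probe(x_train, y_train, n_classes)
test_acc = accuracy(w, b, x_test, y_test)
print(f"linear-probe accuracy: {test_acc:.3f}")
```

Only the linear layer is trained here; the embeddings stay fixed, which is what makes the resulting accuracy a measure of representation quality rather than of the classifier.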

You can also run an evaluation multiple times and take the average result. The following evaluates on UrbanSound8K 10 times, with a unit audio duration of 4.0 seconds.

# usage: python evaluate.py <your weight> <downstream task> <unit duration sec.> <# of iteration>
python evaluate.py pretrained_weights/AudioNTT2020-BYOLA-64x96d2048.pth us8k 4.0 10

Similarly, the following evaluates on NSynth (4.0 seconds long) 10 times.

python evaluate.py pretrained_weights/AudioNTT2020-BYOLA-64x96d2048.pth nsynth 4.0 10
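When evaluations are repeated like this, the per-run accuracies are typically summarized by their mean and spread. The sketch below shows that arithmetic using only the standard library. The function name summarize_runs is hypothetical, and the accuracy values are made-up placeholders, not results produced by evaluate.py.

```python
from statistics import mean, stdev

def summarize_runs(accuracies):
    """Return (mean, sample standard deviation) of per-run accuracies."""
    if len(accuracies) < 2:
        return mean(accuracies), 0.0
    return mean(accuracies), stdev(accuracies)

# Placeholder values only (not real results): five hypothetical runs.
runs = [0.50, 0.52, 0.48, 0.51, 0.49]
avg, spread = summarize_runs(runs)
print(f"accuracy: {avg:.3f} +/- {spread:.3f} over {len(runs)} runs")
```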

Evaluating Representations In Your Tasks

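The following example calculates a feature vector for an audio sample.

# Explicit imports, so the example does not depend on what the wildcard import below provides.
import torch
import torchaudio

from byol_a.common import *
from byol_a.augmentations import PrecomputedNorm
from byol_a.models import AudioNTT2020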


device = torch.device('cuda')
cfg = load_yaml_config('config.yaml')
print(cfg)

# Mean and standard deviation of the log-mel spectrogram of input audio samples, pre-computed.
# See calc_norm_stats in evaluate.py for your reference.
stats = [-5.4919195,  5.0389895]

# Preprocessor and normalizer.
to_melspec = torchaudio.transforms.MelSpectrogram(
    sample_rate=cfg.sample_rate,
    n_fft=cfg.n_fft,
    win_length=cfg.win_length,
    hop_length=cfg.hop_length,
    n_mels=cfg.n_mels,
    f_min=cfg.f_min,
    f_max=cfg.f_max,
)
normalizer = PrecomputedNorm(stats)

# Load pretrained weights.
model = AudioNTT2020(d=cfg.feature_d)
model.load_weight('pretrained_weights/AudioNTT2020-BYOLA-64x96d2048.pth', device)

# Load your audio file.
wav, sr = torchaudio.load('work/16k/spcv2/one/00176480_nohash_0.wav') # a sample from SPCV2 for now
assert sr == cfg.sample_rate, "Let's convert the audio sampling rate in advance, or do it here online."

# Convert to a log-mel spectrogram, then normalize.
lms = normalizer((to_melspec(wav) + torch.finfo(torch.float).eps).log())

# Now, convert the audio to the representation.
features = model(lms.unsqueeze(0))
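The stats pair in the example above holds the mean and standard deviation of log-mel spectrogram values. A minimal sketch of how such statistics might be computed and applied is shown below. It assumes that normalization means subtracting the mean and dividing by the standard deviation; calc_stats and normalize are hypothetical helpers, not the repository's calc_norm_stats or PrecomputedNorm, and random arrays stand in for real log-mel spectrograms.

```python
import numpy as np

def calc_stats(log_mels):
    """Mean and standard deviation over every value of every spectrogram."""
    values = np.concatenate([lms.ravel() for lms in log_mels])
    return float(values.mean()), float(values.std())

def normalize(lms, stats):
    """Standardize a spectrogram with precomputed (mean, std)."""
    mean, std = stats
    return (lms - mean) / std

# Random stand-ins for log-mel spectrograms shaped (n_mels, frames);
# clips may differ in length, so the statistics are pooled over all values.
rng = np.random.default_rng(0)
fake_log_mels = [rng.normal(-5.0, 4.0, size=(64, frames)) for frames in (96, 120, 80)]

stats = calc_stats(fake_log_mels)
normalized = np.concatenate([normalize(lms, stats).ravel() for lms in fake_log_mels])
print(f"stats: {stats}")
print(f"after normalization: mean={normalized.mean():.6f}, std={normalized.std():.6f}")
```

Computing the statistics once over the dataset and reusing them at inference time keeps the input scale consistent between the data used for evaluation and any new audio.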

Training From Scratch

You can also train models yourself. The following is an example of training on FSD50K.

  1. Convert all samples to 16 kHz. This converts all FSD50K files into the folder work/16k/fsd50k while preserving the folder structure.

    python -m utils.convert_wav /path/to/fsd50k work/16k/fsd50k
  2. Start training. This example trains with all development set audio samples from FSD50K.

    python train.py work/16k/fsd50k/FSD50K.dev_audio
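The phrase "preserving the folder structure" in step 1 can be illustrated with a small path-mapping sketch. This is not the code of utils.convert_wav: mirrored_path is a hypothetical helper, and 1234.wav is a made-up file name used only for illustration.

```python
from pathlib import Path

def mirrored_path(src_file, src_root, dst_root):
    """Map a file under src_root to the same relative location under dst_root."""
    relative = Path(src_file).relative_to(src_root)
    return Path(dst_root) / relative

src_root = Path("/path/to/fsd50k")
src_file = src_root / "FSD50K.dev_audio" / "1234.wav"  # hypothetical file name
dst = mirrored_path(src_file, src_root, "work/16k/fsd50k")
print(dst.as_posix())  # work/16k/fsd50k/FSD50K.dev_audio/1234.wav
```

Because the relative layout is kept, the training command in step 2 can point at work/16k/fsd50k/FSD50K.dev_audio, the converted counterpart of the original development set folder.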

Refer to Table VI in our paper for the performance of a model trained on FSD50K.

Pretrained Weights

We include three sets of pretrained weights for our encoder network.

Method  Dim.    Filename                           NSynth  US8K   VoxCeleb1  VoxForge  SPCV2/12  SPCV2  Average
BYOL-A  512-d   AudioNTT2020-BYOLA-64x96d512.pth   69.1%   78.2%  33.4%      83.5%     86.5%     88.9%  73.3%
BYOL-A  1024-d  AudioNTT2020-BYOLA-64x96d1024.pth  72.7%   78.2%  38.0%      88.5%     90.1%     91.4%  76.5%
BYOL-A  2048-d  AudioNTT2020-BYOLA-64x96d2048.pth  74.1%   79.1%  40.1%      90.2%     91.0%     92.2%  77.8%
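The Average column appears to be the unweighted mean of the six task accuracies in each row. That reading is an assumption, but the short check below reproduces the listed values from the per-task numbers in the table.

```python
# Per-task accuracies (%) copied from the table, in column order:
# NSynth, US8K, VoxCeleb1, VoxForge, SPCV2/12, SPCV2.
scores = {
    "512-d": [69.1, 78.2, 33.4, 83.5, 86.5, 88.9],
    "1024-d": [72.7, 78.2, 38.0, 88.5, 90.1, 91.4],
    "2048-d": [74.1, 79.1, 40.1, 90.2, 91.0, 92.2],
}

# Unweighted mean over the six tasks, rounded to one decimal place.
averages = {dim: round(sum(vals) / len(vals), 1) for dim, vals in scores.items()}
print(averages)  # {'512-d': 73.3, '1024-d': 76.5, '2048-d': 77.8}
```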

License

This implementation is provided for your evaluation of the BYOL-A paper; see LICENSE for details.

Acknowledgements

BYOL-A is built on top of byol-pytorch, a BYOL implementation by Phil Wang (@lucidrains). We thank Phil for open-sourcing such sophisticated code.

@misc{wang2020byol-pytorch,
  author =       {Phil Wang},
  title =        {Bootstrap Your Own Latent (BYOL), in Pytorch},
  howpublished = {\url{https://github.com/lucidrains/byol-pytorch}},
  year =         {2020}
}
