chen0040 / mxnet-audio

Implementation of music genre classification, audio-to-vec, song recommender, and music search in mxnet

License: MIT License

Python 100.00%
audio-classification music-recommendation music-search mxnet song-recommender

mxnet-audio

Principles

  • The classifier ResNetV2AudioClassifier converts audio into a mel-spectrogram and uses a simplified ResNet deep convolutional neural network (DCNN) architecture to classify audio based on its associated labels.
  • The classifier Cifar10AudioClassifier converts audio into a mel-spectrogram and uses the CIFAR-10 DCNN architecture to classify audio based on its associated labels.

The classifiers differ from those used in image classification in that:

  • they use softrelu instead of relu.
  • they use an elongated max-pooling shape (because the mel-spectrogram is an elongated "image").
  • they add dropout.
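
The first two differences can be illustrated without MXNet. The sketch below is a pure-Python illustration, not code from this repository. It assumes that softrelu means log(1 + exp(x)), which is how MXNet defines its `softrelu` activation type, and it uses a hypothetical 1x4 pooling window, since the pooling shapes actually used by the classifiers are not stated here.

```python
import math


def relu(x):
    """Standard rectifier: zero for negative inputs, identity otherwise."""
    return max(0.0, x)


def softrelu(x):
    """Smooth rectifier log(1 + exp(x)), written in a numerically stable form."""
    return max(0.0, x) + math.log1p(math.exp(-abs(x)))


def max_pool_2d(image, pool_h, pool_w):
    """Non-overlapping max pooling over a list-of-lists image."""
    rows, cols = len(image), len(image[0])
    return [
        [
            max(image[r + dr][c + dc]
                for dr in range(pool_h) for dc in range(pool_w))
            for c in range(0, cols - pool_w + 1, pool_w)
        ]
        for r in range(0, rows - pool_h + 1, pool_h)
    ]


# A toy 2x8 "spectrogram": few frequency rows, many time columns.
spectrogram = [
    [1, 5, 2, 0, 3, 3, 9, 1],
    [4, 0, 7, 2, 6, 1, 0, 8],
]

# The hypothetical 1x4 window shrinks only the long (time) axis.
pooled = max_pool_2d(spectrogram, pool_h=1, pool_w=4)
print(pooled)  # [[5, 9], [7, 8]]

# relu is exactly zero for negative inputs; softrelu stays slightly positive.
print(relu(-1.0), softrelu(-1.0))
```

Unlike relu, softrelu has a non-zero gradient for negative inputs, and a wide pooling window reduces the time axis faster than the frequency axis.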

Usage

Dependencies

Install the required dependencies into your Python environment by running:

pip install -r requirements.txt

Train a deep learning model

The audio training uses the GTZAN data set to train the music classifier to recognize the genre of songs.

Training works by converting each audio or song file into a mel-spectrogram, which can be thought of as a 3-dimensional tensor, much like an image. With the trained model, it is possible to build other interesting applications such as music recommendation, music search, and audio2vec.

To train on the GTZAN data set, run the following command:
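
To make the image analogy concrete, the sketch below estimates the shape of a mel-spectrogram for one clip. The parameter values (22050 Hz sample rate, hop length of 512 samples, 128 mel bands) are librosa's defaults and are assumptions here; the preprocessing in this repository may use different settings. The frame count follows the rule for a centered short-time Fourier transform, and `mel_spectrogram_shape` is a hypothetical helper, not part of the library.

```python
def mel_spectrogram_shape(duration_s, sample_rate=22050, hop_length=512, n_mels=128):
    """Return (n_mels, n_frames) for a clip of the given duration.

    Assumes a centered STFT, which yields 1 + n_samples // hop_length frames.
    """
    n_samples = int(duration_s * sample_rate)
    n_frames = 1 + n_samples // hop_length
    return n_mels, n_frames


# GTZAN clips are 30 seconds long.
shape = mel_spectrogram_shape(30.0)
print(shape)  # (128, 1292)

# Adding a leading channel axis gives the 3-dimensional, image-like tensor.
tensor_shape = (1,) + shape
print(tensor_shape)  # (1, 128, 1292)
```

Under these assumptions the time axis is roughly ten times longer than the frequency axis, which is why the input is a strongly elongated "image".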

cd demo
python cifar10_train.py

The sample code below shows how to train Cifar10AudioClassifier to classify songs based on their genre labels:

from mxnet_audio.library.cifar10 import Cifar10AudioClassifier
from mxnet_audio.library.utility.gtzan_loader import download_gtzan_genres_if_not_found
import mxnet


def load_audio_path_label_pairs(max_allowed_pairs=None):
    download_gtzan_genres_if_not_found('./very_large_data/gtzan')
    audio_paths = []
    with open('./data/lists/test_songs_gtzan_list.txt', 'rt') as file:
        for line in file:
            audio_path = './very_large_data/' + line.strip()
            audio_paths.append(audio_path)
    pairs = []
    with open('./data/lists/test_gt_gtzan_list.txt', 'rt') as file:
        for line in file:
            label = int(line)
            if max_allowed_pairs is None or len(pairs) < max_allowed_pairs:
                pairs.append((audio_paths[len(pairs)], label))
            else:
                break
    return pairs


def main():
    audio_path_label_pairs = load_audio_path_label_pairs()
    print('loaded: ', len(audio_path_label_pairs))

    classifier = Cifar10AudioClassifier(model_ctx=mxnet.gpu(0), data_ctx=mxnet.gpu(0))
    batch_size = 8
    epochs = 100
    history = classifier.fit(audio_path_label_pairs, model_dir_path='./models',
                             batch_size=batch_size, epochs=epochs,
                             checkpoint_interval=2)


if __name__ == '__main__':
    main()

After training, the trained models are saved to demo/models.

To test the trained Cifar10AudioClassifier model, run the following command:

cd demo
python cifar10_predict.py

Model Comparison

The figure below compares the training quality of ResNetV2AudioClassifier and Cifar10AudioClassifier:

[Figure: training comparison]

Predict Music Genres

The sample code below shows how to use the trained Cifar10AudioClassifier model to predict music genres:

from random import shuffle

from mxnet_audio.library.cifar10 import Cifar10AudioClassifier
from mxnet_audio.library.utility.gtzan_loader import download_gtzan_genres_if_not_found, gtzan_labels


def load_audio_path_label_pairs(max_allowed_pairs=None):
    download_gtzan_genres_if_not_found('./very_large_data/gtzan')
    audio_paths = []
    with open('./data/lists/test_songs_gtzan_list.txt', 'rt') as file:
        for line in file:
            audio_path = './very_large_data/' + line.strip()
            audio_paths.append(audio_path)
    pairs = []
    with open('./data/lists/test_gt_gtzan_list.txt', 'rt') as file:
        for line in file:
            label = int(line)
            if max_allowed_pairs is None or len(pairs) < max_allowed_pairs:
                pairs.append((audio_paths[len(pairs)], label))
            else:
                break
    return pairs


def main():
    audio_path_label_pairs = load_audio_path_label_pairs()
    shuffle(audio_path_label_pairs)
    print('loaded: ', len(audio_path_label_pairs))

    classifier = Cifar10AudioClassifier()
    classifier.load_model(model_dir_path='./models')

    for i in range(0, 20):
        audio_path, actual_label_id = audio_path_label_pairs[i]
        predicted_label_id = classifier.predict_class(audio_path)
        print(audio_path)
        predicted_label = gtzan_labels[predicted_label_id]
        actual_label = gtzan_labels[actual_label_id]
        
        print('predicted: ', predicted_label, 'actual: ', actual_label)


if __name__ == '__main__':
    main()
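
The script above prints predictions one at a time. To summarize them, one option is to collect the (predicted, actual) label pairs and compute an accuracy and per-pair error counts, as in the sketch below. The helper names and the sample pairs are made up for illustration and are not part of mxnet_audio.

```python
from collections import Counter


def accuracy(pairs):
    """Fraction of (predicted, actual) pairs whose labels match."""
    pairs = list(pairs)
    if not pairs:
        return 0.0
    correct = sum(1 for predicted, actual in pairs if predicted == actual)
    return correct / len(pairs)


def confusion_counts(pairs):
    """Count how often each (actual, predicted) combination occurs."""
    return Counter((actual, predicted) for predicted, actual in pairs)


# Made-up (predicted, actual) pairs standing in for real predictions.
results = [
    ('rock', 'rock'),
    ('jazz', 'blues'),
    ('pop', 'pop'),
    ('metal', 'rock'),
]

print(accuracy(results))  # 0.5
print(confusion_counts(results)[('rock', 'metal')])  # 1
```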

Audio to Vector

The sample code below shows how to use the trained Cifar10AudioClassifier model to encode an audio file into a fixed-length numerical vector:

from random import shuffle

from mxnet_audio.library.cifar10 import Cifar10AudioClassifier
from mxnet_audio.library.utility.gtzan_loader import download_gtzan_genres_if_not_found


def load_audio_path_label_pairs(max_allowed_pairs=None):
    download_gtzan_genres_if_not_found('./very_large_data/gtzan')
    audio_paths = []
    with open('./data/lists/test_songs_gtzan_list.txt', 'rt') as file:
        for line in file:
            audio_path = './very_large_data/' + line.strip()
            audio_paths.append(audio_path)
    pairs = []
    with open('./data/lists/test_gt_gtzan_list.txt', 'rt') as file:
        for line in file:
            label = int(line)
            if max_allowed_pairs is None or len(pairs) < max_allowed_pairs:
                pairs.append((audio_paths[len(pairs)], label))
            else:
                break
    return pairs


def main():
    audio_path_label_pairs = load_audio_path_label_pairs()
    shuffle(audio_path_label_pairs)
    print('loaded: ', len(audio_path_label_pairs))

    classifier = Cifar10AudioClassifier()
    classifier.load_model(model_dir_path='./models')

    for i in range(0, 20):
        audio_path, actual_label_id = audio_path_label_pairs[i]
        audio2vec = classifier.encode_audio(audio_path)
        print(audio_path)

        print('audio-to-vec: ', audio2vec)


if __name__ == '__main__':
    main()
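
Because every audio file is encoded to a vector of the same length, two files can be compared numerically. The sketch below uses cosine similarity on plain Python lists. This is an assumption for illustration: the similarity measure used inside the library is not stated here, and the vectors shown are toy values, not real encodings.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError('vectors must have the same length')
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0  # treat an all-zero vector as dissimilar to everything
    return dot / (norm_a * norm_b)


# Toy vectors standing in for the output of encode_audio.
same_direction = cosine_similarity([1.0, 2.0], [2.0, 4.0])
orthogonal = cosine_similarity([1.0, 0.0], [0.0, 1.0])
print(round(same_direction, 6), orthogonal)  # 1.0 0.0
```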

Music Search Engine

The sample code below shows how to use Cifar10AudioSearch with the trained model to search for music similar to a given music file:

from mxnet_audio.library.cifar10 import Cifar10AudioSearch
from mxnet_audio.library.utility.gtzan_loader import download_gtzan_genres_if_not_found


def load_audio_path_label_pairs(max_allowed_pairs=None):
    download_gtzan_genres_if_not_found('./very_large_data/gtzan')
    audio_paths = []
    with open('./data/lists/test_songs_gtzan_list.txt', 'rt') as file:
        for line in file:
            audio_path = './very_large_data/' + line.strip()
            audio_paths.append(audio_path)
    pairs = []
    with open('./data/lists/test_gt_gtzan_list.txt', 'rt') as file:
        for line in file:
            label = int(line)
            if max_allowed_pairs is None or len(pairs) < max_allowed_pairs:
                pairs.append((audio_paths[len(pairs)], label))
            else:
                break
    return pairs


def main():
    search_engine = Cifar10AudioSearch()
    search_engine.load_model(model_dir_path='./models')
    for path, _ in load_audio_path_label_pairs():
        search_engine.index_audio(path)

    query_audio = './data/audio_samples/example.mp3'
    search_result = search_engine.query(query_audio, top_k=10)

    for idx, similar_audio in enumerate(search_result):
        print('result #%s: %s' % (idx+1, similar_audio))


if __name__ == '__main__':
    main()
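
The index-then-query pattern above can be pictured as a nearest-neighbour lookup over encoded vectors. The sketch below is a toy brute-force version using Euclidean distance. It is not the implementation of Cifar10AudioSearch: `ToyVectorIndex`, its methods, and the sample vectors are all hypothetical. It requires Python 3.8 or later for `math.dist`.

```python
import heapq
import math


class ToyVectorIndex:
    """Brute-force nearest-neighbour index over named vectors."""

    def __init__(self):
        self._items = []

    def index(self, name, vector):
        """Store a vector under the given name."""
        self._items.append((name, list(vector)))

    def query(self, vector, top_k=10):
        """Return up to top_k (distance, name) tuples, closest first."""
        scored = ((math.dist(vector, stored), name)
                  for name, stored in self._items)
        return heapq.nsmallest(top_k, scored)


index = ToyVectorIndex()
index.index('a.au', [0.0, 0.0])
index.index('b.au', [1.0, 1.0])
index.index('c.au', [5.0, 5.0])

hits = index.query([0.9, 1.2], top_k=2)
print([name for _, name in hits])  # ['b.au', 'a.au']
```

A linear scan like this is adequate for a data set of about a thousand clips; larger archives would usually call for an approximate nearest-neighbour structure.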

Recommend Songs

The sample code below shows how to use Cifar10AudioRecommender with the trained model to recommend songs based on a user's listening history:

from random import shuffle

from mxnet_audio.library.cifar10 import Cifar10AudioRecommender
from mxnet_audio.library.utility.gtzan_loader import download_gtzan_genres_if_not_found


def load_audio_path_label_pairs(max_allowed_pairs=None):
    download_gtzan_genres_if_not_found('./very_large_data/gtzan')
    audio_paths = []
    with open('./data/lists/test_songs_gtzan_list.txt', 'rt') as file:
        for line in file:
            audio_path = './very_large_data/' + line.strip()
            audio_paths.append(audio_path)
    pairs = []
    with open('./data/lists/test_gt_gtzan_list.txt', 'rt') as file:
        for line in file:
            label = int(line)
            if max_allowed_pairs is None or len(pairs) < max_allowed_pairs:
                pairs.append((audio_paths[len(pairs)], label))
            else:
                break
    return pairs


def main():
    music_recommender = Cifar10AudioRecommender()
    music_recommender.load_model(model_dir_path='./models')
    music_archive = load_audio_path_label_pairs()
    for path, _ in music_archive:
        music_recommender.index_audio(path)

    # build a fake listening history from randomly chosen songs
    shuffle(music_archive)
    for i in range(30):
        # each archive entry is a (path, label) pair; track only the path
        song_i_am_listening, _ = music_archive[i]
        music_recommender.track(song_i_am_listening)

    for idx, similar_audio in enumerate(music_recommender.recommend(limits=10)):
        print('result #%s: %s' % (idx+1, similar_audio))


if __name__ == '__main__':
    main()
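
One simple way to turn a listening history into recommendations is to average the vectors of the tracked songs into a taste profile and rank the remaining songs by distance to it. The sketch below shows that idea on toy data. It is an assumption made for illustration, not a description of how Cifar10AudioRecommender works internally, and the function names and catalog are hypothetical. It requires Python 3.8 or later for `math.dist`.

```python
import math


def centroid(vectors):
    """Element-wise mean of a non-empty list of equal-length vectors."""
    vectors = list(vectors)
    return [sum(column) / len(vectors) for column in zip(*vectors)]


def recommend(catalog, history, limit=10):
    """Rank songs not yet listened to by distance to the history centroid.

    catalog maps song name -> vector; history is a list of song names.
    """
    if not history:
        return []
    profile = centroid(catalog[name] for name in history)
    listened = set(history)
    candidates = [
        (math.dist(profile, vector), name)
        for name, vector in catalog.items()
        if name not in listened
    ]
    return [name for _, name in sorted(candidates)[:limit]]


# Toy catalog standing in for encoded songs.
catalog = {
    'a': [0.0, 0.0],
    'b': [0.2, 0.0],
    'c': [0.1, 0.1],
    'd': [4.0, 4.0],
}

suggestions = recommend(catalog, history=['a', 'b'], limit=2)
print(suggestions)  # ['c', 'd']
```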

Note

On pre-processing

To speed up training, you can pre-generate the mel-spectrograms from the audio files by running the following script before training starts:

cd demo/utility
python gtzan_loader.py
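
Pre-generation pays off because each spectrogram is computed once and then reused in every epoch. The sketch below shows the general caching idea with the standard library only. How gtzan_loader.py actually stores its output is not documented here, so the file layout, the `cached_feature` helper, and the fake compute function are all assumptions.

```python
import hashlib
import pickle
import tempfile
from pathlib import Path


def cached_feature(audio_path, compute, cache_dir):
    """Return compute(audio_path), reusing a pickled result when one exists."""
    cache_dir = Path(cache_dir)
    cache_dir.mkdir(parents=True, exist_ok=True)
    # Key the cache file on a hash of the audio path.
    key = hashlib.sha1(str(audio_path).encode('utf-8')).hexdigest()
    cache_file = cache_dir / (key + '.pkl')
    if cache_file.exists():
        with cache_file.open('rb') as handle:
            return pickle.load(handle)
    feature = compute(audio_path)
    with cache_file.open('wb') as handle:
        pickle.dump(feature, handle)
    return feature


calls = []


def fake_compute(path):
    """Stand-in for an expensive mel-spectrogram computation."""
    calls.append(path)
    return [[0.0, 1.0], [2.0, 3.0]]


with tempfile.TemporaryDirectory() as tmp:
    first = cached_feature('song.au', fake_compute, tmp)
    second = cached_feature('song.au', fake_compute, tmp)

# The second call is served from disk, so fake_compute runs only once.
print(first == second, len(calls))  # True 1
```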

audioread.NoBackend

The audio processing depends on librosa version 0.6, which in turn depends on audioread.

If you are on Windows and see the error "audioread.NoBackend", download the shared-linking build of ffmpeg, unzip it to a local directory, and add ffmpeg's bin folder to the Windows PATH environment variable. After you restart cmd or PowerShell, Python should be able to locate the audioread backend used by librosa.
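
A quick way to confirm the PATH change took effect is to ask Python where it finds the ffmpeg executable. The sketch below uses only the standard library. It checks for an executable named `ffmpeg` and nothing else, so it does not prove that audioread will succeed; `find_ffmpeg` is a hypothetical helper name.

```python
import shutil


def find_ffmpeg():
    """Return the full path of the ffmpeg executable, or None if not on PATH."""
    return shutil.which('ffmpeg')


ffmpeg_path = find_ffmpeg()
if ffmpeg_path is None:
    print('ffmpeg was not found on PATH; open a new shell after editing PATH')
else:
    print('ffmpeg found at', ffmpeg_path)
```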

Training with GPU

Note that the default training scripts in the demo folder train on the GPU. You must therefore set up your graphics card for MXNet, or remove the GPU context arguments such as "model_ctx=mxnet.gpu(0)" from the training scripts.
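
Instead of editing the scripts by hand, a script can fall back to the CPU when no GPU is usable. The sketch below follows a common MXNet idiom: try to allocate a small array on the GPU and catch the resulting error. `pick_context` is a hypothetical helper, not part of this repository, and passing its result as both `model_ctx` and `data_ctx` is an assumption based on the training sample.

```python
def pick_context(gpu_id=0):
    """Return mxnet.gpu(gpu_id) if it is usable, otherwise mxnet.cpu()."""
    import mxnet  # imported here so the module loads even without MXNet

    try:
        ctx = mxnet.gpu(gpu_id)
        # Allocating on the device raises MXNetError when no GPU is available.
        mxnet.nd.zeros((1,), ctx=ctx)
        return ctx
    except mxnet.base.MXNetError:
        return mxnet.cpu()


# Hypothetical usage in a training script:
# ctx = pick_context()
# classifier = Cifar10AudioClassifier(model_ctx=ctx, data_ctx=ctx)
```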
