Non-negative Matrix Factorization in PyTorch

PyTorch is not only a good deep learning framework, but also a fast tool when it comes to matrix operations and convolutions on large data. A great example is PyTorchWavelets.

In this package I implement basic NMF and NMFD modules that minimize the beta-divergence using multiplicative update rules. The modules are based on torch.nn.Module, so the models can be moved freely between CPU and GPU devices. Part of the multiplier is obtained via torch.autograd, so only the denominator has to be computed explicitly, which reduces the amount of code and makes it easier to maintain.
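To make the autograd idea concrete, here is a minimal sketch for the KL case (beta = 1). It is not the package's actual code: the names kl_divergence and mu_step are hypothetical, and the inputs are assumed to be strictly positive. The gradient of the loss splits into a positive part minus a negative part; the positive part (the denominator) is computed by hand, and the negative part (the numerator) is recovered as the positive part minus the autograd gradient.

```python
import torch

def kl_divergence(V, WH):
    # Generalized KL divergence (beta = 1); assumes V > 0 and WH > 0.
    return (V * (V / WH).log() - V + WH).sum()

def mu_step(V, W, H):
    """One multiplicative update of H, then of W (hypothetical helper)."""
    ones = torch.ones_like(V)

    # Update H with W fixed.
    H = H.detach().requires_grad_(True)
    grad, = torch.autograd.grad(kl_divergence(V, W @ H), H)
    pos = W.t() @ ones                 # denominator, computed by hand
    neg = (pos - grad).clamp(min=0)    # numerator, recovered from autograd
    H = H.detach() * neg / pos

    # Update W with H fixed.
    W = W.detach().requires_grad_(True)
    grad, = torch.autograd.grad(kl_divergence(V, W @ H), W)
    pos = ones @ H.t()
    neg = (pos - grad).clamp(min=0)
    W = W.detach() * neg / pos
    return W, H

torch.manual_seed(0)
# Small random positive matrices, in double precision to limit rounding error.
V = torch.rand(20, 30, dtype=torch.float64) + 0.1
W = torch.rand(20, 4, dtype=torch.float64) + 0.1
H = torch.rand(4, 30, dtype=torch.float64) + 0.1

loss_before = kl_divergence(V, W @ H).item()
for _ in range(50):
    W, H = mu_step(V, W, H)
loss_after = kl_divergence(V, W @ H).item()
print(loss_before, loss_after)
```

Because each factor is only ever multiplied by a non-negative ratio, W and H stay non-negative without any explicit projection step.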

The interface is similar to sklearn.decomposition.NMF with some extra options.

Usage

Here is a short example of decomposing a spectrogram.

import torch
import librosa
from torchnmf import NMF
from torchnmf.metrics import KL_divergence

# load the example audio clip as a 1-D float tensor
y, sr = librosa.load(librosa.util.example_audio_file())
y = torch.from_numpy(y)
windowsize = 2048
# On PyTorch 0.4.1, torch.stft returns real and imaginary parts in a trailing
# dimension of size 2, so the magnitude is sqrt(re^2 + im^2).
S = torch.stft(y, windowsize, window=torch.hann_window(windowsize)).pow(2).sum(2).sqrt().cuda()

R = 8   # number of components

net = NMF(S.shape, n_components=R).cuda()
# runs extremely fast on GPU
_, V = net.fit_transform(S)      # fit to target matrix S
print(KL_divergence(V, S))        # KL divergence to S

A more detailed version can be found here, which redoes this example with NMFD.

Comparison with sklearn

The bar chart shows the time cost per iteration for different beta-divergences. The PyTorch-based NMF is faster than the scipy-based NMF (sklearn) when beta != 2 (that is, for anything other than Euclidean distance), and runs even faster when the computation is done on a GPU. The test was conducted on an Acer E5 laptop with an i5-7200U CPU and a GTX 950M GPU.

Installation

Using pip:

pip install git+http://github.com/yoyololicon/pytorch-NMFs

Or clone this repo and do:

python setup.py install

Requirements

  • PyTorch == 0.4.1

Tips

  • If you notice a significant slowdown when operating on CPU, flush denormal numbers by calling torch.set_flush_denormal(True).
  • Replace zeros in the target matrix with a small value to avoid overflow.
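
Both tips can be sketched in a few lines. The helper name replace_zeros and the floor value eps=1e-8 are assumptions chosen for illustration, not part of the package; pick a floor that suits the scale of your data.

```python
import torch

# Ask the CPU to flush denormal numbers to zero. The call returns True only
# if the hardware supports it and the mode was set successfully.
supported = torch.set_flush_denormal(True)
print("flush denormal enabled:", supported)

def replace_zeros(S, eps=1e-8):
    """Return a copy of S with every entry below eps raised to eps.

    Hypothetical helper; eps is an assumed floor, not a package default.
    """
    return S.clamp(min=eps)

S = torch.tensor([[0.0, 1.0], [2.0, 0.0]])
S_safe = replace_zeros(S)
print(S_safe)
```

Entries that are already larger than the floor are left unchanged, so the factorization target is only altered where it was zero.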

TODO

  • Support sparse matrix.
  • Regularization.
  • NNDSVD initialization.

