tf_rsvi

A pure TensorFlow implementation of Rejection Sampling Variational Inference (RSVI).

The code implements a sparse Gamma Deep Exponential Family (DEF) model trained with Rejection Sampling VI.
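For a Gamma variational factor, the core idea of RSVI is to reparameterize through the Marsaglia-Tsang rejection sampler. An accepted standard-normal draw eps is mapped to z = h(eps, alpha) = (alpha - 1/3)(1 + eps / sqrt(9 alpha - 3))^3, and the gradient dz/dalpha flows through h. The sketch below uses plain Python rather than TensorFlow and shows only this reparameterization piece, for alpha >= 1. The function names are hypothetical. The full RSVI gradient in Naesseth et al. also adds a score-function correction term, and the paper handles shapes below 1 with shape augmentation. Both are omitted here.

```python
import math
import random


def mt_transform(eps, alpha):
    """Marsaglia-Tsang map h(eps, alpha): standard-normal eps -> Gamma(alpha, 1) proposal."""
    d = alpha - 1.0 / 3.0
    t = 1.0 + eps / math.sqrt(9.0 * d)
    return d * t ** 3


def mt_transform_grad(eps, alpha):
    """Analytic dh/dalpha, the reparameterization part of the RSVI gradient."""
    d = alpha - 1.0 / 3.0
    t = 1.0 + eps / math.sqrt(9.0 * d)
    return t ** 3 - t ** 2 * eps / (2.0 * math.sqrt(d))


def sample_accepted_eps(alpha, rng):
    """Draw eps ~ N(0, 1) until the Marsaglia-Tsang test accepts it (needs alpha >= 1)."""
    if alpha < 1.0:
        raise ValueError("alpha < 1 needs shape augmentation (not shown in this sketch)")
    d = alpha - 1.0 / 3.0
    while True:
        eps = rng.gauss(0.0, 1.0)
        v = (1.0 + eps / math.sqrt(9.0 * d)) ** 3
        if v <= 0.0:
            continue  # proposal falls outside the Gamma support
        u = 1.0 - rng.random()  # uniform on (0, 1], so log(u) is always defined
        if math.log(u) < 0.5 * eps ** 2 + d - d * v + d * math.log(v):
            return eps


# Usage: a Gamma(shape=2.5, rate=4.0) draw and its derivative w.r.t. the shape.
rng = random.Random(0)
alpha, rate = 2.5, 4.0
eps = sample_accepted_eps(alpha, rng)
z = mt_transform(eps, alpha) / rate
dz_dalpha = mt_transform_grad(eps, alpha) / rate
```

Once eps is accepted, it is held fixed, so z becomes a differentiable function of alpha. This is the property that lets RSVI push reparameterization gradients through an accept-reject step.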

Citations:

Christian A. Naesseth, Francisco J. R. Ruiz, Scott W. Linderman, and David M. Blei. Reparameterization Gradients through Acceptance-Rejection Sampling Algorithms. Proceedings of the 20th International Conference on Artificial Intelligence and Statistics (AISTATS), 2017, Fort Lauderdale, Florida, USA.

Rajesh Ranganath, Linpeng Tang, Laurent Charlin, and David M. Blei. Deep Exponential Families. Proceedings of the Eighteenth International Conference on Artificial Intelligence and Statistics (AISTATS), PMLR 38:762-771, 2015.

This code is deeply indebted to the work at blei-lab/ars-reparameterization.

That repository provides a few more examples (a work in progress) but is designed to run on CPU. This code runs on TensorFlow v1.4 and is roughly 10x faster on a GTX 980 GPU.

Poisson samples from the model after 500 steps:

[Figure: samples]
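To make "Poisson samples from the model" concrete, here is a minimal sketch of ancestral sampling from a sparse gamma DEF with Poisson observations. It follows the general layered structure described by Ranganath et al. (2015). The function names and all hyperparameter defaults (shape 0.1, weight rate 0.3, top-layer rate 0.1) are assumptions for illustration and were not read from tf_rsvi. Python's standard library has no Poisson sampler, so the sketch counts unit-rate exponential arrivals instead.

```python
import random


def sample_poisson(lam, rng):
    """Poisson(lam) draw: count unit-rate exponential arrivals before time lam (O(lam) time)."""
    count, t = 0, rng.expovariate(1.0)
    while t < lam:
        count += 1
        t += rng.expovariate(1.0)
    return count


def sample_sparse_gamma_def(n_docs, layer_sizes, n_vocab, rng,
                            alpha_w=0.1, rate_w=0.3, alpha_z=0.1, rate_top=0.1):
    """Ancestral sampling of Poisson count vectors from a sparse gamma DEF.

    layer_sizes = [K_1, ..., K_L], bottom to top. All hyperparameter defaults are assumptions.
    random.gammavariate takes (shape, scale), so each rate is passed as 1 / rate.
    """
    sizes = [n_vocab] + list(layer_sizes)
    # W[l][j][k]: weight from unit j of layer l+1 to unit k of layer l (layer 0 = vocabulary).
    W = [[[rng.gammavariate(alpha_w, 1.0 / rate_w) for _ in range(sizes[l])]
          for _ in range(sizes[l + 1])]
         for l in range(len(sizes) - 1)]
    docs = []
    for _ in range(n_docs):
        # Top layer: z_L ~ Gamma(alpha_z, rate_top).
        z = [rng.gammavariate(alpha_z, 1.0 / rate_top) for _ in range(sizes[-1])]
        # Lower layers: Gamma with shape alpha_z and mean equal to the weighted sum from above.
        for l in range(len(sizes) - 2, 0, -1):
            means = [sum(z[j] * W[l][j][k] for j in range(sizes[l + 1])) for k in range(sizes[l])]
            # Floor at 1e-12 (an assumption) so the scale stays strictly positive.
            z = [rng.gammavariate(alpha_z, max(m, 1e-12) / alpha_z) for m in means]
        # Observations: x_v ~ Poisson(sum_k z_{1,k} * W[0][k][v]).
        rates = [sum(z[k] * W[0][k][v] for k in range(sizes[1])) for v in range(n_vocab)]
        docs.append([sample_poisson(r, rng) for r in rates])
    return docs
```

For example, `sample_sparse_gamma_def(3, [10, 5], 20, random.Random(0))` returns three count vectors of length 20. All three share the same weights, and each has its own latent units.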

ELBO over the course of training (500 steps; comparable to the RSVI paper, but in ~40 s).

[Figure: ELBO over time]
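The ELBO tracked here is E_q[log p(x, z) - log q(z)], and in practice it is estimated by averaging over samples from q. The sketch below shows such a Monte Carlo estimate on a deliberately tiny model: one Gamma latent z, a Gamma(prior_shape, prior_rate) prior, Poisson(z) observations, and a Gamma(q_shape, q_rate) variational factor. This toy model and the function names are assumptions for illustration. It is not the DEF, and the exact estimator behind the `est elbo` column in the log is not shown here.

```python
import math
import random


def gamma_logpdf(z, shape, rate):
    """log Gamma(z; shape, rate) density."""
    return shape * math.log(rate) - math.lgamma(shape) + (shape - 1.0) * math.log(z) - rate * z


def poisson_logpmf(x, lam):
    """log Poisson(x; lam) mass."""
    return x * math.log(lam) - lam - math.lgamma(x + 1)


def elbo_estimate(xs, prior_shape, prior_rate, q_shape, q_rate, n_samples, rng):
    """Monte Carlo ELBO: average of log p(x, z) - log q(z) over z ~ q."""
    total = 0.0
    for _ in range(n_samples):
        z = rng.gammavariate(q_shape, 1.0 / q_rate)  # gammavariate takes a scale, not a rate
        log_joint = gamma_logpdf(z, prior_shape, prior_rate) + sum(poisson_logpmf(x, z) for x in xs)
        total += log_joint - gamma_logpdf(z, q_shape, q_rate)
    return total / n_samples
```

This model is conjugate, so setting q to the exact posterior Gamma(prior_shape + sum(xs), prior_rate + len(xs)) makes every sample equal the log evidence log p(x). Any other q gives a lower value, with the gap equal to KL(q || posterior). Training drives the estimated ELBO upward in the same way.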

Log of the run:

Epoch    0: total loss is -8.1686e+07 (est elbo: -8.1766e+07) || time elapsed: 2.75 s
Epoch   10: total loss is -1.2606e+07 (est elbo: -1.2428e+07) || time elapsed: 5.36 s
Epoch   20: total loss is -1.2216e+07 (est elbo: -1.22e+07) || time elapsed: 6.06 s
Epoch   30: total loss is -1.1942e+07 (est elbo: -1.2095e+07) || time elapsed: 6.73 s
Epoch   40: total loss is -1.1964e+07 (est elbo: -1.201e+07) || time elapsed: 7.39 s
Epoch   50: total loss is -1.1856e+07 (est elbo: -1.187e+07) || time elapsed: 8.06 s
Epoch   60: total loss is -1.1787e+07 (est elbo: -1.1723e+07) || time elapsed: 8.73 s
Epoch   70: total loss is -1.1678e+07 (est elbo: -1.1724e+07) || time elapsed: 9.38 s
Epoch   80: total loss is -1.1408e+07 (est elbo: -1.1446e+07) || time elapsed: 10.05 s
Epoch   90: total loss is -1.1305e+07 (est elbo: -1.1356e+07) || time elapsed: 10.71 s
Epoch  100: total loss is -1.1055e+07 (est elbo: -1.1067e+07) || time elapsed: 11.39 s
Epoch  110: total loss is -1.0793e+07 (est elbo: -1.0763e+07) || time elapsed: 12.06 s
Epoch  120: total loss is -1.0493e+07 (est elbo: -1.052e+07) || time elapsed: 12.71 s
Epoch  130: total loss is -1.0261e+07 (est elbo: -1.0274e+07) || time elapsed: 13.39 s
Epoch  140: total loss is -1.0007e+07 (est elbo: -1.0036e+07) || time elapsed: 14.07 s
Epoch  150: total loss is -9.8497e+06 (est elbo: -9.8733e+06) || time elapsed: 14.73 s
Epoch  160: total loss is -9.6353e+06 (est elbo: -9.6608e+06) || time elapsed: 15.42 s
Epoch  170: total loss is -9.4338e+06 (est elbo: -9.5857e+06) || time elapsed: 16.10 s
Epoch  180: total loss is -9.3352e+06 (est elbo: -9.3854e+06) || time elapsed: 16.76 s
Epoch  190: total loss is -9.2268e+06 (est elbo: -9.2198e+06) || time elapsed: 17.45 s
Epoch  200: total loss is -9.0963e+06 (est elbo: -9.1414e+06) || time elapsed: 18.13 s
Epoch  210: total loss is -9.0205e+06 (est elbo: -8.9437e+06) || time elapsed: 18.78 s
Epoch  220: total loss is -8.8646e+06 (est elbo: -8.8987e+06) || time elapsed: 19.44 s
Epoch  230: total loss is -8.7613e+06 (est elbo: -8.7465e+06) || time elapsed: 20.12 s
Epoch  240: total loss is -8.6495e+06 (est elbo: -8.6615e+06) || time elapsed: 20.78 s
Epoch  250: total loss is -8.6095e+06 (est elbo: -8.6393e+06) || time elapsed: 21.45 s
Epoch  260: total loss is -8.5225e+06 (est elbo: -8.5681e+06) || time elapsed: 22.12 s
Epoch  270: total loss is -8.537e+06 (est elbo: -8.4787e+06) || time elapsed: 22.78 s
Epoch  280: total loss is -8.4092e+06 (est elbo: -8.42e+06) || time elapsed: 23.45 s
Epoch  290: total loss is -8.3269e+06 (est elbo: -8.3354e+06) || time elapsed: 24.11 s
Epoch  300: total loss is -8.2921e+06 (est elbo: -8.2745e+06) || time elapsed: 24.78 s
Epoch  310: total loss is -8.2316e+06 (est elbo: -8.183e+06) || time elapsed: 25.45 s
Epoch  320: total loss is -8.2003e+06 (est elbo: -8.1439e+06) || time elapsed: 26.10 s
Epoch  330: total loss is -8.1047e+06 (est elbo: -8.1114e+06) || time elapsed: 26.78 s
Epoch  340: total loss is -8.1485e+06 (est elbo: -8.0644e+06) || time elapsed: 27.45 s
Epoch  350: total loss is -7.9622e+06 (est elbo: -8.0318e+06) || time elapsed: 28.12 s
Epoch  360: total loss is -7.9918e+06 (est elbo: -7.9756e+06) || time elapsed: 28.80 s
Epoch  370: total loss is -7.9623e+06 (est elbo: -7.9398e+06) || time elapsed: 29.46 s
Epoch  380: total loss is -7.9109e+06 (est elbo: -7.9497e+06) || time elapsed: 30.12 s
Epoch  390: total loss is -7.8882e+06 (est elbo: -7.8894e+06) || time elapsed: 30.82 s
Epoch  400: total loss is -7.8905e+06 (est elbo: -7.8856e+06) || time elapsed: 31.50 s
Epoch  410: total loss is -7.7921e+06 (est elbo: -7.8335e+06) || time elapsed: 32.16 s
Epoch  420: total loss is -7.8271e+06 (est elbo: -7.8505e+06) || time elapsed: 32.82 s
Epoch  430: total loss is -7.8309e+06 (est elbo: -7.7842e+06) || time elapsed: 33.49 s
Epoch  440: total loss is -7.7404e+06 (est elbo: -7.7589e+06) || time elapsed: 34.14 s
Epoch  450: total loss is -7.7268e+06 (est elbo: -7.7499e+06) || time elapsed: 34.81 s
Epoch  460: total loss is -7.662e+06 (est elbo: -7.6812e+06) || time elapsed: 35.49 s
Epoch  470: total loss is -7.6888e+06 (est elbo: -7.6831e+06) || time elapsed: 36.14 s
Epoch  480: total loss is -7.6736e+06 (est elbo: -7.6294e+06) || time elapsed: 36.84 s
Epoch  490: total loss is -7.6456e+06 (est elbo: -7.6613e+06) || time elapsed: 37.53 s

Contributors: tomblaze
