
bayesian-flow-mnist

A simple Bayesian Flow model for MNIST in PyTorch.

  • Binarised MNIST generation using Bayesian Flow Discrete Data Loss
  • Continuous MNIST generation using Bayesian Flow Continuous Data Loss

This implementation could certainly be factorised further, or extended with more features, but the intention is to keep it minimal.

How to Run

Environment Setup

Aside from PyTorch, Matplotlib, and tqdm, the training script requires bayesian-flow-pytorch.

pip install git+https://github.com/thorinf/bayesian-flow-pytorch

Training

The model can be trained with the following command; MNIST will be downloaded automatically:

python train.py -ckpt CHECKPOINTING_PATH -d MNIST_DOWNLOAD_PATH
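A minimal sketch of the command-line interface implied by the command above. The short flags `-ckpt` and `-d` come from the README; the long option names, help strings, and defaults are assumptions, not the actual `train.py` code:

```python
import argparse

def build_parser():
    # Hypothetical CLI matching the README's training command;
    # only the short flag names are taken from the source.
    parser = argparse.ArgumentParser(
        description="Train a Bayesian Flow model on MNIST")
    parser.add_argument("-ckpt", "--checkpoint_path", type=str, required=True,
                        help="where model checkpoints are written")
    parser.add_argument("-d", "--data_path", type=str, required=True,
                        help="where MNIST is (or will be) downloaded")
    return parser

# Parsing the README's example invocation:
args = build_parser().parse_args(["-ckpt", "ckpts/model.pt", "-d", "data/"])
```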

Experiments

Binarised MNIST

This method treats the pixel intensities as Bernoulli probabilities, i.e. the likelihood of each pixel being on. Since the pixel values can be interpreted as probabilities, the task can be trained with the Bayesian Flow discrete data loss.
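For reference, the continuous-time discrete-data loss from the Bayesian Flow Networks paper, specialised to K = 2 classes (binary pixels), can be sketched directly. This is a hand-rolled illustration, not the bayesian-flow-pytorch API; `net`, `beta_1`, and the dummy network are assumptions:

```python
import torch

def discrete_data_loss(net, x, beta_1=3.0, K=2):
    # x: (B, D) integer class labels in {0, 1}; beta_1 is an assumed
    # final accuracy - the library's default may differ.
    B, D = x.shape
    t = torch.rand(B, 1)                               # t ~ U(0, 1)
    beta = beta_1 * t ** 2                             # accuracy schedule beta(t)
    e_x = torch.nn.functional.one_hot(x, K).float()    # (B, D, K) one-hot targets
    # Sample the noisy logits y ~ N(beta * (K * e_x - 1), beta * K * I)
    mean = beta.unsqueeze(-1) * (K * e_x - 1)
    std = (beta.unsqueeze(-1) * K).sqrt()
    y = mean + std * torch.randn_like(mean)
    theta = torch.softmax(y, dim=-1)                   # input distribution
    e_hat = net(theta, t)                              # predicted class probabilities
    # L_infinity = K * beta(1) * E[ t * ||e_x - e_hat||^2 ]
    return (K * beta_1 * t.unsqueeze(-1) * (e_x - e_hat) ** 2).mean()

# Quick check with a dummy network that echoes its input distribution:
dummy_net = lambda theta, t: theta
loss = discrete_data_loss(dummy_net, torch.randint(0, 2, (4, 28 * 28)))
```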

The sampling method then returns a tensor matching the MNIST spatial dimensions plus a final probability dimension. That final dimension contains two probabilities: the probability of the pixel being on/high, and the probability of it being off/low. To convert this back to pixel intensity we can simply take the channel indicating pixel on/high.
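The conversion back to an intensity image is a single indexing operation. The sampler output below is simulated with random probabilities, and which channel holds "on/high" is an assumption:

```python
import torch

# Stand-in for the sampler's output: (H, W, 2) class probabilities per pixel.
probs = torch.softmax(torch.randn(28, 28, 2), dim=-1)

# Take the "pixel on/high" channel (assumed here to be index 1) as the intensity.
image = probs[..., 1]
```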

Although this isn't a strategy that's come to be expected from generative image models, in this case it works extremely well. In fact, the results at 20 epochs of training are nearly as good as the continuous methodology after 50.

[Animated GIF]
Binarised MNIST sampling after 50 epochs of training, using 1000 sampling steps and '-1' (unconditional) generation.

Continuous MNIST

This is a more typical generative method for MNIST. The data is scaled to [-1, 1], and nothing more. This part of the implementation is experimental, and, as you can see below, the results could be improved. Trying different values of sigma, or adding dropout, may yield better generations. Note that a sigma value of 0.01 caused the loss to become NaN about halfway through training.
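To show where sigma enters, here is a sketch of the continuous-time continuous-data loss from the Bayesian Flow Networks paper, applied to data pre-scaled to [-1, 1]. Again this is an illustration rather than the bayesian-flow-pytorch API; `net`, `sigma_1`, and the dummy network are assumptions:

```python
import math
import torch

def continuous_data_loss(net, x, sigma_1=0.02):
    # x: (B, D) pixel values scaled to [-1, 1]; sigma_1 is an assumed value
    # (the README notes sigma = 0.01 made this loss go NaN mid-training).
    B, D = x.shape
    t = torch.rand(B, 1)                        # t ~ U(0, 1)
    gamma = 1.0 - sigma_1 ** (2.0 * t)          # gamma(t) from the accuracy schedule
    # Sample the input distribution mu ~ N(gamma * x, gamma * (1 - gamma) * I)
    mu = gamma * x + (gamma * (1.0 - gamma)).sqrt() * torch.randn_like(x)
    x_hat = net(mu, t)                          # model's estimate of the clean data
    # L_infinity = -ln(sigma_1) * E[ sigma_1^(-2t) * ||x - x_hat||^2 ]
    weight = sigma_1 ** (-2.0 * t)
    return -math.log(sigma_1) * (weight * (x - x_hat) ** 2).mean()

# Quick check with a dummy network that clamps its noisy input:
dummy_net = lambda mu, t: mu.clamp(-1.0, 1.0)
loss = continuous_data_loss(dummy_net, torch.rand(4, 28 * 28) * 2 - 1)
```

The `sigma_1 ** (-2.0 * t)` weight grows rapidly as sigma shrinks, which is consistent with the NaN observed at sigma = 0.01.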

[Animated GIF]
Continuous MNIST sampling after 50 epochs of training, using 1000 sampling steps and '-1' (unconditional) generation.

Contributors

  • thorinf
