# char-boltzmann

Character-level RBMs for short text. For more information, check out my blog post.

## Requirements

scikit-learn and its dependencies (numpy, scipy) are the big ones. You'll also need enum34. Running `pip install -r requirements.txt` might be all you need to do.

## How-to

The two important scripts are:

- `train.py`: trains an RBM model on a text file with one short text per line. It has a whole bunch of command-line options you can supply, but the defaults are all pretty reasonable. The one you're most likely to need to change is `--extra-chars`: the default behaviour is to use only `[a-z ]` (with `[A-Z]` implicitly downcased), which is not appropriate for datasets with lots of numerals or punctuation (see the sketch after this list).
- `sample.py`: generates new short texts given a pickled model file generated by `train.py`.
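To make the character filtering concrete, here's a purely illustrative sketch of what the default preprocessing amounts to. The function name is made up, and the `$` padding character is a guess based on the fantasy samples shown later in this README — this is not the actual `train.py` code:

```python
import string

def canonicalize(line, extra_chars='', maxlen=10):
    """Illustrative only: downcase, drop characters outside [a-z ]
    plus any --extra-chars, and pad/trim to a fixed length."""
    allowed = set(string.ascii_lowercase + ' ' + extra_chars)
    line = ''.join(ch for ch in line.lower() if ch in allowed)
    return line[:maxlen].ljust(maxlen, '$')  # '$' as padding is an assumption

print(canonicalize('Bernington III'))  # -> 'bernington'
```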

(The last script, `compare_models.py`, is only really relevant if you're training a bunch of different models on the same dataset and enjoy spreadsheets.)

More details on the arguments to these scripts can be seen by running them with `-h`.

`README-datasets.md` has pointers to some suitable datasets.

## Example

To train a small model on first names:

```
wget http://www.cs.cmu.edu/afs/cs/project/ai-repository/ai/areas/nlp/corpora/names/other/names.txt
python train.py --maxlen 10 --extra-chars '' --hid 100 names.txt
python sample.py names__nh100.pickle
```

This should give you some output like...

```
wietzer
sarnimono
buttheo
ressinosoo
bernington
```
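The model file is an ordinary pickle, so you can also inspect it directly from Python if you like. A minimal sketch, assuming only that the model subclasses sklearn's `BernoulliRBM` (which is where the attribute names below come from):

```python
import pickle

with open('names__nh100.pickle', 'rb') as f:
    model = pickle.load(f)

# Attributes inherited from sklearn's BernoulliRBM:
print(model.components_.shape)        # weight matrix, (n_hidden, n_visible)
print(model.intercept_hidden_.shape)  # hidden biases
print(model.intercept_visible_.shape) # visible biases
```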

## Interpreting `train.py` output

During training, you'll see debug output like...

```
[CharBernoulliRBMSoftmax] Iteration 3/5 t = 14.46s
Pseudo-log-likelihood sum: -115047.96   Average per instance: -2.13
E(vali):        -14.00  E(train):       -14.07  difference: 0.07
Fantasy samples: moll$$$$$$|anderd$$$$|gronbel$$$
```

Without going into too much detail, the pseudo-log-likelihood (-2.13 above) is a pretty decent estimate of how well the model is currently fitting the training data. Higher (i.e. closer to zero) is better. The per-instance average is just the sum divided by the number of training instances.
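Since the model is built on sklearn's `BernoulliRBM`, the same quantity is exposed through its `score_samples` method. A toy demonstration with the stock sklearn class, using random binary data standing in for one-hot-encoded text:

```python
import numpy as np
from sklearn.neural_network import BernoulliRBM

rng = np.random.RandomState(0)
X = (rng.rand(500, 27 * 10) > 0.5).astype(np.float64)  # fake binarized data

rbm = BernoulliRBM(n_components=100, n_iter=5, random_state=0)
rbm.fit(X)

pll = rbm.score_samples(X)    # pseudo-log-likelihood, one value per instance
print(pll.sum(), pll.mean())  # corresponds to the "sum" / "average" lines above
```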

The next line compares the energy the model assigns to the training data vs. the validation set. The difference (0.07 in this case) gives an idea of how much the model is overfitting: the larger the difference, the worse, and a difference of 0 suggests no overfitting.
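"Energy" here is the standard Bernoulli RBM free energy. A minimal numpy sketch, assuming the usual parameterization (weights `W` of shape (n_hidden, n_visible), visible bias `b`, hidden bias `c`, matching sklearn's attribute shapes):

```python
import numpy as np

def free_energy(v, W, b, c):
    """F(v) = -v.b - sum_j log(1 + exp(c_j + W_j.v)).
    Lower (more negative) means the model assigns v higher probability."""
    return -(v @ b) - np.logaddexp(0, v @ W.T + c).sum(axis=1)

rng = np.random.RandomState(0)
W, b, c = rng.randn(100, 270) * 0.01, np.zeros(270), np.zeros(100)
v = (rng.rand(5, 270) > 0.5).astype(float)
print(free_energy(v, W, b, c))

# The overfitting check in the log line is then roughly:
#   free_energy(X_vali, ...).mean() - free_energy(X_train, ...).mean()
```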

The final line has string representations of a few of the "fantasy particles" used for the persistent contrastive divergence training.
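If you're not familiar with PCD: the fantasy particles are a fixed set of Gibbs chains that persist across minibatches; each gets a Gibbs update per gradient step and supplies the negative-phase statistics. A stripped-down sketch with plain Bernoulli visible units (the real model samples the visible layer as one softmax per character position):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gibbs_step(v, W, b, c, rng):
    """One v -> h -> v Gibbs sweep of a Bernoulli RBM."""
    h = (rng.rand(v.shape[0], W.shape[0]) < sigmoid(v @ W.T + c)).astype(float)
    return (rng.rand(*v.shape) < sigmoid(h @ W + b)).astype(float)

rng = np.random.RandomState(0)
W, b, c = rng.randn(100, 270) * 0.01, np.zeros(270), np.zeros(100)

# "Persistent": the chains survive across minibatches instead of being
# re-initialized from the data each time.
fantasy = (rng.rand(10, 270) < 0.5).astype(float)
for minibatch in range(3):
    fantasy = gibbs_step(fantasy, W, b, c, rng)  # negative-phase samples
```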

## More details

The core RBM code is cannibalized from scikit-learn's `BernoulliRBM` implementation. I tacked on some additional features, including:

- L2 weight cost
- softmax sampling
- sampling with temperature, for simulated annealing (see the sketch below)
- a flag to gradually reduce the learning rate
- initializing visible biases to the training-set means
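To make the temperature item concrete, here's roughly what sampling one character position with temperature looks like. This is a sketch of the general technique, not the repo's actual code, and the function name is made up:

```python
import numpy as np

def sample_char(logits, temperature=1.0, rng=np.random):
    """Sample one character index from unnormalized scores.
    T < 1 sharpens the distribution (useful when annealing down
    for cleaner samples); T > 1 flattens it toward uniform."""
    scaled = logits / temperature
    scaled -= scaled.max()  # numerical stability
    probs = np.exp(scaled) / np.exp(scaled).sum()
    return rng.choice(len(probs), p=probs)

alphabet = 'abcdefghijklmnopqrstuvwxyz $'
logits = np.random.randn(len(alphabet))
print(alphabet[sample_char(logits, temperature=0.5)])
```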

This code has the same performance limitations as the base sklearn implementation. In particular, it can't run on a GPU.

The `workspace` branch has a lot of extra scripts and data files which might be useful to someone, but which are kind of messy (even relative to the already-kinda-messy master). They mostly relate to model visualization and experiments with different sampling techniques.
