TopicRNN: A Recurrent Neural Network with Long-Range Semantic Dependency

Currently, it is not clear to me whether this model can work. After preliminary experiments and observations, it does not seem that this architecture can learn latent topics very well, so for now the project is archived.

Implementation of Dieng et al.'s TopicRNN: a neural topic model & RNN hybrid that learns global semantic dependencies via latent topics and local syntactic dependencies through an RNN.

Model Details

  • The model learns a beta matrix of size (V x K), where V is the size of the vocabulary and K is the number of latent topics. Each of beta's K columns represents a distinct distribution over the vocabulary, one per topic.
  • A variational distribution, taking word frequencies as input, produces the parameters of the Gaussian from which each topic proportion vector theta of length K is sampled.
  • beta * theta then yields the topic logits over the vocabulary at the given time step, weighting the learned topics before they influence inference of the next word. These topic additions are zeroed out at indices belonging to stop words, so that only semantically significant words receive influence from the topics.
  • The topic additions beta * theta are added to the vocabulary projection W * h_t of the RNN hidden state, and the result is normalized via softmax to give the final distribution over the vocabulary (see the sketch below).
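
A minimal sketch of this computation in PyTorch for a single time step; the names, shapes, and the stop-word mask here are illustrative assumptions, not the repository's actual code.

import torch
import torch.nn.functional as F

V, K, H = 10000, 50, 256            # hypothetical vocabulary, topic, and hidden sizes

beta = torch.randn(V, K)            # topic-word matrix (learned)
theta = torch.randn(K)              # topic proportions sampled from the variational Gaussian
W = torch.randn(V, H)               # vocabulary projection (learned)
h_t = torch.randn(H)                # RNN hidden state at time step t
stop_word_mask = torch.ones(V)      # 0.0 at stop-word indices, 1.0 elsewhere (illustrative)

topic_additions = (beta @ theta) * stop_word_mask  # zero topic influence on stop words
logits = W @ h_t + topic_additions                 # combine local (RNN) and global (topic) signals
next_word_probs = F.softmax(logits, dim=-1)        # final distribution over the vocabulary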

Getting Started

The system is built with PyTorch and AllenNLP, which are the main dependencies.

Prerequisites

  • Python 3.6 (3.6.5+ recommended)
  • AllenNLP 0.6.0

Installing

It is recommended to first create a virtual environment before installing dependencies.

Using Conda

conda create --name topic_rnn python=3.6
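
Then activate the environment (recent versions of Conda; older versions use "source activate" instead):

conda activate topic_rnn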

Using VirtualEnv

python3 -m venv /path/to/new/virtual/environment
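
Then activate it before installing dependencies:

source /path/to/new/virtual/environment/bin/activate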

Download PyTorch and AllenNLP via

pip install -r requirements.txt

Generating a Dataset (IMDB)

imdb_review_reader.py contains a dataset reader primed to take a .jsonl file where each entry is of the form

{
    "id": <integer id>,
    "text": <raw text of movie review>,
    "sentiment": <integer value representing sentiment>
}
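
For concreteness, a minimal sketch of reading such a file in Python; the file name is a hypothetical example.

import json

# Each line of a .jsonl file is one self-contained JSON record.
with open("train_unsup.jsonl") as f:  # hypothetical file name
    for line in f:
        review = json.loads(line)
        print(review["id"], review["sentiment"], review["text"][:50])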

You can download the IMDB 100K dataset from the Large Movie Review Dataset page at https://ai.stanford.edu/~amaas/data/sentiment/.

Upon extracting the dataset from the tar, the resulting directory will look like

aclImdb/
    train/
        unsup/
            <review id>_<sentiment>.txt
            ...
        pos/
            <review id>_<sentiment>.txt
            ...
        neg/
            <review id>_<sentiment>.txt
            ...
    test/
        pos/
            <review id>_<sentiment>.txt
            ...
        neg/
            <review id>_<sentiment>.txt
            ...
    ...

You can generate the .jsonl files needed to reproduce the results of the paper via scripts/generate_imdb_corpus.py. The script expects the aclImdb file structure above; run it with

python generate_imdb_corpus.py --data-path <path to aclImdb>  --save-dir <directory to save the .jsonl files>

The directory specified by --save-dir will then contain five files: train_unsup.jsonl, valid_unsup.jsonl, train_labeled.jsonl, valid_labeled.jsonl, and test.jsonl. You will need to write the relative paths to the training/testing .jsonl files into your experiment JSON config.
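
As an optional sanity check, you can count the records in each generated file; this snippet is a sketch, and the save directory path is hypothetical.

import json
import os

save_dir = "path/to/save_dir"  # hypothetical; wherever --save-dir pointed
for name in ["train_unsup.jsonl", "valid_unsup.jsonl", "train_labeled.jsonl",
             "valid_labeled.jsonl", "test.jsonl"]:
    with open(os.path.join(save_dir, name)) as f:
        num_records = sum(1 for _ in f)
    print(f"{name}: {num_records} records")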

Training the model

tests/fixtures/smoke_imdb_language_model.json contains a base specification for TopicRNN (i.e. hyperparameters, relative paths to training/testing .jsonl files, etc.). The fixtures also include a subset of the IMDB dataset in the expected format.

Training this simple model can be done right out of the box after installing requirements. To ensure things are running smoothly, run

allennlp train tests/fixtures/smoke_imdb_language_model.json -s /tmp/topic_rnn_imdb_smoke --include-package library

To ensure that the model runs properly with a GPU, change cuda_device under trainer in the config JSON to point to an available device.

So long as the model can save a checkpoint when using either a CPU or GPU, you're good to go.

In any configuration file in experiments, you must specify at minimum:

  • The dataset reader, with type (e.g. imdb_review_reader) and words_per_instance (the backpropagation-through-time limit)
  • The relative paths to the training and validation .jsonl files (generate_imdb_corpus.py will be extended to produce training and validation splits at a later time)
  • The vocabulary, with max_vocab_size
  • The model, with type (the base implementation, topic_rnn, is currently the only model), text_field_embedder (whether to use pretrained embeddings, embedding size, etc.), text_encoder (how to encode the utterance: RNN, GRU, LSTM, etc.), and topic_dim (the number of latent topics)

An example, experiments/imdb_language_model.json, is provided; a minimal sketch of such a configuration follows below.
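
The sketch below assumes AllenNLP 0.6.0-style configuration; every value (paths, sizes, hyperparameters, and the embedder/encoder settings) is an illustrative assumption, not the repository's actual configuration.

{
    "dataset_reader": {
        "type": "imdb_review_reader",
        "words_per_instance": 35
    },
    "train_data_path": "path/to/train_unsup.jsonl",
    "validation_data_path": "path/to/valid_unsup.jsonl",
    "vocabulary": {
        "max_vocab_size": 10000
    },
    "model": {
        "type": "topic_rnn",
        "text_field_embedder": {
            "tokens": {
                "type": "embedding",
                "embedding_dim": 100
            }
        },
        "text_encoder": {
            "type": "lstm",
            "input_size": 100,
            "hidden_size": 256
        },
        "topic_dim": 50
    },
    "iterator": {
        "type": "basic",
        "batch_size": 32
    },
    "trainer": {
        "num_epochs": 10,
        "cuda_device": -1,
        "optimizer": {
            "type": "adam"
        }
    }
}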

To train the model with an experimental config, run

allennlp train <path to the current experiment's JSON configuration> \
-s <directory for serialization>  \
--include-package library

Built With

  • AllenNLP - The NLP framework used, built by AI2
  • PyTorch - The deep learning library used

Authors

  • Tam Dang

License

This project is licensed under the Apache License 2.0 - see the LICENSE.md file for details.

Issues

difference between logits and topic_addition

Thanks for your excellent work!

I rebuilt this architecture in TensorFlow and ran into a problem. I noticed that logits and topic_addition can have very different scales; for example, the logits may range from about -20 to 20 while topic_addition ranges from about -1 to 1. I am not sure whether adding them directly causes problems. When I check the training process, I find that the whole network relies much more heavily on the logits. Is something wrong?

calculation of kl_divergence

Hi, thanks for your excellent work.

I have a question about the formula for kl_divergence. In the code, the formula is:
kl_divergence = torch.ones_like(mu) + 2 * log_sigma - (mu ** 2) - (torch.exp(log_sigma) ** 2)
while I think the standard formula is:
kl_divergence = torch.ones_like(mu) + log_sigma - (mu ** 2) - torch.exp(log_sigma)

I'm curious about this part. Can anyone provide some help?
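
For what it's worth, the two expressions agree under different parameterizations: the code's version reads log_sigma as log(sigma), while the proposed one reads it as the log variance log(sigma^2); both reduce to the standard term 1 + log(sigma^2) - mu^2 - sigma^2. A quick numerical check (a sketch, not code from this repository):

import torch

mu = torch.randn(5)
sigma = torch.rand(5) + 0.1  # positive standard deviations

# Code's version: log_sigma interpreted as log(sigma).
log_sigma = torch.log(sigma)
as_log_std = torch.ones_like(mu) + 2 * log_sigma - (mu ** 2) - (torch.exp(log_sigma) ** 2)

# Proposed version: log_sigma interpreted as log(sigma ** 2), the log variance.
log_var = torch.log(sigma ** 2)
as_log_var = torch.ones_like(mu) + log_var - (mu ** 2) - torch.exp(log_var)

print(torch.allclose(as_log_std, as_log_var))  # True: the two formulas are equivalent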
