Bayesian Subspace Multinomial Model (BaySMM)

  • Model for learning document embeddings (i-vectors) along with their uncertainties.
  • Gaussian linear classifier exploiting the uncertainties in document embeddings.
  • See the paper: http://arxiv.org/abs/1908.07599

S. Kesiraju, O. Plchot, L. Burget and S. V. Gangashetty, "Learning Document Embeddings Along With Their Uncertainties," in IEEE/ACM Transactions on Audio, Speech, and Language Processing, vol. 28, pp. 2319-2332, 2020, doi: 10.1109/TASLP.2020.3012062.

Requirements

  • Python >= 3.7

  • PyTorch >= 1.1, <= 1.4

  • scipy >= 1.3

  • numpy >= 1.16.4

  • scikit-learn >= 0.21.2

  • h5py >= 2.9.0

  • See INSTALL.md for detailed instructions.
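A quick way to compare an existing environment against the list above is sketched below. It is a minimal check and not part of the repository: the helper names are hypothetical, torch and sklearn are the usual import names for PyTorch and scikit-learn, and reading "<= 1.4" as "any 1.4.x release" is an assumption.

```python
import importlib
import re
import sys

# Bounds copied from the Requirements list: (minimum, maximum or None).
REQUIREMENTS = {
    "torch": ("1.1", "1.4"),
    "scipy": ("1.3", None),
    "numpy": ("1.16.4", None),
    "sklearn": ("0.21.2", None),
    "h5py": ("2.9.0", None),
}


def version_tuple(version):
    """Turn '1.4.0+cu100' into (1, 4, 0), stopping at the first non-numeric piece."""
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def satisfies(version, minimum, maximum=None):
    """Check minimum <= version, and version <= maximum when a maximum is given."""
    found = version_tuple(version)
    if found < version_tuple(minimum):
        return False
    if maximum is not None:
        # Assumption: compare only as many components as the bound has,
        # so that 1.4.0 still counts as "<= 1.4".
        upper = version_tuple(maximum)
        if found[:len(upper)] > upper:
            return False
    return True


def report():
    """Print one status line per requirement."""
    python_ok = sys.version_info >= (3, 7)
    print("python: " + ("ok" if python_ok else "too old"))
    for name, (minimum, maximum) in REQUIREMENTS.items():
        try:
            module = importlib.import_module(name)
        except ImportError:
            print(name + ": not installed")
            continue
        version = getattr(module, "__version__", "0")
        status = "ok" if satisfies(version, minimum, maximum) else "out of range"
        print(name + " " + version + ": " + status)


if __name__ == "__main__":
    report()
```

INSTALL.md remains the authoritative reference; this check only flags obvious version mismatches.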

Data preparation - sample from 20Newsgroups

    python src/create_sample_data.py sample_data/
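The later commands in this README read train.mtx, vocab, mtx.flist and train.labels from sample_data/. Assuming the data preparation step is what creates all four (the README does not say so explicitly), a small check like the following can confirm they exist before training. The function name is hypothetical.

```python
from pathlib import Path

# File names taken from the training, extraction and classifier commands in this README.
EXPECTED_FILES = ("train.mtx", "vocab", "mtx.flist", "train.labels")


def missing_files(data_dir, expected=EXPECTED_FILES):
    """Return the expected file names that are absent from data_dir."""
    root = Path(data_dir)
    return [name for name in expected if not (root / name).is_file()]


if __name__ == "__main__":
    missing = missing_files("sample_data")
    if missing:
        print("missing: " + ", ".join(missing))
    else:
        print("all expected files are present")
```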

Training the model

  • For help:

    python src/run_baysmm.py --help

  • To train on a GPU, set CUDA_VISIBLE_DEVICES=$GPU_ID, where $GPU_ID is the index of a free GPU.

  • The following command trains the model for 1000 VB (variational Bayes) iterations and saves it in an automatically created sub-directory: exp/s_1.00_rp_1_lw_1e+01_l1_1e-03_50_adam/

    python src/run_baysmm.py train \
        sample_data/train.mtx \
        sample_data/vocab \
        exp/ \
        -K 50 \
        -trn 1000 \
        -lw 1e+01 \
        -var_p 1e+01 \
        -lt 1e-03

  • The ELBO and KLD for every iteration, the log file, and other outputs are saved in this sub-directory.
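When training is launched from a script rather than a shell, the GPU selection and the command can be combined as sketched below. This is only a wrapper around the command shown in this section: the function names are hypothetical, and the flags and values are copied from the example without interpreting them.

```python
import os
import subprocess


def build_train_command(k=50, trn=1000, out_dir="exp/"):
    """Assemble the training command from this section as an argument list."""
    return [
        "python", "src/run_baysmm.py", "train",
        "sample_data/train.mtx",
        "sample_data/vocab",
        out_dir,
        "-K", str(k),
        "-trn", str(trn),
        "-lw", "1e+01",
        "-var_p", "1e+01",
        "-lt", "1e-03",
    ]


def gpu_env(gpu_id, base_env=None):
    """Copy the environment and expose only one GPU index to CUDA."""
    env = dict(os.environ if base_env is None else base_env)
    env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    return env


def run_training(gpu_id):
    """Launch training on the chosen GPU; raises CalledProcessError on failure."""
    subprocess.run(build_train_command(), env=gpu_env(gpu_id), check=True)


if __name__ == "__main__":
    # Print the command instead of launching it; call run_training(0) to execute.
    print(" ".join(build_train_command()))
```

Passing the variable through env keeps the GPU choice local to the child process instead of changing the environment of the calling script.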

Extracting the posterior distributions of embeddings

  • Extract the embedding posteriors [mean, log.std.dev], running 1000 iterations for each stats file listed in sample_data/mtx.flist.

  • With the -nth 100 argument, the embeddings from every 100th iteration are also saved.

    python src/run_baysmm.py extract \
        sample_data/mtx.flist \
        exp/s_1.00_rp_1_lw_1e+01_l1_1e-03_50_adam/model_T1000.h5 \
        -xtr 1000 \
        -nth 100

  • The extracted embedding posterior distributions are saved in the exp/*/ivecs/ sub-directory, with names such as train_model_T1000_e1000.h5.
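Two details of the extraction step can be made concrete with a short sketch. It assumes that "log.std.dev" means the natural logarithm of the per-dimension standard deviation, and that the final iteration is always saved in addition to every nth one. Both helpers are hypothetical and do not read the HDF5 files, whose internal layout is not described here.

```python
import math


def saved_iterations(xtr, nth):
    """Iterations whose embeddings are kept: every nth one, plus the final one."""
    iterations = list(range(nth, xtr + 1, nth))
    if not iterations or iterations[-1] != xtr:
        iterations.append(xtr)
    return iterations


def log_std_to_variance(log_std):
    """Convert log standard deviations to variances: var = exp(2 * log_std)."""
    return [math.exp(2.0 * value) for value in log_std]


if __name__ == "__main__":
    print(saved_iterations(1000, 100))
    print(log_std_to_variance([0.0, -1.0]))
```

The variance form is convenient when the uncertainty has to be added to another covariance, as in a classifier that uses the full posterior.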

Training and testing the classifier

  • Three classifiers can be trained on these embeddings.
  • Use the --final option to train and test the classifier on embeddings from the final iteration.

  1. Gaussian linear classifier - uses only the mean parameter

    python src/train_and_clf_cv.py exp/s_1.00_rp_1_lw_1e+01_l1_1e-03_50_adam/ivecs/train_model_T1000_e1000.h5 sample_data/train.labels glc

  2. Multi-class logistic regression - uses only the mean parameter

    python src/train_and_clf_cv.py exp/s_1.00_rp_1_lw_1e+01_l1_1e-03_50_adam/ivecs/train_model_T1000_e1000.h5 sample_data/train.labels lr

  3. Gaussian linear classifier with uncertainty - uses full posterior distribution

    python src/train_and_clf_cv.py exp/s_1.00_rp_1_lw_1e+01_l1_1e-03_50_adam/ivecs/train_model_T1000_e1000.h5 sample_data/train.labels glcu

  • All the results and predicted classes are saved in exp/*/results/
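To illustrate the difference between using only the mean and using the full posterior, the sketch below scores one document under a Gaussian classifier with class means and a shared covariance. It is a simplified illustration and not the code in src/train_and_clf_cv.py: it assumes diagonal covariances, and it folds in the document's uncertainty by adding the posterior variance to the shared variance, which is one common way to do this. Setting the document variance to zero gives the mean-only case.

```python
import math


def log_gaussian_diag(x, mean, var):
    """Log density of a Gaussian with diagonal covariance."""
    return -0.5 * sum(
        math.log(2.0 * math.pi * v) + (xi - m) ** 2 / v
        for xi, m, v in zip(x, mean, var)
    )


def class_scores(doc_mean, doc_var, class_means, shared_var, log_priors):
    """Per-class log scores; the document variance widens the shared variance."""
    total_var = [s + d for s, d in zip(shared_var, doc_var)]
    return [
        log_gaussian_diag(doc_mean, mu, total_var) + lp
        for mu, lp in zip(class_means, log_priors)
    ]


def posteriors(scores):
    """Normalise log scores into class probabilities (stable softmax)."""
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


if __name__ == "__main__":
    # Toy one-dimensional problem with made-up numbers.
    means = [[-1.0], [1.0]]
    priors = [math.log(0.5), math.log(0.5)]
    mean_only = posteriors(class_scores([0.5], [0.0], means, [1.0], priors))
    with_unc = posteriors(class_scores([0.5], [3.0], means, [1.0], priors))
    print("mean only:       ", mean_only)
    print("with uncertainty:", with_unc)
```

In this toy case the same embedding mean yields a less confident class posterior once its variance is taken into account, which is the intuition behind the glcu option.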

Citation

@ARTICLE{Kesiraju:2020:BaySMM,
  author={Kesiraju, Santosh and Plchot, Oldřich and Burget, Lukáš and Gangashetty, Suryakanth V.},
  journal={IEEE/ACM Transactions on Audio, Speech, and Language Processing}, 
  title={Learning Document Embeddings Along With Their Uncertainties}, 
  year={2020},
  volume={28},
  number={},
  pages={2319-2332},
  doi={10.1109/TASLP.2020.3012062}}
