compshrink

This repository contains code for Bayesian Dirichlet-Multinomial regression using global-local shrinkage priors (horseshoe, horseshoe+, and the Bayesian Lasso), as described in the paper "Shrinkage and Selection for Compositional Data" by Jyotishka Datta and Dipankar Bandopadhyay.

Abstract

We propose a variable selection and estimation framework for Bayesian compositional regression models, using state-of-the-art continuous shrinkage priors to identify significant associations between available covariates and taxonomic abundances in microbiome data. We model the compositional component with generalized Dirichlet and Dirichlet distributions, and compare the popular horseshoe (Carvalho et al., 2010) and horseshoe+ (Bhadra et al., 2017) priors, with the Bayesian Lasso as a benchmark. We use Hamiltonian Monte Carlo for posterior sampling, and posterior credible intervals and pseudo posterior inclusion probabilities for variable selection. Our simulation studies show excellent recovery and estimation accuracy in the sparse parameter regime, and we apply our method to human microbiome data from the NYC-HANES study.

Details

The Stan code is provided in the stan-codes folder; the horseshoe version of the Dirichlet-Multinomial model is shown below.
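
The program implements the horseshoe prior in its non-centered form: writing j for a predictor and k for a taxon,

  beta[j,k] = sigma * tau[k] * lambda_tilde[j,k] * beta_raw[j,k],
  beta_raw[j,k] ~ N(0, 1),   lambda_tilde[j,k] ~ C+(0, 1),   tau[k] ~ C+(0, 1),

so each coefficient gets a half-Cauchy local scale lambda[j,k] = tau[k] * lambda_tilde[j,k] and a per-taxon global scale tau[k].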

//
// This Stan program defines a Dirichlet-Multinomial model: a matrix of
// counts 'Y' is modeled as Dirichlet-Multinomial distributed,
// with a horseshoe prior on the beta coefficients.
// ** Wadsworth's model, with a horseshoe prior in place of the spike-and-slab **
//

functions {
// for likelihood estimation
  real dirichlet_multinomial_lpmf(int[] y, vector alpha) {
    real alpha_plus = sum(alpha);
    return lgamma(alpha_plus) + lgamma(sum(y)+1) + sum(lgamma(alpha + to_vector(y)))
                - lgamma(alpha_plus+sum(y)) - sum(lgamma(alpha))-sum(lgamma(to_vector(y)+1));
  }
}

data {
  int<lower=1> N; // total number of observations
  int<lower=2> ncolY; // number of categories
  int<lower=2> ncolX; // number of predictor levels
  matrix[N,ncolX] X; // predictor design matrix
  int<lower=0> Y[N,ncolY]; // response counts
  real<lower=0> scale_icept;
  // real<lower=0> sd_prior;
  real<lower=0> psi;
}
parameters {
  matrix[ncolX, ncolY] beta_raw; // coefficients (raw)
  vector[N] beta0; // intercept
  matrix<lower=0>[ncolX,ncolY] lambda_tilde; // truncated local shrinkage
  vector<lower=0>[ncolY] tau; // global shrinkage
  real<lower=0> sigma;
  // real<lower = 0> psi;
}
transformed parameters{
  matrix[ncolX,ncolY] beta; // coefficients 
  matrix<lower=0>[ncolX,ncolY] lambda; // local shrinkage
  lambda = diag_post_multiply(lambda_tilde, tau);
  beta = beta_raw .* lambda*sigma;
}

model {
 // psi ~ uniform(0,1);
 sigma ~ inv_gamma(0.5, 0.5);
// prior:
for (k in 1:N) {
  beta0[k] ~ cauchy(0, scale_icept);
}
tau ~ cauchy(0, 1); // half-Cauchy global shrinkage, one scale per taxon
                    // (stated once; inside the k-loop below it would be counted ncolX times)
for (k in 1:ncolX) {
  for (l in 1:ncolY) {
    lambda_tilde[k,l] ~ cauchy(0, 1); // half-Cauchy local shrinkage
    beta_raw[k,l] ~ normal(0, 1);     // non-centered coefficient
  }
}
// likelihood: alpha_i = softmax(logits_i) * (1 - psi) / psi, with fixed
// overdispersion psi supplied as data (estimated below via dirmult)
for (i in 1:N) {
    vector[ncolY] logits;
    for (j in 1:ncolY) {
      logits[j] = beta0[i] + X[i,] * beta[,j];
    }
    Y[i,] ~ dirichlet_multinomial(softmax(logits) * (1 - psi) / psi);
    }
}
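
As a sanity check on the custom likelihood, the same Dirichlet-Multinomial log-pmf can be evaluated directly in R (a minimal sketch mirroring dirichlet_multinomial_lpmf above; not part of the original scripts):

# Dirichlet-Multinomial log-pmf; the lgamma terms match the Stan function line by line
ddirmult_log <- function(y, alpha) {
  alpha_plus <- sum(alpha)
  lgamma(alpha_plus) + lgamma(sum(y) + 1) + sum(lgamma(alpha + y)) -
    lgamma(alpha_plus + sum(y)) - sum(lgamma(alpha)) - sum(lgamma(y + 1))
}
# e.g. ddirmult_log(c(3, 1, 0), alpha = rep(1, 3))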
The model is compiled with rstan before sampling:

library(rstan)
stan.hs.fit <- stan_model(file = 'multinomial-horseshoe-marg.stan', 
                          model_name = "Dirichlet Horseshoe")

For any of the three candidate priors, we can sample from the posterior using the sampling() function in RStan.

library(dirmult) # provides dirmult() for a method-of-moments overdispersion estimate

n.iters = 1000
n.chains = 1
rng.seed = 12345

set.seed(rng.seed)
dirfit <- dirmult(Ymat) # Ymat: N x ncolY matrix of taxa counts

NYC.data = list(N = nrow(Ymat), ncolY = ncol(Ymat), ncolX = ncol(Xmat),
                 X = Xmat, Y = Ymat, psi = dirfit$theta, scale_icept = 2)

ptm = proc.time()
smpls.hs.res = sampling(stan.hs.fit, 
                        data = NYC.data, 
                        iter = n.iters,
                        init = 0,
                        seed = rng.seed,
                        cores = 2,
                        warmup = floor(n.iters/2),
                        chains = n.chains,
                        control = list(adapt_delta = 0.85),
                        refresh = 100)
proc.time()-ptm
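
Before summarizing, it is worth checking the HMC diagnostics; a minimal sketch with standard rstan utilities (not part of the original script):

# divergent transitions, treedepth saturation, and E-BFMI warnings
rstan::check_hmc_diagnostics(smpls.hs.res)
# split-Rhat should be near 1 and n_eff reasonably large for the coefficients
fit.summ <- rstan::summary(smpls.hs.res, pars = "beta")$summary
range(fit.summ[, "Rhat"]); min(fit.summ[, "n_eff"])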
# summarize results

# posterior draws of beta: iterations x ncolX x ncolY
beta.smpls.hs <- rstan::extract(smpls.hs.res, pars=c("beta"), permuted=TRUE)[[1]]
beta.mean.hs <- apply(beta.smpls.hs, c(2,3), mean)
beta.median.hs <- apply(beta.smpls.hs, c(2,3), median)
# Mode is not a base R function; one simple choice is a kernel-density mode estimate
Mode <- function(x) { d <- density(x); d$x[which.max(d$y)] }
beta.mode.hs <- apply(beta.smpls.hs, c(2,3), Mode)
beta.sd.hs <- apply(beta.smpls.hs, c(2,3), sd)
beta.hs.ci <- apply(beta.smpls.hs, c(2,3), quantile, probs=c(0.025,0.5,0.975)) # median with 95% credible intervals
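
These 95% credible intervals drive the variable selection described in the abstract: a coefficient is flagged when its interval excludes zero. A minimal sketch, using the object names above:

# beta.hs.ci is a 3 x ncolX x ncolY array (rows: 2.5%, 50%, 97.5% quantiles)
selected.hs <- (beta.hs.ci[1, , ] > 0) | (beta.hs.ci[3, , ] < 0)
which(selected.hs, arr.ind = TRUE) # (predictor, taxon) pairs whose CI excludes 0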
