dmolitor / bolasso

Model consistent Lasso estimation through the bootstrap.

Home Page: https://dmolitor.github.io/bolasso/

License: Other

Languages: R 97.64%, TeX 2.36%
Topics: bolasso, bootstrap, lasso, rstats, variable-selection


bolasso's Issues

BoLasso Algorithm

Thanks for the package! I have a question about the algorithm.

In the paper, the algorithm sets the regularization parameter mu before the for loop, so I think every bootstrap sample uses the same one. In your implementation, do you tune the penalty term for every bootstrap sample with k-fold cross-validation?

I also tried the bootLassoOLS function from the HDCI package but got slightly different results. Your algorithm gives the best variable selection when fitting an OLS regression afterwards, though.

I also tried to implement the algorithm myself but get different results as well, so I am going wrong somewhere; maybe I am getting the decay wrong. I create 128 bootstrap samples, find the best penalty term for each sample with 10-fold cross-validation, refit the model on the entire bootstrap sample with that best penalty term, and record all the coefficients that are non-zero.

Here is a code snippet of how it looks:

library(tidymodels)
library(tidyverse)
library(broom)

lasso_spec <- linear_reg(penalty = tune::tune(), mixture = 1) %>%
  set_engine("glmnet")

# recipe_lasso is assumed to have been defined earlier on the raw data;
# prep()/juice() return the processed data frame used for the bootstrap
df_boot <- recipe_lasso %>% recipes::prep() %>% recipes::juice()
brain_folds <- rsample::bootstraps(df_boot, times = 128)
lambda_grid <- dials::grid_latin_hypercube(penalty(), size = 50)
  
res <- dplyr::tibble()
for (i in 1:nrow(brain_folds)) {

  df <- analysis(brain_folds$splits[[i]])
  recipe_lasso <- recipes::recipe(
    Cattell_TotalScore ~ ., data = df
  ) %>%
    recipes::step_dummy(all_nominal()) %>% 
    recipes::step_normalize(all_numeric(), -all_outcomes()) 
  
  wf_lasso <- workflows::workflow() %>%
    workflows::add_recipe(recipe_lasso) %>% 
    workflows::add_model(lasso_spec)
  
  folds <- rsample::vfold_cv(df, v = 10)
  
  lasso_grid <- tune::tune_grid(
    wf_lasso,
    metrics = yardstick::metric_set(rmse),
    resamples = folds,
    grid = lambda_grid
  )
  
  lowest_rmse <- lasso_grid %>%
    tune::select_best(metric = "rmse")
  
  final_lasso <- tune::finalize_workflow(
    wf_lasso,
    lowest_rmse
  ) %>% 
    parsnip::fit(df)
  
  res <- res %>% 
    dplyr::bind_rows(
      final_lasso %>%
        workflows::extract_fit_parsnip() %>% 
        broom::tidy() %>% 
        dplyr::filter(estimate != 0) %>% 
        dplyr::mutate(boot = paste0("boot: ", i))
    )
  print(i)
  
}

res %>% 
  dplyr::group_by(term) %>% 
  dplyr::summarise(
    n = dplyr::n(),
    avg = mean(estimate)
  ) %>% 
  dplyr::ungroup() %>% 
  dplyr::mutate(prop = n/128) %>% 
  dplyr::filter(prop >= 0.98)

So my question is just: could you explain in 4-5 bullet points how you implemented the algorithm? That would be great, so I can better understand what was wrong with my code. Thank you.
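
For reference, here is a minimal sketch of my reading of the Bolasso as stated in Bach (2008), where a single penalty is fixed before the bootstrap loop and reused for every sample (not necessarily how the bolasso package implements it; the data here are simulated purely for illustration):

library(glmnet)

set.seed(1)
x <- matrix(rnorm(200 * 20), 200, 20)             # simulated predictors
y <- drop(x[, 1:3] %*% c(2, -1, 1)) + rnorm(200)  # outcome uses 3 of them

# Fix the penalty once, before the loop (stand-in for the paper's mu)
lambda_fixed <- cv.glmnet(x, y, alpha = 1)$lambda.min

n_boot <- 128
selected <- matrix(FALSE, n_boot, ncol(x))
for (b in seq_len(n_boot)) {
  idx <- sample(nrow(x), replace = TRUE)            # bootstrap resample
  fit <- glmnet(x[idx, ], y[idx], alpha = 1, lambda = lambda_fixed)
  selected[b, ] <- as.matrix(fit$beta)[, 1] != 0    # record the support
}

# Variables selected in (nearly) every bootstrap sample
which(colMeans(selected) >= 0.98)

Tuning the penalty separately for each bootstrap sample with k-fold cross-validation, as my loop above does, is a different (and more expensive) choice, which may be part of the discrepancy.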

Parallelization across models in `coef` and `predict`

Currently, bolasso estimates the bootstrapped Lasso models in parallel. I want to extend this to the coef and predict methods as well. Although the performance gains are probably much smaller for these cases, it would make the package more consistent and, for extremely large data, could improve performance a fair amount.
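
A minimal sketch of what that could look like, assuming the fitted bootstrap models are stored in a list (future.apply is used here purely for illustration and is not necessarily the backend bolasso uses; the data and model list are simulated):

library(glmnet)
library(future)
library(future.apply)

plan(multisession)  # parallel backend; plan(sequential) disables it

# Stand-in for the list of bootstrapped Lasso fits that bolasso stores
x <- matrix(rnorm(100 * 10), 100, 10)
y <- rnorm(100)
models <- lapply(1:8, function(b) {
  idx <- sample(nrow(x), replace = TRUE)
  glmnet(x[idx, ], y[idx], alpha = 1)
})

# Extract coefficients and predictions from each model in parallel;
# the same map-over-models pattern applies to both methods
coefs <- future_lapply(models, function(m) coef(m, s = 0.1))
preds <- future_lapply(models, function(m) predict(m, newx = x, s = 0.1))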
