Coder Social home page Coder Social logo

lasso-in-pytorch's Introduction

Lasso-in-PyTorch

Why lasso can't produce sparse solution in pytorch

Why lasso can't produce sparse solution in pytorch

lasso = linear model + $\lambda *\ell_1$ regularization ,从compressed sensing角度,在design matrix X满足RIP时:$\ell_0=\ell_1$, 妥妥稀疏解。扩展到更复杂的模型$f(x)$ ,$f(x)+\ell_1$是否能够得到稀疏解?$f(x)$如果是多层神经网络,理论上不好分析,但感觉肯定可以,e.g., 不断调大$\lambda$。

在pytorch试图实现lasso regression $y= X\beta + \epsilon$. 然而$\hat \beta$ 不稀疏,咋调$\lambda$都没用!!!

分析:pytorch 是用SGD 优化,SGD不可以直接求解lasso, 在0点处绝对值函数不可导。

**结论:**它是直接求的,$\beta^{k+1} = \beta^{k} +\eta\cdot X^T(y-X\beta^k) + \lambda sgn(\beta^k)$ .

这种求法被叫做 SGD-L1 naive,导致了$\beta^{k+1}$ 几乎不会严格等于0,可能在0附近震荡。

一个自然的想法:SGD-L1 clipping, 每一次梯度更新后对参数做 soft-threshold.

python sklearn SGD regressor就是这么干的:the update is truncted to 0 to allow for learning sparse models image

给一篇参考文献:Stochastic Gradient Descent Training for L1-regularized Log-linear Models with Cumulative Penalty。

Example

  1. 产生真实的beta, sample size n = 100, dim p =10, 前两个系数为10,10,第三个第四个是-10,其它为0.

    np.random.seed(5)
    n = 100
    p = 10
    beta = np.zeros([p]).astype(np.float32)
    beta[0] = 10
    beta[1] =10
    beta[2] = -10
    beta[3] =-10
    X = np.random.rand(n,p).astype(np.float32)
    Y = np.dot(X,beta)
    a = torch.from_numpy(X)
    b = torch.from_numpy(Y.reshape(-1,1))
  2. SGD-L1 naive

    可以看到,前四个变量估计的不错,分别在+10和-10附近,和sklearn的lasso结果也很接近。后6个变量的系数非常小,对prediction应该影响不大,但是没有exactly 等于0,所以不是稀疏解 image

  3. SGD-L1 clipping

    三个结果,第一个SGD-L1 clipping, 可以看到,后6个变量系数exactly 被压缩到0了,第三个结果是 sklearn SGDRregressor方法,它用的是参考文献中的clipping方式,效果会好一点,但改进不大了,而且也没有什么理论支持。 image

Summary

  1. L1 regularization在PyTorch中,SGD无法直接产生稀疏解,需要每一次参数更新后做soft-threshold.
  2. code放到了我的github,有需要的朋友可自取(并点个star)

lasso-in-pytorch's People

Contributors

sanyouwu avatar

Stargazers

 avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar

Watchers

 avatar  avatar

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.