
Python 3.6

Adaptive-Regularization-Neural-Network

PyTorch demo code for paper "Learning Neural Networks with Adaptive Regularization"

Paper

Learning Neural Networks with Adaptive Regularization
Han Zhao *, Yao-Hung Hubert Tsai *, Ruslan Salakhutdinov, and Geoffrey J. Gordon
Thirty-third Conference on Neural Information Processing Systems (NeurIPS), 2019. (*equal contribution)

If you use this code for your research and find it helpful, please cite our paper:

@inproceedings{zhao2019adaptive,
  title={Learning Neural Networks with Adaptive Regularization},
  author={Zhao, Han and Tsai, Yao-Hung Hubert and Salakhutdinov, Ruslan and Gordon, Geoffrey J},
  booktitle={Advances in Neural Information Processing Systems},
  year={2019}
}

Summary

In this paper we propose AdaReg, an adaptive and data-dependent regularization method for training neural networks on small-scale datasets. The following figure shows a schematic illustration of AdaReg during training:

To summarize, for a fully connected layer (usually the last layer) with weight matrix W, AdaReg maintains two additional covariance matrices:

  • A row covariance matrix (over tasks / output units)
  • A column covariance matrix (over input features)

During the training phase, AdaReg updates the weight matrix W and the two covariance matrices in a coordinate-descent fashion. As usual, the update of W can use any off-the-shelf optimizer provided by PyTorch. To update the two covariance matrices, we derive a closed-form algorithm that attains the optimal solution for any fixed W. Essentially, the algorithm contains two steps:

  • Compute an SVD of a matrix of size num_tasks x num_tasks or num_feats x num_feats.
  • Truncate all the singular values into the range [lower, upper]. That is, set every singular value smaller than lower to lower, and every singular value greater than upper to upper.

The truncation bounds lower and upper are hyperparameters that must satisfy 0 < lower <= upper; in our experiments we fix them to constants (see the paper for the exact values). Pseudo-code of the algorithm is shown in the following figure.

How to use AdaReg in your model

Really simple! Here we give a minimal code snippet (in PyTorch) to illustrate the main idea. For the full implementation, please see the function BayesNet(args) in src/model.py.

First, for the weight matrix W, we need to define two covariance matrices:

# Define two covariance matrices (in sqrt):
self.sqrt_covt = nn.Parameter(torch.eye(self.num_tasks), requires_grad=False)
self.sqrt_covf = nn.Parameter(torch.eye(self.num_feats), requires_grad=False)
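
These snippets assume a surrounding module that also owns the weight matrix W and the two size attributes. A minimal, hypothetical skeleton tying these names together might look like the following (the class name AdaRegLayer and the weight initialization are illustrative; the real implementation is BayesNet in src/model.py):

```python
import torch
import torch.nn as nn

class AdaRegLayer(nn.Module):
    """Hypothetical minimal container for the attributes used in the snippets:
    the weight matrix W and the two covariance factors (in sqrt)."""
    def __init__(self, num_feats, num_tasks):
        super().__init__()
        self.num_feats = num_feats
        self.num_tasks = num_tasks
        # Weight matrix W of the (last) fully connected layer.
        self.W = nn.Parameter(0.01 * torch.randn(num_tasks, num_feats))
        # Define two covariance matrices (in sqrt); updated analytically,
        # hence excluded from autograd.
        self.sqrt_covt = nn.Parameter(torch.eye(num_tasks), requires_grad=False)
        self.sqrt_covf = nn.Parameter(torch.eye(num_feats), requires_grad=False)
```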

Since we will use our own analytic algorithm to optimize them, we set requires_grad to False. Next, we implement a four-line thresholding function:

def _thresholding(self, sv, lower, upper):
    """
    Two-way truncation (clamping) of singular values.
    :param sv:  Tensor of singular values.
    :param lower:   Lower bound of the truncation.
    :param upper:   Upper bound of the truncation.
    :return:    Truncated singular values (modified in place).
    """
    uidx = sv > upper
    lidx = sv < lower
    sv[uidx] = upper
    sv[lidx] = lower
    return sv
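
For example, a standalone version of this function applied to a small tensor of singular values (the values here are made up for illustration):

```python
import torch

def thresholding(sv, lower, upper):
    # Clamp every singular value into the range [lower, upper].
    sv[sv > upper] = upper
    sv[sv < lower] = lower
    return sv

sv = torch.tensor([0.05, 0.5, 3.0])
out = thresholding(sv, lower=0.1, upper=1.0)
print(out)  # tensor([0.1000, 0.5000, 1.0000])
```

Note this is equivalent to torch.clamp(sv, lower, upper), except that it modifies sv in place.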

The overall algorithm for updating both covariance matrices can then be implemented as:

def update_covs(self, lower, upper):
    """
    Update both the covariance matrix over row and over column, using the closed form solutions.
    :param lower:   Lower bound of the truncation.
    :param upper:   Upper bound of the truncation.
    """
    covt = torch.mm(self.sqrt_covt, self.sqrt_covt.t())
    covf = torch.mm(self.sqrt_covf, self.sqrt_covf.t())
    ctask = torch.mm(torch.mm(self.W, covf), self.W.t())
    cfeat = torch.mm(torch.mm(self.W.t(), covt), self.W)
    # Compute SVD.
    ct, st, _ = torch.svd(ctask.data)
    cf, sf, _ = torch.svd(cfeat.data)
    # Closed-form optimum: invert and rescale the eigenvalues.
    st = self.num_feats / st
    sf = self.num_tasks / sf
    # Truncation of both singular values.
    st = self._thresholding(st, lower, upper)
    st = torch.sqrt(st)
    sf = self._thresholding(sf, lower, upper)
    sf = torch.sqrt(sf)
    # Recompute the value.
    self.sqrt_covt.data = torch.mm(torch.mm(ct, torch.diag(st)), ct.t())
    self.sqrt_covf.data = torch.mm(torch.mm(cf, torch.diag(sf)), cf.t())

Finally, we need to use the optimized covariance matrices to regularize the learning of our weight matrix W (our goal!):

def regularizer(self):
    """
    Compute the weight regularizer w.r.t. the weight matrix W.
    """
    r = torch.mm(torch.mm(self.sqrt_covt, self.W), self.sqrt_covf)
    return torch.sum(r * r)

Add this regularizer to our favorite objective function (cross-entropy, mean-squared error, etc.) and backpropagate to update W. Done!
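
Putting the pieces together, one round of the coordinate descent can be sketched end-to-end as below. This is a self-contained illustration, not the repository's training script: the synthetic data, the bounds 0.1/10.0, the regularization weight, and the single linear layer are all assumptions for the sake of a runnable example (torch.linalg.svd replaces the deprecated torch.svd):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
num_feats, num_tasks = 20, 5
layer = nn.Linear(num_feats, num_tasks, bias=False)
W = layer.weight                      # shape: (num_tasks, num_feats)

# Square roots of the row (task) and column (feature) covariance matrices.
sqrt_covt = torch.eye(num_tasks)
sqrt_covf = torch.eye(num_feats)

lower, upper = 0.1, 10.0              # illustrative truncation bounds
optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)

x = torch.randn(64, num_feats)        # synthetic batch
y = torch.randint(0, num_tasks, (64,))

for step in range(5):
    # (1) Gradient step on W with the AdaReg penalty added to the loss.
    optimizer.zero_grad()
    r = sqrt_covt @ W @ sqrt_covf
    loss = F.cross_entropy(layer(x), y) + 1e-3 * torch.sum(r * r)
    loss.backward()
    optimizer.step()

    # (2) Closed-form coordinate update of the two covariance matrices.
    with torch.no_grad():
        covt = sqrt_covt @ sqrt_covt.t()
        covf = sqrt_covf @ sqrt_covf.t()
        ct, st, _ = torch.linalg.svd(W @ covf @ W.t())
        cf, sf, _ = torch.linalg.svd(W.t() @ covt @ W)
        # Invert, rescale, truncate into [lower, upper], take the sqrt.
        st = (num_feats / st).clamp(lower, upper).sqrt()
        sf = (num_tasks / sf).clamp(lower, upper).sqrt()
        sqrt_covt = ct @ torch.diag(st) @ ct.t()
        sqrt_covf = cf @ torch.diag(sf) @ cf.t()
```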

Try it yourself on MNIST and CIFAR:

Running CIFAR-10

python demo.py --dataset CIFAR10 

Running MNIST

python demo.py --dataset MNIST

Running MNIST with 1000 Training Samples and Batch Size 128

python demo.py --dataset MNIST --trainPartial --trainSize 1000 --batch_size 128

Contact

Please email [email protected] or [email protected] should you have any questions, comments, or suggestions.

Contributors

hanzhaoml, yaohungt

