Weight-Clipping

This is the official repository for the weight clipping implementation and for reproducing the experiments. You can find the paper at this link. Below is a minimal implementation of weight clipping with SGD (change torch.optim.SGD to torch.optim.Adam if you want to use Adam).

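import math

import torch


class InitBounds:
    """Compute the bound 1/sqrt(fan_in) for each parameter tensor."""

    def __init__(self):
        # Most recent weight tensor seen; a following 1-D parameter reuses its fan-in.
        self.previous_weight = None

    def get(self, p):
        if p.dim() == 1:
            # 1-D parameter (e.g. a bias): use the fan-in of the preceding weight.
            if self.previous_weight is None:
                raise ValueError("A 1-D parameter must come after a 2-D or 4-D weight.")
            fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.previous_weight)
            return 1.0 / math.sqrt(fan_in)
        elif p.dim() == 2 or p.dim() == 4:
            # 2-D (linear) or 4-D (convolutional) weight.
            self.previous_weight = p
            fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(p)
            return 1.0 / math.sqrt(fan_in)
        else:
            raise ValueError("Unsupported tensor dimension: {}".format(p.dim()))


class WeightClippingSGD(torch.optim.Optimizer):
    def __init__(self, params, kappa=1.0, optimizer=torch.optim.SGD, **kwargs):
        defaults = dict(kappa=kappa)
        super(WeightClippingSGD, self).__init__(params, defaults)
        # Wrap the base optimizer and share its parameter groups.
        self.optimizer = optimizer(self.param_groups, **kwargs)
        self.param_groups = self.optimizer.param_groups
        self.defaults.update(self.optimizer.defaults)
        self.init_bounds = InitBounds()

    def step(self, closure=None):
        # The closure must compute and return the loss; backward() is called here.
        if closure is None:
            raise ValueError("WeightClippingSGD.step() requires a closure that returns the loss.")
        self.zero_grad()
        loss = closure()
        loss.backward()
        self.optimizer.step()
        self.weight_clipping()
        return loss

    def weight_clipping(self):
        # Clamp every parameter to [-kappa * bound, kappa * bound].
        for group in self.param_groups:
            for p in group["params"]:
                bound = self.init_bounds.get(p)
                p.data.clamp_(-group["kappa"] * bound, group["kappa"] * bound)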
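The clipping rule above can also be illustrated without the wrapper class. The following is a minimal sketch, not the repository's implementation: it assumes a single linear layer, computes the fan-in directly from the weight shape, and clamps all parameters after each plain SGD step. The helper name init_bound, the value kappa = 2.0, the learning rate, and the synthetic data are all hypothetical choices made for illustration.

```python
import math

import torch


def init_bound(weight):
    """Return 1/sqrt(fan_in) for a 2-D (linear) or 4-D (convolutional) weight."""
    # fan_in = input features times the receptive-field size (1 for a linear layer).
    fan_in = weight.shape[1] * math.prod(weight.shape[2:])
    return 1.0 / math.sqrt(fan_in)


torch.manual_seed(0)
model = torch.nn.Linear(16, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.5)

kappa = 2.0                               # hypothetical value for illustration
bound = kappa * init_bound(model.weight)  # 2.0 / sqrt(16) = 0.5

# Synthetic regression data whose best-fit weights lie far outside the clipped range.
x = torch.randn(32, 16)
y = 10.0 * x.sum(dim=1, keepdim=True)

for _ in range(20):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
    # Clip the weight and the bias to the same range; the bias reuses the weight's fan-in.
    with torch.no_grad():
        for p in model.parameters():
            p.clamp_(-bound, bound)

max_abs = max(p.abs().max().item() for p in model.parameters())
print(bound, max_abs)
```

Because clamping runs after every update, no parameter can exceed kappa / sqrt(fan_in) in magnitude, however large the gradient step is.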

Reproducing results:

1. You need an environment with Python 3.11:

conda create --name torch python==3.11
conda activate torch

2. Install Dependencies:

python -m pip install --upgrade pip
pip install .

3. TBD

License

Distributed under the MIT License. See LICENSE for more information.

How to cite

Bibtex:

@inproceedings{elsayed2024weightclipping,
  title={Weight clipping for deep continual and reinforcement learning},
  author={Elsayed, Mohamed and Lan, Qingfeng and Lyle, Clare and Mahmood, A Rupam},
  booktitle={Reinforcement Learning Conference},
  year={2024}
}

APA:

Elsayed, M., Lan, Q., Lyle, C., & Mahmood, A. R. (2024). Weight clipping for deep continual and reinforcement learning. In the First Reinforcement Learning Conference.
