
NFNets and Adaptive Gradient Clipping for SGD implemented in PyTorch. Find explanation at tourdeml.github.io/blog/

Home Page: https://nfnets-pytorch.readthedocs.io/en/latest/

License: MIT License


nfnets-pytorch's Introduction

PyTorch implementation of Normalizer-Free Networks and Adaptive Gradient Clipping


Paper: https://arxiv.org/abs/2102.06171

Original code: https://github.com/deepmind/deepmind-research/tree/master/nfnets

Blog post: https://tourdeml.github.io/blog/posts/2021-03-31-adaptive-gradient-clipping/. Feel free to subscribe to the newsletter, and leave a comment if you have anything to add/suggest publicly.

Do star this repository if it helps your work, and don't forget to cite if you use this code in your research!

Installation

Install from PyPi:

pip3 install nfnets-pytorch

or install the latest code using:

pip3 install git+https://github.com/vballoli/nfnets-pytorch

Usage

WSConv2d

Use WSConv1d, WSConv2d, ScaledStdConv2d (from timm), and WSConvTranspose2d like any other torch.nn.Conv2d or torch.nn.ConvTranspose2d module.

import torch
from torch import nn
from nfnets import WSConv2d, WSConvTranspose2d, ScaledStdConv2d

conv = nn.Conv2d(3,6,3)
w_conv = WSConv2d(3,6,3)

conv_t = nn.ConvTranspose2d(3,6,3)
w_conv_t = WSConvTranspose2d(3,6,3)

Generic AGC (recommended)

import torch
from torch import nn, optim
from torchvision.models import resnet18

from nfnets import WSConv2d
from nfnets.agc import AGC # Needs testing

conv = nn.Conv2d(3,6,3)
w_conv = WSConv2d(3,6,3)

optim = optim.SGD(conv.parameters(), 1e-3)
optim_agc = AGC(conv.parameters(), optim) # Needs testing

# Ignore fc of a model while applying AGC.
model = resnet18()
optim = torch.optim.SGD(model.parameters(), 1e-3)
optim = AGC(model.parameters(), optim, model=model, ignore_agc=['fc'])
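
Once wrapped, the AGC optimizer is used like any other optimizer in a training loop. A minimal sketch (the dataloader, criterion and training data here are placeholders, not part of this library):

criterion = nn.CrossEntropyLoss()
for inputs, targets in dataloader:  # hypothetical dataloader
    optim.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optim.step()  # clips gradients unit-wise, then delegates to the wrapped SGD step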

SGD - Adaptive Gradient Clipping

Similarly, use SGD_AGC like torch.optim.SGD.

# The generic AGC is preferable since the paper recommends not applying AGC to the last fc layer.
import torch
from torch import nn, optim
from nfnets import WSConv2d, SGD_AGC

conv = nn.Conv2d(3,6,3)
w_conv = WSConv2d(3,6,3)

optim = optim.SGD(conv.parameters(), 1e-3)
optim_agc = SGD_AGC(conv.parameters(), 1e-3)

Using it within any non-residual PyTorch model

replace_conv replaces every convolution in your (non-residual) model with the given convolution class and replaces each batch norm with an identity layer. While the identity is not ideal, it shouldn't cause a major difference in latency.

Note that, as per the paper, replace_conv is only valid for non-residual models (VGG, MobileNetV1, etc.). See the above-mentioned blog post for more details.

import torch
from torch import nn
from torchvision.models import vgg16

from nfnets import replace_conv, WSConv2d, ScaledStdConv2d

model = vgg16()
replace_conv(model, WSConv2d)  # This repo's original implementation
replace_conv(model, ScaledStdConv2d)  # Alternatively, timm's implementation (use one or the other)

"""
class YourCustomClass(nn.Conv2d):
  ...
replace_conv(model, YourCustomClass)
"""

Docs

Find the docs at https://nfnets-pytorch.readthedocs.io/en/latest/

Cite Original Work

To cite the original paper, use:

@article{brock2021high,
  author={Andrew Brock and Soham De and Samuel L. Smith and Karen Simonyan},
  title={High-Performance Large-Scale Image Recognition Without Normalization},
  journal={arXiv preprint arXiv:2102.06171},
  year={2021}
}

Cite this repository

To cite this repository, use:

@misc{nfnets2021pytorch,
  author = {Vaibhav Balloli},
  title = {A PyTorch implementation of NFNets and Adaptive Gradient Clipping},
  year = {2021},
  howpublished = {\url{https://github.com/vballoli/nfnets-pytorch}}
}

nfnets-pytorch's People

Contributors

lalondma, mj9, pisarik, shi27feng, vballoli, zuenko


nfnets-pytorch's Issues

Model is None

Hi @vballoli,
There seems to be a bug in the code of AGC:

if model is not None:
    assert ignore_agc not in [None, []], "You must specify ignore_agc for AGC to ignore fc-like(or other) layers"
    names = [name for name, module in model.named_modules()]

    for module_name in ignore_agc:
        if module_name not in names:
            raise ModuleNotFoundError("Module name {} not found in the model".format(module_name))
        params = [{"params": list(module.parameters())} for name,
                          module in model.named_modules() if name not in ignore_agc]
else:
    params = [{"params": list(module.parameters())} for name,
                       module in model.named_modules()]

When model is None, the else branch still calls model.named_modules(), which fails because there is no model to get name and module from.

Thanks
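
A possible fix, sketched under the assumption that AGC keeps its params argument, is to fall back to that argument when no model is passed:

# hypothetical patch inside AGC.__init__
if model is not None:
    ...  # existing model-based parameter grouping, unchanged
else:
    # build the parameter groups from the params argument instead
    params = [{"params": list(params)}]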

Implement `param_groups` for AGC

Describe the bug
When using LambdaLR, PyTorch calls len(optimizer.param_groups), but param_groups is not implemented for AGC.

To Reproduce

model = torch.nn.Conv1d(10,20,4)
optimizer = optim.AdamW(model.parameters())
optimizer_agc = AGC(model.parameters(),optimizer)

lambda1 = lambda iteration: 0.05*iteration
scheduler_warmup = torch.optim.lr_scheduler.LambdaLR(optimizer,lr_lambda=lambda1)
scheduler_warmup_agc = torch.optim.lr_scheduler.LambdaLR(optimizer_agc,lr_lambda=lambda1)

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-58-f68dd026a6de> in <module>
----> 1 scheduler_warmup2 = torch.optim.lr_scheduler.LambdaLR(optimizer2,lr_lambda=lambda1)

/gpfs/alpine/proj-shared/fus131/conda-envs/torch1.5.0v2/lib/python3.6/site-packages/torch/optim/lr_scheduler.py in __init__(self, optimizer, lr_lambda, last_epoch)
    180
    181         if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
--> 182             self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
    183         else:
    184             if len(lr_lambda) != len(optimizer.param_groups):

AttributeError: 'AGC' object has no attribute 'param_groups'

PyTorch v1.5.0
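
A possible workaround until param_groups is implemented (a sketch; it assumes the learning rate lives in the wrapped optimizer's param groups, which is what AGC delegates to) is to attach the scheduler to the inner optimizer and keep calling step() on the AGC wrapper:

model = torch.nn.Conv1d(10, 20, 4)
optimizer = optim.AdamW(model.parameters())
optimizer_agc = AGC(model.parameters(), optimizer)

# drive the LR schedule from the wrapped optimizer, which owns the param groups
scheduler_warmup = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda it: 0.05 * it)

# per training step: clip + update via the wrapper, then schedule via the inner optimizer
optimizer_agc.step()
scheduler_warmup.step()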

Method To Compute Gamma For ScaledStdConv

Is your feature request related to a problem? Please describe.

gamma is one of the hyper-parameters that can be tuned for the ScaledStdConv layers.

This should be computed from the activation function between layers.

Describe the solution you'd like

Add a function called compute_gamma that takes in a callable and returns a sampled gamma value.

import torch
from typing import Callable

@torch.no_grad()
def compute_gamma(activation_fn: Callable[[torch.Tensor], torch.Tensor], batch_size: int = 1024, samples: int = 256, device=None) -> float:
    # from appendix D: https://arxiv.org/pdf/2101.08692.pdf
    x = torch.randn(batch_size, samples, dtype=torch.float32, device=device)
    y = activation_fn(x)
    gamma = torch.mean(torch.var(y, dim=1)) ** -0.5
    return gamma.item()

Example usages

from torch import nn
from torch.nn import functional as F

print(compute_gamma(nn.ELU()))
print(compute_gamma(torch.sigmoid))
print(compute_gamma(F.relu))
>>> 1.2714643478393555
>>> 4.806765079498291
>>> 1.7130787372589111

Describe alternatives you've considered

Re-implement this function each time in your code.

Additional context

Excerpt from the appendix of https://arxiv.org/abs/2101.08692

D.1 NUMERICAL APPROXIMATIONS OF NONLINEARITY-SPECIFIC GAINS

It is often faster to determine the nonlinearity-specific constants γ empirically, especially when the
chosen activation functions are complex or difficult to integrate. One simple way to do this for the
SiLU function is to sample many (say, 1024) random C-dimensional vectors (of say size 256) and
compute the average variance, which will allow for computing an estimate of the constant.
import jax
import jax.numpy as jnp
key = jax.random.PRNGKey(2) # Arbitrary key
# Produce a large batch of random noise vectors
x = jax.random.normal(key, (1024, 256))
y = jax.nn.silu(x)
# Take the average variance of many random batches
gamma = jnp.mean(jnp.var(y, axis=1)) ** -0.5

About AGC arguments

Hi. Thank you for your all effort and quick implementation.

When I was reading your AGC code, I found that the argument named 'params' in the AGC constructor is never used.
In more detail, 'params' is completely replaced by the following code snippet in AGC.

    if model is not None:
        assert ignore_agc not in [
            None, []], "You must specify ignore_agc for AGC to ignore fc-like(or other) layers"
        names = [name for name, module in model.named_modules()]

        for module_name in ignore_agc:
            if module_name not in names:
                raise ModuleNotFoundError(
                    "Module name {} not found in the model".format(module_name))
        params = [{"params": list(module.parameters())} for name,
                      module in model.named_modules() if name not in ignore_agc]
    
    else:
        params = [{"params": list(module.parameters())} for name,
                      module in model.named_modules()]

Why do you keep this argument? Is there any reason for it?
I think more code should be written in the 'else' branch for params (for example, when we optimize an image such as random noise with some objective function).

And I have one more question. It may sound silly, but I'd like to know; it is related to the above.

Why do you keep the local variable 'defaults' in the AGC constructor?

AGC without modifying the optimizer

Hello,

Is there a way to apply AGC externally without modifying the optimizer code?

I am using optimizers from the torch_optimizer package, so that would be very useful.
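
One way to do this is to apply AGC-style clipping directly to the gradients between loss.backward() and optimizer.step(), without touching the optimizer at all. A minimal sketch (independent of this repo's AGC class; the clipping and eps values follow the paper's defaults but are assumptions here):

import torch

def unitwise_norm(x):
    # whole-tensor norm for biases/gains, per-output-unit norm for weight matrices/conv kernels
    if x.ndim <= 1:
        return x.norm(p=2)
    return x.norm(p=2, dim=tuple(range(1, x.ndim)), keepdim=True)

@torch.no_grad()
def adaptive_clip_grad_(parameters, clipping=1e-2, eps=1e-3):
    # rescale each gradient so its unit-wise norm is at most clipping * unit-wise weight norm
    for p in parameters:
        if p.grad is None:
            continue
        p_norm = unitwise_norm(p.detach()).clamp_(min=eps)
        g_norm = unitwise_norm(p.grad.detach())
        scale = (p_norm * clipping / g_norm.clamp(min=1e-6)).clamp(max=1.0)
        p.grad.mul_(scale)

# usage with any optimizer, e.g. one from torch_optimizer:
#   loss.backward()
#   adaptive_clip_grad_(model.parameters())
#   optimizer.step()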

Example in readme does not work

Describe the bug
Running either replace_conv call in this code from the readme on the front page:
model = vgg16()
replace_conv(model, WSConv2d) # This repo's original implementation
replace_conv(model, ScaledStdConv2d) # From timm

Results in this error:

  File "/opt/conda/lib/python3.8/site-packages/nfnets/utils.py", line 25, in replace_conv
    replace_conv(mod, conv_class)
  File "/opt/conda/lib/python3.8/site-packages/nfnets/utils.py", line 18, in replace_conv
    setattr(module, name, conv_class(target_mod.in_channels, target_mod.out_channels, target_mod.kernel_size,
  File "/opt/conda/lib/python3.8/site-packages/nfnets/base.py", line 262, in __init__
    super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 386, in __init__
    super(Conv2d, self).__init__(
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 108, in __init__
    if bias:
RuntimeError: Boolean value of Tensor with more than one value is ambiguous
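
A likely cause is that replace_conv passes the existing bias tensor where the new constructor expects a boolean flag, which is what makes `if bias:` ambiguous. A hedged sketch of the relevant change inside a replacement helper (illustrative, not the repo's exact code):

# pass a boolean, not the bias tensor itself, when re-creating the conv layer
setattr(module, name, conv_class(
    target_mod.in_channels, target_mod.out_channels, target_mod.kernel_size,
    stride=target_mod.stride, padding=target_mod.padding,
    dilation=target_mod.dilation, groups=target_mod.groups,
    bias=target_mod.bias is not None))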

optim_agc.step() raises AttributeError: 'NoneType' object has no attribute 'ndim'

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import torchvision
from torchvision import *
from torch.utils.data import Dataset, DataLoader
from nfnets import replace_conv,SGD_AGC

net = models.resnet50(pretrained=True)
replace_conv(net)
net = net.cuda() if device else net
num_ftrs = net.fc.in_features
net.fc = nn.Linear(num_ftrs, 13)
net.fc = net.fc.cuda() if use_cuda else net.fc
ct = 0
for name, child in net.named_children():
    ct += 1
    if ct < 8:
        for name2, params in child.named_parameters():
            params.requires_grad = False

criterion = nn.CrossEntropyLoss()
optim_agc = SGD_AGC(net.parameters(), 1e-3,momentum=0.9)

n_epochs = 5
print_every = 10
valid_loss_min = np.Inf
val_loss = []
val_acc = []
train_loss = []
train_acc = []
total_step = len(train_dataloader)

for epoch in range(1, n_epochs+1):
    running_loss = 0.0
    correct = 0
    total=0
    print(f'Epoch {epoch}\n')
    for batch_idx, (data_, target_) in enumerate(train_dataloader):
        data_, target_ = data_.to(device), target_.to(device)
        optim_agc.zero_grad()
        outputs = net(data_)
        loss = criterion(outputs, target_)
        loss.backward()
        optim_agc.step()
        running_loss += loss.item()
        _,pred = torch.max(outputs, dim=1)
        correct += torch.sum(pred==target_).item()
        total += target_.size(0)
        if (batch_idx) % 20 == 0:
            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 
                   .format(epoch, n_epochs, batch_idx, total_step, loss.item()))
    train_acc.append(100 * correct / total)
    train_loss.append(running_loss/total_step)
		

Traceback (most recent call last):
  File "image_product_classifier.py", line 74, in <module>
    optim_agc.step()
  File "/home/sachin_mohan/venv/lib/python3.6/site-packages/torch/autograd/grad_mode.py", line 26, in decorate_context
    return func(*args, **kwargs)
  File "/home/sachin_mohan/venv/lib/python3.6/site-packages/nfnets/sgd_agc.py", line 105, in step
    grad_norm = unitwise_norm(p.grad)
  File "/home/sachin_mohan/venv/lib/python3.6/site-packages/nfnets/utils.py", line 27, in unitwise_norm
    if x.ndim <= 1:
AttributeError: 'NoneType' object has no attribute 'ndim'
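
A likely trigger: the frozen parameters (requires_grad = False) never receive gradients, so p.grad is None when SGD_AGC computes the unit-wise norm. A hedged workaround sketch is to hand the optimizer only the trainable parameters:

# only pass parameters that will actually receive gradients
trainable_params = [p for p in net.parameters() if p.requires_grad]
optim_agc = SGD_AGC(trainable_params, 1e-3, momentum=0.9)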

TypeError: replace_conv() takes 1 positional argument but 2 were given

Describe the bug
Hello, thank you for the repo. I'm having an issue with the replace_conv().

To Reproduce

model = torchvision.models.resnet101(pretrained=False)
model.fc = nn.Linear(in_features=2048, out_features=8, bias=True)
replace_conv(model)

model = model.to(device)

Screenshots
Here is a screenshot of the stacktrace
[screenshot]

Train and test accuracy gap is very large

epoch: 0, train_loss: 6.847220252882614, acc: 0.28600405679513186
epoch: 0, test_loss: 1.719438068297228, test_acc: 0.059837728194726165
epoch: 10, train_loss: 1.8927386342755212, acc: 0.2748478701825558
epoch: 10, test_loss: 0.5072576126106355, test_acc: 0.06490872210953347
epoch: 20, train_loss: 1.1933289462255563, acc: 0.4077079107505071
epoch: 20, test_loss: 0.30588787504535936, test_acc: 0.0922920892494929
epoch: 30, train_loss: 0.9264058188388222, acc: 0.7261663286004056
epoch: 30, test_loss: 0.2350183322362089, test_acc: 0.1724137931034483
epoch: 40, train_loss: 0.6766643664132246, acc: 0.7870182555780934
epoch: 40, test_loss: 0.16959500415363776, test_acc: 0.1876267748478702
epoch: 50, train_loss: 0.4851083510709919, acc: 0.8681541582150102
epoch: 50, test_loss: 0.11013727963181884, test_acc: 0.20588235294117646
epoch: 60, train_loss: 0.4282051429778398, acc: 0.8945233265720081
epoch: 60, test_loss: 0.08342219629232595, test_acc: 0.22210953346855983
epoch: 70, train_loss: 0.4014547500337063, acc: 0.920892494929006
epoch: 70, test_loss: 0.07419358513954634, test_acc: 0.22718052738336714

model = nf_resnet50()

optimizer = SGD_AGC(model.parameters(), 1e-3)
# model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=["accuracy"])
train_transform = transforms.Compose([
    transforms.ToPILImage(),  # convert to a PIL Image so the following transforms can be applied
    transforms.Resize(256),
    # transforms.RandomResizedCrop(224,scale=(0.5,1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()])
val_transform = transforms.Compose([
    transforms.ToPILImage(),  # convert to a PIL Image so the following transforms can be applied
    transforms.Resize(256),
    transforms.ToTensor()])

dataset = generate_dataset(batch_size, data_path, img_size=256,transform=train_transform)
dataset_val = generate_dataset(batch_size, data_path, img_size=256, training=False, transform=val_transform)

model.to(device)
for epoch in range(nb_epoch):
    model.train()
    for step, inputs in enumerate(dataset):
        imgs_tensor, label_tensor = inputs
        imgs_tensor = imgs_tensor.to(device)
        label_tensor = label_tensor.to(device)
        output = model(imgs_tensor)
        optimizer.zero_grad()
        loss = F.cross_entropy(output, label_tensor)  # cross_entropy applies softmax internally, so the model's last layer doesn't need a softmax
        loss.backward()
        optimizer.step()

This is the loss/accuracy log and the training code segment. As you can see, the train and test accuracies clearly disagree. Is this normal? If not, what can I do to deal with it?
By the way, compared with other models such as ResNeXt and DenseNet, I found that nf_resnet converges more slowly. Have you observed this as well?

BN version of the models?

Hi,

May I ask what the BN (batch-normalized) counterparts of the provided NF networks are? For example, is torchvision.models.resnet18 the BN counterpart of nf_resnet18?

Thanks in advance!

Is Weight Standardization correct?

Hi

First of all, thank you for sharing this valuable source code.

I'm looking at the code you implemented.
I ask because the implementation of weight standardization is different from the original.

Original paper and GitHub:
https://paperswithcode.com/method/weight-standardization

[screenshot: weight standardization formula from the paper]

Your implementation
[screenshot: this repository's implementation]

Did you misunderstand the formula at the time of implementation?
[screenshot]

Can you confirm whether I have misunderstood, or whether these are mathematically the same formula?

Thanks
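
For reference, a minimal sketch of plain Weight Standardization as commonly formulated (per output channel, mean and variance over the fan-in dimensions); whether an implementation additionally applies the NF-style gain and fan-in scaling is a separate question:

import torch

def standardized_weight(weight, eps=1e-5):
    # weight: (out_channels, in_channels, kH, kW); statistics per output channel
    mean = weight.mean(dim=[1, 2, 3], keepdim=True)
    var = weight.var(dim=[1, 2, 3], keepdim=True, unbiased=False)
    return (weight - mean) / torch.sqrt(var + eps)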

A decoder for Semantic Segmentation

Hi,

I am interested in using your architecture for a semantic segmentation problem. I am therefore using the segmentation_models.pytorch library, which luckily supports timm models, and therefore your architecture, as encoders.

However, all of the decoders supported by segmentation_models.pytorch use normalization. Should I just replace all instances of Conv2D followed by BatchNorm2D with a ScaledStdConv2D, or do you have a better suggestion? (Should I also then put the ReLU before the ScaledStdConv2D, as you seem to do?)

Thank you in advance.

Conv layer from timm works better

resnet18 with this layer swapped in learns better, but it still needs twice as many epochs on CIFAR10 to reach the same result as the original.

import torch
from torch import nn
from torch.nn import functional as F
from timm.models.layers import get_padding  # padding helper used below, shipped with timm

class ScaledStdConv2d(nn.Conv2d):
    """Conv2d layer with Scaled Weight Standardization.
    Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` -
        https://arxiv.org/abs/2101.08692
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=None, dilation=1, groups=1,
                 bias=True, gain=True, gamma=1.0, eps=1e-5, use_layernorm=False):
        if padding is None:
            padding = get_padding(kernel_size, stride, dilation)
        super().__init__(
            in_channels, out_channels, kernel_size, stride=stride,
            padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1)) if gain else None
        self.scale = gamma * self.weight[0].numel() ** -0.5  # gamma * 1 / sqrt(fan-in)
        self.eps = eps ** 2 if use_layernorm else eps
        self.use_layernorm = use_layernorm  # experimental, slightly faster/less GPU memory use

    def get_weight(self):
        if self.use_layernorm:
            weight = self.scale * F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
        else:
            std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
            weight = self.scale * (self.weight - mean) / (std + self.eps)
        if self.gain is not None:
            weight = weight * self.gain
        return weight

    def forward(self, x):
        return F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)

Is it a PyTorch compatibility issue?

I used the code; however, I got the following errors.

At first, I got the following:

base.py", line 262, in __init__
    dilation=dilation, groups=groups, bias=bias, padding_mode=padding_mode)
TypeError: __init__() got an unexpected keyword argument 'padding_mode'

When I removed padding_mode=padding_mode, I got another error, as presented below:

utils.py", line 23, in replace_conv
    setattr(module, name, torch.nn.Identity())
AttributeError: module 'torch.nn' has no attribute 'Identity'

How can I solve this?

torch.max doesn't check for tensors being on different devices.


To Reproduce
Steps to reproduce the behavior:

  1. Go to example and instantiate a resnet18
  2. Send model to torch.device('cuda')
  3. Define a tensor on the gpu
  4. Call model.forward()
  5. RuntimeError: iter.device(arg).is_cuda() INTERNAL ASSERT FAILED at "/pytorch/aten/src/ATen/native/cuda/Loops.cuh":94, please report a bug to PyTorch.

Expected behavior
Regular output


See here

My Solution:

It's hacky, obviously, but it works. Simply replace

scale = torch.rsqrt(torch.max(var * fan_in, torch.tensor(eps))) * self.gain.view_as(var)

with scale = torch.rsqrt(torch.max(var * fan_in, torch.tensor(eps).to(var.device))) * self.gain.view_as(var).to(var.device)
