
torch-tresnet's Introduction

TResNet: High Performance GPU-Dedicated Architecture

A pip-installable package of TResNet, based on the official PyTorch implementation [paper] [github].

Installation

Install with pip:

pip install torch_tresnet

or directly from GitHub:

pip install git+https://github.com/tczhangzhi/torch-tresnet

Usage

The model constructors follow the conventions of torchvision's model constructors:

from torch_tresnet import tresnet_m, tresnet_l, tresnet_xl, tresnet_m_448, tresnet_l_448, tresnet_xl_448

# Pretrained on 224x224 inputs
model = tresnet_m(pretrained=True)
model = tresnet_m(pretrained=True, num_classes=10)              # 10-class classification head
model = tresnet_m(pretrained=True, num_classes=10, in_chans=3)  # 3 input channels (RGB)

# Pretrained on 448x448 inputs
model = tresnet_m_448(pretrained=True)
model = tresnet_m_448(pretrained=True, num_classes=10)
model = tresnet_m_448(pretrained=True, num_classes=10, in_chans=3)
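The constructor names encode two choices: the model size (`m`, `l`, `xl`) and, through the `_448` suffix, whether the weights were pretrained on 448x448 instead of 224x224 inputs. A minimal sketch of that naming convention; `constructor_name` is a hypothetical helper written for illustration and is not part of torch_tresnet:

```python
# Hypothetical helper (not part of torch_tresnet): builds the constructor name
# for a model size and the input resolution its weights were pretrained on.
SIZES = ("m", "l", "xl")
RESOLUTIONS = (224, 448)


def constructor_name(size: str, resolution: int = 224) -> str:
    """Return e.g. 'tresnet_m' for (m, 224) or 'tresnet_xl_448' for (xl, 448)."""
    if size not in SIZES:
        raise ValueError(f"size must be one of {SIZES}, got {size!r}")
    if resolution not in RESOLUTIONS:
        raise ValueError(f"resolution must be one of {RESOLUTIONS}, got {resolution!r}")
    # 224x224 models carry no suffix; 448x448 models end in '_448'.
    suffix = "" if resolution == 224 else f"_{resolution}"
    return f"tresnet_{size}{suffix}"


print(constructor_name("m"))        # tresnet_m
print(constructor_name("xl", 448))  # tresnet_xl_448
```

Because the constructors are importable from the package, the name can be resolved with `getattr(torch_tresnet, constructor_name("l", 448))(pretrained=True, num_classes=10)`.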

For a demonstration of training and testing, see the Jupyter notebook. Note that it is an unofficial implementation, so please do not use it as the final evaluation standard.

Main Results

TResNet Models

Accuracy and GPU throughput of TResNet models on ImageNet, compared to ResNet50. All measurements were done on an Nvidia V100 GPU with mixed precision. All models were trained at an input resolution of 224.

| Model | Top Training Speed (img/sec) | Top Inference Speed (img/sec) | Max Train Batch Size | Top-1 Acc. (%) |
| --- | --- | --- | --- | --- |
| ResNet50 | 805 | 2830 | 288 | 79.0 |
| EfficientNetB1 | 440 | 2740 | 196 | 79.2 |
| TResNet-M | 730 | 2930 | 512 | 80.7 |
| TResNet-L | 345 | 1390 | 316 | 81.4 |
| TResNet-XL | 250 | 1060 | 240 | 82.0 |
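The trade-offs in this table are easier to read as ratios against ResNet50. A minimal sketch in plain Python; the numbers are copied from the table, and `relative_to` is a hypothetical helper, not part of torch_tresnet:

```python
# Figures copied from the table above (V100, mixed precision, 224x224 input).
# model: (train img/sec, inference img/sec, max train batch size, top-1 %)
RESULTS = {
    "ResNet50": (805, 2830, 288, 79.0),
    "EfficientNetB1": (440, 2740, 196, 79.2),
    "TResNet-M": (730, 2930, 512, 80.7),
    "TResNet-L": (345, 1390, 316, 81.4),
    "TResNet-XL": (250, 1060, 240, 82.0),
}


def relative_to(baseline, table):
    """Return (train speed ratio, inference speed ratio, top-1 delta) per model."""
    base_train, base_infer, _, base_acc = table[baseline]
    return {
        name: (
            round(train / base_train, 2),
            round(infer / base_infer, 2),
            round(acc - base_acc, 1),
        )
        for name, (train, infer, _, acc) in table.items()
    }


for name, (train_x, infer_x, acc_delta) in relative_to("ResNet50", RESULTS).items():
    print(f"{name:15s} train x{train_x:.2f}  infer x{infer_x:.2f}  top-1 {acc_delta:+.1f}")
```

For TResNet-M this gives about 0.91x ResNet50's training throughput, 1.04x its inference throughput and +1.7 points of top-1 accuracy.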

Comparison To Other Networks

Comparison of ResNet50 to top modern networks with similar top-1 ImageNet accuracy. All measurements were done on an Nvidia V100 GPU with mixed precision. To obtain optimal speeds, training and inference were measured at 90% of the maximal possible batch size. Except for TResNet-M, all models' ImageNet scores were taken from a public repository that specializes in top implementations of modern networks. All models use an input resolution of 224, except EfficientNet-B1, which uses 240.
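As a worked example of the 90% rule, the sketch below applies it to the maximum training batch sizes listed in the TResNet Models table. Two assumptions: the result is rounded down to an integer (the text only says "90%"), and `benchmark_batch_size` is a hypothetical helper, not part of torch_tresnet. Maximum inference batch sizes are not listed, so only the training case is shown.

```python
# Maximum training batch sizes from the TResNet Models table.
MAX_TRAIN_BATCH = {
    "ResNet50": 288,
    "EfficientNetB1": 196,
    "TResNet-M": 512,
    "TResNet-L": 316,
    "TResNet-XL": 240,
}


def benchmark_batch_size(max_batch_size: int, percent: int = 90) -> int:
    """Return `percent`% of the maximum batch size, rounded down (assumed).

    Integer arithmetic avoids floating-point rounding surprises.
    """
    return max_batch_size * percent // 100


for name, max_bs in MAX_TRAIN_BATCH.items():
    print(f"{name}: max {max_bs}, timed at {benchmark_batch_size(max_bs)}")
```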

| Model | Top Training Speed (img/sec) | Top Inference Speed (img/sec) | Top-1 Acc. (%) | FLOPs (G) |
| --- | --- | --- | --- | --- |
| ResNet50 | 805 | 2830 | 79.0 | 4.1 |
| ResNet50-D | 600 | 2670 | 79.3 | 4.4 |
| ResNeXt50 | 490 | 1940 | 78.5 | 4.3 |
| EfficientNetB1 | 440 | 2740 | 79.2 | 0.6 |
| SEResNeXt50 | 400 | 1770 | 79.0 | 4.3 |
| MixNet-L | 400 | 1400 | 79.0 | 0.5 |
| TResNet-M | 730 | 2930 | 80.7 | 5.5 |
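One way to read this table is to ask which models match or beat ResNet50 on both speed and top-1 accuracy. A minimal sketch; the numbers are copied from the table, and `faster_and_more_accurate` is a hypothetical helper, not part of torch_tresnet:

```python
# Figures copied from the table above.
# model: (train img/sec, inference img/sec, top-1 %, GFLOPs)
MODELS = {
    "ResNet50": (805, 2830, 79.0, 4.1),
    "ResNet50-D": (600, 2670, 79.3, 4.4),
    "ResNeXt50": (490, 1940, 78.5, 4.3),
    "EfficientNetB1": (440, 2740, 79.2, 0.6),
    "SEResNeXt50": (400, 1770, 79.0, 4.3),
    "MixNet-L": (400, 1400, 79.0, 0.5),
    "TResNet-M": (730, 2930, 80.7, 5.5),
}


def faster_and_more_accurate(table, baseline="ResNet50", speed="inference"):
    """List models at least as fast (train or inference) and as accurate as `baseline`."""
    col = {"train": 0, "inference": 1}[speed]
    base = table[baseline]
    return [
        name
        for name, row in table.items()
        if name != baseline and row[col] >= base[col] and row[2] >= base[2]
    ]


print(faster_and_more_accurate(MODELS))                 # inference throughput
print(faster_and_more_accurate(MODELS, speed="train"))  # training throughput
```

On inference throughput only TResNet-M qualifies; on training throughput no model in the table matches ResNet50's 805 img/sec.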

Transfer Learning SotA Results

Comparison of TResNet to state-of-the-art models on transfer learning datasets (only ImageNet-based transfer learning results). Model inference speed is measured on a V100 GPU with mixed precision. Since no official implementation of GPipe was provided, its inference speed is unknown.

| Dataset | Model | Top-1 Acc. (%) | Inference Speed (img/sec) | Input Resolution |
| --- | --- | --- | --- | --- |
| CIFAR-10 | GPipe | 99.0 | - | 480 |
| CIFAR-10 | TResNet-XL | 99.0 | 1060 | 224 |
| CIFAR-100 | EfficientNet-B7 | 91.7 | 70 | 600 |
| CIFAR-100 | TResNet-XL | 91.5 | 1060 | 224 |
| Stanford Cars | EfficientNet-B7 | 94.7 | 70 | 600 |
| Stanford Cars | TResNet-L | 96.0 | 500 | 368 |
| Oxford-Flowers | EfficientNet-B7 | 98.8 | 70 | 600 |
| Oxford-Flowers | TResNet-L | 99.1 | 500 | 368 |
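Per dataset, the table pairs one competing model with one TResNet model. A minimal sketch that summarizes each pair: accuracy difference, inference speedup (skipped when the speed is unknown, as for GPipe), and how many more pixels per image the competitor processes. It assumes exactly two rows per dataset with the TResNet row second, as in the table; `compare` is a hypothetical helper, not part of torch_tresnet.

```python
# Rows copied from the table above:
# (dataset, model, top-1 %, inference img/sec or None if unknown, input resolution)
ROWS = [
    ("CIFAR-10", "GPipe", 99.0, None, 480),
    ("CIFAR-10", "TResNet-XL", 99.0, 1060, 224),
    ("CIFAR-100", "EfficientNet-B7", 91.7, 70, 600),
    ("CIFAR-100", "TResNet-XL", 91.5, 1060, 224),
    ("Stanford Cars", "EfficientNet-B7", 94.7, 70, 600),
    ("Stanford Cars", "TResNet-L", 96.0, 500, 368),
    ("Oxford-Flowers", "EfficientNet-B7", 98.8, 70, 600),
    ("Oxford-Flowers", "TResNet-L", 99.1, 500, 368),
]


def compare(rows):
    """Summarize each (competitor, TResNet) pair, keyed by dataset."""
    by_dataset = {}
    for dataset, model, acc, speed, res in rows:
        by_dataset.setdefault(dataset, []).append((model, acc, speed, res))

    summary = {}
    for dataset, entries in by_dataset.items():
        if len(entries) != 2:
            raise ValueError(f"expected two rows for {dataset}, got {len(entries)}")
        (other, o_acc, o_speed, o_res), (_, t_acc, t_speed, t_res) = entries
        summary[dataset] = {
            "vs": other,
            "acc_delta": round(t_acc - o_acc, 1),
            # No speedup when the competitor's speed is unknown.
            "speedup": round(t_speed / o_speed, 1) if o_speed else None,
            # Square input images: pixel count scales with resolution squared.
            "pixel_ratio": round((o_res / t_res) ** 2, 2),
        }
    return summary


for dataset, s in compare(ROWS).items():
    print(dataset, s)
```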

torch-tresnet's People

Contributors

tczhangzhi


torch-tresnet's Issues

Wrong url for tresnet_l_448 and tresnet_xl_448

I think you made a mistake in tresnet.py, lines 265 and 271: you used 'tresnet_m_448' for the tresnet_l_448 and tresnet_xl_448 versions, which leads to an error when loading the pretrained models.

def tresnet_l_448(pretrained=False, progress=True, **kwargs):
    """ Constructs a large TResnet model.
    """
    return _tresnet('tresnet_m_448', [4, 5, 18, 3], pretrained, progress, width_factor=1.2, **kwargs)  # here: should be 'tresnet_l_448'


def tresnet_xl_448(pretrained=False, progress=True, **kwargs):
    """ Constructs an extra-large TResnet model.
    """
    return _tresnet('tresnet_m_448', [4, 5, 24, 3], pretrained, progress, width_factor=1.3, **kwargs)  # and here: should be 'tresnet_xl_448'
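A minimal sketch of the fix this issue suggests, assuming that the first argument of the private `_tresnet` helper is the key used to look up the pretrained weights. The `_tresnet` below is a hypothetical stub that only echoes its arguments, so the key can be checked without building a network or downloading anything; it is not the package's implementation.

```python
# Hypothetical stub standing in for the package's private `_tresnet` helper.
# The real helper builds the network and, when pretrained=True, loads weights
# keyed by `arch`; this stub only returns its arguments for inspection.
def _tresnet(arch, layers, pretrained, progress, **kwargs):
    return {"arch": arch, "layers": layers, "pretrained": pretrained,
            "progress": progress, **kwargs}


def tresnet_l_448(pretrained=False, progress=True, **kwargs):
    """Constructs a large TResNet model (448x448 weights)."""
    # Fixed: pass the large model's own key instead of 'tresnet_m_448'.
    return _tresnet('tresnet_l_448', [4, 5, 18, 3], pretrained, progress,
                    width_factor=1.2, **kwargs)


def tresnet_xl_448(pretrained=False, progress=True, **kwargs):
    """Constructs an extra-large TResNet model (448x448 weights)."""
    # Fixed: pass the extra-large model's own key instead of 'tresnet_m_448'.
    return _tresnet('tresnet_xl_448', [4, 5, 24, 3], pretrained, progress,
                    width_factor=1.3, **kwargs)


# Guard against this class of copy-paste bug: each constructor's key must
# match its own function name.
for constructor in (tresnet_l_448, tresnet_xl_448):
    assert constructor()["arch"] == constructor.__name__, constructor.__name__
```

The final loop is a cheap regression check: it fails as soon as a constructor passes another model's key, which is exactly the mistake reported above.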
