Vision Transformer on CIFAR-10


Contents

  1. Highlights
  2. Requirements
  3. Usage
  4. Results
  5. Citations

Highlights

This project is an implementation from scratch of a slightly modified version of the vanilla vision transformer introduced in the paper An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. We train the model on the small-scale benchmark dataset CIFAR-10.

Vision transformers often struggle when trained from scratch on small datasets such as CIFAR-10. This is primarily because they lack the locality, inductive biases, and hierarchical representation structure that are built into convolutional neural networks. As a result, ViTs typically require large-scale pre-training to learn these properties from data before they transfer well to downstream tasks.

This project shows that, with a few modifications, supervised training of a vision transformer on a small dataset like CIFAR-10 can reach very high accuracy on a modest computational budget.

The vanilla vision transformer model uses the standard multi-head self-attention mechanism introduced in the seminal paper by Vaswani et al.
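
For reference, each attention head computes the scaled dot-product attention from that paper, where $d_k$ is the per-head dimension (head_dim in the code below):

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V$$

The outputs of all heads are concatenated and passed through a final linear projection.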

import torch
import torch.nn as nn

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # 1/sqrt(d_k) scaling from scaled dot-product attention
        self.scale = qk_scale or head_dim ** -0.5
        all_head_dim = head_dim * self.num_heads
        # Single linear layer producing queries, keys, and values in one pass
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape  # batch size, number of tokens, embedding dim

        # (B, N, 3C) -> (3, B, num_heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Attention weights: softmax(q @ k^T / sqrt(d_k))
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Weighted sum over values, then merge heads back to (B, N, C)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn
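
As a quick sanity check, the module can be exercised on random token embeddings. The shapes below are illustrative, chosen to match the 257-token, 192-dimensional configuration in the flop table further down:

x = torch.randn(2, 257, 192)                  # (batch, tokens incl. CLS, dim)
attn_layer = Attention(dim=192, num_heads=8)  # head_dim = 192 // 8 = 24
out, weights = attn_layer(x)
print(out.shape)      # torch.Size([2, 257, 192])
print(weights.shape)  # torch.Size([2, 8, 257, 257])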

Requirements

pip install -r requirements.txt

Usage

To replicate the reported results, clone this repo

cd your_directory
git clone git@github.com:jordandeklerk/ViT-pytorch.git

and run the main training script

python train.py 

Make sure to adjust the checkpoint directory in train.py so that checkpoint files are saved where you want them.


Results

We test our approach on CIFAR-10, with the intention of extending the model to four other small, low-resolution datasets: Tiny-ImageNet, CIFAR-100, CINIC-10, and SVHN. All training took place on a single A100 GPU.

  • CIFAR-10
    • vit_cifar10_patch2_input32 - 96.80% top-1 accuracy @ 32×32 input resolution

The model name encodes the configuration: 2×2 patches on a 32×32 input give (32/2)² = 256 patch tokens, which together with the class token account for the 257 positions in the pos_embed row of the table below.

Flop analysis:

total flops: 915674304
total activations: 10735212
number of parameters: 2725632

| module            | #parameters or shape   | #flops   |
|:------------------|:-----------------------|:---------|
| model             | 2.726M                 | 0.916G   |
|  cls_token        |  (1, 1, 192)           |          |
|  pos_embed        |  (1, 257, 192)         |          |
|  patch_embed.proj |  2.496K                |  0.59M   |
|  blocks           |  2.673M                |  0.915G  |
|  norm             |  0.384K                |  0.247M  |
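
The table above is in the format produced by fvcore's flop counter. A minimal sketch of how such an analysis can be generated, assuming a build_model() helper that constructs the ViT (the helper name is illustrative, not the repo's actual API):

import torch
from fvcore.nn import FlopCountAnalysis, flop_count_table

model = build_model()  # hypothetical helper; construct the ViT as in train.py
# Count flops for a single 3-channel 32x32 CIFAR-10 image
flops = FlopCountAnalysis(model, torch.randn(1, 3, 32, 32))
print("total flops:", flops.total())
print(flop_count_table(flops, max_depth=1))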

Citations

@misc{vaswani2017attention,
    title   = {Attention Is All You Need},
    author  = {Ashish Vaswani and Noam Shazeer and Niki Parmar and Jakob Uszkoreit and Llion Jones and Aidan N. Gomez and Lukasz Kaiser and Illia Polosukhin},
    year    = {2017},
    eprint  = {1706.03762},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@misc{dosovitskiy2020image,
    title   = {An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale},
    author  = {Alexey Dosovitskiy and Lucas Beyer and Alexander Kolesnikov and Dirk Weissenborn and Xiaohua Zhai and Thomas Unterthiner and Mostafa Dehghani and Matthias Minderer and Georg Heigold and Sylvain Gelly and Jakob Uszkoreit and Neil Houlsby},
    year    = {2020},
    eprint  = {2010.11929},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
