
tfgcvit's Introduction

tfgcvit

Keras (TensorFlow v2) reimplementation of Global Context Vision Transformer models.

  • Based on the official PyTorch implementation.
  • Supports variable-shape inference for downstream tasks.
  • Contains pretrained weights converted from official ones.

Installation

pip install tfgcvit

Examples

Default usage (without preprocessing):

from tfgcvit import GCViTTiny  # + 4 other variants and input preprocessing

model = GCViTTiny()  # by default will download imagenet-pretrained weights
model.compile(...)
model.fit(...)

Custom classification (with preprocessing):

from keras import layers, models
from tfgcvit import GCViTTiny, preprocess_input

inputs = layers.Input(shape=(224, 224, 3), dtype='uint8')
outputs = layers.Lambda(preprocess_input)(inputs)
outputs = GCViTTiny(include_top=False)(outputs)
outputs = layers.Dense(100, activation='softmax')(outputs)

model = models.Model(inputs=inputs, outputs=outputs)
model.compile(...)
model.fit(...)

Differences

Code simplification:

  • All input shapes are evaluated automatically (they are not passed through the constructor as in PyTorch).
  • Downsampling has been moved out of the GCViTLayer layer to simplify feature extraction in downstream tasks (see the sketch below).
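
A minimal sketch of the downstream feature extraction this separation enables. The per-stage layer names below are hypothetical placeholders, not the library's actual names; inspect model.summary() to find the real endpoints:

from keras import models
from tfgcvit import GCViTTiny

backbone = GCViTTiny(include_top=False)
backbone.summary()  # find the real per-stage layer names; 'stage_0'..'stage_3' are placeholders

# Expose one feature map per stage as a multi-output model.
endpoints = [backbone.get_layer(f'stage_{i}').output for i in range(4)]
features = models.Model(inputs=backbone.inputs, outputs=endpoints)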

Performance improvements:

  • Layer normalization epsilon is fixed at 1.001e-5, and inputs are cast to float32 to use the fused op implementation (see the note after this list).
  • Some layers have been refactored to use faster TF operations.
  • Many reshape/transpose operations have been removed; most of the time the internal representation is a 4D tensor.
  • Relative index estimation has been moved up to the GCViTLayer level.
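
For illustration, the normalization settings described above amount to the following; a minimal sketch, not code from the library itself:

from keras import layers

# In tf.keras, LayerNormalization can only take the fast fused path when
# epsilon is at least 1.001e-5 and the dtype is not float64, hence the
# fixed epsilon and the float32 cast mentioned above.
norm = layers.LayerNormalization(epsilon=1.001e-5, dtype='float32')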

Variable shapes

When using GCViT models with input shapes different from the pretraining one, try to make the height and width multiples of 32 * window_size. Otherwise, many tensors will be padded, resulting in speed degradation.
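
A small helper that rounds a desired side length up to the nearest pad-free size; window_size=7 is an assumption here, so check the actual window size of the variant you use:

def pad_free_size(size, window_size=7):
    # Round up to the nearest multiple of 32 * window_size (224 when window_size=7).
    multiple = 32 * window_size
    return (size + multiple - 1) // multiple * multiple

print(pad_free_size(300))  # -> 448, the nearest pad-free size above 300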

Evaluation

To verify correctness, the Tiny and Small models (original and ported) were evaluated on the ImageNet-V2 test set.

import tensorflow as tf
import tensorflow_datasets as tfds
from tfgcvit import GCViTTiny, preprocess_input


def _prepare(example, input_size=224, crop_pct=0.875):
    # Resize so the shorter side equals input_size / crop_pct (256 for 224),
    # then take a central crop of input_size x input_size.
    scale_size = tf.math.floor(input_size / crop_pct)

    image = example['image']

    # Scale the spatial dimensions so the shorter side matches scale_size.
    shape = tf.shape(image)[:2]
    shape = tf.cast(shape, 'float32')
    shape *= scale_size / tf.reduce_min(shape)
    shape = tf.round(shape)
    shape = tf.cast(shape, 'int32')

    # Bicubic interpolation can overshoot [0, 255], so round and clip
    # before casting back to uint8.
    image = tf.image.resize(image, shape, method=tf.image.ResizeMethod.BICUBIC)
    image = tf.round(image)
    image = tf.clip_by_value(image, 0., 255.)
    image = tf.cast(image, 'uint8')

    # Central crop offsets.
    offset_h, offset_w = tf.unstack((shape - input_size) // 2)
    image = image[offset_h:offset_h + input_size, offset_w:offset_w + input_size]

    image = preprocess_input(image)

    return image, example['label']


imagenet2 = tfds.load('imagenet_v2', split='test', shuffle_files=True)
imagenet2 = imagenet2.map(_prepare, num_parallel_calls=tf.data.AUTOTUNE)
imagenet2 = imagenet2.batch(8).prefetch(tf.data.AUTOTUNE)

model = GCViTTiny()
# 'sparse_top_k_categorical_accuracy' uses k=5 by default, i.e. acc@5.
model.compile('sgd', 'sparse_categorical_crossentropy', ['accuracy', 'sparse_top_k_categorical_accuracy'])
history = model.evaluate(imagenet2)

print(history)
name    original acc@1    ported acc@1    original acc@5    ported acc@5
Tiny    73.01             72.93           90.75             90.70
Small   73.39             73.46           91.09             91.14

Most of the metric difference comes from input data preprocessing (decoding, interpolation). All layer outputs have been compared with the original ones: the maximum absolute difference across all layers is 8e-4, and most layers differ by less than 1e-5.
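
The layer-by-layer comparison boils down to a maximum absolute difference check; a minimal sketch, assuming original_outputs and ported_outputs are matching lists of per-layer numpy activations (hypothetical names):

import numpy as np

def max_abs_diff(original_outputs, ported_outputs):
    # Largest elementwise deviation across all compared layers.
    return max(
        float(np.max(np.abs(o - p)))
        for o, p in zip(original_outputs, ported_outputs))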

Citation

@article{hatamizadeh2022global,
  title={Global Context Vision Transformers},
  author={Hatamizadeh, Ali and Yin, Hongxu and Kautz, Jan and Molchanov, Pavlo},
  journal={arXiv preprint arXiv:2206.09959},
  year={2022}
}

tfgcvit's People

Contributors

  • shkarupa-alex


tfgcvit's Issues

Release Date

@shkarupa-alex Thank you for open-sourcing your project. I think this project is yet to be released, as model weights are currently being loaded from a local directory:

weights_path = f'/Users/alex/Develop/tfgcvit/weights/{model_name}.h5'

  • As the original repo is yet to be fixed, and it's unclear when that will happen, I'm wondering if there is any plan to make a pre-release.
  • I've tried running the code directly and ran into errors, but with some minor changes it ran nicely.
  • I've also checked the predictions using converted_weights.py, and it works nicely.
  • Is there any plan to migrate from keras to tf.keras? Nowadays, keras seems to be getting obsolete.
