
giser18 / an-overview-of-segformer-and-details-description


This project forked from acsekevin/an-overview-of-segformer-and-details-description


This repository provides an overview of the Segformer architecture, the encoder in particular. Some details of Segformer are easily misunderstood, so a short description is given here to help understand the model. Supporting code (Keras/TensorFlow) is also provided.

License: MIT License

Python 100.00%


An Overview of Segformer and Details Description

Preface

This repository explains the structure of the Segformer model. Many recent blog posts and tutorials, including some by experienced computer vision engineers, misunderstand the structure of Segformer, possibly because the diagram in the original paper is misleading, even though the structure is clear from the source code linked in the paper. Therefore, the details of Segformer, including OverlapPatchEmbedding, Efficient Multi-head Attention, the Mix-Feedforward Network, OverlapPatchMerging and the Segformer block, are elaborated here. If you find any problem, please feel free to raise it, and get in touch if convenient.
A reference implementation written in Keras/TensorFlow has also been uploaded.

This is a multi-task segmentation example of a street scene, with images taken in the city centre of Sheffield, England.

Basics and File Description

Project:

git clone https://github.com/ACSEkevin/An-Overview-of-Segformer-and-Details-Description.git

ADEChallengeData2016/: the ADE20K dataset, used for training and testing the model; see the ADE20K Dataset page.
models/: the model implemented in two styles: structural and class inheritance.
adedataset.py: a dataset batch generator (as required by Keras).
train.py: the model training script. NOTICE: this is a basic example script that uses the Keras .fit() API; for more detailed model training, use TensorFlow to build a custom train_one_epoch().

To do: a validation script, and a prediction script for model outputs.

A General Overview of the Model Architecture

The architecture has been re-drawn here to replace the figure in the original paper, which may help in understanding the model.

[Figure: re-drawn Segformer architecture]

To conclude and compare:

  • In the encoder, an input image is progressively downsampled to $\frac{1}{32}$ of its original size; the decoder then upsamples the features to $\frac{1}{4}$ of the original size. However, the model in this repository upsamples to the full input size in an attempt to obtain better results. This can be changed after cloning.
  • In the original figure, the OverlapPatchEmbedding layer is shown only at the beginning of the architecture, which can be misleading. In fact, every stage starts with an OverlapPatchEmbedding layer, so from the second stage onwards one always follows the transformer blocks of the previous stage (shown as SegFormer Block in the figure). Indeed, the paper uses the plural term 'OverlapPatchEmbeddings', which implies that there is more than one such layer.
  • There is an OverlapPatchMerging layer at the end of the transformer blocks of each stage; this layer reshapes the sequence of patch vectors back into feature maps. These two layers are easily confused, and many blogs assume that there is no merging after the blocks.
  • The feature map $C_1$ goes through the MLP layer without upsampling. The others ($C_2$, $C_3$, $C_4$) are upsampled by $\times 2$, $\times 4$ and $\times 8$ respectively with bilinear interpolation.
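The resolution bookkeeping above can be checked with a small helper. This is a sketch assuming a four-stage encoder whose stages downsample by 4, 2, 2 and 2 (overall 1/4, 1/8, 1/16, 1/32) and a 512×512 input chosen only as an example; the function names are hypothetical and not part of the repository.

```python
def stage_resolutions(height, width, strides=(4, 2, 2, 2)):
    """Spatial size of C1..C4, assuming each stage downsamples by the given stride."""
    sizes = []
    h, w = height, width
    for s in strides:
        # ceil division, matching what padding='same' strided convolutions produce
        h, w = -(-h // s), -(-w // s)
        sizes.append((h, w))
    return sizes


def decoder_upsample_factors(sizes):
    """Factor each C_i must be upsampled by to reach the 1/4-resolution C1."""
    target_h = sizes[0][0]
    return [target_h // h for h, _ in sizes]


sizes = stage_resolutions(512, 512)
print(sizes)                            # [(128, 128), (64, 64), (32, 32), (16, 16)]
print(decoder_upsample_factors(sizes))  # [1, 2, 4, 8]
```

$C_1$ already sits at 1/4 resolution, so its factor is 1; upsampling the decoder output to the full input size, as this repository's model does, adds a further ×4.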

A Single Stage of the Encoder

OverlapPatchEmbedding

In a basic transformer (e.g. ViT), an image is split into patches that form a 'sequence', and there is no information exchange between patches because strides = patch_size. In Segformer, by contrast, the patch size (the convolution kernel size) is larger than the stride, so neighbouring patches overlap and share information; hence the name 'overlapped' patches. The embedding is followed by a Layer Normalization.

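import tensorflow as tf
from tensorflow.keras.layers import Conv2D, LayerNormalization
# overlapping patches: kernel_size > strides
x = Conv2D(n_filters, kernel_size=kernel_size, strides=strides, padding='same')(inputs)
batches, height, width, embed_dim = x.shape
x = tf.reshape(x, shape=[-1, height * width, embed_dim])  # feature map -> sequence of patches
x = LayerNormalization()(x)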
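To see what 'overlapping' means, the one-dimensional sketch below lists which input positions each patch covers along one axis. It assumes kernel 7, stride 4 and padding 3, the first-stage values given in the SegFormer paper; kernel = stride = 4 without padding gives ViT-style non-overlapping patches for comparison. The helpers are hypothetical.

```python
def patch_spans(length, kernel, stride, padding=0):
    """Return the (first, last) input index covered by each patch along one axis.

    Indices outside [0, length) fall in the zero padding.
    """
    n_patches = (length + 2 * padding - kernel) // stride + 1
    spans = []
    for i in range(n_patches):
        start = i * stride - padding
        spans.append((start, start + kernel - 1))
    return spans


def shared_pixels(spans):
    """Number of input positions shared by two neighbouring patches."""
    (_, end0), (start1, _) = spans[0], spans[1]
    return max(0, end0 - start1 + 1)


overlapping = patch_spans(16, kernel=7, stride=4, padding=3)
print(overlapping)                 # [(-3, 3), (1, 7), (5, 11), (9, 15)]
print(shared_pixels(overlapping))  # 3

vit_style = patch_spans(16, kernel=4, stride=4)
print(vit_style)                   # [(0, 3), (4, 7), (8, 11), (12, 15)]
print(shared_pixels(vit_style))    # 0
```

Neighbouring patches share kernel - stride positions, which is exactly the information sharing that non-overlapping patches lack.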

A Segformer Block

The diagram below shows the detailed architecture of a Segformer Block. A sequence passes through an Efficient Self-Attention layer and a Mix-Feedforward Network layer, each preceded by a Layer Normalization.

[Figure: detailed architecture of a Segformer Block]
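The data flow of one Segformer Block can be sketched in plain Python. This assumes the usual pre-norm residual wiring, where each sub-layer's output is added back to its input; `attention` and `mix_ffn` are placeholder callables standing in for the real layers, and this layer norm has no learnable scale or shift. All names are hypothetical.

```python
import math


def layer_norm(tokens, eps=1e-6):
    """Normalise each token vector to zero mean and unit variance (no learnable affine)."""
    out = []
    for vec in tokens:
        mean = sum(vec) / len(vec)
        var = sum((v - mean) ** 2 for v in vec) / len(vec)
        out.append([(v - mean) / math.sqrt(var + eps) for v in vec])
    return out


def add(a, b):
    """Element-wise sum of two token sequences (the residual connection)."""
    return [[x + y for x, y in zip(u, v)] for u, v in zip(a, b)]


def segformer_block(tokens, attention, mix_ffn):
    """tokens: list of n_patches vectors, each of length embed_dim."""
    tokens = add(tokens, attention(layer_norm(tokens)))  # LN -> Efficient Self-Attention -> residual
    tokens = add(tokens, mix_ffn(layer_norm(tokens)))    # LN -> Mix-FFN -> residual
    return tokens


def zero_sublayer(seq):
    """Placeholder sub-layer that outputs zeros of the same shape."""
    return [[0.0] * len(vec) for vec in seq]


x = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
# With zero sub-layers, only the residual path contributes, so tokens pass through unchanged.
print(segformer_block(x, zero_sublayer, zero_sublayer))  # [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
```

Because both sub-layers map a sequence of shape $[num_{patches}, dim_{embed}]$ to the same shape, several blocks can be stacked within a stage.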

Efficient Self-Attention

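In the paper, the authors propose Efficient Self-Attention to reduce the time complexity of self-attention from $O(n^2)$ to $O(\frac{n^2}{R})$, where $n$ is the sequence length and $R$ is the reduction ratio of the key/value sequence. In the code, the reduction is applied spatially with a sampling reduction ratio $sr$ along each side, so $R = sr^2$ and the complexity is $O(\frac{n^2}{sr^2})$, consistent with the shape changes below. With $sr=1$ the module reduces to standard Self-Attention.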

  • As in a normal Self-Attention module, each vector of the input sequence produces a $query$, a $key$ and a $value$, although only one vector is shown in the figure.
  • Unlike the normal module, the inputs used for the $key$ and $value$ matrices first pass through a reduction layer before taking part in the attention computation. The reduction layer can be implemented with a Conv2D that downsamples (strides $=$ kernel_size), followed by a Layer Normalization. The Reshape layers reconstruct the feature map before the convolution and flatten it back into a sequence afterwards.
  • Shape changes in the reduction layer: $[num_{patches}, dim_{embed}]$ -> $[height, width, dim_{embed}]$ -> $[\frac{height}{sr}, \frac{width}{sr}, dim_{embed}]$ -> $[\frac{height \times width}{sr^2}, dim_{embed}]$.

reduction layer:

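import tensorflow as tf
from tensorflow.keras.layers import Conv2D, LayerNormalization
# inputs: [batch, n_patches, embed_dim]; height, width: spatial size of the current stage
batches, n_patches, embed_dim = inputs.shape
if sr_ratio > 1:
  inputs = tf.reshape(inputs, shape=[-1, height, width, embed_dim])  # sequence -> feature map
  inputs = Conv2D(embed_dim, kernel_size=sr_ratio, strides=sr_ratio, padding='same')(inputs)  # spatial reduction
  inputs = LayerNormalization()(inputs)
  inputs = tf.reshape(inputs, shape=[-1, (height * width) // (sr_ratio ** 2), embed_dim])  # feature map -> shorter sequence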
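The saving can be made concrete by counting the entries of the attention-score matrix, which has one row per query and one column per key. The sketch assumes a 512×512 input, stage resolutions from 1/4 to 1/32, and sr ratios 8, 4, 2, 1 (the values used in the reference SegFormer code); the helper is hypothetical.

```python
def attention_cost(height, width, sr_ratio):
    """Return (n_queries, n_keys, score_entries) for one Efficient Self-Attention layer.

    Assumes height and width are divisible by sr_ratio.
    """
    n_queries = height * width
    # keys/values come from the sr x sr-reduced feature map
    n_keys = (height // sr_ratio) * (width // sr_ratio)
    return n_queries, n_keys, n_queries * n_keys


stage_sizes = [(128, 128), (64, 64), (32, 32), (16, 16)]  # 512x512 input
sr_ratios = [8, 4, 2, 1]

for (h, w), sr in zip(stage_sizes, sr_ratios):
    n_q, n_k, reduced = attention_cost(h, w, sr)
    full = n_q * n_q  # plain self-attention (sr = 1)
    print(f"{h}x{w} sr={sr}: {n_q} queries, {n_k} keys, {full // reduced}x fewer scores")
```

In every stage the score matrix shrinks by a factor of $sr^2$, because the key sequence is $sr^2$ times shorter while the number of queries is unchanged.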

Mix-Feedforward Network

The Conditional Positional Encoding (CPE) method addresses the loss of accuracy caused by different input resolutions in Vision Transformers. In the SegFormer paper, the authors point out that positional encoding (PE) is not necessary for segmentation tasks, arguing that the zero padding of a convolution leaks enough location information. Mix-FFN therefore contains only a Conv $3 \times 3$ layer and no PE.

  • In the reference code, a depthwise convolution (DWConv) is adopted rather than the plain Conv $3 \times 3$ described in the paper, which can be misleading.
  • The Reshape layers have the same purpose as those in reduction layer from Efficient Self-Attention.
  • Shape changes in Mix-FFN layer: $[num_{patches}, dim_{embed}]$ -> $[num_{patches}, dim_{embed} \cdot rate_{exp}]$ -> $[height, width, dim_{embed} \cdot rate_{exp}]$ -> $[num_{patches}, dim_{embed} \cdot rate_{exp}]$ -> $[num_{patches}, dim_{embed}]$.
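import tensorflow as tf
from tensorflow.keras.layers import Dense, DepthwiseConv2D, Activation, Dropout
# inputs: [batch, n_patches, embed_dim]; height, width: spatial size of the current stage
batches, n_patches, embed_dim = inputs.shape
hidden_dim = int(embed_dim * expansion_rate)
x = Dense(hidden_dim, use_bias=True)(inputs)
x = tf.reshape(x, shape=[-1, height, width, hidden_dim])  # sequence -> feature map
x = DepthwiseConv2D(kernel_size=3, strides=1, padding='same')(x)  # DWConv in place of Conv 3x3
x = tf.reshape(x, shape=[-1, n_patches, hidden_dim])  # feature map -> sequence
x = Activation('gelu')(x)
x = Dense(embed_dim, use_bias=True)(x)
x = Dropout(rate=drop_rate)(x)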
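The practical difference between the paper's Conv $3 \times 3$ and the code's DWConv shows up in the parameter count. The sketch below counts weights plus biases for both layers on the expanded hidden features and traces the Mix-FFN shapes, assuming embed_dim = 32, expansion_rate = 4 and a 128×128 stage (illustrative values only); the function names are hypothetical.

```python
def conv3x3_params(channels):
    """Standard 3x3 convolution, channels -> channels, with bias."""
    return 3 * 3 * channels * channels + channels


def depthwise3x3_params(channels):
    """Depthwise 3x3 convolution: one 3x3 filter per channel, with bias."""
    return 3 * 3 * channels + channels


def mix_ffn_shapes(height, width, embed_dim, expansion_rate):
    """Trace the tensor shapes through Mix-FFN (batch dimension omitted)."""
    n_patches = height * width
    hidden = int(embed_dim * expansion_rate)
    return [
        (n_patches, embed_dim),   # input sequence
        (n_patches, hidden),      # Dense expansion
        (height, width, hidden),  # reshape to feature map for DWConv
        (n_patches, hidden),      # reshape back to sequence, then GELU
        (n_patches, embed_dim),   # Dense projection
    ]


hidden = 32 * 4
print(conv3x3_params(hidden), depthwise3x3_params(hidden))  # 147584 1280
print(mix_ffn_shapes(128, 128, 32, 4))
# [(16384, 32), (16384, 128), (128, 128, 128), (16384, 128), (16384, 32)]
```

The depthwise version mixes information only spatially within each channel; channel mixing is left to the two surrounding Dense layers.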

OverlapPatchMerging

This is a simple reshape operation that reconstructs the sequence of patches into a feature map. Note that the layer is also preceded by a Layer Normalization.

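import tensorflow as tf
from tensorflow.keras.layers import LayerNormalization
x = LayerNormalization()(x)
feature_Cx = tf.reshape(x, shape=[-1, height_Cx, width_Cx, embed_dims[index]])  # sequence -> feature map C_x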

where embed_dims is a list that stores the embedding dimension of each stage, and index selects the current stage.
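Because OverlapPatchMerging is only a reshape, patch k of the sequence ends up at row k // width, column k % width of the feature map (row-major order, which is what tf.reshape uses). The pure-Python sketch below, with hypothetical helper names, makes that mapping explicit and checks that it inverts the flattening done in OverlapPatchEmbedding.

```python
def sequence_to_feature_map(tokens, height, width):
    """[n_patches, dim] -> [height, width, dim] in row-major order."""
    if len(tokens) != height * width:
        raise ValueError("n_patches must equal height * width")
    return [tokens[r * width:(r + 1) * width] for r in range(height)]


def feature_map_to_sequence(feature_map):
    """[height, width, dim] -> [n_patches, dim], the inverse flattening."""
    return [vec for row in feature_map for vec in row]


# 2x3 feature map with 1-dimensional embeddings labelled by patch index
seq = [[k] for k in range(6)]
fmap = sequence_to_feature_map(seq, height=2, width=3)
print(fmap)  # [[[0], [1], [2]], [[3], [4], [5]]]
assert feature_map_to_sequence(fmap) == seq
```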

References

  • Xie, E. et al. (2021) 'SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers', NeurIPS 2021. arXiv doi: 10.48550/arXiv.2105.15203

  • Chu, X. et al. (2021) 'Conditional Positional Encodings for Vision Transformers', ICLR 2023, pp. 1-19. arXiv doi: 10.48550/arXiv.2102.1088

  • Zhou, B. et al. (2017) 'Scene Parsing through ADE20K Dataset', Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR). doi: 10.1109/CVPR.2017.544

Author and Contributor


acsekevin
