youngwanlee / mpvit

[CVPR 2022] MPViT: Multi-Path Vision Transformer for Dense Prediction

Home Page: https://arxiv.org/abs/2112.11010

License: Other

Python 83.74% Shell 1.04% C++ 1.38% Cuda 13.85%
vision-transformer detectron2 mmsegmentation

mpvit's Introduction

MPViT : Multi-Path Vision Transformer for Dense Prediction

This repository includes the official implementation and model weights of MPViT.

[Arxiv] [BibTeX]

MPViT : Multi-Path Vision Transformer for Dense Prediction
🏛️🏫 Youngwan Lee, 🏛️ Jonghee Kim, 🏫 Jeff Willette, 🏫 Sung Ju Hwang
🏛️ ETRI, 🏫 KAIST

News

🎉 MPViT has been accepted to CVPR 2022.

Abstract

We explore multi-scale patch embedding and multi-path structure, constructing the Multi-Path Vision Transformer (MPViT). MPViT embeds features of the same size (i.e., sequence length) with patches of different scales simultaneously by using overlapping convolutional patch embedding. Tokens of different scales are then independently fed into the Transformer encoders via multiple paths and the resulting features are aggregated, enabling both fine and coarse feature representations at the same feature level. Thanks to the diverse and multi-scale feature representations, our MPViTs scaling from Tiny (5M) to Base (73M) consistently achieve superior performance over state-of-the-art Vision Transformers on ImageNet classification, object detection, instance segmentation, and semantic segmentation. These extensive results demonstrate that MPViT can serve as a versatile backbone network for various vision tasks.
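As a rough illustration of the multi-scale patch embedding idea, the toy sketch below runs several overlapping convolutional embeddings in parallel with the same stride, so every path produces a token sequence of the same length while covering a different patch scale. This is only a conceptual sketch, not the repository's actual implementation, which (per the paper) builds larger effective patch sizes from stacked 3x3 convolutions rather than a single large kernel per path.

```python
import torch
import torch.nn as nn


class MultiScalePatchEmbedSketch(nn.Module):
    """Toy multi-scale patch embedding: parallel overlapping conv embeddings
    with a shared stride yield identical sequence lengths per path."""

    def __init__(self, in_chans=3, embed_dim=64, kernel_sizes=(3, 5, 7), stride=4):
        super().__init__()
        self.paths = nn.ModuleList([
            nn.Conv2d(in_chans, embed_dim, kernel_size=k, stride=stride, padding=k // 2)
            for k in kernel_sizes
        ])

    def forward(self, x):
        # Each path: [B, C, H, W] -> [B, N, embed_dim], with the same N for every path.
        return [p(x).flatten(2).transpose(1, 2) for p in self.paths]


tokens = MultiScalePatchEmbedSketch()(torch.randn(1, 3, 224, 224))
print([t.shape for t in tokens])  # three tensors of shape [1, 3136, 64]
```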

Main results on ImageNet-1K

🚀 All models are trained on ImageNet-1K with the same training recipe as DeiT and CoaT.

| model | resolution | acc@1 | #params | FLOPs | weight |
| --- | --- | --- | --- | --- | --- |
| MPViT-T | 224x224 | 78.2 | 5.8M | 1.6G | weight |
| MPViT-XS | 224x224 | 80.9 | 10.5M | 2.9G | weight |
| MPViT-S | 224x224 | 83.0 | 22.8M | 4.7G | weight |
| MPViT-B | 224x224 | 84.3 | 74.8M | 16.4G | weight |
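A hedged sketch of loading one of these checkpoints for evaluation; the factory name mpvit_small and the checkpoint key layout are assumptions here, so check mpvit.py and the released weight files for the exact names.

```python
import torch
from mpvit import mpvit_small  # assumed factory name; see mpvit.py in this repo

model = mpvit_small()
ckpt = torch.load("mpvit_small.pth", map_location="cpu")
# Some released checkpoints wrap the weights under a "model" key; adjust if needed.
state_dict = ckpt.get("model", ckpt) if isinstance(ckpt, dict) else ckpt
model.load_state_dict(state_dict)
model.eval()
```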

Main results on COCO object detection

🚀 All models are trained using ImageNet-1K pretrained weights.

☀️ MS denotes the same multi-scale training augmentation as in Swin Transformer, which in turn follows the MS augmentation used in DETR and Sparse R-CNN. Accordingly, we also follow the official implementations of DETR and Sparse R-CNN, which are likewise based on Detectron2 (see the sketch after the table below).

Please refer to detectron2/ for the details.

| Backbone | Method | lr Schd | box mAP | mask mAP | #params | FLOPs | weight |
| --- | --- | --- | --- | --- | --- | --- | --- |
| MPViT-T | RetinaNet | 1x | 41.8 | - | 17M | 196G | model \| metrics |
| MPViT-XS | RetinaNet | 1x | 43.8 | - | 20M | 211G | model \| metrics |
| MPViT-S | RetinaNet | 1x | 45.7 | - | 32M | 248G | model \| metrics |
| MPViT-B | RetinaNet | 1x | 47.0 | - | 85M | 482G | model \| metrics |
| MPViT-T | RetinaNet | MS+3x | 44.4 | - | 17M | 196G | model \| metrics |
| MPViT-XS | RetinaNet | MS+3x | 46.1 | - | 20M | 211G | model \| metrics |
| MPViT-S | RetinaNet | MS+3x | 47.6 | - | 32M | 248G | model \| metrics |
| MPViT-B | RetinaNet | MS+3x | 48.3 | - | 85M | 482G | model \| metrics |
| MPViT-T | Mask R-CNN | 1x | 42.2 | 39.0 | 28M | 216G | model \| metrics |
| MPViT-XS | Mask R-CNN | 1x | 44.2 | 40.4 | 30M | 231G | model \| metrics |
| MPViT-S | Mask R-CNN | 1x | 46.4 | 42.4 | 43M | 268G | model \| metrics |
| MPViT-B | Mask R-CNN | 1x | 48.2 | 43.5 | 95M | 503G | model \| metrics |
| MPViT-T | Mask R-CNN | MS+3x | 44.8 | 41.0 | 28M | 216G | model \| metrics |
| MPViT-XS | Mask R-CNN | MS+3x | 46.6 | 42.3 | 30M | 231G | model \| metrics |
| MPViT-S | Mask R-CNN | MS+3x | 48.4 | 43.9 | 43M | 268G | model \| metrics |
| MPViT-B | Mask R-CNN | MS+3x | 49.5 | 44.5 | 95M | 503G | model \| metrics |
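As referenced above, here is a minimal detectron2-style sketch of the DETR-style multi-scale training augmentation. The scale list below is the one commonly used by DETR and is only an assumption here; take the exact values from the configs under detectron2/.

```python
from detectron2.data import transforms as T

# DETR-style multi-scale training: the shorter image side is sampled per image
# from a fixed list, with the longer side capped at 1333.
multi_scale_train = T.ResizeShortestEdge(
    short_edge_length=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800),
    max_size=1333,
    sample_style="choice",
)
```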

Deformable-DETR

All models are trained using the same training recipe.

Please refer to deformable_detr/ for the details.

| backbone | box mAP | epochs | link |
| --- | --- | --- | --- |
| ResNet-50 | 44.5 | 50 | - |
| CoaT-lite S | 47.0 | 50 | link |
| CoaT-S | 48.4 | 50 | link |
| MPViT-S | 49.0 | 50 | link |

Main results on ADE20K Semantic segmentation

All models are trained using ImageNet-1K pretrained weights.

Please refer to semantic_segmentation/ for the details.

| Backbone | Method | Crop Size | Lr Schd | mIoU | #params | FLOPs | weight |
| --- | --- | --- | --- | --- | --- | --- | --- |
| MPViT-S | UperNet | 512x512 | 160K | 48.3 | 52M | 943G | weight |
| MPViT-B | UperNet | 512x512 | 160K | 50.3 | 105M | 1185G | weight |

Getting Started

✋ We use pytorch==1.7.0, torchvision==0.8.1, and cuda==10.1 on NVIDIA V100 GPUs. If you use a different CUDA version, you may obtain slightly different accuracies, but the differences are negligible.
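A quick way to confirm your environment matches the versions above (a convenience snippet only, not part of the repository):

```python
import torch
import torchvision

print("torch:", torch.__version__)              # expected: 1.7.0
print("torchvision:", torchvision.__version__)  # expected: 0.8.1
print("CUDA runtime:", torch.version.cuda)      # expected: 10.1
print("GPU available:", torch.cuda.is_available())
```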

Acknowledgement

This repository is built on the timm library and the DeiT, CoaT, Detectron2, and mmsegmentation repositories.

This work was supported by Institute of Information & Communications Technology Planning & Evaluation (IITP) grant funded by the Korean government (MSIT) (No. 2020-0-00004, Development of Previsional Intelligence based on Long-term Visual Memory Network and No. 2014-3-00123, Development of High Performance Visual BigData Discovery Platform for Large-Scale Realtime Data Analysis).

License

Please refer to MPViT LSA.

Citing MPViT

@inproceedings{lee2022mpvit,
      title={MPViT: Multi-Path Vision Transformer for Dense Prediction}, 
      author={Youngwan Lee and Jonghee Kim and Jeffrey Willette and Sung Ju Hwang},
      booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
      year={2022}
}


mpvit's Issues

What is the meaning of the following code in 'MPViT-main\semantic_segmentation\configs\_base_\models'?

Which module in the paper does this code correspond to, and why does it involve relative position encoding?
import torch
import torch.nn as nn
from einops import rearrange


class ConvRelPosEnc(nn.Module):
    """Convolutional relative position encoding."""

    def __init__(self, Ch, h, window):
        """Initialization.

        Ch: Channels per head.
        h: Number of heads.
        window: Window size(s) in convolutional relative positional encoding.
                It can have two forms:
                1. An integer window size, which assigns all attention heads
                   the same window size in ConvRelPosEnc.
                2. A dict mapping window size to #attention head splits
                   (e.g. {window size 1: #attention head split 1,
                          window size 2: #attention head split 2}).
                   It applies a different window size to each
                   attention head split.
        """
        super().__init__()

        if isinstance(window, int):
            # Set the same window size for all attention heads.
            window = {window: h}
            self.window = window
        elif isinstance(window, dict):
            self.window = window
        else:
            raise ValueError()

        self.conv_list = nn.ModuleList()
        self.head_splits = []
        for cur_window, cur_head_split in window.items():
            dilation = 1  # Use dilation=1 by default.
            padding_size = (cur_window + (cur_window - 1) *
                            (dilation - 1)) // 2
            # Depth-wise convolution over the channels of this head split.
            cur_conv = nn.Conv2d(
                cur_head_split * Ch,
                cur_head_split * Ch,
                kernel_size=(cur_window, cur_window),
                padding=(padding_size, padding_size),
                dilation=(dilation, dilation),
                groups=cur_head_split * Ch,
            )
            self.conv_list.append(cur_conv)
            self.head_splits.append(cur_head_split)
        self.channel_splits = [x * Ch for x in self.head_splits]

    def forward(self, q, v, size):
        """Forward function."""
        B, h, N, Ch = q.shape
        H, W = size

        # We don't use CLS_TOKEN.
        q_img = q
        v_img = v

        # Shape: [B, h, H*W, Ch] -> [B, h*Ch, H, W].
        v_img = rearrange(v_img, "B h (H W) Ch -> B (h Ch) H W", H=H, W=W)
        # Split according to channels.
        v_img_list = torch.split(v_img, self.channel_splits, dim=1)
        conv_v_img_list = [
            conv(x) for conv, x in zip(self.conv_list, v_img_list)
        ]
        conv_v_img = torch.cat(conv_v_img_list, dim=1)
        # Shape: [B, h*Ch, H, W] -> [B, h, H*W, Ch].
        conv_v_img = rearrange(conv_v_img, "B (h Ch) H W -> B h (H W) Ch", h=h)

        # Element-wise product with Q gives the convolutional relative
        # position term added to the factorized attention output.
        EV_hat_img = q_img * conv_v_img
        EV_hat = EV_hat_img
        return EV_hat
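For context, this module matches the convolutional relative position encoding used in CoaT-style factorized attention: the depth-wise convolutions over V produce a position-dependent term that is multiplied element-wise with Q and added to the attention output. A hypothetical usage sketch follows; the window split {3: 2, 5: 3, 7: 3} and the tensor sizes are examples only, not necessarily the settings used in MPViT.

```python
import torch

# 8 heads, 32 channels per head, a 14x14 token grid (hypothetical sizes).
crpe = ConvRelPosEnc(Ch=32, h=8, window={3: 2, 5: 3, 7: 3})  # head splits sum to h=8
q = torch.randn(2, 8, 14 * 14, 32)  # [B, h, N, Ch]
v = torch.randn(2, 8, 14 * 14, 32)
rel_pos_term = crpe(q, v, size=(14, 14))
print(rel_pos_term.shape)  # torch.Size([2, 8, 196, 32]), same shape as q
```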

missing deps

Thanks for the amazing work!

When trying to use the semantic segmentation part, some libraries were missing, so I had to install them with:

pip install einops timm

Regards

About visual patches of different sizes

Thank you for writing a good paper!

It is mentioned in the paper as follows.
"MPViT embeds features of the same size (i.e., sequence length) with patches of different scales simultaneously by using overlapping convolutional patch embedding."

I wonder where patches of different scales are implemented in the code.

Looking forward to your reply. Thank you.

The following error occurred when I was training the upernet_mpvit_base_160k_ade20k.py file with a single GPU.

The following error occurred when I was training the upernet_mpvit_base_160k_ade20k.py file with a single GPU.
fatal: not a git repository (or any parent up to mount point /)
Stopping at filesystem boundary (GIT_DISCOVERY_ACROSS_FILESYSTEM not set).
Traceback (most recent call last):
File "./tools/train.py", line 176, in
main()
File "./tools/train.py", line 165, in main
train_segmentor(
File "/export/home/rny/SegNeXt-main/mmseg/apis/train.py", line 110, in train_segmentor
cfg.device,
File "/export/home/rny/.conda/envs/openmmlab/lib/python3.8/site-packages/mmcv/utils/config.py", line 510, in getattr
return getattr(self._cfg_dict, name)
File "/export/home/rny/.conda/envs/openmmlab/lib/python3.8/site-packages/mmcv/utils/config.py", line 48, in getattr
raise ex
AttributeError: 'ConfigDict' object has no attribute 'device'
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 2698) of binary: /export/home/rny/.conda/envs/openmmlab/bin/python
Traceback (most recent call last):
File "/export/home/rny/.conda/envs/openmmlab/lib/python3.8/runpy.py", line 194, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/export/home/rny/.conda/envs/openmmlab/lib/python3.8/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/export/home/rny/.conda/envs/openmmlab/lib/python3.8/site-packages/torch/distributed/launch.py", line 193, in
main()
File "/export/home/rny/.conda/envs/openmmlab/lib/python3.8/site-packages/torch/distributed/launch.py", line 189, in main
launch(args)
File "/export/home/rny/.conda/envs/openmmlab/lib/python3.8/site-packages/torch/distributed/launch.py", line 174, in launch
run(args)
File "/export/home/rny/.conda/envs/openmmlab/lib/python3.8/site-packages/torch/distributed/run.py", line 715, in run
elastic_launch(
File "/export/home/rny/.conda/envs/openmmlab/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 131, in call
return launch_agent(self._config, self._entrypoint, list(args))
File "/export/home/rny/.conda/envs/openmmlab/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 245, in launch_agent
raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
The SegNeXt-main path on line 9 of the traceback belongs to another paper's codebase.
How do I solve this problem?

download in the model zoo

Great work!
I would really like to use your weights as a pretrained model, but it seems the link cannot be opened. Could you please check it?

about the channels

Thanks for the great work!

I wonder why the in_channels of decode_head are [224, 368, 480, 480] rather than [128, 224, 368, 480] for MPViT-Base?

Looking forward to your reply. Thanks again.

Pretrained model download

I cannot download the pretrained models from Dropbox. Can anyone help me? Is there any other place to download the pretrained models, such as Baidu Netdisk?

multi-scale patch embedding

Thanks for your great work!
In your paper, you propose a multi-scale patch embedding that tokenizes patches of different sizes at the same time via overlapping convolution operations.
But I notice that the patch sizes are fixed at 3.
Additionally, the spatial sizes of the inputs to the mhca_stages in different paths are identical.
So what does multi-scale patch embedding mean in this case?
Looking forward to your reply!

Demo for visualization?

Thanks for the effort and for releasing the code. Could you please provide a demo for the detection part?

About FLOPS and Parameters of Mask R-CNN

How do you calculate the FLOPs and parameters for the Mask R-CNN models? I used analyze_model provided by detectron2, but the results are not consistent with yours. Would you mind sharing the tool you used?

Welcome update to OpenMMLab 2.0

Welcome update to OpenMMLab 2.0

I am Vansin, the technical operator of OpenMMLab. In September of last year, we announced the release of OpenMMLab 2.0 at the World Artificial Intelligence Conference in Shanghai. We invite you to upgrade your algorithm library to OpenMMLab 2.0 using MMEngine, which can be used for both research and commercial purposes. If you have any questions, please feel free to join us on the OpenMMLab Discord at https://discord.gg/A9dCpjHPfE or add me on WeChat (ID: van-sin) and I will invite you to the OpenMMLab WeChat group.

Here are the OpenMMLab 2.0 repos branches:

| | OpenMMLab 1.0 branch | OpenMMLab 2.0 branch |
| --- | --- | --- |
| MMEngine | | 0.x |
| MMCV | 1.x | 2.x |
| MMDetection | 0.x, 1.x, 2.x | 3.x |
| MMAction2 | 0.x | 1.x |
| MMClassification | 0.x | 1.x |
| MMSegmentation | 0.x | 1.x |
| MMDetection3D | 0.x | 1.x |
| MMEditing | 0.x | 1.x |
| MMPose | 0.x | 1.x |
| MMDeploy | 0.x | 1.x |
| MMTracking | 0.x | 1.x |
| MMOCR | 0.x | 1.x |
| MMRazor | 0.x | 1.x |
| MMSelfSup | 0.x | 1.x |
| MMRotate | 0.x | 1.x |
| MMYOLO | | 0.x |

Attention: please create a new virtual environment for OpenMMLab 2.0.

About the parameters of drop_path_rate

@youngwanLEE Hi, I have a question about the drop_path_rate used for the small and base models:
When I set drop_path to, e.g., 0.1 and run the code from your repo, I receive a warning from the fvcore library:

"The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.
mhca_stages.1.mhca_blks.0.MHCA_layers.0.drop_path ....."

Does this indicate that drop_path is not being used?

Some details on the implementation.

Hi @youngwanLEE.

Thanks for the great work. After reading the paper, I have some questions about the implementation details. One question is whether you used the Relative Position Encoding in the Transformer architecture design, as CoaT does (code here). I found that discarding the relative position encoding in Factorized Attention leads to a significant accuracy drop.

Looking forward to your reply. Thanks.

Question about the norm layer

Hi, @youngwanLEE. Great work, and the performance on the downstream tasks is promising.

I notice that you use BN in the patch embedding layers and conv_stem, and layer norm in MHCABlock. When transferring the pre-trained model to downstream tasks, you replace the BN layers with SyncBN.

So I wonder how much improvement replacing BN with SyncBN brings, for COCO detection and ADE20K semantic segmentation respectively? Thanks in advance.
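For reference, a minimal sketch of the standard PyTorch way to swap BatchNorm for SyncBatchNorm when fine-tuning on downstream tasks; Detectron2 and mmsegmentation typically handle this through their norm configuration instead, and the toy backbone below is only a stand-in, not the repository's actual model.

```python
import torch.nn as nn

# Toy stand-in for a backbone that uses BatchNorm (as MPViT's conv stem and
# patch embedding layers do).
backbone = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(inplace=True),
)

# Standard conversion: every BatchNorm*d module becomes a SyncBatchNorm.
backbone = nn.SyncBatchNorm.convert_sync_batchnorm(backbone)
print(backbone)
```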
