
[NeurIPS 2021] Twins: Revisiting the Design of Spatial Attention in Vision Transformers


Very recently, a variety of vision transformer architectures for dense prediction tasks have been proposed and they show that the design of spatial attention is critical to their success in these tasks. In this work, we revisit the design of the spatial attention and demonstrate that a carefully-devised yet simple spatial attention mechanism performs favourably against the state-of-the-art schemes. As a result, we propose two vision transformer architectures, namely, Twins-PCPVT and Twins-SVT. Our proposed architectures are highly efficient and easy to implement, only involving matrix multiplications that are highly optimized in modern deep learning frameworks. More importantly, the proposed architectures achieve excellent performance on a wide range of visual tasks including image-level classification as well as dense detection and segmentation. The simplicity and strong performance suggest that our proposed architectures may serve as stronger backbones for many vision tasks.

Figure 1. Twins-SVT-S Architecture (the right side shows the inside of two consecutive Transformer Encoders).

Usage

First, clone the repository locally:

git clone https://github.com/Meituan-AutoML/Twins.git

Then, install PyTorch 1.7.0+, torchvision 0.8.1+, and timm==0.3.2 (pytorch-image-models):

conda install -c pytorch pytorch torchvision
pip install timm==0.3.2
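
Optionally, a quick way to confirm the installed versions (a convenience snippet, not part of the repository):

import torch, torchvision, timm

# expect >= 1.7.0, >= 0.8.1, and 0.3.2 respectively
print(torch.__version__, torchvision.__version__, timm.__version__)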

Data preparation

Download and extract ImageNet train and val images from http://image-net.org/. The directory structure is the standard layout for torchvision's datasets.ImageFolder; the training and validation data are expected to be in the train/ and val/ folders, respectively:

/path/to/imagenet/
  train/
    class1/
      img1.jpeg
    class2/
      img2.jpeg
  val/
    class1/
      img3.jpeg
    class2/
      img4.jpeg
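
For reference, a minimal sketch of how this layout is consumed via torchvision's ImageFolder (the transform below is only an example, not the training pipeline used by main.py):

from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])
# each class subfolder (class1/, class2/, ...) becomes one label
train_set = datasets.ImageFolder('/path/to/imagenet/train', transform=transform)
val_set = datasets.ImageFolder('/path/to/imagenet/val', transform=transform)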

Model Zoo

Image Classification

We provide a series of Twins models pretrained on the ILSVRC2012 ImageNet-1K dataset.

| Model | Alias in the paper | Acc@1 | FLOPs (G) | #Params (M) | URL | Log |
| --- | --- | --- | --- | --- | --- | --- |
| PCPVT-Small | Twins-PCPVT-S | 81.2 | 3.7 | 24.1 | pcpvt_small.pth | pcpvt_s.txt |
| PCPVT-Base | Twins-PCPVT-B | 82.7 | 6.4 | 43.8 | pcpvt_base.pth | pcpvt_b.txt |
| PCPVT-Large | Twins-PCPVT-L | 83.1 | 9.5 | 60.9 | pcpvt_large.pth | pcpvt_l.txt |
| ALTGVT-Small | Twins-SVT-S | 81.7 | 2.8 | 24 | alt_gvt_small.pth | svt_s.txt |
| ALTGVT-Base | Twins-SVT-B | 83.2 | 8.3 | 56 | alt_gvt_base.pth | svt_b.txt |
| ALTGVT-Large | Twins-SVT-L | 83.7 | 14.8 | 99.2 | alt_gvt_large.pth | svt_l.txt |

Training

To train Twins-SVT-B on ImageNet using 8 GPUs for 300 epochs, run:

python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --model alt_gvt_base --batch-size 128 --data-path path_to_imagenet --dist-eval --drop-path 0.3

Evaluation

To evaluate the performance of Twins-SVT-L on ImageNet using one GPU, run

python main.py --eval --resume alt_gvt_large.pth  --model alt_gvt_large --data-path path_to_imagenet

This should give

* Acc@1 83.660 Acc@5 96.512 loss 0.722
Accuracy of the network on the 50000 test images: 83.7%

Semantic Segmentation

Our code is based on mmsegmentation. Please install mmsegmentation first.

We provide a series of Twins models and training logs trained on the ADE20K dataset. It is easy to extend them to other datasets and segmentation methods.

| Model | Alias in the paper | mIoU (ss/ms) | FLOPs (G) | #Params (M) | URL | Log |
| --- | --- | --- | --- | --- | --- | --- |
| PCPVT-Small | Twins-PCPVT-S | 46.2/47.5 | 234 | 54.6 | pcpvt_small.pth | pcpvt_s.txt |
| PCPVT-Base | Twins-PCPVT-B | 47.1/48.4 | 250 | 74.3 | pcpvt_base.pth | pcpvt_b.txt |
| PCPVT-Large | Twins-PCPVT-L | 48.6/49.8 | 269 | 91.5 | pcpvt_large.pth | pcpvt_l.txt |
| ALTGVT-Small | Twins-SVT-S | 46.2/47.1 | 228 | 54.4 | alt_gvt_small.pth | svt_s.txt |
| ALTGVT-Base | Twins-SVT-B | 47.4/48.9 | 261 | 88.5 | alt_gvt_base.pth | svt_b.txt |
| ALTGVT-Large | Twins-SVT-L | 48.8/50.2 | 297 | 133 | alt_gvt_large.pth | svt_l.txt |

Training

To train Twins-PCPVT-Large on ADE20K using 8 GPUs for 160k iterations with a global batch size of 16, run:

 bash dist_train.sh configs/upernet_pcpvt_l_512x512_160k_ade20k_swin_setting.py 8

Evaluation

To evaluate Twins-PCPVT-Large on ADE20K using 8 GPUs (single scale), run:

bash dist_test.sh configs/upernet_pcpvt_l_512x512_160k_ade20k_swin_setting.py checkpoint_file 8 --eval mIoU

To evaluate Twins-PCPVT-Large on ADE20K using 8 GPUs (multi-scale), run:

bash dist_test.sh configs/upernet_pcpvt_l_512x512_160k_ade20k_swin_setting.py checkpoint_file 8 --eval mIoU --aug-test

Detection

Our code is based on mmdetection. Please install mmdetection first (we use v2.8.0). We use both Mask R-CNN and RetinaNet to evaluate our method. It is easy to apply Twins to other detectors provided by mmdetection based on our examples.

Training

To train Twins-SVT-Small on COCO with 8 GPUs for the 1x schedule (PVT setting) under the Mask R-CNN framework, run:

 bash dist_train.sh configs/mask_rcnn_alt_gvt_s_fpn_1x_coco_pvt_setting.py 8

To train Twins-SVT-Small on COCO with 8 GPUs for the 3x schedule (Swin setting) under the Mask R-CNN framework, run:

 bash dist_train.sh configs/mask_rcnn_alt_gvt_s_fpn_3x_coco_swin_setting.py 8

Evaluation

To evaluate the mAP of Twins-SVT-Small on COCO using 8 GPUs under the RetinaNet framework, run:

bash dist_test.sh configs/retinanet_alt_gvt_s_fpn_1x_coco_pvt_setting.py checkpoint_file 8   --eval mAP

Citation

If you find this project useful in your research, please consider citing the following:

Twins:

@inproceedings{chu2021Twins,
	title={Twins: Revisiting the Design of Spatial Attention in Vision Transformers},
	author={Xiangxiang Chu and Zhi Tian and Yuqing Wang and Bo Zhang and Haibing Ren and Xiaolin Wei and Huaxia Xia and Chunhua Shen},
	booktitle={NeurIPS 2021},
  url={https://openreview.net/forum?id=5kTlVBkzSRx},
	year={2021}
}

CPVT:

@inproceedings{chu2023CPVT,
	title={Conditional Positional Encodings for Vision Transformers},
	author={Xiangxiang Chu and Zhi Tian and Bo Zhang and Xinlong Wang and Chunhua Shen},
	booktitle={ICLR 2023},
	url={https://openreview.net/forum?id=3KWnuT-R1bh},
	year={2023}
}

Acknowledgement

We heavily borrow the code from DeiT and PVT. We test throughputs as in Swin Transformer.

License

This repository is released under the Apache 2.0 license as found in the LICENSE file.


Issues

About the code stage


Dear authors, I see that there are stages 1, 2, 3, and 4 in your Twins paper. I want to output the results of these four stages separately, but I can't find them in the code. I am looking forward to your reply.

mIoU eval method (got higher mIoU than provided)

For the ALTGVT-Large model on the ADE20K dataset, the reported single-scale mIoU on this page is 48.8. However, I used the default mmsegmentation evaluation code and got 49.07 (higher?).

(I am using the default test_pipeline with img_scale=(2048, 512) and mode='whole'.)
I used a single GPU for single-scale inference.

Could you please have a look at this?

Thank you

Some confusion about warmup strategy

In the log you provided, I find that the number of warm-up epochs is not 5 and the linear warm-up starts at the second epoch. This is inconsistent with the paper:

We use a linear warm-up in the first five epochs ...

Twins/logs/svt_s.txt, lines 1 to 7 at commit 37f9dbf:

{"train_lr": 1.000000000000014e-06, "train_loss": 6.9166167094230655, "test_loss": 6.881752743440516, "test_acc1": 0.18800001103878022, "test_acc5": 0.9300000336265564, "epoch": 0, "n_parameters": 24060776}
{"train_lr": 1.000000000000014e-06, "train_loss": 6.900423232269287, "test_loss": 6.852993618039524, "test_acc1": 0.41600001563549044, "test_acc5": 1.5720000462150574, "epoch": 1, "n_parameters": 24060776}
{"train_lr": 0.00040080000000000486, "train_loss": 6.646979278850555, "test_loss": 5.493567599969752, "test_acc1": 6.424000176010132, "test_acc5": 18.284000532073975, "epoch": 2, "n_parameters": 24060776}
{"train_lr": 0.0008005999999999952, "train_loss": 6.297702983379364, "test_loss": 4.646971811266506, "test_acc1": 14.610000506286621, "test_acc5": 32.9140008934021, "epoch": 3, "n_parameters": 24060776}
{"train_lr": 0.001200399999999992, "train_loss": 6.0142835487365724, "test_loss": 3.968842138262356, "test_acc1": 22.43800064086914, "test_acc5": 45.01200138336181, "epoch": 4, "n_parameters": 24060776}
{"train_lr": 0.001600200000000024, "train_loss": 5.753731050109863, "test_loss": 3.4170068081687477, "test_acc1": 30.60400089279175, "test_acc5": 55.76000146942139, "epoch": 5, "n_parameters": 24060776}
{"train_lr": 0.001998636387080776, "train_loss": 5.494997973155975, "test_loss": 2.9434911687584484, "test_acc1": 38.45400118209839, "test_acc5": 64.25200190032959, "epoch": 6, "n_parameters": 24060776}

Why is that?
Thanks.

Question about calling ALTGVT

Hello, when calling ALTGVT for the classification task, it reports that the class GroupBlock passes one extra argument when calling its parent class TimmBlock's constructor. The error is as follows:
super(GroupBlock, self).__init__(dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop, attn_drop,
TypeError: __init__() takes from 3 to 10 positional arguments but 11 were given

After checking TimmBlock, I found that it does not have a qk_scale parameter. How can this bug be resolved?

FLOPs calculation is not accurate

Hi, get_flops.py doesn't account for the FLOPs of self-attention, so it is not accurate.
For Twins-PCPVT-S, get_flops.py gives:

==============================
Input shape: (3, 512, 2048)
Flops: 162.66 GFLOPs
Params: 28.37 M
==============================

When using fvcore.nn.flop_count (attention will be included), I get:

==============================
Input shape: (3, 512, 2048)
Flops: 225.98693683200003
Params: 28372862
==============================
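
For completeness, a minimal sketch of how such fvcore numbers can be obtained. The model factory import is hypothetical (the exact constructor name in gvt.py may differ), and the input size follows the shape above:

import torch
from fvcore.nn import flop_count

from gvt import pcpvt_small_v0  # hypothetical factory name; use the actual one from gvt.py

model = pcpvt_small_v0().eval()
dummy = torch.randn(1, 3, 512, 2048)            # (C, H, W) = (3, 512, 2048) as above
gflops_per_op, unsupported = flop_count(model, (dummy,))
print(sum(gflops_per_op.values()), "GFLOPs")    # fvcore's default handles count the attention matmuls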

Questions about table 5

Hi,

In your paper's Table 5, the (G,G,G,G) variant uses the number (79.8%) from the PVT paper, which uses absolute positional encoding. However, I suppose the other model variants listed in this table use CPE, so they are not directly comparable. Should the accuracy of (G,G,G,G) with CPE be 81.2%, as shown in Table 1?

In general, I am interested in knowing whether there is a benefit to using global attention in the early layers.

Thanks.

Learning rate

Twins/main.py, line 263 at commit 4700293:

linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0

In the code, the lr is scaled linearly with the batch size. But why don't warmup-lr and min-lr need to be adjusted as well?

Hoping for some help, thanks!
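
As an aside, the scaling rule quoted above works out as follows for the ImageNet training command earlier on this page (an illustrative calculation only; the 5e-4 base lr is the default mentioned in a later issue, not stated here):

# Illustrative arithmetic for the linear lr scaling rule quoted above.
# base lr 5e-4 (assumed default), batch size 128 per GPU, 8 GPUs.
base_lr, batch_size, world_size = 5e-4, 128, 8
linear_scaled_lr = base_lr * batch_size * world_size / 512.0
print(linear_scaled_lr)  # 0.001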

Welcome update to OpenMMLab 2.0


I am Vansin, the technical operator of OpenMMLab. In September of last year, we announced the release of OpenMMLab 2.0 at the World Artificial Intelligence Conference in Shanghai. We invite you to upgrade your algorithm library to OpenMMLab 2.0 using MMEngine, which can be used for both research and commercial purposes. If you have any questions, please feel free to join us on the OpenMMLab Discord at https://discord.gg/amFNsyUBvm or add me on WeChat (van-sin) and I will invite you to the OpenMMLab WeChat group.

Here are the OpenMMLab 2.0 repos branches:

| Project | OpenMMLab 1.0 branch | OpenMMLab 2.0 branch |
| --- | --- | --- |
| MMEngine | | 0.x |
| MMCV | 1.x | 2.x |
| MMDetection | 0.x, 1.x, 2.x | 3.x |
| MMAction2 | 0.x | 1.x |
| MMClassification | 0.x | 1.x |
| MMSegmentation | 0.x | 1.x |
| MMDetection3D | 0.x | 1.x |
| MMEditing | 0.x | 1.x |
| MMPose | 0.x | 1.x |
| MMDeploy | 0.x | 1.x |
| MMTracking | 0.x | 1.x |
| MMOCR | 0.x | 1.x |
| MMRazor | 0.x | 1.x |
| MMSelfSup | 0.x | 1.x |
| MMRotate | 1.x | 1.x |
| MMYOLO | | 0.x |

Attention: please create a new virtual environment for OpenMMLab 2.0.

Question about GSA

Hello, thank you very much for your excellent work. I have some questions about GSA. According to my personal understanding, GSA in the paper takes one representation from each window, so the sr_ratio should be the same as the window size ([7, 7, 7, 7]) when calculating Key and Value, but it is [8, 4, 2, 1] in the code. Is there anything wrong with my understanding?

@BACKBONES.register_module()
class alt_gvt_large(ALTGVT):
    def __init__(self, **kwargs):
        super(alt_gvt_large, self).__init__(
            patch_size=4, embed_dims=[128, 256, 512, 1024], num_heads=[4, 8, 16, 32], mlp_ratios=[4, 4, 4, 4], qkv_bias=True,
            norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 18, 2], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1],
            extra_norm=True, drop_path_rate=0.3,
        )
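
For reference, a minimal sketch of what sr_ratio typically controls in this kind of PVT-style sub-sampled attention: queries keep the full resolution while keys/values are spatially reduced. This module is an illustration, not the repository's implementation; the strided-conv reduction is an assumption.

import torch
import torch.nn as nn

class SubsampledKV(nn.Module):
    """Sketch: queries keep all H*W positions; keys/values are reduced by sr_ratio."""
    def __init__(self, dim, sr_ratio):
        super().__init__()
        # assumed reduction: a strided conv that summarises each sr_ratio x sr_ratio patch
        self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, H, W):
        B, N, C = x.shape                                   # N == H * W tokens
        q = x                                               # queries: full resolution
        kv = x.transpose(1, 2).reshape(B, C, H, W)
        kv = self.sr(kv).reshape(B, C, -1).transpose(1, 2)  # (H/sr)*(W/sr) tokens
        kv = self.norm(kv)
        return q, kv                                        # attention output keeps length H*W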

GSA code

Hello, I would like to ask where in your project the GSA code can be found?

The difference between PEG and PEG for detection

Hi, thanks for your great work.
I ran into some confusion when reading the paper. Specifically, in Supplement C (Example Code), there is a slight difference between the two presented algorithms.
In Algorithm 1 (PyTorch snippet of PEG), the PEG includes only a Conv layer, while in Algorithm 2 (PyTorch snippet of PEG for detection), there are additional BN + ReLU layers.
How do these two settings compare in effectiveness? Would the second setting with BN + ReLU be better?
Thank you.
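
For context, a minimal sketch of a PEG in the spirit described above (a depthwise 3x3 convolution over the token map, added back as a residual), with the detection variant appending BN and ReLU. This is an illustration based on the issue's description, not the exact snippet from the paper's supplement:

import torch.nn as nn

def make_peg(dim, for_detection=False):
    # depthwise 3x3 conv acting on the (B, C, H, W) token map
    layers = [nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim)]
    if for_detection:
        # per the issue above, the detection variant adds BN + ReLU
        layers += [nn.BatchNorm2d(dim), nn.ReLU(inplace=True)]
    return nn.Sequential(*layers)

# usage sketch: tokens are first reshaped to a (B, C, H, W) map, then
# feat = feat + make_peg(C)(feat)   # residual connection as in the PCPVT figure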

issue with the ALTGVT models code

Model = alt_gvt_base(pretrained=False)

When I use models that call GroupBlock, the following error occurred:
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
TypeError: '>' not supported between instances of 'type' and 'float'

Cannot reproduce the performance of the SVT-Small model

Thanks for your nice work!
I would like to reproduce the performance of the SVT-Small (alt_gvt_small) model. Below is my command:

python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --model alt_gvt_small --batch-size 256 --data-path ../data/ImageNet --dist-eval --drop-path 0.2

The other parameters are the defaults, but the result only reaches 81.1%, not 81.7%.
Could you give me some suggestions on how to reproduce your reported performance from scratch?

TypeError: 'DataContainer' object is not subscriptable

Traceback (most recent call last):
  File "train.py", line 202, in <module>
    main()
  File "train.py", line 198, in main
    meta=meta)
  File "/opt/tiger/conda/lib/python3.7/site-packages/mmdet/apis/train.py", line 170, in train_detector
    runner.run(data_loaders, cfg.workflow)
  File "/opt/tiger/conda/lib/python3.7/site-packages/mmcv/runner/epoch_based_runner.py", line 125, in run
    epoch_runner(data_loaders[i], **kwargs)
  File "/opt/tiger/conda/lib/python3.7/site-packages/mmcv/runner/epoch_based_runner.py", line 54, in train
    self.call_hook('after_train_epoch')
  File "/opt/tiger/conda/lib/python3.7/site-packages/mmcv/runner/base_runner.py", line 308, in call_hook
    getattr(hook, fn_name)(self)
  File "/opt/tiger/conda/lib/python3.7/site-packages/mmdet/core/evaluation/eval_hooks.py", line 276, in after_train_epoch
    gpu_collect=self.gpu_collect)
  File "/opt/tiger/conda/lib/python3.7/site-packages/mmdet/apis/test.py", line 97, in multi_gpu_test
    result = model(return_loss=False, rescale=True, **data)
  File "/opt/tiger/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/opt/tiger/conda/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 705, in forward
    output = self.module(*inputs[0], **kwargs[0])
  File "/opt/tiger/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/opt/tiger/conda/lib/python3.7/site-packages/mmcv/runner/fp16_utils.py", line 110, in new_func
    output = old_func(*new_args, **new_kwargs)
  File "/opt/tiger/conda/lib/python3.7/site-packages/mmdet/models/detectors/base.py", line 183, in forward
    return self.forward_test(img, img_metas, **kwargs)
  File "/opt/tiger/conda/lib/python3.7/site-packages/mmdet/models/detectors/base.py", line 150, in forward_test
    img_meta[img_id]['batch_input_shape'] = tuple(img.size()[-2:])
TypeError: 'DataContainer' object is not subscriptable

For COCO detection, training is fine, but testing always hits the bug above.

Why 'ws 1 for stand attention' in your GroupAttention code?

I find that in your implementation of GroupAttention in gvt.py, you comment that 'ws 1 for stand attention'.

class GroupAttention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., ws=1, sr_ratio=1.0):
        """
        ws 1 for stand attention
        """
        super(GroupAttention, self).__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

However, I think ws means the window size; if ws=1, then self-attention is only performed within a 1x1 window, which is not standard self-attention.

A little confusion about GSA

Hi, I am reading your concise but effective work. Since I am not in the field of computer vision, I have a narrow question about the GSA module:
after the global sub-sampled attention, the size of the feature map is reduced to m*n; how do you restore it to the original size?

Another question: have you tried global max pooling as the sub-sampling function?

Thank you!

Learning rate

The default learning rate is 5e-4 rather than 1e-3. Which learning rate should I use to reproduce your results?

About the mmdet version

I installed mmdet 2.8.0, but when loading the model with

config_file = path_to_config_file
checkpoint_file = path_to_checkpoint
model = init_detector(config_file, checkpoint_file, device='cuda:0')

I got the following error:

Traceback (most recent call last):
File "newtest.py", line 17, in
model = init_detector(config_file, checkpoint_file, device='cuda:0')
File "/home/sxn/data/env/openmmlab/lib/python3.7/site-packages/mmdet/apis/inference.py", line 38, in init_detector
model = build_detector(config.model, test_cfg=config.test_cfg)
File "/home/sxn/data/env/openmmlab/lib/python3.7/site-packages/mmdet/models/builder.py", line 67, in build_detector
return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
File "/home/sxn/data/env/openmmlab/lib/python3.7/site-packages/mmdet/models/builder.py", line 32, in build
return build_from_cfg(cfg, registry, default_args)
File "/home/sxn/data/env/openmmlab/lib/python3.7/site-packages/mmcv/utils/registry.py", line 171, in build_from_cfg
return obj_cls(**args)
File "/home/sxn/data/env/openmmlab/lib/python3.7/site-packages/mmdet/models/detectors/mask_rcnn.py", line 24, in init
pretrained=pretrained)
File "/home/sxn/data/env/openmmlab/lib/python3.7/site-packages/mmdet/models/detectors/two_stage.py", line 26, in init
self.backbone = build_backbone(backbone)
File "/home/sxn/data/env/openmmlab/lib/python3.7/site-packages/mmdet/models/builder.py", line 37, in build_backbone
return build(cfg, BACKBONES)
File "/home/sxn/data/env/openmmlab/lib/python3.7/site-packages/mmdet/models/builder.py", line 32, in build
return build_from_cfg(cfg, registry, default_args)
File "/home/sxn/data/env/openmmlab/lib/python3.7/site-packages/mmcv/utils/registry.py", line 164, in build_from_cfg
f'{obj_type} is not in the {registry.name} registry')
KeyError: 'alt_gvt_small is not in the backbone registry'

However, gvt.py does register it as a backbone, so I'm not sure how to fix this.

Runtime error for mmseg

RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by (1) passing the keyword argument find_unused_parameters=True to torch.nn.parallel.DistributedDataParallel; (2) making sure all forward function outputs participate in calculating loss. If you already have done the above two steps, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's forward function. Please include the loss function and the structure of the return value of forward of your module when reporting this issue (e.g. list, dict, iterable).

Hi! Thanks for open-sourcing the code.
When I use the Twins backbone in mmseg, I get this error. It seems that several parameters do not contribute to the loss.
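
A common workaround (an assumption on my part, not an official fix from the authors) is to enable unused-parameter detection via the MMSegmentation config, which the training script forwards to DistributedDataParallel:

# add to the training config, e.g. configs/upernet_pcpvt_l_512x512_160k_ade20k_swin_setting.py
find_unused_parameters = True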

Part of the code is missing

Hello, I am very interested in your work, but I did not find the GSA part when reading the code. Is this part not open source yet?

Training always hangs

After training for a few epochs, the program always hangs: one GPU's utilization drops to 0 while the other 7 GPUs stay at 100%. The program stops making progress and no error is reported.

About the learning rate

Hello, while reproducing the paper I found that in pcpvt_s.txt the maximum learning rate is 0.00125:

{"train_lr": 1.000000000000015e-06, "train_loss": 6.913535571038723, "test_loss": 6.8714314655021385, "test_acc1": 0.2160000117301941, "test_acc5": 1.2940000837326049, "epoch": 0, "n_parameters": 24106216}
{"train_lr": 1.000000000000015e-06, "train_loss": 6.896081995010376, "test_loss": 6.839652865021317, "test_acc1": 0.40800002346038816, "test_acc5": 1.7700001041412354, "epoch": 1, "n_parameters": 24106216}
{"train_lr": 0.0002507999999999969, "train_loss": 6.628805226147175, "test_loss": 5.555466687237775, "test_acc1": 6.462000321006775, "test_acc5": 18.384000979614257, "epoch": 2, "n_parameters": 24106216}
{"train_lr": 0.0005006000000000066, "train_loss": 6.272622795701027, "test_loss": 4.64223074471509, "test_acc1": 15.328000774383545, "test_acc5": 34.724001724243166, "epoch": 3, "n_parameters": 24106216}
{"train_lr": 0.0007504000000000098, "train_loss": 5.958464104115963, "test_loss": 3.9152780064830073, "test_acc1": 24.868001231384277, "test_acc5": 48.68200244750977, "epoch": 4, "n_parameters": 24106216}
{"train_lr": 0.0010002000000000064, "train_loss": 5.670980889737606, "test_loss": 3.3432016747969167, "test_acc1": 33.47600182495117, "test_acc5": 59.10200291748047, "epoch": 5, "n_parameters": 24106216}
{"train_lr": 0.0012491503115478462, "train_loss": 5.421633140593767, "test_loss": 2.9919017029029353, "test_acc1": 39.19200199279785, "test_acc5": 65.3920035522461, "epoch": 6, "n_parameters": 24106216}
{"train_lr": 0.0012487765716255204, "train_loss": 5.202176849722862, "test_loss": 2.6157535275927297, "test_acc1": 45.32800238342285, "test_acc5": 71.13600388793945, "epoch": 7, "n_parameters": 24106216}

Is this learning rate obtained by scaling with the batch size, or was it tuned manually?
Thanks a lot.

Pretrained model

Could you please upload your pretrained model and an example for running your code on GitHub?

Two gvt.py files in your repository: what is the difference between them?

I find there are two gvt.py files in your repository: one in the main directory and another in the segmentation directory. By carefully comparing the two files, I found that the group-attention computation in the segmentation directory's gvt.py is different: one adds an attn_mask and the other does not. So I want to ask: what is the role of attn_mask in the group-attention computation? Why do you do this? Which gvt.py should I use when doing a segmentation task?

Can we train or test on a single GPU for detection?

If we want to test the detection task, can we just use a shell command like 'bash dist_test.sh configs/retinanet_alt_gvt_s_fpn_1x_coco_pvt_setting.py checkpoint_file 1 --eval mAP'?

Or do we need to change the lr and the number of workers?
I'm a beginner with the mmdet framework, please help...
These are the error lines:

/home/user/miniconda3/envs/twins/lib/python3.8/site-packages/torch/distributed/launch.py:163: DeprecationWarning: The 'warn' method is deprecated, use 'warning' instead
logger.warn(
The module torch.distributed.launch is deprecated and going to be removed in future.Migrate to torch.distributed.run
WARNING:torch.distributed.run:--use_env is deprecated and will be removed in future releases.
Please read local_rank from os.environ('LOCAL_RANK') instead.
INFO:torch.distributed.launcher.api:Starting elastic_operator with launch configs:
entrypoint : ./test.py
min_nodes : 1
max_nodes : 1
nproc_per_node : 1
run_id : none
rdzv_backend : static
rdzv_endpoint : 127.0.0.1:29500
rdzv_configs : {'rank': 0, 'timeout': 900}
max_restarts : 3
monitor_interval : 5
log_dir : None
metrics_cfg : {}

INFO:torch.distributed.elastic.agent.server.local_elastic_agent:log directory set to: /tmp/torchelastic_o5bp99y9/none_u2fqutod
INFO:torch.distributed.elastic.agent.server.api:[default] starting workers for entrypoint: python
INFO:torch.distributed.elastic.agent.server.api:[default] Rendezvous'ing worker group
/home/user/miniconda3/envs/twins/lib/python3.8/site-packages/torch/distributed/elastic/utils/store.py:52: FutureWarning: This is an experimental API and will be changed in future.
warnings.warn(
INFO:torch.distributed.elastic.agent.server.api:[default] Rendezvous complete for workers. Result:
restart_count=0
master_addr=127.0.0.1
master_port=29500
group_rank=0
group_world_size=1
local_ranks=[0]
role_ranks=[0]
global_ranks=[0]
role_world_sizes=[1]
global_world_sizes=[1]

INFO:torch.distributed.elastic.agent.server.api:[default] Starting worker group
INFO:torch.distributed.elastic.multiprocessing:Setting worker0 reply file to: /tmp/torchelastic_o5bp99y9/none_u2fqutod/attempt_0/0/error.json
loading annotations into memory...
Done (t=0.52s)
creating index...
index created!
Traceback (most recent call last):
File "./test.py", line 213, in
main()
File "./test.py", line 166, in main
model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
File "/home/user/miniconda3/envs/twins/lib/python3.8/site-packages/mmdet/models/builder.py", line 67, in build_detector
return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
File "/home/user/miniconda3/envs/twins/lib/python3.8/site-packages/mmdet/models/builder.py", line 32, in build
return build_from_cfg(cfg, registry, default_args)
File "/home/user/miniconda3/envs/twins/lib/python3.8/site-packages/mmcv/utils/registry.py", line 171, in build_from_cfg
return obj_cls(**args)
File "/home/user/miniconda3/envs/twins/lib/python3.8/site-packages/mmdet/models/detectors/retinanet.py", line 16, in init
super(RetinaNet, self).init(backbone, neck, bbox_head, train_cfg,
File "/home/user/miniconda3/envs/twins/lib/python3.8/site-packages/mmdet/models/detectors/single_stage.py", line 25, in init
self.backbone = build_backbone(backbone)
File "/home/user/miniconda3/envs/twins/lib/python3.8/site-packages/mmdet/models/builder.py", line 37, in build_backbone
return build(cfg, BACKBONES)
File "/home/user/miniconda3/envs/twins/lib/python3.8/site-packages/mmdet/models/builder.py", line 32, in build
return build_from_cfg(cfg, registry, default_args)
File "/home/user/miniconda3/envs/twins/lib/python3.8/site-packages/mmcv/utils/registry.py", line 171, in build_from_cfg
return obj_cls(**args)
File "/home/user/project/Twins/detection/gvt.py", line 482, in init
super(alt_gvt_small, self).init(
File "/home/user/project/Twins/detection/gvt.py", line 419, in init
super(ALTGVT, self).init(img_size, patch_size, in_chans, num_classes, embed_dims, num_heads,
File "/home/user/project/Twins/detection/gvt.py", line 408, in init
super(PCPVT, self).init(img_size, patch_size, in_chans, num_classes, embed_dims, num_heads,
File "/home/user/project/Twins/detection/gvt.py", line 343, in init
super(CPVTV2, self).init(img_size, patch_size, in_chans, num_classes, embed_dims, num_heads, mlp_ratios,
File "/home/user/project/Twins/detection/gvt.py", line 234, in init
_block = nn.ModuleList([block_cls(
File "/home/user/project/Twins/detection/gvt.py", line 234, in
_block = nn.ModuleList([block_cls(
File "/home/user/project/Twins/detection/gvt.py", line 164, in init
super(GroupBlock, self).init(dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop, attn_drop,
TypeError: init() takes from 3 to 10 positional arguments but 11 were given
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 11449) of binary: /home/user/miniconda3/envs/twins/bin/python
ERROR:torch.distributed.elastic.agent.server.local_elastic_agent:[default] Worker group failed
INFO:torch.distributed.elastic.agent.server.api:[default] Worker group FAILED. 3/3 attempts left; will restart worker group
INFO:torch.distributed.elastic.agent.server.api:[default] Stopping worker group
INFO:torch.distributed.elastic.agent.server.api:[default] Rendezvous'ing worker group
INFO:torch.distributed.elastic.agent.server.api:[default] Rendezvous complete for workers. Result:
restart_count=1
master_addr=127.0.0.1
master_port=29500
group_rank=0
group_world_size=1
local_ranks=[0]
role_ranks=[0]
global_ranks=[0]
role_world_sizes=[1]
global_world_sizes=[1]

INFO:torch.distributed.elastic.agent.server.api:[default] Starting worker group
INFO:torch.distributed.elastic.multiprocessing:Setting worker0 reply file to: /tmp/torchelastic_o5bp99y9/none_u2fqutod/attempt_1/0/error.json

About the computational complexity

Hello, there is one thing I haven't fully figured out: why is the computational complexity of GSA O(mnHWd) = O(H^2 W^2 d / (k1 k2))? Thanks a lot!
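
For reference, a short derivation using the paper's notation (assuming the feature map is divided into m x n sub-windows of size k1 x k2, so the summarised key/value map has m = H/k1 and n = W/k2 positions):

% every one of the H*W query positions attends to the m*n summarised positions,
% and each query-key interaction costs O(d)
\mathcal{O}(mn \cdot HW \cdot d)
  = \mathcal{O}\!\left(\frac{H}{k_1} \cdot \frac{W}{k_2} \cdot HW \cdot d\right)
  = \mathcal{O}\!\left(\frac{H^2 W^2 d}{k_1 k_2}\right)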

PCPVT code question

Hello, in the PCPVT figure of the paper, the PEG output features are added to the PEG input features, but I don't see this addition in the PCPVT code. Why?

Run train in Windows 10: error

Running the training script on Windows 10 gives the following error:

yapf.yapflib.verifier.InternalError: (unicode error) 'unicodeescape' codec can't decode bytes in position 9-10: truncated \uXXXX escape (, line 1)
