
tome's Introduction

Token Merging: Your ViT but Faster

Official PyTorch implementation of ToMe from our paper: Token Merging: Your ViT but Faster.
Daniel Bolya, Cheng-Yang Fu, Xiaoliang Dai, Peizhao Zhang, Christoph Feichtenhofer, Judy Hoffman.

What is ToMe?

ToMe Concept Figure

Token Merging (ToMe) allows you to take an existing Vision Transformer architecture and efficiently merge tokens inside of the network for 2-3x faster evaluation (see benchmark script). ToMe is tuned to seamlessly fit inside existing vision transformers, so you can use it without having to do additional training (see eval script). And if you do use ToMe during training, you can reduce the accuracy drop even further while also speeding up training considerably.

What ToMe does

ToMe Visualization

ToMe merges tokens based on their similarity, implicitly grouping parts of objects together. This is in contrast to token pruning, which only removes background tokens. ToMe can get away with reducing more tokens because we can merge redundant foreground tokens in addition to background ones.

Visualization of merged tokens on ImageNet-1k val using a trained ViT-H/14 MAE model with ToMe. See this example for how to produce these visualizations. For more, see the paper appendix.
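For the curious, here is a minimal, self-contained sketch of the bipartite soft matching step behind this merging. It is an illustration only: it uses the token features themselves as the similarity metric and ignores the class token, whereas the official tome.merge implementation works on the attention keys and also handles the class token and proportional attention.

import torch

def bipartite_soft_matching_sketch(x: torch.Tensor, r: int) -> torch.Tensor:
    # x: [batch, tokens, channels]; returns [batch, tokens - r, channels].
    with torch.no_grad():
        metric = x / x.norm(dim=-1, keepdim=True)          # cosine-normalize the tokens
        a, b = metric[..., ::2, :], metric[..., 1::2, :]   # alternating split into sets A and B
        scores = a @ b.transpose(-1, -2)                   # similarity of each A token to each B token

        node_max, node_idx = scores.max(dim=-1)            # best B match for every A token
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]
        src_idx = edge_idx[..., :r, :]                     # the r most similar A tokens get merged
        unm_idx = edge_idx[..., r:, :]                     # the rest stay unmerged
        dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx)

    src, dst = x[..., ::2, :], x[..., 1::2, :]
    n, _, c = src.shape
    unm = src.gather(dim=-2, index=unm_idx.expand(n, -1, c))
    src = src.gather(dim=-2, index=src_idx.expand(n, r, c))
    # Merge each selected A token into its matched B token by averaging (needs PyTorch >= 1.12).
    dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce="mean")
    return torch.cat([unm, dst], dim=1)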

News

  • [2023.03.30] Daniel has released his implementation of ToMe for diffusion here. Check it out! (Note: this is an external implementation not affiliated with Meta in any way).
  • [2023.02.08] We are delighted to announce that the Meta Research Blog has highlighted our work, Token Merging! Check out the article at Meta Research Blog for more information.
  • [2023.01.31] We are happy to announce that our paper has been accepted for an oral presentation at ICLR 2023.
  • [2023.01.30] We've released checkpoints trained with ToMe for DeiT-Ti, DeiT-S, ViT-B, ViT-L, and ViT-H!
  • [2022.10.18] Initial release.

Installation

See INSTALL.md for installation details.

Usage

This repo does not include training code. Instead, we provide a set of tools to patch existing vision transformer implementations, which you can then use out of the box. Currently, we support the following ViT implementations, covered in the sections below: timm, SWAG (via Torch Hub), and MAE.

See the examples/ directory for a set of usage examples.

ToMe has also been implemented externally for other applications:

Note: these external implementations aren't associated with Meta in any way.

Using timm models

timm is a commonly used implementation of vision transformers in PyTorch. As of version 0.4.12, it uses AugReg weights.

import timm, tome

# Load a pretrained model, can be any vit / deit model.
model = timm.create_model("vit_base_patch16_224", pretrained=True)
# Patch the model with ToMe.
tome.patch.timm(model)
# Set the number of tokens reduced per layer. See paper for details.
model.r = 16
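As a rough sanity check of the speedup, you can time the patched model directly. The snippet below is only a sketch (the benchmark script in examples/ is the reference); the batch size, input size, and iteration counts here are arbitrary.

import time
import timm, tome, torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = timm.create_model("vit_base_patch16_224", pretrained=True)
tome.patch.timm(model)
model.r = 16
model = model.eval().to(device)

batch_size, warmup, runs = 64, 10, 50
dummy = torch.randn(batch_size, 3, 224, 224, device=device)

with torch.no_grad():
    for _ in range(warmup):            # warm-up iterations
        model(dummy)
    if device == "cuda":
        torch.cuda.synchronize()
    start = time.time()
    for _ in range(runs):
        model(dummy)
    if device == "cuda":
        torch.cuda.synchronize()

print(f"~{batch_size * runs / (time.time() - start):.1f} im/s")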

Here are some expected results when using the timm implementation off-the-shelf on ImageNet-1k val using a V100:

| Model          | original acc | original im/s | r  | ToMe acc | ToMe im/s |
|----------------|--------------|---------------|----|----------|-----------|
| ViT-S/16       | 81.41        | 953           | 13 | 79.30    | 1564      |
| ViT-B/16       | 84.57        | 309           | 13 | 82.60    | 511       |
| ViT-L/16       | 85.82        | 95            | 7  | 84.26    | 167       |
| ViT-L/16 @ 384 | 86.92        | 28            | 23 | 86.14    | 56        |

See the paper for full results with all models and all values of r.

We've trained some DeiT (v1) models using the official implementation. To use, instantiate a DeiT timm model, patch it with the timm patch (prop_attn=True), and use ImageNet mean and variance for data loading.
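For example, loading one of the checkpoints below might look like the following sketch. The checkpoint filename and state-dict key are assumptions; adapt them to however you download and store the files.

import timm, torch, tome

# Instantiate the matching DeiT model from timm (no pretrained weights needed,
# since the ToMe-trained checkpoint is loaded on top).
model = timm.create_model("deit_small_patch16_224", pretrained=False)

# Hypothetical local path / key layout for the deit_S_r13 checkpoint below.
ckpt = torch.load("deit_S_r13.pth", map_location="cpu")
model.load_state_dict(ckpt["model"] if "model" in ckpt else ckpt)

# Patch with the timm patch and enable proportional attention, as described above.
tome.patch.timm(model, prop_attn=True)
model.r = 13
# Remember to use the ImageNet mean and variance for data loading.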

| Model      | original acc | original im/s | r  | ToMe acc | ToMe im/s | Checkpoint |
|------------|--------------|---------------|----|----------|-----------|------------|
| DeiT-S/16  | 79.8         | 930           | 13 | 79.36    | 1550      | deit_S_r13 |
| DeiT-Ti/16 | 71.8         | 2558          | 13 | 71.27    | 3980      | deit_T_r13 |

Using SWAG models through Torch Hub

SWAG is a repository of massive weakly-supervised ViT models. They are available via Torch Hub, and we include a function to patch their implementation.

import torch, tome

# Load a pretrained model, can be one of ["vit_b16_in1k", "vit_l16_in1k", or "vit_h14_in1k"].
model = torch.hub.load("facebookresearch/swag", model="vit_b16_in1k")
# Patch the model with ToMe.
tome.patch.swag(model)
# Set the amount of reduction. See paper for details.
model.r = 45
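If you want to sanity-check the patched model, a single forward pass at the model's native resolution should work. This is just a sketch; the input resolutions follow the table below (384 for vit_b16_in1k, 512 for vit_l16_in1k, 518 for vit_h14_in1k).

import torch, tome

model = torch.hub.load("facebookresearch/swag", model="vit_b16_in1k")
tome.patch.swag(model)
model.r = 45

# vit_b16_in1k runs at 384 x 384 input resolution.
with torch.no_grad():
    logits = model(torch.randn(1, 3, 384, 384))
print(logits.shape)  # expected: [1, 1000] for the ImageNet-1k head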

Here are some results using these SWAG models off-the-shelf on ImageNet-1k val using a V100:

| Model          | original acc | original im/s | r  | ToMe acc | ToMe im/s |
|----------------|--------------|---------------|----|----------|-----------|
| ViT-B/16 @ 384 | 85.30        | 85.7          | 45 | 84.59    | 167.7     |
| ViT-L/16 @ 512 | 88.06        | 12.8          | 40 | 87.80    | 26.3      |
| ViT-H/14 @ 518 | 88.55        | 4.7           | 40 | 88.25    | 9.8       |

Full results for other values of r are available in the paper appendix.

Training with MAE

We fine-tune models pretrained with MAE using the official MAE codebase. Apply the patch as shown in this example and set r as desired (see the paper appendix for the full list of accuracies vs. r). Then, follow the instructions in the MAE codebase to fine-tune your model from the pretrained weights.

Here are some results after training, evaluated on ImageNet-1k val using a V100:

| Model    | original acc | original im/s | r  | ToMe acc | ToMe im/s | Checkpoint   |
|----------|--------------|---------------|----|----------|-----------|--------------|
| ViT-B/16 | 83.62        | 309           | 16 | 81.91    | 603       | vit_B_16_r16 |
| ViT-L/16 | 85.66        | 93            | 8  | 85.09    | 183       | vit_L_16_r8  |
| ViT-H/14 | 86.88        | 35            | 7  | 86.46    | 63        | vit_H_14_r7  |

To use the checkpoints, apply the MAE patch (tome.patch.mae) to an MAE model from the official MAE codebase as shown in this example. Pass global_pool=True to the ViT MAE constructors and use the ImageNet mean for data loading. For the models we trained (the checkpoints above), we used prop_attn=True when patching with ToMe, but leave that as False for off-the-shelf models. Note that the original models in this table were also trained by us.
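Concretely, loading one of the checkpoints above might look like the sketch below. The checkpoint filename and state-dict key are assumptions, and models_vit refers to the ViT definition file in the official MAE codebase.

import torch, tome
import models_vit  # from the official MAE codebase

# Construct the ViT with global average pooling, as required for these checkpoints.
model = models_vit.vit_base_patch16(num_classes=1000, global_pool=True)

# Hypothetical local path / key layout for the vit_B_16_r16 checkpoint above.
ckpt = torch.load("vit_B_16_r16.pth", map_location="cpu")
model.load_state_dict(ckpt["model"] if "model" in ckpt else ckpt)

# Apply the MAE patch; prop_attn=True for the ToMe-trained checkpoints,
# False for off-the-shelf MAE models (see above).
tome.patch.mae(model, prop_attn=True)
model.r = 16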

As a sanity check, here is our baseline result without training using the off-the-shelf ViT-L model available here as described in Table 1 of the paper:

| Model    | original acc | original im/s | r | ToMe acc | ToMe im/s |
|----------|--------------|---------------|---|----------|-----------|
| ViT-L/16 | 85.96        | 93            | 8 | 84.22    | 183       |

License and Contributing

Please refer to the CC-BY-NC 4.0 license. For contributing, see the contributing guidelines and the code of conduct.

Citation

If you use ToMe or this repository in your work, please cite:

@inproceedings{bolya2022tome,
  title={Token Merging: Your {ViT} but Faster},
  author={Bolya, Daniel and Fu, Cheng-Yang and Dai, Xiaoliang and Zhang, Peizhao and Feichtenhofer, Christoph and Hoffman, Judy},
  booktitle={International Conference on Learning Representations},
  year={2023}
}

tome's People

Contributors

chengyangfu, dbolya


tome's Issues

Some doubts about Accuracy

Thank you for your excellent work!

But when I apply ToMe to DeiT-S without training, I find the accuracy is 78.826%, which is lower than the 79.4% reported in Table 4. Do you know what causes the gap?

AttributeError: 'ToMeBlock' object has no attribute 'drop_path'

File "/tome/tome/patch/timm.py", line 35, in forward x = x + self.drop_path_rate(x_attn) File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1208, in __getattr__ type(self).__name__, name)) AttributeError: 'ToMeBlock' object has no attribute 'drop_path_rate'

Video Classification off-the-shelf model

Hi,

Is it possible to share your baseline ViT-L model used for the Kinetics-400 experiments?
The official MAE-ST repo does not provide any fine-tuned checkpoints, and it is not feasible for me to go through the fine-tuning process at the current time.

Best,
Joakim

About using ToMe in ImageEncoderViT within Segment Anything

Hello Author,

Thank you for your contributions. I am currently looking to optimize the ImageEncoderViT method from “Segment Anything” using your token merging method, but I have encountered two issues:

  1. I noticed that the Block in ImageEncoderViT uses windowed attention, and the shape of the tokens is (B, W, H, C), such as (1, 64, 64, 1280) for vit_h. This shape cannot be processed by bipartite soft matching. I am considering whether flattening W and H into a single dimension for the computation would work. Do you have a better suggestion?
  2. The ToMe implementation can reduce the number of tokens by about 98%, which changes the final feature shape. In ImageEncoderViT, there are two Conv2d operations at the end, and after the token shape is changed, the features cannot go through the convolutions. I am wondering whether adding a shape-expanding operation at this point would be feasible.

Thank you for your help.

I keep getting this issue, how do I fix it?

raise NotImplementedError(f"Invalid/unsupported device. Expected cpu, cuda, or mps, got {device.type}.")
NotImplementedError: Invalid/unsupported device. Expected cpu, cuda, or mps, got privateuseone.

End-to-end inference is not accelerated

Hi, thanks for your excellent work!
I'm quite interested in your approach to speeding up ViT throughput. However, when I run ViT-B end-to-end inference (including data input, preprocessing, and model inference), the processing time is the same whether or not ToMe is used. I even tried different batch_size values to fill the GPU memory, but the results are still the same.
Here's the result:
- device: each row uses an RTX 3090 GPU
- dataset: ImageNet-1k validation set
end-to-end_result

For every test case, I only change the model or batch_size. The other components for data input, preprocessing, etc. are the same (same device and code).

My question is why the "Total Inference Time" of models with ToMe is similar to the baseline (no ToMe). Doesn't throughput measure the efficiency of model inference? Even if I didn't optimize the code for data input and preprocessing, the "Total Inference Time" should still be smaller than the baseline, because ToMe speeds up the time spent in model inference. Did I misunderstand something?
Did I misunderstand something?

Does not support swin-transformer?

Hello,
Thanks for this amazing work!
When I use ToMe, I found that it doesn't seem to support Swin Transformer. I want to know why it is not supported and whether it will be supported in the future.

I am looking forward to your reply!

Training Speedup for ToMe with Video

Is the fine-tuning time reported in Table 6 for the same number of epochs, or the total wall-clock time to convergence? I don't observe a noticeable reduction in training time per iteration; however, I can replicate the 2x inference speedup when r=65.

`ToMeTransformer` object has no attribute 'dist_token'

I was trying out timm_validation.ipynb from the examples directory on Colab. I loaded the ViT-Base model and uploaded an image, but when I tried patching my model with ToMe, it gave an error.

This is the stacktrace

AttributeError                            Traceback (most recent call last)
<ipython-input-11-6cb8d2faecdd> in <module>
      1 # Patch Model to ToMe
----> 2 tome.patch.timm(model)

1 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in __getattr__(self, name)
   1206                 return modules[name]
   1207         raise AttributeError("'{}' object has no attribute '{}'".format(
-> 1208             type(self).__name__, name))
   1209 
   1210     def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None:

AttributeError: 'ToMeVisionTransformer' object has no attribute 'dist_token'

How to apply ToMe to the image encoder in the CLIP model?

I tried to apply ToMe to the image encoder in the CLIP model. However, the ViT in CLIP uses nn.MultiheadAttention, whose forward process I couldn't modify. I wonder if you have any ideas on how to apply ToMe to the original CLIP models? Thanks!

Used in the vanilla Transformer

Hi! Thanks for the amazing work! In the paper, ToMe is only used with ViT. I am wondering whether ToMe can be applied to the vanilla Transformer. In that case, I guess it would be similar to setting the patch size to 1. Have you tried something similar? Please correct me if I said something wrong. Thanks!

Is there any way to generate heatmap with ToMe

Hello,
Thanks for this amazing work!

I'm wondering if you've tried generating a heatmap of the attention weights before, like Grad-CAM. For example, after I get the attention weights of each attention block and the corresponding source matrix, is there any way to generate a heatmap with ToMe like the following image:
image

Looking forward to your reply :)

Best

About keeping token order

Hi, thank you for your work. I am now trying to apply it to BERT for token merging. I want to keep the token order during merging, but I am having trouble since my programming skills are limited. Can you give me some hints on how to do so? Thank you.

How exactly are tokens reduced when there is no change in the model dimensions before and after tome.patch?

Hi, I understood that after the tokens come from the attention module, you feed them into a ToMe block, after which the number of tokens becomes N-r. But if you haven't changed the input dimension of the MLP (which comes after the ToMe block) to N-r, how exactly can you claim that you are reducing tokens?
And how can this increase throughput if there are no changes in the model's dimensions before and after modifying it?

hyper-parm "r" in ViT-g/14

Hi! In the paper, ViT-L (24 layers) uses r=8. Do you have any suggestions for r in ViT-g (40 layers)? Thank you!

Adding support for HuggingFace vision Transformers

Hi,

Thanks for this great work!

In 🤗 Transformers, we support the Vision Transformer (ViT) - among many other models like MAE, BEiT, ConvNeXt, Swin Transformer, Swin Transformer v2, etc. Recent additions also include Transformer-based video models, like VideoMAE and X-CLIP.

As can be seen on the hub, the 2 most popular ViT models have over 500k and 300k downloads, respectively, in the last month. It would be great if people could leverage this speedup! An increase in throughput would be very beneficial for people putting these models in production.

As models in the Transformers library are implemented very independently (we duplicate code rather than inheriting, for the sake of readability + independence among the models), we could add ToMe as a separate model in the library.

Let me know your thoughts :)

Best,

Niels
ML Engineer @ HuggingFace

How can I use ToMe under PyTorch 1.10.1?

Hello,
Thanks for this amazing work!

I know that I need PyTorch >= 1.12.0 to use ToMe because of the scatter_reduce function. However, for some reason, I can't use PyTorch > 1.10.1 in my development environment. I would like to apply ToMe to my ViT-like model, so is it possible to use ToMe under PyTorch 1.10.1?

Moreover, I'm wondering whether there is any elegant way to restore the original order of the tokens that is disrupted by the ToMe module. For example, how can I use the merged tokens and the source matrix to fill in the missing tokens?

Looking forward to your reply :)

Best

Does ToMe work for focal modulation networks?

Any help on modifying ToMe for focal modulation networks?
I guess in FMN we could apply ToMe on Q/M. Also, it has downsampling layers in each stage, so would the r value change for each stage and model definition?

ToMeBlock cannot be used with torch.utils.checkpoint

When using a relatively small ViT model like ViT-Ti, we do not need to use torch.utils.checkpoint. But for ViT-L or ViT-H, it is necessary to use torch.utils.checkpoint to save a lot of GPU memory.

The forward_features method in the original timm VisionTransformer is as follows:

def forward_features(self, x):
    x = self.patch_embed(x)
    cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
    if self.dist_token is None:
        x = torch.cat((cls_token, x), dim=1)
    else:
        x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
    x = self.pos_drop(x + self.pos_embed)
    x = self.blocks(x)
    x = self.norm(x)
    if self.dist_token is None:
        return self.pre_logits(x[:, 0])
    else:
        return x[:, 0], x[:, 1]

Then we add torch.utils.checkpoint to it as follows, but it causes an error on loss.backward().

def forward_features(self, x):
    x = self.patch_embed(x)
    cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
    if self.dist_token is None:
        x = torch.cat((cls_token, x), dim=1)
    else:
        x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
    x = self.pos_drop(x + self.pos_embed)
    # x = self.blocks(x)
    for func in self.blocks:
        if self.using_checkpoint and self.training:
            from torch.utils.checkpoint import checkpoint
            x = checkpoint(func, x)
        else:
            x = func(x)
    x = self.norm(x)
    if self.dist_token is None:
        return self.pre_logits(x[:, 0])
    else:
        return x[:, 0], x[:, 1]

The error is as follows. I have verified that the error is caused by torch.utils.checkpoint.

 File "/mnt/my_dist/code/classification/backbones/vit_tome.py", line 417, in forward
    attn = attn + size.log()[:, None, None, :, 0]
RuntimeError: The size of tensor a (29) must match the size of tensor b (24) at non-singleton dimension 3

Can anyone help with a solution? Thanks in advance.

Training with merging

Hi, I want to apply bipartite_soft_matching to my project. During training, you treat the token merging as a pooling operation. Can I directly use bipartite_soft_matching without any change during the training stage?


AttributeError: 'ToMeVisionTransformer' object has no attribute 'norm'

Hi,

This is a simple but great idea for improving the performance of ViT. But I got an error while exploring the code.

I tried to reproduce the MAE evaluation results on ImageNet-1k with vit_base_patch16 using the official MAE code. I applied the tome patch to models_vit.py as mentioned in the given example, but it gave an error because the model was None. So I changed this,

model = tome.patch.timm(model, prop_attn=False)

to this,

tome.patch.timm(model, prop_attn=False)

I did this because I had already used this code with the given benchmark. But it gave a new error:

File "/nfs/users/ext_vignagajan.vigneswaran/ToMe/experiments/models/mae/main_finetune.py", line 360, in
main(args)
File "/nfs/users/ext_vignagajan.vigneswaran/ToMe/experiments/models/mae/main_finetune.py", line 307, in main
test_stats = evaluate(data_loader_val, model, device)
File "/nfs/users/ext_vignagajan.vigneswaran/miniconda3/envs/tome/lib/python3.10/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "/nfs/users/ext_vignagajan.vigneswaran/ToMe/experiments/models/mae/engine_finetune.py", line 118, in evaluate
output = model(images)
File "/nfs/users/ext_vignagajan.vigneswaran/miniconda3/envs/tome/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/nfs/users/ext_vignagajan.vigneswaran/ToMe/tome/patch/timm.py", line 114, in forward
return super().forward(*args, **kwdargs)
File "/nfs/users/ext_vignagajan.vigneswaran/miniconda3/envs/tome/lib/python3.10/site-packages/timm/models/vision_transformer.py", line 347, in forward
x = self.forward_features(x)
File "/nfs/users/ext_vignagajan.vigneswaran/miniconda3/envs/tome/lib/python3.10/site-packages/timm/models/vision_transformer.py", line 340, in forward_features
x = self.norm(x)
File "/nfs/users/ext_vignagajan.vigneswaran/miniconda3/envs/tome/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1207, in getattr
raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'ToMeVisionTransformer' object has no attribute 'norm'

I explored the code but wasn't able to fix it. Can you help me fix this error?

Stable Diffusion

This work looks incredible. Is there an ETA on releasing the SD patch? Would we be reinventing the wheel by attempting to add the patch ourselves, would you accept a PR, or is there a plan coming for that?

eucl distance function in Tab. 1b

Hello,
Thank you for your interesting work!

I want to know the implementation of the eucl (Euclidean) distance function in Tab. 1b. I tried using scores = torch.cdist(a, b) to calculate the Euclidean distance, but with my implementation the top-1 accuracy on ViT-L/16 with r=8 is only 61.50.

Thanks!

About Token Merging Order

Thank you very much for your open-source work.
Similar to issue #9, I would like to be able to keep the original token order as each block merges tokens.
However, looking at your code implementation, I see that you concatenate sets A and B directly (sequentially, via return torch.cat([unm, dst], dim=1)), rather than interleaving them.
I would like to know why this is, and how I can modify the implementation to keep the original order of the tokens (the merged ones can be put into set B).

How to restore it to its original shape

After performing the merge, the number of tokens is reduced, but I need to restore the image patches to their original size to perform the next step. Please tell me how to do this.
