FPTrans: Feature-Proxy Transformer for Few-Shot Segmentation

Jian-Wei Zhang, Yifan Sun, Yi Yang, Wei Chen

[arXiv][Bibtex]

This repository contains the PyTorch implementation. The PaddlePaddle implementation can be found here.

Framework

Installation

Create a virtual environment and install the required packages.

conda create -n fptrans python=3.9.7
conda activate fptrans
conda install numpy=1.21.2
conda install pytorch==1.10.0 torchvision==0.11.1 cudatoolkit=11.3 -c pytorch
conda install tqdm scipy pyyaml
pip install git+https://github.com/IDSIA/[email protected]
pip install dropblock pycocotools opencv-python

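Add the following bash function to ~/.bashrc to simplify setting CUDA_VISIBLE_DEVICES.

# Run a command with CUDA_VISIBLE_DEVICES set to the first argument.
function cuda()
{
    # Do nothing when called without arguments.
    if [ "$#" -eq 0 ]; then
        return
    fi
    local GPU_ID=$1
    shift 1
    # Quote "$@" so that arguments containing spaces are passed through intact.
    CUDA_VISIBLE_DEVICES="$GPU_ID" "$@"
}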

We can now use cuda 0 python for a single GPU and cuda 0,1 python for multiple GPUs.
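The same idea can be expressed in Python when a launcher script is more convenient than a shell function. This is a minimal sketch, not part of FPTrans: run_on_gpus is a hypothetical helper that sets CUDA_VISIBLE_DEVICES only for the child process, just as the bash function does.

```python
import os
import subprocess
import sys


def run_on_gpus(gpu_ids, command):
    """Run `command` with CUDA_VISIBLE_DEVICES set to `gpu_ids` (e.g. "0" or "0,1").

    Hypothetical helper for illustration; it mirrors the bash `cuda` function.
    """
    env = dict(os.environ)  # copy, so the parent environment stays untouched
    env["CUDA_VISIBLE_DEVICES"] = gpu_ids
    return subprocess.run(command, env=env, capture_output=True, text=True, check=True)


if __name__ == "__main__":
    # The child only echoes the variable, so no GPU is needed for this check.
    child = [sys.executable, "-c", "import os; print(os.environ['CUDA_VISIBLE_DEVICES'])"]
    result = run_on_gpus("0,1", child)
    print(result.stdout.strip())  # prints 0,1
```

Because the variable is set in a copied environment, other commands started from the same shell or script are unaffected.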

Getting Started

See Preparing Datasets and Pretrained Backbones for FPTrans

Usage for inference with our pretrained models

Download the checkpoints of our pretrained FPTrans from GoogleDrive or BaiduDrive (Code: FPTr), and put the pretrained models (the numbered folders) into ./output/.

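| Dataset | Backbone | #Shots | Experiment ID (Split 0 - Split 3) |
|---|---|---|---|
| PASCAL-5i | ViT-B/16 | 1-shot | 1, 2, 3, 4 |
| PASCAL-5i | DeiT-B/16 | 1-shot | 5, 6, 7, 8 |
| PASCAL-5i | DeiT-S/16 | 1-shot | 9, 10, 11, 12 |
| PASCAL-5i | DeiT-T/16 | 1-shot | 13, 14, 15, 16 |
| PASCAL-5i | ViT-B/16 | 5-shot | 17, 18, 19, 20 |
| PASCAL-5i | DeiT-B/16 | 5-shot | 21, 22, 23, 24 |
| COCO-20i | ViT-B/16 | 1-shot | 25, 26, 27, 28 |
| COCO-20i | DeiT-B/16 | 1-shot | 29, 30, 31, 32 |
| COCO-20i | ViT-B/16 | 5-shot | 33, 34, 35, 36 |
| COCO-20i | DeiT-B/16 | 5-shot | 37, 38, 39, 40 |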
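Since the IDs in the table are consecutive across splits, the exp_id for a given checkpoint can be computed rather than looked up by hand. The sketch below only encodes the table above; FIRST_EXP_ID and experiment_id are hypothetical names and are not part of the FPTrans code.

```python
# Experiment ID of split 0 for each released checkpoint group (copied from the table).
FIRST_EXP_ID = {
    ("PASCAL-5i", "ViT-B/16", 1): 1,
    ("PASCAL-5i", "DeiT-B/16", 1): 5,
    ("PASCAL-5i", "DeiT-S/16", 1): 9,
    ("PASCAL-5i", "DeiT-T/16", 1): 13,
    ("PASCAL-5i", "ViT-B/16", 5): 17,
    ("PASCAL-5i", "DeiT-B/16", 5): 21,
    ("COCO-20i", "ViT-B/16", 1): 25,
    ("COCO-20i", "DeiT-B/16", 1): 29,
    ("COCO-20i", "ViT-B/16", 5): 33,
    ("COCO-20i", "DeiT-B/16", 5): 37,
}


def experiment_id(dataset, backbone, shot, split):
    """Return the exp_id of a pretrained checkpoint; splits are numbered 0 to 3."""
    if split not in range(4):
        raise ValueError(f"split must be 0, 1, 2 or 3, got {split!r}")
    return FIRST_EXP_ID[(dataset, backbone, shot)] + split


if __name__ == "__main__":
    print(experiment_id("PASCAL-5i", "ViT-B/16", 5, 0))  # prints 17
```

An unknown combination (for example DeiT-S/16 on COCO-20i) raises KeyError, which reflects that no checkpoint is listed for it.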

Run the test command:

# PASCAL ViT 1shot
cuda 0 python run.py test with configs/pascal_vit.yml exp_id=1 split=0

# PASCAL ViT 5shot
cuda 0 python run.py test with configs/pascal_vit.yml exp_id=17 split=0 shot=5

# COCO to PASCAL 1shot (cross-domain: no training needed, test only)
# Loads a model trained on COCO and tests it on PASCAL.
# Note: the code uses splits that differ from PASCAL-5i so that the PASCAL
#       test classes do not overlap with the COCO training classes.
cuda 0 python run.py test with configs/coco2pascal_vit.yml exp_id=29 split=0

Usage for training from scratch

Run the train command (adjust the batch size bs to fit the available GPU memory):

# PASCAL 1shot
cuda 0 python run.py train with split=0 configs/pascal_vit.yml

# PASCAL 5shot
cuda 0,1 python run.py train with split=0 configs/pascal_vit.yml shot=5

# COCO 1shot
cuda 0,1 python run.py train with split=0 configs/coco_vit.yml

# COCO 5shot
cuda 0,1,2,3 python run.py train with split=0 configs/coco_vit.yml shot=5 bs=8

Optional arguments:

  • -i <Number>: Specify the experiment ID. By default, IDs are assigned incrementally in the ./output directory (or in MongoDB if used).
  • -p: Print the configuration.
  • -u: Run the command without saving experiment details (useful for debugging).

Please refer to the Sacred documentation for the complete command-line interface.

Performance

  • Results on PASCAL-5i
| Backbone | Method | 1-shot | 5-shot |
|---|---|---|---|
| ResNet-50 | HSNet | 64.0 | 69.5 |
| ResNet-50 | BAM | 67.8 | 70.9 |
| ViT-B/16-384 | FPTrans | 64.7 | 73.7 |
| DeiT-T/16 | FPTrans | 59.7 | 68.2 |
| DeiT-S/16 | FPTrans | 65.3 | 74.2 |
| DeiT-B/16-384 | FPTrans | 68.8 | 78.0 |
  • Results on COCO-20i
| Backbone | Method | 1-shot | 5-shot |
|---|---|---|---|
| ResNet-50 | HSNet | 39.2 | 46.9 |
| ResNet-50 | BAM | 46.2 | 51.2 |
| ViT-B/16-384 | FPTrans | 42.0 | 53.8 |
| DeiT-B/16-384 | FPTrans | 47.0 | 58.9 |

Note that these results were obtained on NVIDIA A100/V100 GPUs. We find that the results may fluctuate slightly on an NVIDIA GeForce 3090, even with exactly the same model and environment.

Citing FPTrans

@inproceedings{zhang2022FPTrans,
  title={Feature-Proxy Transformer for Few-Shot Segmentation},
  author={Jian-Wei Zhang and Yifan Sun and Yi Yang and Wei Chen},
  booktitle={NeurIPS},
  year={2022}
}

fptrans's Issues

run error

I followed your steps to precompute the cross-entropy weights. When I ran: cuda 0 python tools.py precompute_loss_weights with dataset=PASCAL dry_run=True
this error occurred: sacred.utils.ConfigAddedError: Added new config entry that is not used anywhere
Conflicting configuration values:
dry_run=True
Then I ran: cuda 0 python tools.py precompute_loss_weights with dataset=PASCAL
and this error occurred: ValueError: Can not open file /home/gc/deskbook/FPTrans-main/data/VOCdevkit/VOC2012/SegmentationClassAug/2008_000008.png.
The image 2008_000008.png does not actually exist in the VOC dataset. I don't know what the problem is and hope you can give me an answer.
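A quick way to narrow this kind of error down is to list every image ID that has no mask file in the annotation directory. The sketch below assumes the IDs are available as plain strings and that masks are named <id>.png. find_missing_masks is a hypothetical helper and is not part of FPTrans.

```python
from pathlib import Path


def find_missing_masks(image_ids, mask_dir, suffix=".png"):
    """Return the IDs in `image_ids` that have no `<id><suffix>` file in `mask_dir`."""
    mask_dir = Path(mask_dir)
    return [image_id for image_id in image_ids
            if not (mask_dir / f"{image_id}{suffix}").is_file()]


if __name__ == "__main__":
    # Directory taken from the error message above; adjust it to the local setup.
    masks = "data/VOCdevkit/VOC2012/SegmentationClassAug"
    print(find_missing_masks(["2008_000008"], masks))
```

If the returned list is long, the annotation directory is probably incomplete. A single missing ID points instead at a mismatch between the ID list and the downloaded annotations.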

Why is the feature extractor not fixed?

Hi,

Thanks for your great work! I have a question: why is the feature extractor not fixed, unlike in most common FSS methods?

Regards,
MS

SBD Dataset

Congratulations on the good work! Where can I download the SBD dataset?

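Question about prompt generation in the code

Hello! I have read your paper and code, and there is one point I am unsure about:
Section 3.3.2 (Prompt Generation) says that after the C-dimensional mean vector is generated, it is expanded to a G×C tensor that has the same form as the tokens and is then concatenated with the tokens. I would like to know what G stands for here. I also could not find this step in your code...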

code

Dear author, I have recently read your work FPTrans. I am very interested in your work and appreciate your contribution. I would like to ask when the relevant code will be released?

zlib.error

I followed your steps for training from scratch.
However, the following error always occurs:
[image: error]
Can you give me some suggestions?
Thanks a lot!

zlib error

When I followed your steps to configure the environment and run the code, the following error occurred:

Traceback (most recent call last):
  File "/home/wj/.local/lib/python3.8/site-packages/sacred/experiment.py", line 312, in run_commandline
    return self.run(
  File "/home/wj/.local/lib/python3.8/site-packages/sacred/experiment.py", line 276, in run
    run()
  File "/home/wj/.local/lib/python3.8/site-packages/sacred/run.py", line 238, in __call__
    self.result = self.main_function(*args)
  File "/home/wj/.local/lib/python3.8/site-packages/sacred/config/captured_function.py", line 42, in captured_function
    result = wrapped(*args, **kwargs)
  File "run.py", line 99, in train
    trainer.start_training_loop(start_epoch, evaluator, num_classes)
  File "/home/wj/code/FPTrans/core/base_trainer.py", line 281, in start_training_loop
    for i, batch in enumerate(gen, start=1):
  File "/home/wj/.local/lib/python3.8/site-packages/tqdm/std.py", line 1195, in __iter__
    for obj in iterable:
  File "/home/wj/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 681, in __next__
    data = self._next_data()
  File "/home/wj/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1376, in _next_data
    return self._process_data(data)
  File "/home/wj/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1402, in _process_data
    data.reraise()
  File "/home/wj/.local/lib/python3.8/site-packages/torch/_utils.py", line 461, in reraise
    raise exception
zlib.error: Caught error in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/wj/.local/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 302, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/wj/.local/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/wj/.local/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/wj/code/FPTrans/data_kits/voc_coco.py", line 298, in __getitem__
    kwargs['weights'] = self.get_weights(weight_path, cls)
  File "/home/wj/code/FPTrans/data_kits/voc_coco.py", line 263, in get_weights
    cache_weights[name] = load_weights(data_dir / name)
  File "/home/wj/code/FPTrans/utils_/misc.py", line 36, in load_weights
    edts = npzfile['x']
  File "/home/wj/.local/lib/python3.8/site-packages/numpy/lib/npyio.py", line 249, in __getitem__
    magic = bytes.read(len(format.MAGIC_PREFIX))
  File "/usr/lib/python3.8/zipfile.py", line 940, in read
    data = self._read1(n)
  File "/usr/lib/python3.8/zipfile.py", line 1016, in _read1
    data = self._decompressor.decompress(data, n)
zlib.error: Error -3 while decompressing data: invalid distance too far back


During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "run.py", line 193, in <module>
    ex.run_commandline()
  File "/home/wj/.local/lib/python3.8/site-packages/sacred/experiment.py", line 347, in run_commandline
    print_filtered_stacktrace()
  File "/home/wj/.local/lib/python3.8/site-packages/sacred/utils.py", line 493, in print_filtered_stacktrace
    print(format_filtered_stacktrace(filter_traceback), file=sys.stderr)
  File "/home/wj/.local/lib/python3.8/site-packages/sacred/utils.py", line 528, in format_filtered_stacktrace
    return "".join(filtered_traceback_format(tb_exception))
  File "/home/wj/.local/lib/python3.8/site-packages/sacred/utils.py", line 568, in filtered_traceback_format
    current_tb = tb_exception.exc_traceback
AttributeError: 'TracebackException' object has no attribute 'exc_traceback'

How can I solve this problem? The same error still occurs after I update with 'pip install --upgrade setuptools'.
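The first traceback fails while NumPy decompresses an .npz file, which is consistent with a corrupted or partially written weight file, although the log alone does not prove it. Because .npz files are zip archives, they can be checked with the standard library. The sketch below is a hypothetical diagnostic and is not part of FPTrans. The directory that holds the precomputed weights is an assumption and must be adapted.

```python
import zipfile
import zlib
from pathlib import Path


def find_corrupt_npz(weights_dir):
    """Return (path, reason) pairs for .npz files that fail a zip integrity check."""
    corrupt = []
    for path in sorted(Path(weights_dir).rglob("*.npz")):
        try:
            with zipfile.ZipFile(path) as archive:
                # testzip() reads every member and returns the first bad name, or None.
                bad_member = archive.testzip()
            if bad_member is not None:
                corrupt.append((path, f"bad member: {bad_member}"))
        except (zipfile.BadZipFile, zlib.error, EOFError, OSError) as exc:
            corrupt.append((path, f"{type(exc).__name__}: {exc}"))
    return corrupt


if __name__ == "__main__":
    # "data/weights" is a placeholder path, not a documented FPTrans location.
    for path, reason in find_corrupt_npz("data/weights"):
        print(path, reason)
```

Any file reported here can be deleted and regenerated. Files that pass the check are at least structurally valid archives.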
