This project forked from qiu023/sats_continual_semantic_seg

Official Code of SATS: Self-Attention Transfer for Continual Semantic Segmentation

Self-Attention Transfer for Continual Semantic Segmentation

SATS (Self-Attention Transfer for Continual Semantic Segmentation) is a new state-of-the-art (SOTA) method for continual semantic segmentation. It exploits the intra-class and inter-class relationship information that is unique to advanced Transformer-based segmentation models.

SATS has been published in the journal Pattern Recognition. If you use or mention the SATS method, code, or other related content, please cite SATS:

@article{QIU2023109383,
title = {SATS: Self-Attention Transfer for Continual Semantic Segmentation},
journal = {Pattern Recognition},
pages = {109383},
year = {2023},
volume = {138},
issn = {0031-3203},
doi = {10.1016/j.patcog.2023.109383},
url = {https://www.sciencedirect.com/science/article/pii/S0031320323000845},
author = {Yiqiao Qiu and Yixing Shen and Zhuohao Sun and Yanchong Zheng and Xiaobin Chang and Weishi Zheng and Ruixuan Wang},
keywords = {Continual learning, Semantic segmentation, Self-attention transfer, Class-specific region pooling},
}

Results on VOC

Visualization on VOC 15-1

This repository contains all of our code. It is a modified version of Douillard et al.'s repository.

Requirements

You need to install at least the following libraries:

  • Python (3.7)
  • PyTorch (1.8.1+cu102)
  • torchvision (0.9.1+cu102)
  • tensorboardX (1.8)
  • apex (0.1)
  • matplotlib (3.3.1)
  • numpy (1.17.2)
  • mmcv
  • timm
  • MulticoreTSNE
  • einops

To set up the basic environment on a new machine, you can use the script utils.sh:

bash utils.sh

The complete list of libraries used can be found in requirements.txt.

Note also that apex seems to work only with some CUDA versions, so try to install PyTorch (and torchvision) built for CUDA 10.2. In that case you will probably need Anaconda instead of pip, sorry! Run:

conda install -y pytorch==1.8.1 torchvision==0.9.1 cudatoolkit=10.2 -c pytorch
cd apex
pip3 install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./

Note that while the code should run without mixed precision (apex), some users have reported lower performance and slower training without it, so we recommend using it.
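Before building apex, it can help to confirm which CUDA version the installed PyTorch was built against. The following is a minimal sketch, assuming `python` points at the interpreter of the active environment; the helper names `torch_cuda_version` and `is_cuda_102` are hypothetical and not part of this repository.

```shell
#!/usr/bin/env bash
# Sketch: check the CUDA version PyTorch was built with before compiling apex.
# Helper names are hypothetical; they are not part of this repository.

# Prints torch.version.cuda (e.g. "10.2"), "None" for CPU-only builds,
# or "unavailable" if PyTorch cannot be imported.
torch_cuda_version() {
  python -c 'import torch; print(torch.version.cuda)' 2>/dev/null || echo "unavailable"
}

# Succeeds if the given version string denotes a CUDA 10.2 build.
is_cuda_102() {
  [[ "$1" == 10.2* ]]
}

version="$(torch_cuda_version)"
if is_cuda_102 "$version"; then
  echo "PyTorch built with CUDA ${version}: matches the recommended 10.2"
else
  echo "PyTorch CUDA version is '${version}'; consider reinstalling with cudatoolkit=10.2"
fi
```

If the check fails, reinstall PyTorch with the conda command above before running the apex build.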

How to perform training

The most important file is run.py, which launches the training or test procedure. To run it, use the following command:

python -m torch.distributed.launch --nproc_per_node=<num_GPUs> run.py --data_root <data_folder> --name <exp_name> .. other args ..

There are many options (you can list them all with the --help option), but the code is arranged so that the reported methods are straightforward to test. Keeping all other parameters at their defaults, you can replicate the experiments by setting the following options:

  • data folder: --data_root <data_root>
  • dataset: --dataset voc (Pascal-VOC 2012) | ade (ADE20K)
  • task: --task <task>, where tasks are
    • 15-5, 15-5s, 19-1 (VOC), 100-10, 25-25, 100-50, 50 (ADE)
  • step (each step is run separately): --step <N>, where N is the step number, starting from 0
  • disjoint is the default setup; to enable the overlapped setup: --overlapped (all experiments in SATS use the overlapped setup)
  • learning rate: --lr 0.01 (for step 0) | 0.001 (for step > 0)
  • batch size: --batch_size <24/num_GPUs>
  • epochs: --epochs 30 (Pascal-VOC 2012) | 60 (ADE20K)
  • method: --method <method name>, where names are
    • FT, LWF, LWF-MC, ILT, EWC, RW, PI, MIB, PLOP, SSUL, SATS
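As an illustration of how these options combine, here is a minimal dry-run sketch that prints the command for each step of a VOC 15-5 overlapped SATS run. It assumes 15-5 consists of two steps (step 0 with the 15 base classes, step 1 with the 5 new ones); `DATA_ROOT`, `EXP_NAME`, `NUM_GPUS`, and the helper `build_cmd` are placeholders, not part of the repository. The provided scripts remain the reference for the exact settings.

```shell
#!/usr/bin/env bash
# Sketch: print (dry run) the per-step commands for VOC 15-5, overlapped, SATS.
# DATA_ROOT, EXP_NAME, NUM_GPUS and build_cmd are hypothetical placeholders.
set -euo pipefail

DATA_ROOT="/path/to/data"   # placeholder: your data folder
EXP_NAME="sats_15-5"        # placeholder experiment name
NUM_GPUS=2
TASK="15-5"
NUM_STEPS=2                 # assumption: 15-5 = step 0 (15 classes) + step 1 (5 classes)

build_cmd() {
  local step="$1"
  local lr
  # --lr 0.01 for step 0, 0.001 for later steps
  if [ "$step" -eq 0 ]; then lr=0.01; else lr=0.001; fi
  # Total batch size 24 is split across GPUs: --batch_size <24/num_GPUs>
  echo "python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} run.py" \
       "--data_root ${DATA_ROOT} --name ${EXP_NAME}" \
       "--dataset voc --task ${TASK} --step ${step} --overlapped" \
       "--lr ${lr} --batch_size $((24 / NUM_GPUS)) --epochs 30 --method SATS"
}

# Each step is run separately, in order.
for ((step = 0; step < NUM_STEPS; step++)); do
  build_cmd "$step"
done
```

Printing instead of executing makes it easy to check the arguments first; piping the output into `bash` from the repository root would launch the steps one after another.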

For all other details, see the information printed by the --help option.

You can use one of the provided scripts in scripts/voc or scripts/ade, which launch every step of a continual training run for each experiment reported in SATS. For example:

bash scripts/voc/sats_segformer_15-1.sh

In these scripts, the first training command is commented out, because all continual learning methods share the same initially trained model (as long as the hyper-parameters are the same). If you want to modify the hyper-parameters and retrain the first step of the continual learning, simply uncomment that command.

Once you have trained the model, you can view the results in TensorBoard (the test is performed after the whole training), or you can test it with the same script and parameters by adding the option

--test

which skips the whole training procedure and evaluates the model on the test data. Remember to also add

--ckpt /path/to/model/checkpoint/files

Note that you will need to modify these scripts to set the paths of your checkpoint, data, and pretrained-model directories.
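A minimal sketch of a test-only run, assuming you evaluate step 1 of a VOC 15-5 SATS experiment; every path, the experiment name, and the GPU count are hypothetical placeholders to replace with your own values. Storing the arguments in a bash array keeps paths containing spaces intact.

```shell
#!/usr/bin/env bash
# Sketch: evaluate a trained model without retraining.
# All paths and the experiment name are hypothetical placeholders.
set -euo pipefail

DATA_ROOT="/path/to/data"
CKPT="/path/to/model/checkpoint/files"
NUM_GPUS=1

# Same arguments as the training run of the step being evaluated ...
ARGS=(--data_root "$DATA_ROOT" --name sats_15-5 --dataset voc --task 15-5
      --step 1 --overlapped --method SATS)
# ... plus --test (skip training) and --ckpt (the weights to load).
ARGS+=(--test --ckpt "$CKPT")

CMD=(python -m torch.distributed.launch --nproc_per_node="$NUM_GPUS" run.py "${ARGS[@]}")

printf '%q ' "${CMD[@]}"; echo   # dry run: print the quoted command
# "${CMD[@]}"                    # uncomment to actually launch it
```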

Some intermediate directories storing reusable data are generated the first time you run training: plop_median stores the median values of the pseudo-labels used in PLOP, balance_step_exemplar stores the exemplar image indexes of a dataset for the re-implemented SSUL, and data stores the image indexes of the overlapped VOC/ADE splits of a dataset.
