diff-plugin's Introduction


Yuhao Liu · Zhanghan Ke · Fang Liu · Nanxuan Zhao · Rynson W.H. Lau


Code repository for our paper "Diff-Plugin: Revitalizing Details for Diffusion-based Low-level Tasks" in CVPR 2024.


Diff-Plugin introduces a novel framework that empowers a single pre-trained diffusion model to produce high-fidelity results across a variety of low-level tasks. It achieves this through lightweight plugins, without compromising the generative capabilities of the original pre-trained model. Its plugin selector also allows users to perform low-level visual tasks (either single or multiple) by selecting the appropriate task plugin via language instructions.

News

  • Testing and training code and models are released!

To-Do List

  • Provide training scripts for the plugin selector.

Environment Setup

To set up your environment, follow these steps:

conda create --name DiffPlugin python=3.8
conda activate DiffPlugin
conda install --file requirements.txt

Inference

Run the following command to apply task plugins. Results are saved in the temp_results folder.

bash infer.sh

Currently, we provide eight task plugins: derain, desnow, dehaze, demoire, deblur, highlight removal, lowlight enhancement and blind face restoration.

Before testing, specify your desired task on the first line of infer.sh.

Gradio Demo

A single GPU with 24 GB of memory is generally enough for an image at a resolution of 1024×720.

python demo.py 

Note that in the advanced options, you can adjust

  • the image resolution in the format of width==height for flexible outputs;
  • the number of diffusion steps, where fewer steps run faster (default: 20);
  • the random seed to generate different results.
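
As an illustration of the width==height format mentioned above, here is a minimal sketch of how such a string could be validated before it is typed into the demo. The helper name parse_resolution is hypothetical and is not taken from demo.py; how the demo itself parses this field is not shown in this README.

```python
from typing import Tuple


def parse_resolution(text: str) -> Tuple[int, int]:
    """Parse a 'width==height' string into a (width, height) tuple.

    Hypothetical helper for illustration only; not part of demo.py.
    """
    parts = text.strip().split("==")
    if len(parts) != 2:
        raise ValueError(f"expected 'width==height', got {text!r}")
    # int() raises ValueError on empty or non-numeric parts.
    width, height = (int(p) for p in parts)
    if width <= 0 or height <= 0:
        raise ValueError("width and height must be positive")
    return width, height


if __name__ == "__main__":
    print(parse_resolution("1024==720"))
```

A check like this only catches malformed input early; it does not say which resolutions fit in GPU memory.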

Train your own task-plugin

Step-1: construct your own dataset

  1. Navigate to your training data directory and create a CSV file named after your task:
cd data/train
touch task_name.csv

List your samples in task_name.csv following the format in example.csv. Each line holds the path of an input image and the path of its ground-truth (GT) image.

Note that the data root is set in the train.sh script, and only paths relative to that data root are supported for the samples.

Example for a sample named a1, with the data root set to /user/Dataset:

  • Input: /user/Dataset/task1/input/a1.jpg
  • GT: /user/Dataset/task1/GT/a1.jpg

Your task_name.csv should contain lines in the following format:

task1/input/a1.jpg,task1/GT/a1.jpg
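
For larger datasets, the CSV can be generated instead of written by hand. The following is a minimal sketch, not a script shipped with this repository. It assumes the layout of the example above (an input folder and a GT folder under the task folder) and that each input image has a GT image with the same file name; the function names build_pairs and write_csv are hypothetical.

```python
import csv
import tempfile
from pathlib import Path


def build_pairs(data_root, task):
    """Return (input, GT) path pairs relative to data_root.

    Assumes <data_root>/<task>/input and <data_root>/<task>/GT exist and
    that paired images share the same file name (an assumption).
    """
    root = Path(data_root)
    input_dir = root / task / "input"
    gt_dir = root / task / "GT"
    pairs = []
    for inp in sorted(input_dir.iterdir()):
        gt = gt_dir / inp.name
        # Skip inputs that have no matching ground-truth file.
        if inp.is_file() and gt.is_file():
            pairs.append((inp.relative_to(root).as_posix(),
                          gt.relative_to(root).as_posix()))
    return pairs


def write_csv(pairs, csv_path):
    """Write one 'input,GT' line per pair."""
    with open(csv_path, "w", newline="") as f:
        csv.writer(f).writerows(pairs)


if __name__ == "__main__":
    # Self-contained demo on a throwaway directory tree.
    with tempfile.TemporaryDirectory() as tmp:
        for sub in ("input", "GT"):
            folder = Path(tmp) / "task1" / sub
            folder.mkdir(parents=True)
            (folder / "a1.jpg").touch()
        pairs = build_pairs(tmp, "task1")
        write_csv(pairs, Path(tmp) / "task1.csv")
        print((Path(tmp) / "task1.csv").read_text())
```

In real use, data_root would be the same data root that is filled in in train.sh, and the CSV would be written to data/train.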

Step-2: start training

Replace the task name in train.sh with the name of your task, then run the training script:

bash train.sh

Note that:

  • For single-GPU training, use python xxx; for multi-GPU training, use accelerate xxx.
  • Ensure the task name matches the folder name in data/train.
  • Our training was conducted on four A100 GPUs with a batch size of 64.

Evaluation

Please refer to the README.md under the metric folder.

License

This project is released under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 (CC BY-NC-SA 4.0) license.

Citation

If you find Diff-Plugin useful in your research, please consider citing us:

@inproceedings{liu2024diff,
  title={Diff-Plugin: Revitalizing Details for Diffusion-based Low-level Tasks},
  author={Liu, Yuhao and Ke, Zhanghan and Liu, Fang and Zhao, Nanxuan and Lau, Rynson W.H.},
  booktitle={CVPR},
  year={2024}
}

Acknowledgements

Special thanks to the following repositories for supporting our research:

Contact

This repository is maintained by Yuhao Liu (@Yuhao).
For questions, please contact [email protected].

diff-plugin's Issues

Multi-GPU training failed

Hi, I trained on 4 A100 GPUs with a batch size of 64 and got an out-of-memory (OOM) error. Maybe there is a bug in train.py for multi-GPU training.

how to train with SDXL

Great work!
Have you tried different SD models? For example, SD2.0, SDXL, etc. What modifications need to be made to update the SD model?

Inverse generation

Hi,

Much appreciate your effort in creating this interesting work!

I'm curious about the inverse generation of rain and snow in your paper. Could you please give some hints on how to generate those effects inversely?

some questions

Very interesting work! I have a question about it.
The CLIP vision encoder and the VAE encoder in the task-plugin are both pre-trained. Can I retrain these two encoders? I plan to use them for denoising tasks in other fields. Looking forward to the authors' reply!

Paper in details

Very interesting work!!!
Q1. Regarding Section 3.3 (Task-Plugin): you want to obtain a task-specific visual guidance prior and a spatial prior in order to learn which parts to remove while preserving the image. However, in the training loss the initial latent is encoded from the ground-truth image, which is the "clean" image. How do the TPB and SCB then learn to clean the degraded image?
Q2. I want to check my understanding of the overall pipeline. Suppose I provide an image and a text instruction asking for desnowing. The image is fed to the task-plugins to obtain the two priors, and the plugin selector computes their similarity with the text. The priors whose similarity exceeds the threshold are chosen; the task-specific visual guidance prior is fed to the cross-attention layers and the spatial prior to the final stage of the decoder. Is that correct?
Thanks!

Colorization model

Thank you for your great work! Could you share the pretrained plugin for the colorization task?

How to resume training?

I tried to use --resume_from_checkpoint "savepath/checkpointxxxx", but got errors.

I found that the checkpoints only save scb_net and tpb_net, but the code for --resume_from_checkpoint looks like it needs to load the original SD1.4 model file pytorch_model.bin.

Is resuming supported in train.py?

SCB dim problem

Hello! This work is very interesting, so I want to reproduce it.
When I train the model, I run into a dimension problem in the SCB module.
Here is the code where the error occurs:
for downsample_block in self.down_blocks:
    if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
        sample, res_samples = downsample_block(
            hidden_states=sample,
            temb=emb,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            cross_attention_kwargs=cross_attention_kwargs,
        )
    else:
        sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

The error message is:
/data/***/anaconda3/envs/dewarp/lib/python3.9/site-packages/diffusers/models/unets/unet_2d_blocks.py", line 1290, in forward
hidden_states = hidden_states + additional_residuals
RuntimeError: The size of tensor a (32) must match the size of tensor b (64) at non-singleton dimension 3

Can you help me solve this problem?

Snow test set

Hi, can you provide the test set for the desnow task? I cannot download it from the website. Thanks.

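About document images

Hello, your work is really interesting. Diffusion-based work is very interesting right now and also very effective. As the paper says, your method lets the diffusion model control the generation so that the original image information is preserved. What do you think about document images? Compared with natural images, document images have more high-frequency features, and all of the text needs to be preserved, so more constraints can be placed on the generation space. I trained a version with your method on text image restoration data, and the results were really quite poor.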

Installing the requirements file failed

PackagesNotFoundError: The following packages are not available from current channels. Not a single package installs successfully. How did you solve this?

bash train.sh raises an error

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /home/notebook/code/personal/80239864/Diff-Plugin/train.py:430 in <module> │
│ │
│ 427 if __name__ == "__main__": │
│ 428 │ args = parse_args() │
│ 429 │ │
│ ❱ 430 │ main(args) │
│ 431 │
│ │
│ /home/notebook/code/personal/80239864/Diff-Plugin/train.py:160 in main │
│ │
│ 157 │ backup_unet.config.down_block_types = args.down_block_types │
│ 158 │ backup_unet.config.block_out_channels = args.block_out_channels │
│ 159 │ # ------- │
│ ❱ 160 │ scb_net = SCBNet.from_unet(backup_unet, load_weights_from_unet=args.load_weights_fro │
│ 161 │ │
│ 162 │ vae.requires_grad_(False) │
│ 163 │ unet.requires_grad_(False) │
│ │
│ /home/notebook/code/personal/80239864/Diff-Plugin/modules/SCB.py:192 in from_unet │
│ │
│ 189 │ │ │ │ UNet model which weights are copied to the ControlNet. Note that all con │
│ 190 │ │ │ │ copied where applicable. │
│ 191 │ │ """ │
│ ❱ 192 │ │ controlnet = cls( │
│ 193 │ │ │ in_channels=unet.config.in_channels, │
│ 194 │ │ │ flip_sin_to_cos=unet.config.flip_sin_to_cos, │
│ 195 │ │ │ freq_shift=unet.config.freq_shift, │
│ │
│ /opt/conda/envs/py38/lib/python3.8/site-packages/diffusers/configuration_utils.py:636 in │
│ inner_init │
│ │
│ 633 │ │ │
│ 634 │ │ new_kwargs = {**config_init_kwargs, **new_kwargs} │
│ 635 │ │ getattr(self, "register_to_config")(**new_kwargs) │
│ ❱ 636 │ │ init(self, *args, **init_kwargs) │
│ 637 │ │
│ 638 │ return inner_init │
│ 639 │
│ │
│ /home/notebook/code/personal/80239864/Diff-Plugin/modules/SCB.py:145 in __init__ │
│ │
│ 142 │ │ │ output_channel = block_out_channels[i] │
│ 143 │ │ │ is_final_block = i == len(block_out_channels) - 1 │
│ 144 │ │ │ │
│ ❱ 145 │ │ │ down_block = get_down_block( │
│ 146 │ │ │ │ down_block_type, │
│ 147 │ │ │ │ num_layers=layers_per_block, │
│ 148 │ │ │ │ in_channels=input_channel, │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
TypeError: get_down_block() got an unexpected keyword argument 'attn_num_head_channels'
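
A TypeError of this kind means the installed function does not accept the keyword it was called with, which may indicate that the installed diffusers version differs from the one the repository code was written against. One way to check is to inspect the function's signature. The sketch below is generic and uses a hypothetical stand-in function, get_down_block_stub, rather than the real diffusers function; its parameter list is an assumption made only for the demonstration.

```python
import inspect


def accepts_keyword(func, name):
    """Return True if func can be called with the keyword argument `name`."""
    params = inspect.signature(func).parameters
    if name in params:
        return True
    # A **kwargs parameter accepts any keyword.
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


# Hypothetical stand-in: NOT the real diffusers signature.
def get_down_block_stub(down_block_type, num_layers, in_channels):
    return down_block_type


if __name__ == "__main__":
    for kw in ("num_layers", "attn_num_head_channels"):
        print(kw, accepts_keyword(get_down_block_stub, kw))
```

In an environment with diffusers installed, the same helper could be pointed at the get_down_block function that modules/SCB.py imports, to see which keyword names that version accepts.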

Demoire dataset

Hello. Can I get the LCD Moire dataset from you? There is no way to download it now.
