[CVPR 2024 Highlight] This is the official PyTorch implementation of "TFMQ-DM: Temporal Feature Maintenance Quantization for Diffusion Models".

Home Page: https://modeltc.github.io/TFMQ-DM/

License: Apache License 2.0

Topics: ddim, diffusion-models, ldm, post-training-quantization, stable-diffusion, cvpr, cvpr2024, quantization, highlight

TFMQ-DM: Temporal Feature Maintenance Quantization for Diffusion Models

[paper][slides][poster][project page]

Yushi Huang*, Ruihao Gong*, Jing Liu, Tianlong Chen, Xianglong Liu📧

(* denotes equal contribution, 📧 denotes corresponding author.)

This is the official implementation of our paper TFMQ-DM, a novel training-free framework that achieves a new state-of-the-art result in PTQ of diffusion models, especially under 4-bit weight quantization, and significantly accelerates quantization time. For some hard tasks, e.g., CelebA-HQ, our method reduces the FID score by 6.7.

(Left) Images generated by w4a8 quantized and full-precision Stable Diffusion with MS-COCO captions. (Right) Random samples from w4a8 quantized and full-precision LDM-4 on CelebA-HQ. These results show that our TFMQ-DM outperforms the previous PTQ methods. More qualitative and quantitative results can be found in our paper.

News

  • Apr 5, 2024: 🌟 Our paper has been selected as a Highlight Poster at CVPR 2024 (top 2.8%)! 🎉 Cheers!

  • Mar 31, 2024: 🔥 We have released our Python code for quantizing all the models presented in our paper. Have a try!

  • Feb 27, 2024: 🌟 Our paper has been accepted by CVPR 2024! 🎉 Cheers!

Overview

(Figure: framework overview)

The diffusion model, a prevalent framework for image generation, faces significant challenges to broad applicability due to its extended inference times and substantial memory requirements. Efficient post-training quantization (PTQ) is pivotal for addressing these issues in traditional models. Unlike traditional models, however, diffusion models heavily depend on the time-step $t$ to achieve satisfactory multi-round denoising. Usually, $t$ from the finite set { $1, \ldots, T$ } is encoded into a temporal feature by a few modules, irrespective of the sampling data. However, existing PTQ methods do not optimize these modules separately. They adopt inappropriate reconstruction targets and complex calibration methods, resulting in severe disturbance of the temporal feature and the denoising trajectory, as well as low compression efficiency. To solve this, we propose a Temporal Feature Maintenance Quantization (TFMQ) framework built upon a Temporal Information Block that is related only to the time-step $t$ and unrelated to the sampling data. Powered by this pioneering block design, we devise temporal information aware reconstruction (TIAR) and finite set calibration (FSC) to align the full-precision temporal features within a limited time. Equipped with this framework, we can maintain most of the temporal information and ensure end-to-end generation quality. Extensive experiments on various datasets and diffusion models demonstrate our state-of-the-art results.
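
To make the data-independence of the temporal feature concrete, the sketch below shows the kind of time-step encoding used in DDPM-style UNets (as in the ddim and latent-diffusion codebases this repository builds on). The TemporalInformationBlock name and the layer sizes here are illustrative assumptions, not the repository's actual code:

import math
import torch
import torch.nn as nn

def timestep_embedding(t: torch.Tensor, dim: int) -> torch.Tensor:
    # Sinusoidal encoding of the time-step t, as in DDPM-style UNets.
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
    args = t[:, None].float() * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

class TemporalInformationBlock(nn.Module):
    # Illustrative stand-in for the modules that map t to a temporal feature.
    def __init__(self, ch: int = 128):
        super().__init__()
        self.ch = ch
        self.time_embed = nn.Sequential(  # the usual two-layer MLP on the encoding
            nn.Linear(ch, 4 * ch), nn.SiLU(), nn.Linear(4 * ch, 4 * ch)
        )

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        # The output depends only on t, never on the image latent x.
        return self.time_embed(timestep_embedding(t, self.ch))

emb = TemporalInformationBlock()(torch.tensor([1, 500, 999]))
print(emb.shape)  # torch.Size([3, 512])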

Quick Start

After cloning the repository, you can follow these steps to quantize the models and run inference with them.

Requirements

conda env create -f ./stable-diffusion/environment.yaml
conda activate ldm
pip install -r requirements.txt

If you encounter errors while installing the packages listed in requirements.txt, try installing each Python package individually with pip.

Before quantization, you need to download the pre-trained weights (preliminary checkpoints):

# ----------- DDIM -----------
# The model execution script will automatically download the required checkpoints during runtime.

# ----------- LDM ------------
cd ./stable-diffusion/
sh ./scripts/download_first_stages.sh
sh ./scripts/download_models.sh
mkdir -p models/ldm/cin256-v2/
wget -O models/ldm/cin256-v2/model.ckpt https://ommer-lab.com/files/latent-diffusion/nitro/cin/model.ckpt
cd ../

# ----- Stable Diffusion -----
wget https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt
mkdir -p ./stable-diffusion/models/ldm/stable-diffusion-v1/
mv sd-v1-4.ckpt ./stable-diffusion/models/ldm/stable-diffusion-v1/sd-v1-4.ckpt

Quantization

In this part, we first generate calibration data and then quantize the model. Alternatively, you can generate the calibration data separately by adding a torch.save call to the script, and quantize later by adding torch.load and commenting out the calibration-data generation code.
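
As a minimal sketch of that save/load split (the file name and the structure of cali_data below are hypothetical, not taken from the repository):

import torch

# First run: suppose the script has collected its calibration inputs, e.g. sample
# latents and time-steps (this dict structure is a hypothetical example).
cali_data = {"xs": torch.randn(128, 3, 32, 32), "ts": torch.randint(0, 1000, (128,))}
torch.save(cali_data, "cali_data.pth")  # add this after the generation code

# Later run: comment out the generation code and reload the saved data instead.
cali_data = torch.load("cali_data.pth")
print(cali_data["xs"].shape, cali_data["ts"].shape)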

If you want to quantize your diffusion models on multiple GPUs, add --multi_gpu to the corresponding command (this does not apply to DDIM). You can also remove --use_aq and --aq 8 to disable activation quantization.

Additionally, before quantizing Stable Diffusion, you also need to prepare some prompts. For example, you can download MS-COCO and use the path of captions_train*.json as <PATH/TO/LOAD/DATA> in the commands below. We use 128 prompts from captions_train*.json as part of the calibration data.
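
For reference, a minimal sketch of how 128 prompts could be read out of a COCO captions file (the exact sampling in the repository may differ; the path below is an example):

import json

# COCO captions files store annotations as {"annotations": [{"caption": ...}, ...]}.
with open("annotations/captions_train2017.json") as f:
    coco = json.load(f)

prompts = [ann["caption"] for ann in coco["annotations"][:128]]  # first 128 prompts
print(len(prompts), prompts[0])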

# ----------- DDIM -----------
python sample_diffusion_ddim.py --config ddim/configs/cifar10.yml --timesteps 100 --eta 0 --skip_type quad --wq <4 OR 8> --ptq --aq 8 -l <PATH/TO/SAVE/LOG> --cali --use_aq --cali_save_path <PATH/TO/SAVE/QUANTIZED/MODEL> --interval_length 5

# ----------- LDM ------------
# LSUN-Bedrooms
python sample_diffusion_ldm.py -r ./stable-diffusion/models/ldm/lsun_beds256/model.ckpt -c 200 -e 1.0 --seed 40 --wq <4 OR 8> --ptq --aq 8 -l <PATH/TO/SAVE/LOG> --cali --use_aq --cali_save_path <PATH/TO/SAVE/QUANTIZED/MODEL> --interval_length 10
# LSUN-Churches
python sample_diffusion_ldm.py -r ./stable-diffusion/models/ldm/lsun_churches256/model.ckpt -c 400 -e 0.0 --seed 40 --wq <4 OR 8> --ptq --aq 8 -l <PATH/TO/SAVE/LOG> --cali --use_aq --cali_save_path <PATH/TO/SAVE/QUANTIZED/MODEL> --interval_length 25
# CelebA-HQ
python sample_diffusion_ldm.py -r ./stable-diffusion/models/ldm/celeba256/model.ckpt -c 200 -e 0.0 --seed 40 --wq <4 OR 8> --ptq --aq 8 -l <PATH/TO/SAVE/LOG> --cali --use_aq --cali_save_path <PATH/TO/SAVE/QUANTIZED/MODEL> --interval_length 10
# FFHQ
python sample_diffusion_ldm.py -r ./stable-diffusion/models/ldm/ffhq256/model.ckpt -c 200 -e 1.0 --seed 40 --wq <4 OR 8> --ptq --aq 8 -l <PATH/TO/SAVE/LOG> --cali --use_aq --cali_save_path <PATH/TO/SAVE/QUANTIZED/MODEL> --interval_length 10
# ImageNet
python latent_imagenet_diffusion.py -e 0.0 --ddim_steps 20 --scale 3.0 --seed 40 --wq <4 OR 8> --ptq --aq 8 --outdir <PATH/TO/SAVE/LOG> --cali --use_aq --cali_save_path <PATH/TO/SAVE/QUANTIZED/MODEL>

# ----- Stable Diffusion -----
python txt2img.py --plms --no_grad_ckpt --ddim_steps 50 --seed 40 --cond --wq <4 OR 8> --ptq --aq 8 --outdir <PATH/TO/SAVE/LOG> --cali --skip_grid --use_aq --ckpt ./stable-diffusion/models/ldm/stable-diffusion-v1/sd-v1-4.ckpt --config stable-diffusion/configs/stable-diffusion/v1-inference.yaml --data_path <PATH/TO/LOAD/DATA> --cali_save_path <PATH/TO/SAVE/QUANTIZED/MODEL>

Inference

After the quantization process, you can generate images as you like. You can remove --use_aq and --aq 8 to disable activation quantization.

# ----------- DDIM -----------
python sample_diffusion_ddim.py --config ddim/configs/cifar10.yml --timesteps 100 --eta 0 --skip_type quad --wq <4 OR 8> --ptq --aq 8 -l <PATH/TO/SAVE/RESULT> --use_aq --cali_ckpt <PATH/TO/LOAD/QUANTIZED/MODEL> --max_images 128

# ----------- LDM ------------
# LSUN-Bedrooms
python sample_diffusion_ldm.py -r ./stable-diffusion/models/ldm/lsun_beds256/model.ckpt -c 200 -e 1.0 --seed 40 --wq <4 OR 8> --ptq --aq 8 -l <PATH/TO/SAVE/RESULT> --use_aq --cali_ckpt <PATH/TO/LOAD/QUANTIZED/MODEL> -n 5 --batch_size 5
# LSUN-Churches
python sample_diffusion_ldm.py -r ./stable-diffusion/models/ldm/lsun_churches256/model.ckpt -c 400 -e 0.0 --seed 40 --wq <4 OR 8> --ptq --aq 8 -l <PATH/TO/SAVE/RESULT> --use_aq --cali_ckpt <PATH/TO/LOAD/QUANTIZED/MODEL> -n 5 --batch_size 5
# CelebA-HQ
python sample_diffusion_ldm.py -r ./stable-diffusion/models/ldm/celeba256/model.ckpt -c 200 -e 0.0 --seed 40 --wq <4 OR 8> --ptq --aq 8 -l <PATH/TO/SAVE/RESULT> --use_aq --cali_ckpt <PATH/TO/LOAD/QUANTIZED/MODEL> -n 5 --batch_size 5
# FFHQ
python sample_diffusion_ldm.py -r ./stable-diffusion/models/ldm/ffhq256/model.ckpt -c 200 -e 1.0 --seed 40 --wq <4 OR 8> --ptq --aq 8 -l <PATH/TO/SAVE/RESULT> --use_aq --cali_ckpt <PATH/TO/LOAD/QUANTIZED/MODEL> -n 5 --batch_size 5
# ImageNet
python latent_imagenet_diffusion.py -e 0.0 --ddim_steps 20 --scale 3.0 --seed 40 --wq <4 OR 8> --ptq --aq 8 --outdir <PATH/TO/SAVE/RESULT>  --use_aq --cali_ckpt <PATH/TO/LOAD/QUANTIZED/MODEL> --n_sample_per_class 2 --classes <CLASSES. e.g. 7,489,765>

# ----- Stable Diffusion -----
python txt2img.py --prompt <PROMPT. e.g. "A white dog."> --plms --no_grad_ckpt --ddim_steps 50 --seed 40 --cond --n_iter 1 --n_samples 1 --wq <4 OR 8> --ptq --aq 8 --skip_grid --outdir <PATH/TO/SAVE/RESULT> --use_aq --ckpt ./stable-diffusion/models/ldm/stable-diffusion-v1/sd-v1-4.ckpt --config stable-diffusion/configs/stable-diffusion/v1-inference.yaml --cali_ckpt <PATH/TO/LOAD/QUANTIZED/MODEL>

Acknowledgments

Our code was developed based on ddim, latent-diffusion and stable-diffusion. We referred to BRECQ and Q-Diffusion for the blockwise calibration implementation.

We thank OpenVINO for providing the framework to deploy our quantized model and measure acceleration. We also thank torch-fidelity, guided-diffusion, and clip-score for IS, sFID, FID, and CLIP score computation.

Citation

If you find our TFMQ-DM useful or relevant to your research, please kindly cite our paper:

@InProceedings{Huang_2024_CVPR,
    author    = {Huang, Yushi and Gong, Ruihao and Liu, Jing and Chen, Tianlong and Liu, Xianglong},
    title     = {TFMQ-DM: Temporal Feature Maintenance Quantization for Diffusion Models},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2024},
    pages     = {7362-7371}
}

Issues

Can this recipe run on an A100 with only 40 GB of memory?

Thank you so much for this wonderful and valuable work. I ran the code on one A100 with 40 GB of memory, and it reported 'out of memory'. I have two A100 cards. Can you give me some suggestions on how to deal with this memory issue? Looking forward to your feedback.

Where is the code that implements the finite set calibration?

Hi,
Thanks for this work! In the paper, finite set calibration (FSC) uses a timestep-aware scale factor to quantize the activations in the embedding_layers and time_embed. I tried to run the calibration in txt2img.py and have two questions.

  1. It seems that the disable_out_quantizaion function sets use_aq of the time_embed to False, so the time_embed layer is actually not activation-quantized. I am confused, since the paper states that time_embed is also activation-quantized.

  2. Where can I find the code that implements the timestep-aware scale factor for activation quantization? I have searched the whole project but still cannot find a timestep-specific scale $s$ for activation quantization. May I get your help, please?
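
For clarity, what I am looking for is something like the following sketch (just an illustration of the idea of one activation scale per time-step, not code from this repository):

import torch
import torch.nn as nn

class TimestepAwareActQuant(nn.Module):
    # Illustrative 8-bit uniform activation quantizer with one scale per time-step.
    # This only sketches the FSC idea from the paper; the real code may differ.
    def __init__(self, num_timesteps: int, n_bits: int = 8):
        super().__init__()
        self.qmin, self.qmax = -(2 ** (n_bits - 1)), 2 ** (n_bits - 1) - 1
        # One calibrated scale per t in the finite set {1, ..., T}.
        self.scales = nn.Parameter(torch.ones(num_timesteps))

    def forward(self, x: torch.Tensor, t: int) -> torch.Tensor:
        s = self.scales[t]
        x_q = torch.clamp(torch.round(x / s), self.qmin, self.qmax)  # quantize
        return x_q * s  # dequantize

q = TimestepAwareActQuant(num_timesteps=1000)
out = q(torch.rand(4, 512), t=500)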

Thx.

A small bug during inference

Hi,
Thanks for this work! I found a small bug in the inference phase: in the DDIM inference example, --cali_save_path is used to load <PATH/TO/LOAD/QUANTIZED/MODEL>.

So, in my opinion, I would prefer to change
python sample_diffusion_ddim.py --config ddim/configs/cifar10.yml --timesteps 100 --eta 0 --skip_type quad --wq <4 OR 8> --ptq --aq 8 -l <PATH/TO/SAVE/RESULT> --use_aq --cali_save_path <PATH/TO/LOAD/QUANTIZED/MODEL> --max_images 128

to
python sample_diffusion_ddim.py --config ddim/configs/cifar10.yml --timesteps 100 --eta 0 --skip_type quad --wq <4 OR 8> --ptq --aq 8 -l <PATH/TO/SAVE/RESULT> --use_aq --cali_ckpt <PATH/TO/LOAD/QUANTIZED/MODEL> --max_images 128

Also, when I read the sample_fid function and its untill_fake_t logic, the code seemed quite smelly to me.

Thx.

Calibrate with Stable Diffusion

Hi,

Sorry to bother you. I notice that you use the text-prompt JSON from the COCO dataset to generate cali_data. Since there are various versions of COCO, which specific dataset should I download? Also, which GPUs should be used for calibration with the SD model?

Thx.

Quantization on ImageNet takes several days on an A100

Hi there,

I have been running your code on the ImageNet model using a single A100 GPU, and it has been taking several days to complete. I am curious if this is the expected duration for this task.

In your paper, you mentioned that training on LSUN256 took around 2 GPU hours. Could you please provide more details on how you measured the GPU hours? It would help me understand the performance expectations better.

Thank you for your assistance!

Best regards,
Dongyeun Lee
