
diffusion-classifier / diffusion-classifier


Diffusion Classifier leverages pretrained diffusion models to perform zero-shot classification without additional training

Home Page: http://diffusion-classifier.github.io/

Languages: Python 100.00%
Topics: bayes-theorem, classification, computer-vision, deep-learning, diffusion, diffusion-models, elbo, generative, generative-models, machine-learning, monte-carlo, robustness, supervised-learning, zero-shot-classification, zero-shot-learning

diffusion-classifier's Introduction

Your Diffusion Model is Secretly a Zero-Shot Classifier

arXiv Website

This is the official implementation of the ICCV 2023 paper Your Diffusion Model is Secretly a Zero-Shot Classifier by Alexander Li, Mihir Prabhudesai, Shivam Duggal, Ellis Brown, and Deepak Pathak.

Abstract

The recent wave of large-scale text-to-image diffusion models has dramatically increased our text-based image generation abilities. These models can generate realistic images for a staggering variety of prompts and exhibit impressive compositional generalization abilities. Almost all use cases thus far have solely focused on sampling; however, diffusion models can also provide conditional density estimates, which are useful for tasks beyond image generation. In this paper, we show that the density estimates from large-scale text-to-image diffusion models like Stable Diffusion can be leveraged to perform zero-shot classification without any additional training. Our generative approach to classification, which we call Diffusion Classifier, attains strong results on a variety of benchmarks and outperforms alternative methods of extracting knowledge from diffusion models. Although a gap remains between generative and discriminative approaches on zero-shot recognition tasks, our diffusion-based approach has significantly stronger multimodal compositional reasoning ability than competing discriminative approaches. Finally, we use Diffusion Classifier to extract standard classifiers from class-conditional diffusion models trained on ImageNet. Our models achieve strong classification performance using only weak augmentations and exhibit qualitatively better "effective robustness" to distribution shift. Overall, our results are a step toward using generative over discriminative models for downstream tasks.

Installation

Create a conda environment with the following command:

conda env create -f environment.yml

If this takes too long, running conda config --set solver libmamba switches conda to the libmamba solver, which can speed up installation considerably.

Zero-shot Classification with Stable Diffusion

python eval_prob_adaptive.py --dataset cifar10 --split test --n_trials 1 \
  --to_keep 5 1 --n_samples 50 500 --loss l1 \
  --prompt_path prompts/cifar10_prompts.csv

This command reads potential prompts from a csv file and evaluates the epsilon prediction loss for each prompt using Stable Diffusion. This should work on a variety of GPUs, from as small as a 2080Ti or 3080 to as large as a 3090 or A6000. Losses are saved separately for each test image in the log directory. For the command above, the log directory is data/cifar10/v2-0_1trials_5_1keep_50_500samples_l1. Accuracy can be computed by running:

python scripts/print_acc.py data/cifar10/v2-0_1trials_5_1keep_50_500samples_l1
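For intuition, the decision rule behind this evaluation can be sketched in a few lines of Python. This is a simplified, non-adaptive sketch, not the repository's implementation: it assumes a diffusers-style UNet2DConditionModel and scheduler, and the function name diffusion_classifier_score is hypothetical.

import torch

@torch.no_grad()
def diffusion_classifier_score(unet, scheduler, latent, text_embs, n_samples=50):
    # Mean L1 epsilon-prediction error for one (image latent, class prompt) pair.
    # latent: VAE-encoded image, e.g. shape (1, 4, 64, 64);
    # text_embs: text-encoder output for one class prompt. Lower is a better fit.
    errors = []
    for _ in range(n_samples):
        t = torch.randint(0, scheduler.config.num_train_timesteps, (1,), device=latent.device)
        noise = torch.randn_like(latent)
        noised = scheduler.add_noise(latent, noise, t)
        eps_pred = unet(noised, t, encoder_hidden_states=text_embs).sample
        errors.append((eps_pred - noise).abs().mean().item())
    return sum(errors) / len(errors)

# The predicted class is the prompt with the lowest average error:
# pred = min(range(n_classes), key=lambda c: diffusion_classifier_score(unet, scheduler, latent, embs[c]))

The adaptive strategy in eval_prob_adaptive.py refines this: all prompts are first scored with a few samples, and further samples are spent only on the prompts still in the running (the --to_keep and --n_samples stages above).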

Commands to run Diffusion Classifier on each dataset are here. If evaluation on your use case is taking too long, there are a few options:

  1. Parallelize evaluation across multiple workers. Try using the --n_workers and --worker_idx flags.
  2. Play around with the evaluation strategy (e.g. --n_samples and --to_keep).
  3. Evaluate on a smaller subset of the dataset. Saving an .npy array of test set indices and using the --subset_path flag can be useful for this; see the sketch below.
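For the third option, a subset file is just a NumPy array of dataset indices (the filename below is hypothetical):

import numpy as np

# Keep only the first 10 test images.
np.save("subset_first10.npy", np.arange(10))

Then add --subset_path subset_first10.npy to the evaluation command.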

Evaluating on your own dataset

  1. Create a csv file with the prompts that you want to evaluate, making sure to match each prompt with the correct class label. See scripts/write_cifar10_prompts.py for an example, and the sketch after this list. Note that you can use multiple prompts per class.
  2. Run the command above, changing the --dataset and --prompt_path flags to match your use case.
  3. Play around with the evaluation strategy on a small subset of the dataset to reduce evaluation time.
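As a sketch, a prompt file might be written like this; the exact column layout should be taken from scripts/write_cifar10_prompts.py, so the header and columns below are assumptions:

import csv

# Hypothetical two-class prompt file: one row per prompt, each mapped
# to its integer class index. Multiple rows may share a class index.
classes = ["cat", "dog"]
with open("prompts/my_dataset_prompts.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["prompt", "classidx"])
    for idx, name in enumerate(classes):
        writer.writerow([f"a photo of a {name}", idx])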

Standard ImageNet Classification with Class-conditional Diffusion Models

Additional installations

Within the diffusion-classifier folder, clone the DiT repository:

git clone git@github.com:facebookresearch/DiT.git

Running Diffusion Classifier

First, save a consistent set of noise (epsilon) that will be used for all image-class pairs:

python scripts/save_noise.py --img_size 256
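Conceptually, this step just saves a reproducible bank of Gaussian noise. The sketch below is hypothetical (the bank size and the 4 x 32 x 32 latent shape for 256px images are assumptions; scripts/save_noise.py is authoritative):

import torch

# Fix the seed so every image-class pair sees the same epsilon draws.
torch.manual_seed(42)
noise = torch.randn(1024, 4, 32, 32)  # assumed bank size and latent shape
torch.save(noise, "noise_256.pt")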

Then, compute and save the epsilon-prediction error for each class:

python eval_prob_dit.py  --dataset imagenet --split test \
  --noise_path noise_256.pt --randomize_noise \
  --batch_size 32 --cls CLS --t_interval 4 --extra dit256 --save_vb

For example, for ImageNet, this needs to be run with CLS ranging from 0 to 999; a sketch of that sweep follows. This is currently a very expensive process, so we recommend using the --subset_path flag to evaluate on a smaller subset of the dataset. We also plan on releasing an adaptive version that greatly reduces the computation time per test image.
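A minimal way to drive that sweep (a serial sketch; in practice, shard the class range across workers or GPUs):

import subprocess

# Evaluate every ImageNet class one after another.
for cls in range(1000):
    subprocess.run([
        "python", "eval_prob_dit.py", "--dataset", "imagenet", "--split", "test",
        "--noise_path", "noise_256.pt", "--randomize_noise",
        "--batch_size", "32", "--cls", str(cls),
        "--t_interval", "4", "--extra", "dit256", "--save_vb",
    ], check=True)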

Finally, compute the accuracy using the saved errors:

python scripts/print_dit_acc.py data/imagenet_dit256 --dataset imagenet

We show the commands to run DiT on all ImageNet variants here.

Compositional Reasoning on Winoground with Stable Diffusion

To run Diffusion Classifier on Winoground, first save a consistent set of noise (epsilon) that will be used for all image-caption pairs:

python scripts/save_noise.py --img_size 512

Then, evaluate on Winoground:

python run_winoground.py --model sd --version 2-0 --t_interval 1 --batch_size 32 --noise_path noise_512.pt --randomize_noise --interpolation bicubic
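For reference, the standard Winoground metrics can be computed from the per-pair errors roughly as follows. This is a sketch under the assumption that err[i][c] holds the epsilon-prediction error of image i paired with caption c (lower meaning a better fit):

def winoground_scores(err):
    # Text score: for each image, the matching caption must fit best.
    text = err[0][0] < err[0][1] and err[1][1] < err[1][0]
    # Image score: for each caption, the matching image must fit best.
    image = err[0][0] < err[1][0] and err[1][1] < err[0][1]
    # Group score: both conditions must hold.
    return text, image, text and image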

To run CLIP or OpenCLIP baselines:

python run_winoground.py --model clip --version ViT-L/14
python run_winoground.py --model openclip --version ViT-H-14

Citation

If you find this work useful in your research, please cite:

@misc{li2023diffusion,
      title={Your Diffusion Model is Secretly a Zero-Shot Classifier}, 
      author={Alexander C. Li and Mihir Prabhudesai and Shivam Duggal and Ellis Brown and Deepak Pathak},
      year={2023},
      eprint={2303.16203},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

diffusion-classifier's People

Contributors

alexlioralexli


diffusion-classifier's Issues

The diffusion.datasets package in print_dit_acc.py

conda env create -f environment.yml completes successfully, but:

(diffusion-classifier) ➜ diffusion-classifier git:(master) ✗ python scripts/print_dit_acc.py data/imagenet_dit256 --dataset imagenet
Traceback (most recent call last):
File "/home/featurize/work/diffusion-classifier/scripts/print_dit_acc.py", line 5, in
from diffusion.datasets import get_target_dataset, IMAGENET_A_CLASSES
ModuleNotFoundError: No module named 'diffusion'

What's the problem?

Public Benchmarking

Thanks for the work guys!
It would be incredibly helpful if you could run classification benchmarks on paperswithcode.com

Possible to run on CPU?

Hi there,

Forgive me if this is the wrong place to ask. I have installed the relevant dependencies in a fresh environment and tried running the classification command. Since I'm using a CPU-only machine, I get a few errors relating to xformers, CUDA, and so on.

Is there any way to run the model with CPU only and avoid these errors? I imagine not, but if anyone has any ideas, please let me know.

About test samples for computing accuracy

Hi,

Congrats on the impressive work. The inference speed is quite slow, though, and it was mentioned that "for computational reasons, we evaluated on 4 images per class (4000 test images total)." I am wondering if this protocol has been adopted for all the datasets (Food, Pets, CIFAR-10, Aircraft, Flowers, STL-10). For ImageNet, it was mentioned that the test set is 2000 images. Any clarification would be much appreciated.

Thanks,
Aniket.

Other diffusion models

Hi, thanks for your brilliant work! I've tried the diffusion classifier with Stable Diffusion 2.1 locally and found that the performance on CIFAR-10 reaches 85.88, which is even higher than reported in the paper. This is proof of the discriminative power of generative diffusion models.

My question is: have you tried other diffusion models, like DDIM, unCLIP, and Score SDE? Is the impressive performance highly dependent on Stable Diffusion?

conda env create takes forever, anyone have the same issue?

Here's some of the debug output from conda. How can I resolve this? Any help is appreciated!

DEBUG conda.resolve:solve(1346): Solve: minimize removed packages
DEBUG conda.resolve:solve(1353): Solve: maximize versions of requested packages
DEBUG conda.common._logic:_run_sat(607): Invoking SAT with clause count: 432932
DEBUG conda.common._logic:minimize(745): Final peak objective: 0
DEBUG conda.common._logic:_run_sat(607): Invoking SAT with clause count: 433887
DEBUG conda.common._logic:minimize(745): Final peak objective: 0
DEBUG conda.resolve:solve(1357): Initial package channel/version metric: 0/0
DEBUG conda.resolve:solve(1360): Solve: minimize track_feature count
DEBUG conda.resolve:generate_feature_count(933): generate_feature_count returning with clause count: 435425
DEBUG conda.common._logic:minimize(658): Clauses added, recomputing solution
DEBUG conda.common._logic:_run_sat(607): Invoking SAT with clause count: 435425
DEBUG conda.common._logic:_run_sat(607): Invoking SAT with clause count: 435447
DEBUG conda.common._logic:minimize(745): Final sum objective: 0
DEBUG conda.resolve:solve(1363): Track feature count: 0
DEBUG conda.common._logic:minimize(664): Empty objective, trivial solution
DEBUG conda.resolve:solve(1374): Package misfeature count: 0
DEBUG conda.resolve:solve(1377): Solve: maximize build numbers of requested packages
DEBUG conda.common._logic:_run_sat(607): Invoking SAT with clause count: 435608
DEBUG conda.common._logic:minimize(745): Final peak objective: 0
DEBUG conda.resolve:solve(1379): Initial package build metric: 0
DEBUG conda.resolve:solve(1382): Solve: prefer arch over noarch for requested packages
DEBUG conda.common._logic:minimize(664): Empty objective, trivial solution
DEBUG conda.resolve:solve(1384): Noarch metric: 0
DEBUG conda.resolve:solve(1388): Solve: minimize number of optional installations
DEBUG conda.common._logic:minimize(664): Empty objective, trivial solution
DEBUG conda.resolve:solve(1391): Optional package install metric: 0
DEBUG conda.resolve:solve(1394): Solve: minimize number of necessary upgrades
DEBUG conda.common._logic:minimize(664): Empty objective, trivial solution
DEBUG conda.resolve:solve(1397): Dependency update count: 0
DEBUG conda.resolve:solve(1400): Solve: maximize versions and builds of indirect dependencies. Prefer arch over noarch where equivalent.
DEBUG conda.common._logic:_run_sat(607): Invoking SAT with clause count: 455945
DEBUG conda.common._logic:_run_sat(607): Invoking SAT with clause count: 435611
DEBUG conda.common._logic:_run_sat(607): Invoking SAT with clause count: 438652
DEBUG conda.common._logic:minimize(745): Final peak objective: 1
DEBUG conda.common._logic:_run_sat(607): Invoking SAT with clause count: 10576976
DEBUG conda.common._logic:_run_sat(607): Invoking SAT with clause count: 4038608
DEBUG conda.common._logic:_run_sat(607): Invoking SAT with clause count: 1975208
DEBUG conda.common._logic:_run_sat(607): Invoking SAT with clause count: 1250750

Questions about the paper.

Hi, thank you for sharing such impressive work! I am wondering if there is another path to deriving a diffusion classifier. A diffusion model is an SDE, which has an equivalent ODE form. The ODE allows us to calculate the exact log-likelihood of an input x: the higher the likelihood, the more strongly the model believes that x comes from the distribution. So I think using the log-likelihood could be another strategy to implement all the functionality in your paper, such as the zero-shot classifier.

Best,
Zhangzhi

Question about the inference hyper-parameters

I see the 'Prompts kept per stage' and 'Evaluations per stage' in the paper, like '5 1' and '50 500' for CIFAR-10, where Diffusion Classifier gets a high score of 88.5. So I'm curious about the exact inference hyper-parameters for CIFAR-10 and STL-10. Thanks a lot!

Question about the code running speed

I tried to run the code on a single A100 as described in the README. It took about 30 s to process a single image. Since CIFAR-10 has 10,000 test images, it would take about 83 h to finish the whole process. Is this normal? Is there any way to speed it up? Thank you!

Question about results of cifar 100 dataset

Hello, thanks for the great work! I was wondering if you have some insights on how well the model performs on CIFAR-100, or if there is a set of hyperparameters that works well in this setting? Thank you!

About the diffusion model implementation

I am currently running classification on my own dataset. To improve accuracy on this specific dataset, my idea is to fine-tune a diffusion model on my own data. Is there any problem with this idea? Please advise.
If it works, how do I load my own diffusion model? Thanks!

About the prediction probability

Thank you for your interesting work, and for replying to my previous question.

I have another question regarding computing the prediction probability distribution for a given input x.

When I run eval_prob_adaptive.py and get an array of losses with respect to each class prompt, I find that the losses are very close to each other. That is, when I apply a softmax over the array, the probability for each class is close to 1/N, which may not be desired.

I found that similar issue had been raised before. (#11 (comment))

Could you check about this issue?
Thank you in advance.

Getting pred. probabilities

Hi there,

Is there an obvious way to extract the prediction probabilities instead of just the number of correct labels vs. the number of ground-truth labels? I have checked print_acc.py and eval_prob_adaptive.py, but I can't see anything obvious. Any help would be much appreciated.

Thanks.

Example about multiple workers

Thanks for your impressive work!

I noticed you have mentioned the --n_workers and --worker_idx options. In the code, they just split the dataset.
I'm wondering how to use them, since inference is relatively slow (maybe with multiprocessing?), and whether it would be better to use multiple GPUs.

Could you please give an example about the options :)

question about 'SD Features'

I'm looking forward to seeing you release the implementation code for the 'SD Features' model mentioned in the paper, because I think SD features are a promising research area.

about add noise implementation

Another question about the add-noise implementation:

noised_latent = latent * (scheduler.alphas_cumprod[batch_ts] ** 0.5).view(-1, 1, 1, 1).to(device) + \
                noise * ((1 - scheduler.alphas_cumprod[batch_ts]) ** 0.5).view(-1, 1, 1, 1).to(device)

scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")

  1. I notice this is the DDIM/DDPM add_noise function written out by hand. Can I interpret it as equivalent, i.e., replace it with scheduler.add_noise(latent, noise, batch_ts)?
  2. The scheduler used here is EulerDiscreteScheduler. Is anything different from other schedulers, like the original DDPM one? Would it be better to use this scheduler in the classification setting?

Example of the argument "subset_path"

Can you provide some examples of how to set the argument "subset_path"?
For example, if I just want to test the first 10 test images in CIFAR-10, how do I do it?

Thank you.

multiplication of the encoded image by 0.18215

Hey,

Can you please explain the reasoning behind multiplying the VAE-encoded latent by 0.18215? I see that it helps produce better results; however, I cannot discern the reasoning behind this operation.

Thanks!

No such operator xformers

Hello everyone!
I encountered the following error when running the code: "RuntimeError: No such operator xformers::efficient_attention_forward_cutlass - did you forget to build xformers with 'python setup.py develop'?"
I have looked for many solutions, but none of them worked. Does anyone have a solution to make the code run?
Thanks!

When will the code be released?

Dear authors:

I'm very interested in your work and look forward to using it in my research. Is there a specific timeline for the code release, and is it possible that the code will be released in the next two weeks?

Really looking forward to the release :)
