diff-mining's Introduction

Official PyTorch implementation of Diffusion Models as Data Mining Tools, accepted at ECCV'24.

Introduction

Our approach takes a large labelled input dataset and mines the patches that are important for each label. It involves three steps:

  1. First, you finetune Stable-Diffusion v1.5 on your custom dataset with its standard loss $L_t(x, \epsilon, c)$, using prompts of the form $\text{"An image of Y"}$ (where Y is your label).
  2. For the sample of your input data that you want to analyze, you then compute the typicality $\mathbf{T}(x|c) = \mathbb{E}_{\epsilon,t}[L_t(x, \epsilon, \varnothing) - L_t(x, \epsilon, c)]$ of every image.
  3. You extract the top-1000 patches according to $\mathbf{T}(x | c)$ and cluster them using DIFT-161 features, ranking clusters by the median typicality of their elements.
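
As a rough illustration of step 2, the expectation in $\mathbf{T}(x|c)$ can be estimated by Monte Carlo sampling over noise and timesteps. The sketch below is not the repository's implementation: `typicality`, `toy_loss`, and the `loss_fn(x, eps, t, cond)` signature are hypothetical names chosen for this example, and the toy loss merely stands in for a finetuned diffusion model.

```python
import random
from statistics import fmean
from typing import Callable, Optional, Sequence

# Hypothetical signature: loss_fn(x, eps, t, cond) -> scalar denoising loss.
# cond=None plays the role of the null conditioning.
LossFn = Callable[[Sequence[float], Sequence[float], int, Optional[str]], float]


def typicality(
    x: Sequence[float],
    cond: str,
    loss_fn: LossFn,
    num_samples: int = 64,
    num_timesteps: int = 1000,
    seed: int = 0,
) -> float:
    """Monte Carlo estimate of E_{eps,t}[L_t(x, eps, null) - L_t(x, eps, cond)]."""
    rng = random.Random(seed)
    diffs = []
    for _ in range(num_samples):
        t = rng.randrange(num_timesteps)
        eps = [rng.gauss(0.0, 1.0) for _ in x]
        # Reuse the same (eps, t) for both terms so that the difference
        # reflects only the effect of the conditioning.
        diffs.append(loss_fn(x, eps, t, None) - loss_fn(x, eps, t, cond))
    return fmean(diffs)


def toy_loss(x: Sequence[float], eps: Sequence[float], t: int, cond: Optional[str]) -> float:
    """Stand-in loss (an assumption for this demo): conditioning reduces the error."""
    scale = 1.0 if cond is None else 0.25
    return scale * sum(e * e for e in eps) / len(eps)


score = typicality([0.1, 0.2, 0.3], "An image of Y", toy_loss)
print(score > 0.0)
```

A positive value means the conditional model denoises the input better than the unconditional one, which is what marks an input as typical of the label.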

Installation 🌱

Our codebase is mainly built on the diffusers implementation of LDMs.

conda env create -f environment.yaml
conda activate diff-mining

Data 💽

We apply our method to five different types of datasets: cars (CarDB), faces (FTT), street-view images (G^3), scenes (Places, high-res) and X-rays (ChestX-ray):

  • A properly extracted version of CarDB can be found here and can be downloaded with:
python scripts/download-cardb.py

Models 🔬

We share our models on Hugging Face, which you can access through the handles:

or download them locally using:

python scripts/download-models.py

Approach

A full walkthrough of the pipeline is given in the scripts scripts/training.sh and scripts/typicality.sh.

  • Code for finetuning models can be found under: diffmining/finetuning/.
  • Code for computing typicality can be found at: diffmining/typicality/compute.py.
  • Code for averaging typicality across patches, computing DIFT features and clustering can be found at: diffmining/typicality/cluster.py
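
To make the last bullet more concrete, here is a minimal sketch of averaging a per-pixel typicality map over patches, keeping the top-k patches, and ranking clusters by the median typicality of their members. It is not the code in diffmining/typicality/cluster.py: the function names, the non-overlapping patch grid, and the precomputed cluster labels are all assumptions made for this example (in the actual pipeline the clusters come from DIFT features).

```python
from collections import defaultdict
from statistics import fmean, median
from typing import Dict, Hashable, List, Sequence, Tuple

Patch = Tuple[int, int]  # (top, left) corner of a patch


def patch_scores(tmap: Sequence[Sequence[float]], patch: int) -> Dict[Patch, float]:
    """Average a 2D typicality map over non-overlapping patch x patch windows."""
    height, width = len(tmap), len(tmap[0])
    scores = {}
    for top in range(0, height - patch + 1, patch):
        for left in range(0, width - patch + 1, patch):
            values = [
                tmap[r][c]
                for r in range(top, top + patch)
                for c in range(left, left + patch)
            ]
            scores[(top, left)] = fmean(values)
    return scores


def top_k(scores: Dict[Patch, float], k: int) -> List[Patch]:
    """Return the k patches with the highest mean typicality."""
    return sorted(scores, key=scores.get, reverse=True)[:k]


def rank_clusters(labels: Dict[Patch, Hashable], scores: Dict[Patch, float]) -> List[Hashable]:
    """Order cluster ids by the median typicality of their patches, highest first."""
    groups = defaultdict(list)
    for key, cluster_id in labels.items():
        groups[cluster_id].append(scores[key])
    return sorted(groups, key=lambda cid: median(groups[cid]), reverse=True)


tmap = [
    [1, 1, 0, 0],
    [1, 1, 0, 0],
    [0, 0, 3, 3],
    [0, 0, 3, 3],
]
scores = patch_scores(tmap, patch=2)
print(top_k(scores, k=2))  # [(2, 2), (0, 0)]
print(rank_clusters({(0, 0): "a", (0, 2): "a", (2, 2): "b"}, scores))  # ['b', 'a']
```

Ranking by the median rather than the mean keeps a cluster from rising to the top because of a single very typical patch.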

Applications🔸

We test our typicality measure in two applications, which we discuss in detail in our paper.

Clustering of Translated Visual Elements

Using our diffusion model, we can translate each image, e.g. in the case of geography, from one country to another. We use PnP, which was the only method we found to be relatively robust at keeping translated objects consistent (i.e., windows would remain windows). You can launch this translation by running:

source scripts/parallel.sh translate

Afterwards you need to compute typicality for all elements:

source scripts/parallel.sh compute

and then cluster them using:

source scripts/parallel.sh cluster

Emergent Disease Localization in X-rays 🩻

As typicality is connected to a binary classifier between the conditional and the null conditioning, it can be used to "spatialize" information related to the condition on the input image. We test this on X-ray images and show how typicality improves after finetuning. To reproduce our results and evaluations, run:

source scripts/xray.sh
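
The connection to a binary classifier mentioned above can be sketched informally. The step below assumes that the expected denoising loss acts, up to weighting and an additive constant shared by both conditionings, as a bound on the negative log-likelihood; this is a standard reading of diffusion training and an assumption of this note, not a statement taken from the repository.

$$
\mathbf{T}(x|c) = \mathbb{E}_{\epsilon,t}[L_t(x, \epsilon, \varnothing) - L_t(x, \epsilon, c)] \approx \log p(x|c) - \log p(x) = \log \frac{p(c|x)}{p(c)}
$$

The last equality is Bayes' rule. Under this reading, a positive typicality indicates that the condition $c$ is more probable given $x$ than it is a priori, which is the decision a binary classifier between the conditional and the null model would make.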

Comparing with Doersch et al. 2012 🥐

We provide a minimal optimized implementation of the algorithm of "What makes Paris look like Paris?" under doersch/. Running the code should only require:

python doersch.py --which geo --category 'Italy'

yet you will probably have to adjust it to the dataset of choice.

Citing 💫

  @article{diff-mining,
    title = {Diffusion Models as Data Mining Tools},
    author = {Siglidis, Ioannis and Holynski, Aleksander and Efros, Alexei A. and Aubry, Mathieu and Ginosar, Shiry},
    journal = {ECCV},
    year = {2024},
  }

Acknowledgements

This work was partly supported by the European Research Council (ERC project DISCOVER, number 101076028) and leveraged the HPC resources of IDRIS under the allocation AD011012905R1, AD0110129052 made by GENCI. We would like to thank Grace Luo for data, code, and discussion; Loic Landrieu and David Picard for insights on geographical representations and diffusion; Karl Doersch for project advice and implementation insights; and Sophia Koepke for feedback on our manuscript.

diff-mining's Issues

about xrays

Greetings!
Thanks for sharing the code.
I want to run the CXR8 dataset with this model. I downloaded the code from GitHub and downloaded the CXR8 dataset to ./dataset/CXR8. Then I executed conda env create -f environment.yaml and conda activate diff-mining. At this time, when I executed python diffmining/applications/xray/finetune.py --data_path dataset/CXR8/ --train_batch_size 8 --output_dir models/xray --num_train_epochs 100, I encountered some errors in diffmining/applications/xray/finetune.py. For example, No module named 'pandas', name 'sys' is not defined, No module named 'matplotlib'. So did I miss any steps?
