mlbio-epfl / turtle

[ICML 2024] Let Go of Your Labels with Unsupervised Transfer

Home Page: https://brbiclab.epfl.ch/projects/turtle/

clustering deep-learning foundation-models icml icml-2024 implicit-bias maximum-margin-learning transfer-learning unsupervised-learning


Let Go of Your Labels with Unsupervised Transfer

Artyom Gadetsky*, Yulun Jiang*, Maria Brbić

Project page | Paper | BibTeX


This repo contains the source code of 🐢 TURTLE, an unsupervised learning algorithm written in PyTorch. 🔥 TURTLE achieves state-of-the-art unsupervised performance on a variety of benchmark datasets. For more details, please check our paper Let Go of Your Labels with Unsupervised Transfer (ICML '24).


The question we aim to answer in our work is how to utilize representations from foundation models to solve a new task in a fully unsupervised manner. We introduce the problem setting of unsupervised transfer and highlight the key differences between unsupervised transfer and other types of transfer. Specifically, types of downstream transfer differ in the amount of available supervision. Given representation spaces of foundation models, (i) supervised transfer, represented as a linear probe, trains a linear classifier given labeled examples of a downstream dataset; (ii) zero-shot transfer assumes descriptions of the visual categories that appear in a downstream dataset are given, and employs them via a text encoder to solve the task; and (iii) unsupervised transfer assumes the least amount of available supervision, i.e., only the number of categories is given, and aims to uncover the underlying human labeling of a dataset.


TURTLE is a method that enables fully unsupervised transfer from foundation models. The key idea behind our approach is to search for the labeling of a downstream dataset that maximizes the margins of linear classifiers in the space of single or multiple foundation models, thereby uncovering the underlying human labeling. Compared to zero-shot and supervised transfer, unsupervised transfer with TURTLE does not require supervision in any form. Compared to deep clustering methods, TURTLE does not require task-specific representation learning, which is expensive for modern foundation models.
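To make the maximum-margin intuition concrete, here is a toy sketch (not the actual TURTLE objective, which is a bilevel optimization over foundation-model representation spaces): among all balanced labelings of a small 2D dataset, the labeling that separates the two underlying clusters is the one that yields the largest margin for a linear classifier. The least-squares classifier and the synthetic blobs below are illustrative stand-ins.

```python
import itertools
import numpy as np

rng = np.random.default_rng(0)
# Two well-separated Gaussian blobs standing in for foundation-model features.
X = np.vstack([rng.normal(-2.0, 0.5, (5, 2)), rng.normal(2.0, 0.5, (5, 2))])

def margin_score(X, y):
    """Fit a least-squares linear classifier and return its minimum
    functional margin, min_i y_i * f(x_i). Larger is better."""
    Xb = np.hstack([X, np.ones((len(X), 1))])   # append a bias column
    w = np.linalg.lstsq(Xb, y, rcond=None)[0]   # least-squares fit to +-1 targets
    w = w / np.linalg.norm(w[:-1])              # normalize to a geometric margin
    return float(np.min(y * (Xb @ w)))

# Exhaustively search all balanced +-1 labelings of the 10 points.
best = max(
    ([1 if i in pos else -1 for i in range(10)]
     for pos in itertools.combinations(range(10), 5)),
    key=lambda y: margin_score(X, np.array(y)),
)
print(best)  # the labeling splitting the two blobs maximizes the margin
```

TURTLE replaces this brute-force search with gradient-based bilevel optimization, so it scales to real datasets and representation spaces.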

Dependencies

The code is built with the following libraries:

To install cuml, you can follow the instructions on this page.

Quick Start

In our paper, we consider 26 vision datasets studied in (Radford et al. 2021) and 9 different foundation models. As a running example, we present the full pipeline to train TURTLE on the CIFAR100 dataset.

  1. Precompute representations and save ground truth labels for the dataset
python precompute_representations.py --dataset cifar100 --phis clipvitL14
python precompute_representations.py --dataset cifar100 --phis dinov2
python precompute_labels.py --dataset cifar100
  2. Train TURTLE with 2 representation spaces
python run_turtle.py --dataset cifar100 --phis clipvitL14 dinov2

or with the single representation space

python run_turtle.py --dataset cifar100 --phis clipvitL14
python run_turtle.py --dataset cifar100 --phis dinov2

The results and the checkpoints will be saved at ./data/results and ./data/task_checkpoints, respectively. You can also use --root_dir in all scripts to specify a root directory other than the default ./data.

Data Preparation

Most datasets can be automatically downloaded by running precompute_representations.py and precompute_labels.py. However, some datasets require manual downloading. Please check dataset_preparation/data_utils.py for a guide to preparing all the datasets used in our paper.

As an example, to prepare the pets dataset, which is not directly available in torchvision.datasets, one can run:

python dataset_preparation/prepare_pets.py -i ./data/datasets/pets -o ./data/datasets/pets -d

to download and extract the dataset at ./data/datasets/pets.

After downloading the dataset, run the following command to precompute the representations and labels:

python precompute_representations.py --dataset ${DATASET} --phis ${REPRESENTATION}
python precompute_labels.py --dataset ${DATASET}

Datasets and representations covered in this repo:

  • 26 datasets: food101, cifar10, cifar100, birdsnap, sun397, cars, aircraft, dtd, pets, caltech101, flowers, mnist, fer2013, stl10, eurosat, resisc45, gtsrb, kitti, country211, pcam, ucf101, kinetics700, clevr, hatefulmemes, sst, imagenet.
  • 9 representations: clipRN50, clipRN101, clipRN50x4, clipRN50x16, clipRN50x64, clipvitB32, clipvitB16, clipvitL14, dinov2.

Running TURTLE

Once the representations and labels are precomputed, to train TURTLE with a single space, run:

python run_turtle.py --dataset ${DATASET} --phis ${REPRESENTATION} 

or to train TURTLE with multiple representation spaces, run:

python run_turtle.py --dataset ${DATASET} --phis ${REPRESENTATION1} ${REPRESENTATION2}

You can also use --inner_lr, --outer_lr and --warm_start to specify the inner step size, the outer step size, and whether to use cold-start or warm-start bilevel optimization. Furthermore, use --cross_val to compute the generalization score for the found labeling after training. You can perform a hyperparameter sweep and use the generalization score to select the best hyperparameters without using ground truth labels.
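Such a sweep is easy to script. The sketch below only generates and prints the commands to run; the grid values are illustrative, not tuned recommendations:

```python
import itertools

# Illustrative step-size grids; adjust for your dataset.
inner_lrs = [1e-4, 1e-3, 1e-2]
outer_lrs = [1e-3, 1e-2, 1e-1]

cmds = [
    f"python run_turtle.py --dataset cifar100 --phis clipvitL14 dinov2 "
    f"--inner_lr {il} --outer_lr {ol} --cross_val"
    for il, ol in itertools.product(inner_lrs, outer_lrs)
]
for cmd in cmds:
    print(cmd)
```

After the runs finish, pick the configuration with the best generalization score reported by --cross_val.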

Pre-trained Checkpoints

We also release the labelings found by TURTLE for all datasets and all model architectures used in our paper. To download pre-trained checkpoints, run:

wget https://brbiclab.epfl.ch/wp-content/uploads/2024/06/turtle_tasks.zip
unzip turtle_tasks.zip

Then, you can evaluate a pre-trained checkpoint of TURTLE with a single representation space by running:

python evaluate.py --dataset cifar100 --phis clipvitL14 --task_ckpt {PATH_TO_TURTLE_TASKS}/1space/clipvitL14/cifar100.pt
python evaluate.py --dataset cifar100 --phis dinov2     --task_ckpt {PATH_TO_TURTLE_TASKS}/1space/dinov2/cifar100.pt

or with two representation spaces by running:

python evaluate.py --dataset cifar100 --phis clipvitL14 dinov2 --task_ckpt {PATH_TO_TURTLE_TASKS}/2space/clipvitL14_dinov2/cifar100.pt

Baselines

We also provide implementations of the Zero-shot Transfer with CLIP, Linear Probe and K-Means baselines in the baselines folder. For the linear probe and K-Means baselines, we employ cuml for highly efficient CUDA implementations.

Linear Probe

Precompute the representations and then perform linear probe evaluation by running:

python baselines/linear_probe.py --dataset ${DATASET} --phis ${REPRESENTATION}

To select the l2 regularization strength for better performance, run:

python baselines/linear_probe.py --dataset ${DATASET} --phis ${REPRESENTATION} --validation
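The idea behind the validation flag can be illustrated with a minimal numpy sketch. This is a simplified stand-in for the script's probe (closed-form ridge regression on synthetic features, binary labels), not the cuml-based implementation: it picks the l2 strength with the best validation accuracy.

```python
import numpy as np

rng = np.random.default_rng(0)
# Toy stand-ins for precomputed train/val features; the true rule is sign of feature 0.
Xtr = rng.normal(size=(80, 16)); ytr = np.sign(Xtr[:, 0])
Xva = rng.normal(size=(40, 16)); yva = np.sign(Xva[:, 0])

def ridge_fit(X, y, lam):
    """Closed-form ridge regression used as a linear probe."""
    d = X.shape[1]
    return np.linalg.solve(X.T @ X + lam * np.eye(d), X.T @ y)

# Select the l2 strength with the highest validation accuracy.
grid = [0.01, 0.1, 1.0, 10.0, 100.0]
best_lam = max(grid, key=lambda lam: np.mean(np.sign(Xva @ ridge_fit(Xtr, ytr, lam)) == yva))
print(best_lam)
```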

K-Means

Precompute the representations and run the K-Means baseline:

python baselines/kmeans.py --dataset ${DATASET} --phis ${REPRESENTATION}
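Note that cluster ids are only defined up to a permutation of the class labels, so clustering baselines are typically scored with Hungarian matching. Below is a minimal sketch of such an accuracy metric (an illustration, not necessarily the repo's exact evaluation code):

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def cluster_accuracy(y_true, y_pred, k):
    """Best-match accuracy between predicted clusters and ground-truth labels,
    with the optimal cluster-to-class matching found by the Hungarian algorithm."""
    cost = np.zeros((k, k), dtype=int)
    for t, p in zip(y_true, y_pred):
        cost[p, t] += 1                          # co-occurrence counts
    rows, cols = linear_sum_assignment(-cost)    # negate to maximize matches
    return cost[rows, cols].sum() / len(y_true)

# Toy check: predicted cluster ids are a permutation of the true labels.
y_true = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([2, 2, 0, 0, 1, 1])
print(cluster_accuracy(y_true, y_pred, 3))  # 1.0
```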

Zero-shot Transfer

Run CLIP zero-shot transfer:

python baselines/clip_zs.py --dataset ${DATASET} --phis ${REPRESENTATION}
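Conceptually, zero-shot transfer assigns each image to the class whose text embedding is most similar. A minimal numpy sketch, with toy features standing in for CLIP image and class-prompt text embeddings:

```python
import numpy as np

def zero_shot_predict(image_feats, text_feats):
    """Assign each image to the class whose text embedding has the highest
    cosine similarity, as in CLIP-style zero-shot classification."""
    img = image_feats / np.linalg.norm(image_feats, axis=1, keepdims=True)
    txt = text_feats / np.linalg.norm(text_feats, axis=1, keepdims=True)
    return (img @ txt.T).argmax(axis=1)

# Toy example: 2 classes, 3 images.
text = np.array([[1.0, 0.0], [0.0, 1.0]])
imgs = np.array([[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]])
print(zero_shot_predict(imgs, text))  # [0 1 0]
```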

Acknowledgements

While developing TURTLE, we greatly benefited from the following open-source repositories:

Citing

If you find our code useful, please consider citing:

@inproceedings{
    gadetsky2024let,
    title={Let Go of Your Labels with Unsupervised Transfer},
    author={Gadetsky, Artyom and Jiang, Yulun and Brbi\'c, Maria},
    booktitle={International Conference on Machine Learning},
    year={2024},
}

