
yanndubs / ssl-risk-decomposition

Benchmark and analysis of 165 pretrained SSL models. Code for "Evaluating Self-Supervised Learning via Risk Decomposition".

benchmark deep-learning evaluation machine-learning model-zoo pytorch representation-learning self-supervised-learning

Evaluating Self-Supervised Learning via Risk Decomposition

License: MIT, Python 3.8+

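This repository contains:

  • all pretrained models, hyperparameters, and results;
  • code to compute the risk decomposition;
  • instructions to reproduce the paper's results.

All pretrained models, hyperparameters, results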

We release all pretrained weights, hyperparameters, and results on torch.hub, which can be loaded using:

import torch

# loads the desired pretrained model and preprocessing pipeline
name = "dino_rn50" # example
model, preprocessor = torch.hub.load('YannDubs/SSL-Risk-Decomposition:main', name, trust_repo=True)

# gets all available models 
available_names = torch.hub.list('YannDubs/SSL-Risk-Decomposition:main')

# gets all results and hyperparameters as a dataframe 
results_df = torch.hub.load('YannDubs/SSL-Risk-Decomposition:main', "results_df", trust_repo=True)

The necessary dependencies are:

  • for most models: pip install torch torchvision tqdm timm pandas
  • for all models: pip install torch torchvision tqdm timm pandas dill open_clip_torch git+https://github.com/openai/CLIP.git
What each optional dependency is used for:
  • timm: for any ViT architecture
  • pandas: for results_df, metadata_df
  • dill: for BYOL
  • open-clip-torch: for OpenCLIP
  • git+https://github.com/openai/CLIP.git: for CLIP
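Before loading a specific model, it can help to check which of the optional dependencies above are installed. The hypothetical helper `missing_optional_deps` below is only a sketch; it assumes the import names `open_clip` (installed by open_clip_torch) and `clip` (installed by the OpenAI CLIP repository) and uses only the standard library.

```python
import importlib.util

# Import names of the optional dependencies, mapped to what they are needed for.
# Assumption: open_clip_torch provides the `open_clip` module and the OpenAI CLIP repo provides `clip`.
OPTIONAL_DEPS = {
    "timm": "ViT architectures",
    "pandas": "results_df, metadata_df",
    "dill": "BYOL",
    "open_clip": "OpenCLIP",
    "clip": "CLIP",
}


def missing_optional_deps():
    """Return {module: purpose} for every optional dependency that cannot be found."""
    return {
        name: purpose
        for name, purpose in OPTIONAL_DEPS.items()
        # find_spec returns None for a top-level module that is not installed
        if importlib.util.find_spec(name) is None
    }


for name, purpose in missing_optional_deps().items():
    print(f"{name} is not installed (needed for {purpose})")
```

Checking with `find_spec` avoids actually importing heavy packages just to test for their presence.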

Computing the risk decomposition

Here is minimal code to compute the risk decomposition.

def compute_risk_components(model_ssl, D_train, D_test, model_sup=None, n_sub=10000, **kwargs):
    """Computes the SSL risk decomposition for `model_ssl` using a given training and testing set.
    
    If we are given a supervised `model_sup` with the same architecture as `model_ssl`, we compute the
    approximation error. Otherwise we merge it with the usability error, given that the approximation
    error is negligible.
    """
    errors = dict()
    
    # featurize data to make probing much faster (optional). The raw D_train is kept
    # so that it can also be featurized by the supervised model below.
    Z_train = featurize_data(model_ssl, D_train)
    Z_test = featurize_data(model_ssl, D_test)
    
    # split the featurized training set into its complement and a subset of size n_sub
    Z_comp, Z_sub = data_split(Z_train, n=n_sub)
    
    r_A_F = train_eval_probe(Z_train, Z_train, **kwargs)  # probe trained and evaluated on all training data
    r_A_S = train_eval_probe(Z_comp, Z_sub, **kwargs)  # probe evaluated on training examples it did not see
    r_U_S = train_eval_probe(Z_train, Z_test, **kwargs)  # probe evaluated on the test set
    
    if model_sup is not None:
        Z_train_sup = featurize_data(model_sup, D_train)
        errors["approx"] = train_eval_probe(Z_train_sup, Z_train_sup, **kwargs)
        errors["usability"] = r_A_F - errors["approx"]
    else:
        errors["usability"] = r_A_F  # merges both errors, since the approximation error is negligible
        
    errors["probe_gen"] = r_A_S - r_A_F
    errors["encoder_gen"] = r_U_S - r_A_S
    errors["agg_risk"] = r_U_S
    return errors

def featurize_data(model, dataset):
    """Featurize a dataset using the model."""
    ...


def train_eval_probe(D_train, D_test, **kwargs):
    """Trains a probe on the (frozen) representations in D_train and returns its risk on D_test."""
    ...

def data_split(dataset, n):
    """Splits a dataset into its complement and a subset of size n, returned in that order."""
    ...
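Because every component is a difference of consecutive risks, the components telescope: approximation, usability, probe generalization, and encoder generalization errors sum to the aggregated risk `r_U_S`. The hypothetical helper `decompose` below mirrors the arithmetic of `compute_risk_components` on already-computed risks; the risk values are made up purely for illustration.

```python
import math


def decompose(r_A_F, r_A_S, r_U_S, r_approx=None):
    """Mirror the arithmetic of compute_risk_components on precomputed risks."""
    errors = {}
    if r_approx is not None:
        errors["approx"] = r_approx
        errors["usability"] = r_A_F - r_approx
    else:
        errors["usability"] = r_A_F  # approximation error folded into usability
    errors["probe_gen"] = r_A_S - r_A_F
    errors["encoder_gen"] = r_U_S - r_A_S
    errors["agg_risk"] = r_U_S
    return errors


# hypothetical risks, chosen only to illustrate the telescoping sum
errors = decompose(r_A_F=0.10, r_A_S=0.15, r_U_S=0.25, r_approx=0.02)
components = sum(value for key, value in errors.items() if key != "agg_risk")
assert math.isclose(components, errors["agg_risk"])
```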

For a minimal notebook that computes the risk decomposition and provides specific implementations of the functions above, see: Minimal training of DISSL.
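The notebook's implementations are not reproduced here. As a hedged stand-in, the three helpers could be filled in as below, assuming each dataset is an `(X, y)` pair of NumPy arrays and that `model` maps a batch of inputs to a feature matrix. The nearest-centroid probe is a deliberately simple substitute for a trained linear probe, chosen only so the sketch runs with NumPy alone; it is not the probe used in the paper.

```python
import numpy as np


def featurize_data(model, dataset):
    """Featurize a dataset (assumes `model` maps a batch of inputs to features)."""
    X, y = dataset
    return model(X), y


def data_split(dataset, n, seed=0):
    """Return (complement, subset), where `subset` holds n randomly chosen examples."""
    X, y = dataset
    idx = np.random.default_rng(seed).permutation(len(X))
    sub_idx, comp_idx = idx[:n], idx[n:]
    return (X[comp_idx], y[comp_idx]), (X[sub_idx], y[sub_idx])


def train_eval_probe(D_train, D_test, **kwargs):
    """Fit a nearest-centroid probe on D_train and return its 0-1 risk on D_test.

    Extra keyword arguments are accepted and ignored so the signature matches
    the calls in compute_risk_components.
    """
    X_train, y_train = D_train
    X_test, y_test = D_test
    classes = np.unique(y_train)
    # one centroid per class, computed from the training features
    centroids = np.stack([X_train[y_train == c].mean(axis=0) for c in classes])
    # squared Euclidean distance from every test point to every centroid
    dists = ((X_test[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
    y_pred = classes[dists.argmin(axis=1)]
    return float(np.mean(y_pred != y_test))


# toy usage: two well-separated classes in 2-D
X = np.array([[0.0, 0.0], [0.1, 0.0], [1.0, 1.0], [1.1, 1.0]])
y = np.array([0, 0, 1, 1])
print(train_eval_probe((X, y), (X, y)))  # 0.0
```

Swapping in a real linear probe only requires changing `train_eval_probe`; the decomposition code calls it the same way.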

For the code we actually used (which includes hyperparameter tuning), see: main_fullshot.py

Reproducing results

Steps to reproduce all results of the paper:

  0. Install requirements_running.txt (pip) or environment_running.yml (conda) to compute all risk components.
  1. To recompute all risk components, run scripts/run_all.sh. To recompute specific models, run the corresponding script in scripts/ with the correct server (see config/server), e.g. scripts/simsiam.sh -s nlprun
  2. To recompute all few-shot evaluations, run script_sk/run_all.sh (we use sklearn instead of PyTorch for that).
  3. Install requirements_analyzing.txt (pip) or environment_analyzing.yml (conda) to analyze all results.
  4. To reproduce all analyses and plots from the main paper, run notebooks/main_paper.ipynb.
  5. To reproduce all analyses and plots from the appendices, run notebooks/appcs.ipynb.

Contributing

If you have a pretrained model that you would like to add, please open a PR with the following:

  1. In hub/, the files and code to load your model. Then, in hubconf.py, add a one-line function that loads the desired model; the name of that function will be the name of the model in torch.hub. Make sure that you import everything from hub/ under a name with a leading underscore. Follow previous examples.
  2. Add all the hyperparameters and metadata in metadata.yaml. Documentation of every field can be found at the top of that file. Follow previous examples.
