Code Repository for the ICML 2024 paper: "Towards Scalable and Versatile Weight Space Learning".

Overview of the SANE approach

SANE: Scalable and Versatile Weight Space Learning

This repository contains the code for the paper "Towards Scalable and Versatile Weight Space Learning," presented at ICML 2024. This work introduces SANE*, a novel approach for learning task-agnostic representations of neural networks that are scalable to larger models and applicable to various tasks. The paper can be found here.

*Sequential Autoencoder for Neural Embeddings

Summary

Learning representations of well-trained neural network models holds the promise to provide an understanding of the inner workings of those models. However, previous work has faced limitations when processing larger networks or was task-specific to either discriminative or generative tasks. This paper introduces the SANE approach to weight-space learning. SANE overcomes previous limitations by learning task-agnostic representations of neural networks that are scalable to larger models of varying architectures and show capabilities beyond a single task. Our method extends the idea of hyper-representations towards sequential processing of subsets of neural network weights, allowing one to embed larger neural networks as a set of tokens into the learned representation space. SANE reveals global model information from layer-wise embeddings and can sequentially generate unseen neural network models, which was unattainable with previous hyper-representation learning methods. As the figure below shows, extensive empirical evaluation demonstrates that SANE (light blue) matches or exceeds state-of-the-art performance on several weight representation learning benchmarks, particularly in initialization for new tasks and larger ResNet architectures.

Performance overview of SANE

Key Methods

  • Sequential Decomposition: Breaking down neural network weights into smaller, manageable token sequences.
  • Self-Supervised Pretraining: Using a self-supervised approach to pretrain the SANE model on a variety of tasks and architectures on subsequences of models.
  • Model Analysis: Analysing models by their embedding sequences.
  • Model Sampling: Generating new neural network models by sampling from the learned representation space.
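To make the sequential decomposition concrete, here is a minimal sketch of slicing a model's weights into fixed-size tokens. This is an illustrative helper, not SANE's actual tokenization: the token size, zero-padding, and per-layer flattening order are assumptions for the example.

```python
import numpy as np

def decompose_to_tokens(state_dict, token_size=64):
    """Flatten each layer's weights and slice them into fixed-size tokens.

    Hypothetical helper; token_size and zero-padding are illustrative
    choices, not the exact scheme used by SANE.
    """
    tokens = []
    for name, w in state_dict.items():
        flat = np.asarray(w, dtype=np.float32).ravel()
        # Pad so each layer splits evenly into whole tokens.
        pad = (-len(flat)) % token_size
        flat = np.pad(flat, (0, pad))
        tokens.extend(flat.reshape(-1, token_size))
    return np.stack(tokens)  # shape: (num_tokens, token_size)

# Toy "model": two layers of different shapes.
toy_model = {"conv1": np.zeros((8, 3, 3, 3)), "fc": np.zeros((10, 32))}
seq = decompose_to_tokens(toy_model, token_size=64)
```

The resulting token sequence is what the autoencoder processes, which is why the approach scales to architectures of different sizes: longer models simply yield longer sequences.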

Results

  • Model Property Prediction: SANE embeddings demonstrate high predictive performance for model properties such as test accuracy, epoch, and generalization gap across various datasets and architectures.
  • Generative Capabilities: SANE can generate high-performing neural network models from scratch or fine-tune them with significantly less computational effort compared to training from scratch.
  • Scalability: The method scales to large models like ResNet-18, preserving meaningful information across long sequences of tokens.

Code Structure

  • data/: Scripts for data preprocessing and loading.
  • experiments/: Example experiments to pre-train SANE on a CIFAR100-ResNet18 zoo, predict properties, and sample models.
  • src/: The SANE package to preprocess model checkpoint datasets, pre-train SANE, and perform discriminative and generative downstream tasks.

Running Experiments

We include example experiments that showcase how to use the code.

Download Model Zoo Datasets

We have made several model zoos available at modelzoos.cc. Any of these zoos can be used in our pipeline, with minor adjustments.

To get started with a small experiment, navigate to ./data/ and run

bash download_cifar10_cnn_sample.sh

This will download and unzip a small model zoo example with CNN models trained on CIFAR-10. Training on large model zoos requires preprocessing for training efficiency. We provide code to preprocess training samples. To compile those datasets, run

python3 preprocess_dataset_cnn_cifar10_sample.py

in ./data/. In the same directory, we provide download and preprocessing scripts for other zoos as well. The preprocessed datasets have no specific dependencies beyond standard numpy and pytorch.

Please note that these are not the exact models used in the paper, so they will produce different results. The full zoos can be downloaded from modelzoos.cc and used in the same way as the sample zoo.

Pretraining SANE

Code to pretrain SANE on the sample zoo is contained in experiments/pretrain_sane_cifar100_resnet18.py. The code relies on ray.tune to manage resources but currently runs only a single configuration. To vary any configuration value, replace it with tune.grid_search([value_1, ..., value_n]). To run the experiment, run

python3 pretrain_sane_cifar100_resnet18.py

in experiments/.

Using SANE embeddings to predict properties

SANE embeddings preserve the sequential decomposition of models. This enables a more fine-grained analysis of models compared to global model embeddings. The figure below shows a comparison between SANE embeddings (right) and features used in the WeightWatcher library (left), which are based on the eigendecomposition of the weight matrices. Both show similar trends of layer properties in ResNet models, but SANE appears to pick up on additional signals in the middle layers.

Analysis of models: comparing weight matrix eigendecomposition features (left) to SANE embeddings (right)

We provide code to use SANE embeddings to predict such model properties in experiments/property_prediction_cifar100_resnet18.py. It assumes the dataset has been downloaded and SANE pretrained as described above. Within property_prediction_cifar100_resnet18.py, set the path to the pretrained SANE model and the epoch, then run

python3 property_prediction_cifar100_resnet18.py

within experiments/. This will compute the property prediction results from both SANE embeddings and weight-statistics baselines and save them to a JSON file.
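The prediction step itself can be sketched as a simple linear probe on pooled embedding sequences. The data below is synthetic and the mean-pooling plus least-squares probe is one plausible setup, not the script's actual pipeline.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-ins: 100 models, each represented by a sequence of
# 30 token embeddings of dimension 16, with one scalar property
# (e.g. test accuracy) per model.
embeddings = rng.normal(size=(100, 30, 16))
accuracy = rng.uniform(0.5, 0.9, size=100)

# Mean-pool each embedding sequence into one vector per model.
pooled = embeddings.mean(axis=1)              # (100, 16)
X = np.hstack([pooled, np.ones((100, 1))])    # append a bias column

# Fit a least-squares linear probe: property ≈ X @ w.
w, *_ = np.linalg.lstsq(X, accuracy, rcond=None)
pred = X @ w
```

With real SANE embeddings, the pooled vectors carry enough signal for such simple probes to predict accuracy, epoch, and generalization gap, as reported in the Results section.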

Generating Models

Generating models can provide initializations, even for new tasks and architectures, that give an advantage over random initialization; see the figure below.

SANE weight generation comparison: models initialized with SANE outperform random initialization on new models or tasks

In experiments, there is also code to generate and evaluate models. cnn-cifar10_exploration.ipynb is a quick-start notebook to explore the datasets, SANE models, the training loop, and encoding and decoding models. We further provide experiment code for the CNN sample dataset and a larger ResNet dataset. For the latter, sample_finetune_cifar100_resnet18.py contains an example for model sampling. As above, set the path to a pretrained SANE model and epoch. Then, run

python3 sample_finetune_cifar100_resnet18.py

within experiments. Generating models requires a pre-processed CIFAR100 dataset that can be generated by running

python3 prepare_cifar100_dataset.py

The small CNN sample experiment code is structured correspondingly.
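Conceptually, sampling draws new token embeddings near those of encoded anchor models and decodes them back into weights. The sketch below uses random stand-ins for anchor embeddings and a fixed linear map as a stub decoder; SANE's real decoder is a learned transformer, and the perturbation strategy here is just one simple assumption.

```python
import numpy as np

rng = np.random.default_rng(1)

token_dim, num_tokens, token_size = 16, 9, 64

# Anchor embeddings would come from encoding trained models with SANE;
# here they are random stand-ins.
anchor = rng.normal(size=(num_tokens, token_dim))

# Sample new embeddings by perturbing the anchors (one simple strategy).
sampled = anchor + 0.1 * rng.normal(size=anchor.shape)

# Stub decoder: a fixed linear map from embeddings to weight tokens.
decoder = rng.normal(size=(token_dim, token_size))
weight_tokens = sampled @ decoder          # (num_tokens, token_size)
flat_weights = weight_tokens.ravel()       # reassemble a flat weight vector
```

The flat weight vector would then be reshaped back into the target architecture's layers and either evaluated directly or used as an initialization for fine-tuning.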

Contact

Feel free to get in touch with any questions about the project or to request access to data and/or pretrained models. Reach out to [email protected].

Citation

If you use this code in your research, please cite our paper:

@inproceedings{schuerholt2024sane,
    title={Towards Scalable and Versatile Weight Space Learning},
    author={Konstantin Sch{\"u}rholt and Michael W. Mahoney and Damian Borth},
    booktitle={Proceedings of the 41st International Conference on Machine Learning (ICML)},
    year={2024},
    organization={PMLR}
}

License

This project is licensed under the MIT License.
