Empirical NTKs in PyTorch

This repository contains code for efficiently computing empirical NTKs and is published alongside our paper "More Than a Toy: Random Matrix Models Predict How Real-World Neural Representations Generalize" (ICML 2022).

Usage

The following command computes the empirical NTK for a subset of CIFAR-10 (specifically, on the first 2,500 train samples and the first 1,000 test samples). The output is a 3,500 x 2,500 matrix.

python3 ntk.py CIFAR-10_0_2500_0_1000 resnet-18_pretrained --workers-per-device 2 --grad-chunksize 1900000 --mm-col-chunksize 20000 --loader-batch-size 50 --loader-num-workers 12

The following command computes the empirical NTK for all of CIFAR-10. The output is a 60,000 x 50,000 matrix.

python3 ~/empirical-ntks/ntk.py CIFAR-10 resnet-18_pretrained --workers-per-device 4 --grad-chunksize 1900000 --mm-col-chunksize 20000 --loader-batch-size 50 --loader-num-workers 12

To work with other datasets or models, see utils.py for further options.

Implementation

We pursue a very simple strategy for computing the empirical NTK: compute the N x P Jacobian matrix (for N samples and P parameters) and multiply it by its transpose. To make this computation feasible, we compute the Jacobian in chunks along the P axis, each of size N x P0 (where P0 is set by --grad-chunksize), and store each (still large) chunk in RAM. For each chunk, we then compute the N x N matrix obtained by multiplying the chunk by its transpose; summing these products over all chunks gives the full kernel. Within each such product, we again chunk along the P axis (and optionally along the N axis), sending each smaller matrix multiplication to the GPU. This matrix multiplication step is typically the bottleneck in computation time.
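The chunking along the P axis relies on the identity J J^T = sum_k J_k J_k^T, where J_k are the column chunks of the Jacobian. The following is a minimal NumPy sketch of that accumulation on a small random matrix; it is not the code in ntk.py. The function name `chunked_kernel`, the toy sizes, and the choice of rows = all samples and columns = train samples (inferred from the output shapes quoted in the Usage section) are assumptions made for illustration.

```python
import numpy as np


def chunked_kernel(jac_rows, jac_cols, p_chunk):
    """Accumulate jac_rows @ jac_cols.T by summing products of chunks along the P axis.

    Hypothetical helper for illustration only; not part of the repository.
    """
    n_rows, n_params = jac_rows.shape
    kernel = np.zeros((n_rows, jac_cols.shape[0]), dtype=jac_rows.dtype)
    for start in range(0, n_params, p_chunk):
        stop = min(start + p_chunk, n_params)
        # Each partial product only touches p_chunk columns of the Jacobian,
        # so only one chunk needs to be materialized at a time.
        kernel += jac_rows[:, start:stop] @ jac_cols[:, start:stop].T
    return kernel


rng = np.random.default_rng(0)
# Toy stand-in for a Jacobian: 25 "train" rows followed by 10 "test" rows, P = 1000.
jac = rng.standard_normal((35, 1000))
jac_train = jac[:25]

# Assumed layout: rows index all samples, columns index train samples.
ntk = chunked_kernel(jac, jac_train, p_chunk=128)
print(ntk.shape)
```

In the real setting each chunk product would be sent to a GPU and the Jacobian chunks would come from backpropagation rather than a random generator; the sketch only checks that the chunked sum matches the direct product.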

By optimizing data transfer, increasing GPU utilization, and parallelizing with care, our implementation improves significantly over naive baselines. See ntk.py for implementation details.

Performance

Our library computes an empirical NTK (60,000 x 50,000) for a ResNet-18 over CIFAR-10 at float32 precision in 43 minutes (<1e-6 seconds per NTK entry) on a machine with four A100 GPUs and 755GB RAM.
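The per-entry figure follows from simple arithmetic, sketched below. The kernel size, float32 precision, and 43-minute runtime come from the text above; the Jacobian-chunk estimate additionally assumes that a chunk holds all 60,000 samples in float32 with the --grad-chunksize value of 1900000 from the full CIFAR-10 command, which is an assumption rather than something the source states.

```python
rows, cols = 60_000, 50_000      # kernel shape quoted above
elapsed_seconds = 43 * 60        # 43 minutes
bytes_per_float32 = 4

entries = rows * cols
seconds_per_entry = elapsed_seconds / entries
kernel_gb = entries * bytes_per_float32 / 1e9

# Assumption: one Jacobian chunk is rows x grad_chunksize in float32.
grad_chunksize = 1_900_000
chunk_gb = rows * grad_chunksize * bytes_per_float32 / 1e9

print(f"{seconds_per_entry:.2e} s per entry")   # 8.60e-07 s per entry
print(f"{kernel_gb:.0f} GB kernel")             # 12 GB kernel
print(f"{chunk_gb:.0f} GB per Jacobian chunk")  # 456 GB per Jacobian chunk
```

Under that assumption, a single Jacobian chunk would fit within the 755GB of RAM mentioned above, while the full Jacobian would not.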

Citation

If you find this code useful in your research, please consider citing our paper:

@inproceedings{wei2022more,
  title = {More Than a Toy: Random Matrix Models Predict How Real-World Neural Representations Generalize},
  author = {Wei, Alexander and Hu, Wei and Steinhardt, Jacob},
  booktitle = {Proceedings of the 39th International Conference on Machine Learning},
  year = {2022}
}
