
kvignesh1420 / gnn_collapse


[NeurIPS 2023] Official implementation of "A Neural Collapse Perspective on Feature Evolution in Graph Neural Networks"

Home Page: https://arxiv.org/abs/2307.01951

License: Apache License 2.0

graph-neural-networks neural-collapse optimization community-detection stochastic-block-model


Exploring neural collapse in graph neural networks

This repo contains the code for our NeurIPS 2023 paper titled "A Neural Collapse Perspective on Feature Evolution in Graph Neural Networks". In this work, we focus in particular on the role of graph structure in facilitating or hindering the tendency towards collapsed minimizers.

NC1 metrics across projected power iterations and GNN layers.
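For readers new to the NC1 metric, a common definition from the neural collapse literature is NC1 = Tr(Sigma_W Sigma_B^+) / K, where Sigma_W is the within-class covariance of the features, Sigma_B is the between-class covariance of the class means, ^+ is the pseudo-inverse, and K is the number of classes. The sketch below computes that quantity with NumPy for node features of shape (N, d). The function name nc1_metric is hypothetical, and the repository's exact normalization may differ; smaller values mean the features are closer to collapsing onto their class means.

```python
import numpy as np


def nc1_metric(features: np.ndarray, labels: np.ndarray) -> float:
    """Hypothetical helper: NC1 = Tr(Sigma_W @ pinv(Sigma_B)) / K.

    features: array of shape (N, d), one row per node.
    labels:   array of shape (N,), the community/class of each node.
    """
    classes = np.unique(labels)
    num_classes = len(classes)
    num_nodes, dim = features.shape
    global_mean = features.mean(axis=0)

    sigma_w = np.zeros((dim, dim))  # within-class covariance
    sigma_b = np.zeros((dim, dim))  # between-class covariance
    for c in classes:
        class_feats = features[labels == c]
        class_mean = class_feats.mean(axis=0)
        centered = class_feats - class_mean
        sigma_w += centered.T @ centered
        diff = (class_mean - global_mean)[:, None]
        sigma_b += diff @ diff.T

    sigma_w /= num_nodes
    sigma_b /= num_classes
    return float(np.trace(sigma_w @ np.linalg.pinv(sigma_b)) / num_classes)


# Usage: perfectly collapsed features vs. noisy features.
labels = np.repeat(np.arange(3), 4)
collapsed = np.repeat(np.eye(3), 4, axis=0)
noisy = collapsed + 0.1 * np.random.default_rng(0).standard_normal(collapsed.shape)
print(nc1_metric(collapsed, labels), nc1_metric(noisy, labels))
```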

Setup

$ python3.9 -m virtualenv .venv
$ source .venv/bin/activate
$ pip install -r requirements.txt

Data

We randomly sample graphs from the stochastic block model (SBM) to control the properties of the planted communities. The SBM class in gnn_collapse.data.sbm subclasses torch_geometric.data.Dataset, so it can be wrapped directly in a torch DataLoader. The following node feature strategies are currently supported:

from enum import Enum

class FeatureStrategy(Enum):
    EMPTY = "empty"
    DEGREE = "degree"
    RANDOM = "random"
    RANDOM_NORMAL = "random_normal"
    DEGREE_RANDOM = "degree_random"
    DEGREE_RANDOM_NORMAL = "degree_random_normal"
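To make the sampling step concrete, here is a minimal NumPy sketch of drawing an undirected SBM graph with equally sized communities, intra-community edge probability p, and inter-community edge probability q, followed by two of the feature strategies. The names sample_sbm and make_features are hypothetical and are not the gnn_collapse.data.sbm API; the equal community sizes and the exact meaning of "degree" (one column holding the node degree) and "random_normal" (i.i.d. standard normal entries) are assumptions.

```python
import numpy as np


def sample_sbm(n_per_class: int, k: int, p: float, q: float, rng: np.random.Generator):
    """Hypothetical helper: sample a symmetric SBM adjacency matrix and node labels."""
    n = n_per_class * k
    labels = np.repeat(np.arange(k), n_per_class)
    same_block = labels[:, None] == labels[None, :]
    probs = np.where(same_block, p, q)  # p inside a community, q across communities
    # Sample only the strict upper triangle, then mirror it (undirected, no self-loops).
    upper = np.triu(rng.random((n, n)) < probs, k=1)
    adj = (upper | upper.T).astype(np.float64)
    return adj, labels


def make_features(adj: np.ndarray, strategy: str, dim: int, rng: np.random.Generator):
    """Hypothetical helper covering two strategies; their definitions are assumed."""
    n = adj.shape[0]
    if strategy == "degree":
        # Assumed: a single feature per node holding its degree.
        return adj.sum(axis=1, keepdims=True)
    if strategy == "random_normal":
        # Assumed: i.i.d. standard normal features, independent of the graph.
        return rng.standard_normal((n, dim))
    raise NotImplementedError(f"strategy {strategy!r} is not covered by this sketch")


rng = np.random.default_rng(0)
adj, labels = sample_sbm(n_per_class=10, k=2, p=0.5, q=0.1, rng=rng)
x = make_features(adj, "degree", dim=4, rng=rng)
```

Setting p noticeably larger than q yields well-separated communities; shrinking the gap between them makes the planted structure harder to recover.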

Models

We primarily focus on the GraphConv model because of its simplicity and its similarity to a wide variety of message passing approaches. We customize the PyTorch Geometric source code of class GraphConv(MessagePassing) to control whether the lin_root weight matrix ($W_1$ in the paper) is applied.
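The effect of toggling lin_root can be seen in a dense NumPy sketch of one GraphConv-style layer: each node's output is a weight matrix (called W_2 here, corresponding to lin_rel) applied to the sum of its neighbours' features, plus, optionally, W_1 applied to the node's own features. The function name graphconv_layer, the dense adjacency, and the omission of the bias term that the PyTorch Geometric layer adds are assumptions made for illustration.

```python
import numpy as np


def graphconv_layer(x, adj, w1, w2, use_w1=True):
    """Hypothetical dense GraphConv-style layer (no bias).

    x:   (n, d_in) node features
    adj: (n, n) adjacency, adj[i, j] = 1 if j is a neighbour of i
    w1:  (d_out, d_in) root/self weight (lin_root, W_1)
    w2:  (d_out, d_in) neighbour weight (lin_rel, called W_2 here)
    """
    out = (adj @ x) @ w2.T  # sum-aggregate neighbour features, then transform
    if use_w1:
        out = out + x @ w1.T  # optional self term
    return out


rng = np.random.default_rng(0)
x = rng.standard_normal((5, 3))
adj = (rng.random((5, 5)) < 0.5).astype(float)
w1 = rng.standard_normal((2, 3))
w2 = rng.standard_normal((2, 3))
with_root = graphconv_layer(x, adj, w1, w2, use_w1=True)
without_root = graphconv_layer(x, adj, w1, w2, use_w1=False)
```

Dropping W_1 makes each layer depend on a node's own features only through the graph, which is why the switch is useful when studying the role of graph structure.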

When adding new models, a key point to consider is the naming convention of the weight matrices in the various layers. For instance, the GCNConv layer has a single lin attribute that holds its weight matrix. To handle such cases, it is best to modify the weight variable allocation in the track_train_graphs_final_nc(...) method of the gnn_collapse.train.online.OnlineRunner() class.
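One way to picture that weight selection is a lookup from model name to the attribute names that hold each layer's linear modules. The WEIGHT_ATTRS table and collect_layer_weights function below are hypothetical and are not the code in OnlineRunner; SimpleNamespace objects stand in for real PyTorch Geometric layers so the sketch runs without torch.

```python
from types import SimpleNamespace

# Hypothetical mapping: model name -> attributes holding the layer's linear modules.
WEIGHT_ATTRS = {
    "graphconv": ("lin_root", "lin_rel"),  # GraphConv: root (W_1) and neighbour weights
    "gcnconv": ("lin",),                   # GCNConv: a single linear module
}


def collect_layer_weights(layer, model_name):
    """Return {attribute name: weight} for every weight module present on `layer`."""
    weights = {}
    for attr in WEIGHT_ATTRS[model_name]:
        module = getattr(layer, attr, None)
        if module is not None:  # e.g. lin_root may be disabled in a customized GraphConv
            weights[attr] = module.weight
    return weights


# Stand-in layers for illustration only.
graphconv_no_root = SimpleNamespace(lin_root=None, lin_rel=SimpleNamespace(weight="W2"))
gcn = SimpleNamespace(lin=SimpleNamespace(weight="W"))
```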

Finally, to register a new model, add an entry to the gnn_collapse.models.GNN_factory dictionary. This enables model name validation and custom behaviours (such as the weight matrix selection mentioned above) during training and inference.

NOTE: The code for gnn_collapse.models.graphconv.GraphConvModel() can be used as a reference to add new models.
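The registration step follows a familiar factory-dictionary pattern. The sketch below is a hypothetical stand-in for gnn_collapse.models.GNN_factory: the keys, the placeholder model classes, and the build_model helper are illustrative assumptions, shown only to indicate how a dictionary entry gives name validation for free.

```python
# Placeholder model classes; real models would subclass torch.nn.Module.
class GraphConvModelStub:
    def __init__(self, hidden_dim: int = 8):
        self.hidden_dim = hidden_dim


class MyNewModel:
    def __init__(self, hidden_dim: int = 8):
        self.hidden_dim = hidden_dim


# Hypothetical stand-in for gnn_collapse.models.GNN_factory.
GNN_FACTORY = {
    "graphconv": GraphConvModelStub,
    "my_new_model": MyNewModel,  # registering a new model is one new entry
}


def build_model(name: str, **kwargs):
    """Validate the model name against the factory, then instantiate it."""
    if name not in GNN_FACTORY:
        raise ValueError(f"Unknown model {name!r}; expected one of {sorted(GNN_FACTORY)}")
    return GNN_FACTORY[name](**kwargs)
```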

Experiments

We use a config-based design to run experiments and identify them by hash. The configs folder contains a final subfolder with the configurations of the experiments presented in the paper, and an experimental subfolder that serves as a placeholder for new contributions. Each config is a JSON file that is passed to the Python script for parsing; it determines the runtime parameters of the experiment and is hashed to give each run a unique identifier.
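A minimal sketch of how a JSON config can be hashed deterministically: serialize it with sorted keys so that key order does not change the hash, then take a digest of the bytes. The choice of MD5, the config_hash and output_dir names, and the out/<hash> layout are assumptions; the repository's own hashing scheme may differ.

```python
import hashlib
import json
import os


def config_hash(config: dict) -> str:
    """Hypothetical helper: hash a config independently of its key order."""
    canonical = json.dumps(config, sort_keys=True, separators=(",", ":"))
    return hashlib.md5(canonical.encode("utf-8")).hexdigest()


def output_dir(config: dict, root: str = "out") -> str:
    """Hypothetical results folder for a run, named after its config hash."""
    return os.path.join(root, config_hash(config))


cfg_a = {"model": "graphconv", "lr": 0.01}
cfg_b = {"lr": 0.01, "model": "graphconv"}  # same content, different key order
```

Canonical serialization is the important design choice: two configs that differ only in key order or whitespace should map to the same results folder.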

To run GNN experiments:

$ bash run_gnn.sh

To run gUFM experiments:

$ bash run_ufm.sh

To run GNN experiments with deeper networks:

$ bash run_gnn_deeper.sh

To run spectral method experiments:

$ bash run_spectral.sh

A new folder called out is created, and the results of each run are stored in a folder named after the hash of its config.

Citation

@inproceedings{kothapalli2023neural,
  title={A Neural Collapse Perspective on Feature Evolution in Graph Neural Networks},
  author={Kothapalli, Vignesh and Tirer, Tom and Bruna, Joan},
  booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
  year={2023}
}

Contributing

Please feel free to open issues and create pull requests to fix bugs and improve performance.

License

MIT
