borgwardtlab / neuralwalker

Official PyTorch implementation of NeuralWalker

Home Page: https://arxiv.org/pdf/2406.03386

License: BSD 3-Clause "New" or "Revised" License

Python 96.40% Cython 2.27% Shell 1.33%
graph-neural-networks graph-representation-learning graph-transformer random-walk state-space-models

🚶 NeuralWalker

This repository implements NeuralWalker in PyTorch Geometric, as described in the following paper:

Dexiong Chen, Till Schulz, and Karsten Borgwardt. Learning Long Range Dependencies on Graphs via Random Walks, Preprint 2024.

TL;DR: A novel random-walk based neural architecture for graph representation learning.

📖 Overview

NeuralWalker samples random walks with a predefined sampling rate and length, then processes them with sequence models. Local and global message passing can additionally be employed to capture complementary information. The main components of NeuralWalker are a random walk sampler and a stack of neural walker blocks, each consisting of a walk encoder block followed by a message passing block. Each walk encoder block has a walk embedder, a sequence layer, and a walk aggregator.
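The sampling step can be pictured with the minimal pure-Python sketch below. It is not the repository's sampler: it assumes that the number of walks is m = ceil(sampling_rate * n) for a graph with n nodes, that start nodes are drawn uniformly without replacement, and that walk_length counts nodes (so each walk takes walk_length - 1 uniform steps). The function name sample_random_walks and the "stay in place at a dead end" rule are hypothetical.

```python
import math
import random
from typing import Dict, List


def sample_random_walks(
    adj: Dict[int, List[int]],
    sampling_rate: float,
    walk_length: int,
    seed: int = 0,
) -> List[List[int]]:
    """Sample uniform random walks on a graph given as an adjacency list.

    Assumptions (hypothetical, for illustration only):
      * m = ceil(sampling_rate * num_nodes) walks, capped at num_nodes;
      * start nodes are drawn uniformly without replacement;
      * walk_length is the number of nodes in each walk.
    """
    rng = random.Random(seed)
    nodes = sorted(adj)
    m = min(len(nodes), math.ceil(sampling_rate * len(nodes)))
    starts = rng.sample(nodes, m)  # distinct start nodes
    walks = []
    for start in starts:
        walk = [start]
        for _ in range(walk_length - 1):
            neighbors = adj[walk[-1]]
            if neighbors:
                walk.append(rng.choice(neighbors))  # uniform step to a neighbor
            else:
                walk.append(walk[-1])  # dead end: stay in place (assumption)
        walks.append(walk)
    return walks


if __name__ == "__main__":
    toy_graph = {0: [1, 2], 1: [0, 2], 2: [0, 1, 3], 3: [2]}
    for walk in sample_random_walks(toy_graph, sampling_rate=0.5, walk_length=5):
        print(walk)
```

With sampling_rate=0.5 on the four-node toy graph, this draws two walks of five nodes each, starting from two different nodes.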

  • Random walk sampler: samples m random walks independently without replacement.
  • Walk embedder: computes walk embeddings given the node/edge embeddings at the current layer.
  • Sequence layer: any sequence model, e.g. CNNs, RNNs, Transformers, or state-space models.
  • Walk aggregator: aggregates walk features into node features via pooling of the node features encountered in all the walks passing through that node.
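To make the data flow of one walk encoder block concrete, here is a minimal sketch in plain Python that uses lists as vectors so it runs without PyTorch. Several simplifications are assumptions, not NeuralWalker's actual layers: the walk embedder only looks up node embeddings (edge embeddings are omitted), the sequence layer is a fixed causal moving average standing in for a learned CNN, RNN, Transformer, or state-space model, and the aggregator uses mean pooling and leaves unvisited nodes at zero. All function names are hypothetical.

```python
from typing import Dict, List, Sequence

Vector = List[float]


def embed_walk(walk: Sequence[int], node_feats: Dict[int, Vector]) -> List[Vector]:
    """Walk embedder (simplified): the current node embedding at each position.
    Edge embeddings are omitted in this sketch."""
    return [list(node_feats[v]) for v in walk]


def sequence_layer(seq: List[Vector], kernel_size: int = 3) -> List[Vector]:
    """Placeholder sequence model: a causal moving-average 1D convolution.
    A learned CNN, RNN, Transformer, or state-space model would go here."""
    out = []
    for t in range(len(seq)):
        window = seq[max(0, t - kernel_size + 1): t + 1]
        out.append([sum(col) / len(window) for col in zip(*window)])
    return out


def aggregate_walks(
    walks: Sequence[Sequence[int]],
    walk_feats: Sequence[List[Vector]],
    num_nodes: int,
    dim: int,
) -> List[Vector]:
    """Walk aggregator: mean-pool the features of every walk position that
    visits a node. Unvisited nodes get a zero vector (assumption)."""
    sums = [[0.0] * dim for _ in range(num_nodes)]
    counts = [0] * num_nodes
    for walk, feats in zip(walks, walk_feats):
        for v, h in zip(walk, feats):
            counts[v] += 1
            sums[v] = [s + x for s, x in zip(sums[v], h)]
    return [
        [s / counts[v] for s in sums[v]] if counts[v] else sums[v]
        for v in range(num_nodes)
    ]


def walk_encoder_block(
    walks: Sequence[Sequence[int]],
    node_feats: Dict[int, Vector],
    num_nodes: int,
    dim: int,
) -> List[Vector]:
    """Embed each walk, run the sequence layer, then pool back onto nodes."""
    walk_feats = [sequence_layer(embed_walk(w, node_feats)) for w in walks]
    return aggregate_walks(walks, walk_feats, num_nodes, dim)
```

Swapping sequence_layer for any other function that maps a sequence of vectors to a sequence of the same length leaves the embedder and aggregator unchanged, which reflects the "any sequence model" slot in the list above.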

๐Ÿ› ๏ธ Installation

We recommend managing dependencies with miniconda or micromamba:

# Replace micromamba with conda if you use conda or miniconda
micromamba env create -f environment.yaml 
micromamba activate neuralwalker
pip install -e .

Note

Our code is also compatible with more recent PyTorch versions; for development purposes, you can use micromamba env create -f environment_latest.yaml instead.

Tip

NeuralWalker relies on a sequence model, such as a CNN, a Transformer, or a state-space model, to process random walk sequences. If you encounter issues when installing the state-space model Mamba, please consult its installation guidelines.
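After setting up the environment, a quick way to see whether the optional Mamba backend is available is to probe for its import name. This sketch assumes Mamba installs under the import name mamba_ssm (the name used by the upstream Mamba project) and only checks that each module can be found, not that its compiled CUDA kernels load. The helper check_backends is hypothetical and not part of this repository.

```python
import importlib.util
from typing import Dict, Iterable

# Import names to probe: core PyTorch / PyTorch Geometric, plus Mamba
# (assumed to be importable as "mamba_ssm").
DEFAULT_MODULES = ("torch", "torch_geometric", "mamba_ssm")


def check_backends(modules: Iterable[str] = DEFAULT_MODULES) -> Dict[str, bool]:
    """Return {module_name: found} using find_spec, without importing modules."""
    return {name: importlib.util.find_spec(name) is not None for name in modules}


if __name__ == "__main__":
    for name, found in check_backends().items():
        print(f"{name:16s} {'found' if found else 'MISSING'}")
```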

💡 Reproducing results

Running NeuralWalker on Benchmarking GNNs, LRGB, and OGB

All experiment configurations are managed by Hydra and stored in ./config.

Below you can find the list of experiments conducted in the paper:

  • Benchmarking GNNs: zinc, mnist, cifar10, pattern, cluster
  • LRGB: pascalvoc, coco, peptides_func, peptides_struct, pcqm_contact
  • OGB: ogbg_molpcba, ogbg_ppa, ogbg_code2

# You can replace zinc with any of the datasets above
python train.py experiment=zinc

# Running NeuralWalker with a different model architecture
python train.py experiment=zinc experiment/model=conv+vn_3L
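To run every benchmark listed above in one go, a small driver can generate the corresponding train.py calls. The sketch below only assembles the command-line overrides shown in this README (experiment=<name> and, optionally, experiment/model=<model>); the script itself, build_commands, and its --run flag are hypothetical. By default it only prints the commands, and it executes them sequentially when invoked with --run.

```python
import subprocess
import sys
from typing import List, Optional

# Dataset names exactly as listed in this README.
EXPERIMENTS = {
    "Benchmarking GNNs": ["zinc", "mnist", "cifar10", "pattern", "cluster"],
    "LRGB": ["pascalvoc", "coco", "peptides_func", "peptides_struct", "pcqm_contact"],
    "OGB": ["ogbg_molpcba", "ogbg_ppa", "ogbg_code2"],
}


def build_commands(model: Optional[str] = None) -> List[List[str]]:
    """One `python train.py experiment=<name>` call per dataset, optionally
    overriding the architecture with `experiment/model=<model>`."""
    commands = []
    for names in EXPERIMENTS.values():
        for name in names:
            cmd = [sys.executable, "train.py", f"experiment={name}"]
            if model is not None:
                cmd.append(f"experiment/model={model}")
            commands.append(cmd)
    return commands


if __name__ == "__main__":
    execute = "--run" in sys.argv[1:]  # default is a dry run that only prints
    for cmd in build_commands():
        print(" ".join(cmd))
        if execute:
            subprocess.run(cmd, check=True)  # stop at the first failing run
```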

Tip

You can replace conv+vn_3L with any model provided in config/experiment/model, or a customized model by creating a new one in that folder.

Running NeuralWalker on node classification tasks

We integrate NeuralWalker with Polynormer, a state-of-the-art model for node classification. See node_classification for more details.

Debug mode

python train.py mode=debug
