Tree Mover's Distance for Graphs

Understanding generalization and robustness of machine learning models fundamentally relies on assuming an appropriate metric on the data space. Identifying such a metric is particularly challenging for non-Euclidean data such as graphs. Here, we propose a pseudometric for attributed graphs, the Tree Mover's Distance (TMD), and study its relation to generalization. Via a hierarchical optimal transport problem, TMD reflects the local distribution of node attributes as well as the distribution of local computation trees, which are known to be decisive for the learning behavior of graph neural networks (GNNs). First, we show that TMD captures properties relevant to graph classification: a simple TMD-SVM performs competitively with standard GNNs. Second, we relate TMD to generalization of GNNs under distribution shifts, and show that it correlates well with performance drop under such shifts.

Tree Mover’s Distance: Bridging Graph Metrics and Stability of Graph Neural Networks, NeurIPS 2022 [paper]
Ching-Yao Chuang and Stefanie Jegelka

Prerequisites

  • Python 3.7
  • PyTorch 1.3.1
  • PyTorch Geometric
  • POT

Usage Examples

The code for computing the Tree Mover's Distance (TMD) lives in tmd.py. For instance, the following code computes the TMD between two graphs from the MUTAG dataset.

from tmd import TMD
from torch_geometric.datasets import TUDataset

dataset = TUDataset('data', name='MUTAG')
d = TMD(dataset[0], dataset[1], w=1.0, L=4)

One can also specify different weighting constants for each layer as follows:

d = TMD(dataset[0], dataset[1], w=[0.33, 1, 3], L=4)

This results in a tighter bound on the stability of GNNs, as Theorem 8 of the paper shows. Note that len(w) must equal L-1.
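The length constraint is easy to violate when L is changed. Below is a minimal sketch of a check that could be run before calling TMD. The helper name expand_weights is hypothetical and not part of tmd.py, and the sketch assumes that a scalar w means the same constant is used for each of the L-1 layers.

```python
from numbers import Real

def expand_weights(w, L):
    """Return a list of L-1 per-layer weights (hypothetical helper, not in tmd.py)."""
    if isinstance(w, Real):
        # Assumption: a scalar w applies the same constant to every layer.
        return [float(w)] * (L - 1)
    w = [float(x) for x in w]
    if len(w) != L - 1:
        raise ValueError(f"len(w) is {len(w)}, expected L-1 = {L - 1}")
    return w

# Three weights are needed for L=4, matching the example above.
weights = expand_weights([0.33, 1, 3], L=4)
```

Calling expand_weights([1, 2], L=4) raises a ValueError, which surfaces the mismatch before any distance is computed.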

Graph Classification on TUDataset

Step 1: Pre-compute the pairwise distances (optionally in parallel). For instance, the following commands compute the pairwise distances for MUTAG with pairwise_dist.py by splitting the work into 4 batches, each of which can be run in parallel. The batches are then merged with merge.py.

python pairwise_dist.py --w 0.5 --L 4 --dataset MUTAG --n_per_idx 50 --idx 0
python pairwise_dist.py --w 0.5 --L 4 --dataset MUTAG --n_per_idx 50 --idx 1
python pairwise_dist.py --w 0.5 --L 4 --dataset MUTAG --n_per_idx 50 --idx 2
python pairwise_dist.py --w 0.5 --L 4 --dataset MUTAG --n_per_idx 50 --idx 3

python merge.py --w 0.5 --L 4 --dataset MUTAG
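For other datasets or batch sizes, the list of batch commands can be generated rather than typed by hand. The sketch below assumes that --n_per_idx is the number of graphs handled per batch index, which is consistent with MUTAG (188 graphs) being split into 4 batches of 50. The function batch_commands is a hypothetical helper, not part of the repository.

```python
import math

def batch_commands(n_graphs, n_per_idx, w, L, dataset):
    """Build one pairwise_dist.py command per batch (hypothetical helper)."""
    # Assumption: each --idx covers n_per_idx graphs, so round up to cover all.
    n_batches = math.ceil(n_graphs / n_per_idx)
    return [
        f"python pairwise_dist.py --w {w} --L {L} --dataset {dataset} "
        f"--n_per_idx {n_per_idx} --idx {idx}"
        for idx in range(n_batches)
    ]

# MUTAG contains 188 graphs, so 50 per batch gives 4 batches (idx 0 to 3).
commands = batch_commands(188, 50, w=0.5, L=4, dataset="MUTAG")
for cmd in commands:
    print(cmd)
```

Each printed line can be submitted as a separate job. merge.py is run once after all of them finish.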

Step 2: Train an SVM classifier on the pre-computed distances:

python svc.py --w 0.5 --L 4 --dataset MUTAG
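An SVM needs a kernel rather than a raw distance. One common way to turn a distance matrix into a kernel is K = exp(-gamma * D). The sketch below illustrates that conversion on a toy matrix. Whether svc.py uses exactly this form is an assumption, and distance_to_kernel and gamma are illustrative names.

```python
import math

def distance_to_kernel(D, gamma):
    """Map a square distance matrix to K[i][j] = exp(-gamma * D[i][j])."""
    return [[math.exp(-gamma * d) for d in row] for row in D]

# Toy 3x3 symmetric matrix standing in for pairwise TMD values (made-up numbers).
D = [[0.0, 2.0, 4.0],
     [2.0, 0.0, 1.0],
     [4.0, 1.0, 0.0]]
K = distance_to_kernel(D, gamma=0.5)
```

A matrix like K can be passed to a classifier that accepts precomputed kernels, for example scikit-learn's SVC(kernel="precomputed"), with gamma chosen by cross-validation.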

Measuring the Stability of GNNs

The script stability.py reproduces the stability experiments in Figure 5 of the paper. In particular, it plots the correlation between an (L+1)-layer GIN and the Tree Mover's Distance on graphs sampled from MUTAG.

python stability.py --L 3
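To make "correlation" concrete, the sketch below computes a Pearson coefficient between TMD values for graph pairs and the corresponding gaps between GNN outputs. This is an illustration only. The numbers are made up, the choice of Pearson correlation is an assumption, and pearson is not a function from stability.py.

```python
import math

def pearson(xs, ys):
    """Pearson correlation coefficient of two equal-length sequences."""
    n = len(xs)
    if n != len(ys) or n < 2:
        raise ValueError("need two sequences of equal length >= 2")
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    vx = sum((x - mx) ** 2 for x in xs)
    vy = sum((y - my) ** 2 for y in ys)
    if vx == 0 or vy == 0:
        raise ValueError("correlation is undefined for constant input")
    return cov / math.sqrt(vx * vy)

# Made-up values: TMD for four graph pairs and the gap between GNN outputs.
tmd_values = [0.5, 1.0, 1.5, 2.0]
output_gaps = [0.2, 0.4, 0.6, 0.8]
r = pearson(tmd_values, output_gaps)
```

A coefficient near 1 would indicate that pairs of graphs far apart in TMD also produce far-apart GNN outputs.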

Citation

If you find this repo useful for your research, please consider citing the paper:

@article{chuang2022tree,
  title={Tree Mover’s Distance: Bridging Graph Metrics and Stability of Graph Neural Networks},
  author={Chuang, Ching-Yao and Jegelka, Stefanie},
  journal={Advances in Neural Information Processing Systems},
  volume={35},
  year={2022}
}
