maxreiss123 / medgeconv

This project is forked from stefreck/medgeconv.

TensorFlow implementation of the Edge Convolution

License: MIT License

Shell 1.34% C++ 15.46% Python 73.60% Cuda 6.36% Makefile 3.24%

MEdgeConv

An efficient TensorFlow 2 implementation of the edge-convolution layer (EdgeConv), as used in e.g. ParticleNet.

The structure of the layer is as described in 'ParticleNet: Jet Tagging via Particle Clouds' https://arxiv.org/abs/1902.08570.
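
As a rough illustration of the operation described in that paper, the sketch below computes one EdgeConv update for a single graph in plain Python. For every node x_i it builds an edge feature from (x_i, x_j - x_i) for each neighbor j, applies a shared function to it, and averages the results over the neighbors. The names `edge_conv` and `toy_edge_fn`, the fixed weights, and the hand-written neighbor lists are assumptions made for this example. This is not the code of this package, which works on batched tensors and uses a trainable multi-layer network as the edge function.

```python
def edge_conv(nodes, neighbors, edge_fn):
    """One EdgeConv update for a single graph (illustrative sketch).

    nodes:     list of feature vectors, one per node
    neighbors: neighbors[i] is the list of neighbor indices of node i
    edge_fn:   function mapping an edge feature vector to an output vector
    """
    updated = []
    for i, x_i in enumerate(nodes):
        messages = []
        for j in neighbors[i]:
            x_j = nodes[j]
            # edge feature: features of the central node, followed by
            # the difference between the neighbor and the central node
            diff = [b - a for a, b in zip(x_i, x_j)]
            messages.append(edge_fn(x_i + diff))
        # aggregate the messages of all neighbors with the mean
        n_out = len(messages[0])
        updated.append(
            [sum(m[c] for m in messages) / len(messages) for c in range(n_out)]
        )
    return updated


def toy_edge_fn(edge_feature):
    # hypothetical stand-in for the trainable network:
    # one fixed linear map followed by a ReLU
    weights = [0.5, 0.5, 1.0, -1.0]
    value = sum(w * f for w, f in zip(weights, edge_feature))
    return [max(0.0, value)]


nodes = [[0.0, 2.0], [3.0, 7.0], [4.0, 0.0], [1.0, 2.0]]
neighbors = [[3, 2], [3, 0], [3, 0], [0, 2]]  # hand-written, k = 2
out = edge_conv(nodes, neighbors, toy_edge_fn)
print(out)  # one output vector per node
```

The number of nodes is unchanged by the update; only the number of features per node is set by the edge function.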

Instructions

Install via:

pip install medgeconv

Use e.g. like this:

import medgeconv

nodes = medgeconv.DisjointEdgeConvBlock(
    units=[64, 64, 64],
    next_neighbors=16,
)((nodes, coordinates))

The inputs to EdgeConv are two ragged tensors, nodes and coordinates:

  • nodes, shape (batchsize, None, n_features)
    Node features of the graph. The second dimension is the number of nodes, which can vary from graph to graph.
  • coordinates, shape (batchsize, None, n_coords)
    Features of each node used for calculating nearest neighbors.

Example: Input for a graph with 2 features per node, and all node features used as coordinates.

import tensorflow as tf

nodes = tf.ragged.constant([
   # graph 1: 2 nodes
   [[2., 4.],
    [2., 6.]],
   # graph 2: 4 nodes
   [[0., 2.],
    [3., 7.],
    [4., 0.],
    [1., 2.]],
], ragged_rank=1)

print(nodes.shape)  # output: (2, None, 2)

# using all node features as coordinates
coordinates = nodes
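
Since the coordinates are features of each node, both inputs presumably have to describe the same nodes, graph by graph. A small plain-Python check on nested lists, run before converting them with tf.ragged.constant, could look like the sketch below. `node_counts` is a hypothetical helper written for this example and is not part of medgeconv.

```python
def node_counts(nodes, coordinates):
    """Return the number of nodes per graph, checking that both inputs agree."""
    if len(nodes) != len(coordinates):
        raise ValueError("nodes and coordinates differ in batch size")
    counts = []
    for g, (n, c) in enumerate(zip(nodes, coordinates)):
        if len(n) != len(c):
            raise ValueError(
                f"graph {g}: {len(n)} nodes but {len(c)} coordinate rows"
            )
        counts.append(len(n))
    return counts


graph_nodes = [
    # graph 1: 2 nodes
    [[2., 4.], [2., 6.]],
    # graph 2: 4 nodes
    [[0., 2.], [3., 7.], [4., 0.], [1., 2.]],
]
counts = node_counts(graph_nodes, graph_nodes)
print(counts)  # [2, 4]
```

The varying entries of `counts` are the reason the second dimension of the ragged shape is reported as None.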

Example

The full ParticleNet for n_features = n_coords = 2, with a dense layer of 2 neurons as the output, can be built like this:

import tensorflow as tf
import medgeconv

inp = (
    tf.keras.Input((None, 2), ragged=True),
    tf.keras.Input((None, 2), ragged=True),
)
x = medgeconv.DisjointEdgeConvBlock(
    units=[64, 64, 64],
    batchnorm_for_nodes=True,
    next_neighbors=16,
)(inp)

x = medgeconv.DisjointEdgeConvBlock(
    units=[128, 128, 128],
    next_neighbors=16,
)(x)

x = medgeconv.DisjointEdgeConvBlock(
    units=[256, 256, 256],
    next_neighbors=16,
    pooling=True,
)(x)

output = tf.keras.layers.Dense(2)(x)
model = tf.keras.Model(inp, output)

The last EdgeConv block has pooling=True. This appends a node-wise global average pooling layer at the end, so the output is a regular (non-ragged) tensor again.
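
To make the effect of the pooling step concrete, here is a minimal plain-Python sketch, assuming the pooling averages each feature over all nodes of a graph. `global_average_pool` is a hypothetical name used only for this example; the package does this on ragged tensors inside the layer.

```python
def global_average_pool(batch):
    """Average the node features over the nodes of each graph."""
    pooled = []
    for graph in batch:
        n_nodes = len(graph)
        n_features = len(graph[0])
        pooled.append(
            [sum(node[f] for node in graph) / n_nodes for f in range(n_features)]
        )
    return pooled


batch = [
    # graph 1: 2 nodes
    [[2., 4.], [2., 6.]],
    # graph 2: 4 nodes
    [[0., 2.], [3., 7.], [4., 0.], [1., 2.]],
]
pooled = global_average_pool(batch)
print(pooled)  # [[2.0, 5.0], [2.0, 2.75]]
```

Every graph ends up with exactly one row of n_features values regardless of its number of nodes, which is why a regular dense layer can follow.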

The model can then be used on ragged tensors:

nodes = tf.RaggedTensor.from_tensor(tf.ones((3, 17, 2)))
model.predict((nodes, nodes))

Loading models

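To load a saved model, pass the custom objects provided by the package:

import tensorflow as tf
import medgeconv

# path: location of the saved model
model = tf.keras.models.load_model(path, custom_objects=medgeconv.custom_objects)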

knn_graph kernel

This package includes a CUDA kernel for calculating the k nearest neighbors on a batch of graphs. It comes with a precompiled kernel for the version of TensorFlow specified in requirements.txt.
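
What such a kernel computes can be sketched in plain Python. The version below is a brute-force reference under stated assumptions: squared Euclidean distance, a node is not its own neighbor, and graphs with too few nodes simply return fewer than k neighbors. None of these choices is confirmed by the package, and the function name `knn_graph` is only borrowed from the section title.

```python
def knn_graph(batch, k):
    """For each graph, return the indices of the k nearest neighbors of every node."""
    result = []
    for coords in batch:
        graph_neighbors = []
        for i, p in enumerate(coords):
            # squared Euclidean distance to every other node of the same graph
            dists = [
                (sum((a - b) ** 2 for a, b in zip(p, q)), j)
                for j, q in enumerate(coords)
                if j != i  # assumption: exclude the node itself
            ]
            dists.sort()  # closest first, ties broken by index
            graph_neighbors.append([j for _, j in dists[:k]])
        result.append(graph_neighbors)
    return result


coordinates = [
    # graph 1: 2 nodes
    [[2., 4.], [2., 6.]],
    # graph 2: 4 nodes
    [[0., 2.], [3., 7.], [4., 0.], [1., 2.]],
]
knn = knn_graph(coordinates, k=2)
print(knn[1])  # [[3, 2], [3, 0], [3, 0], [0, 2]]
```

Neighbors are searched only within a graph, never across graphs of the same batch. The brute-force search costs O(n^2) per graph, which is the part a GPU kernel parallelizes.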

To compile it locally, e.g. for a different version of TensorFlow, go to medgeconv/tf_ops and adjust the compile.sh bash script. Running it will download the specified TensorFlow development Docker image and produce the file medgeconv/tf_ops/python/ops/_knn_graph_ops.so.

Publications

Results using this model architecture in the context of particle physics were presented at the ICRC 2021 conference (https://doi.org/10.22323/1.395.1048) and at VLVnT 2021 (https://arxiv.org/abs/2107.13375).

Contributors

stefreck
