
Pointcloud Classification with PointNet and GCNs

This project uses PointNet and graph convolutional networks (GCNs) for point cloud classification. We implement the original PointNet [1] and experiment with incorporating graph convolutional networks [2] into it. Both models are trained and tested on the ModelNet40 [3] dataset.

Model Description

This is an illustration of PointNet, taken from Figure 2 in [1]. The max-pooling operation aggregates the local features of all points into a global feature and ensures permutation invariance. Each T-Net is a mini version of PointNet, consisting of a shared MLP(64, 128, 1024), a max-pooling operator and an MLP(512, 256, k×k); it regresses the n×k input to a k×k transform matrix.

[Figure: illustration of PointNet]
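The following minimal sketch makes the permutation-invariance argument concrete. It is written in plain Python, with no PyTorch, so it runs anywhere. It applies one shared MLP to every point, max-pools channel-wise, and checks that shuffling the points leaves the global feature unchanged. It also shows how a k×k transform, such as the one a T-Net regresses, would be applied to an n×k input. The layer sizes (8 and 16 instead of 64, 128 and 1024), the random weights and all function names are hypothetical choices for illustration, not the repository's code.

```python
import random

def relu(vec):
    return [max(0.0, x) for x in vec]

def dense(vec, weights, bias):
    # weights holds one row per output unit; each row has len(vec) entries
    return [sum(w * x for w, x in zip(row, vec)) + b for row, b in zip(weights, bias)]

def shared_mlp(points, layers):
    # The same layers are applied to every point independently ("shared" MLP)
    features = []
    for p in points:
        h = p
        for weights, bias in layers:
            h = relu(dense(h, weights, bias))
        features.append(h)
    return features

def max_pool(features):
    # Channel-wise max over all points: symmetric, so point order is irrelevant
    return [max(channel) for channel in zip(*features)]

def apply_transform(points, matrix):
    # Multiply each k-dim point (one row of the n x k input) by a k x k matrix
    k = len(matrix)
    return [[sum(p[i] * matrix[i][j] for i in range(k)) for j in range(k)] for p in points]

def random_layer(n_in, n_out, rng):
    weights = [[rng.uniform(-1, 1) for _ in range(n_in)] for _ in range(n_out)]
    bias = [rng.uniform(-0.1, 0.1) for _ in range(n_out)]
    return weights, bias

rng = random.Random(0)
layers = [random_layer(3, 8, rng), random_layer(8, 16, rng)]  # toy shared MLP(8, 16)
cloud = [[rng.uniform(-1, 1) for _ in range(3)] for _ in range(32)]  # n = 32, k = 3

global_feature = max_pool(shared_mlp(cloud, layers))
shuffled = cloud[:]
rng.shuffle(shuffled)
print(global_feature == max_pool(shared_mlp(shuffled, layers)))  # True
```

In the real network the transform matrix is predicted by the T-Net rather than fixed. The max over points is what removes any dependence on input order.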

To incorporate graph convolutional networks, we modified the three MLP(64, 128, 1024)s (one in the main PointNet and one in each of the two T-Nets) by adding GCN layers, as illustrated below. Dense connections between the GCN layers combine features at different levels. For a fair comparison, the output dimension is kept identical in both models (1024). In addition, GCNs require a graph as an extra input, which we construct with a thresholded Gaussian kernel weighting function. The hope is that this modification gives the network more capacity to capture structural information and thus improves performance.

[Figure: incorporating GCNs into PointNet]
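A thresholded Gaussian kernel weighting function can be sketched as follows. It uses the common form W_ij = exp(-||x_i - x_j||^2 / (2 sigma^2)) and sets weights below a threshold to zero. The values of sigma and threshold, the absence of self-loops and the function name are assumptions; this README does not state the parameters the repository uses.

```python
import math

def gaussian_kernel_graph(points, sigma, threshold):
    """Dense adjacency with W[i][j] = exp(-||p_i - p_j||^2 / (2 sigma^2)).
    Weights below `threshold` are zeroed; no self-loops are added (assumption)."""
    n = len(points)
    W = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1, n):
            d2 = sum((a - b) ** 2 for a, b in zip(points[i], points[j]))
            w = math.exp(-d2 / (2.0 * sigma ** 2))
            if w >= threshold:
                W[i][j] = W[j][i] = w  # undirected graph, so W is symmetric
    return W

pts = [[0.0, 0.0, 0.0], [0.1, 0.0, 0.0], [1.0, 1.0, 1.0]]
W = gaussian_kernel_graph(pts, sigma=0.1, threshold=0.1)
print(W[0][1], W[0][2])  # only the two nearby points are connected
```

The double loop visits n(n-1)/2 point pairs. That is 8,128 pairs for n = 128 but 523,776 for n = 1024.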

Implementation

We implement both models in PyTorch. The code is stored in Models/PointNet/ and Models/PointNet+GCN/, respectively. The dataset classes in Models/*/data.py are modified from https://github.com/WangYueFt/dgcnn/blob/master/pytorch/data.py. The PyTorch implementation of the GCN (Models/PointNet+GCN/layers.py and the class GCN in Models/PointNet+GCN/models.py) is borrowed from https://github.com/tkipf/pygcn/blob/master/pygcn/layer.py.
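For reference, a graph convolution layer in the style of pygcn's GraphConvolution works in two steps. It first applies a linear transform to the node features (support = H·W), then aggregates the result over the graph (Â·support) and adds a bias. The plain-Python sketch below assumes the symmetric renormalization Â = D^(-1/2)(A + I)D^(-1/2) from [2]; the normalization actually used in this repository is not stated. The ReLU and the dense_concat helper, which mimics the dense connections between GCN layers, are illustrative assumptions.

```python
import math

def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def normalize_adjacency(W):
    """Renormalization from [2]: D^-1/2 (W + I) D^-1/2, with D the degree matrix of W + I."""
    n = len(W)
    A = [[W[i][j] + (1.0 if i == j else 0.0) for j in range(n)] for i in range(n)]
    d = [sum(row) for row in A]
    return [[A[i][j] / math.sqrt(d[i] * d[j]) for j in range(n)] for i in range(n)]

def gcn_layer(H, A_hat, weight, bias):
    """One graph convolution followed by ReLU: ReLU(A_hat @ H @ weight + bias)."""
    support = matmul(H, weight)   # per-node linear transform
    out = matmul(A_hat, support)  # aggregate features over neighbours
    return [[max(0.0, x + b) for x, b in zip(row, bias)] for row in out]

def dense_concat(*feature_lists):
    """Dense connection: concatenate per-node features from several layers."""
    return [sum(rows, []) for rows in zip(*feature_lists)]

A_hat = normalize_adjacency([[0.0, 1.0], [1.0, 0.0]])  # two connected nodes
H = [[1.0], [3.0]]                                     # one input feature per node
H1 = gcn_layer(H, A_hat, weight=[[1.0, -1.0]], bias=[0.0, 0.0])
print(dense_concat(H, H1))  # [[1.0, 2.0, 0.0], [3.0, 2.0, 0.0]]
```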

Similar to [1], a regularization loss term (with weight 0.001) that encourages the 64×64 feature transform matrix to be close to an orthonormal matrix is added to the total loss. We use the Adam optimizer with an initial learning rate of 0.001, decayed by a factor of 0.95 every epoch. We test both models with n (the number of points per cloud) = 128, 256, 512 and 1024.
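The regularizer in [1] penalizes ||I - A A^T||_F^2 for the feature transform A. The schedule described above gives a learning rate of 0.001 × 0.95^e after e epochs. The minimal plain-Python sketch below covers both; it uses a 2×2 matrix instead of 64×64, and its function names are hypothetical.

```python
def orthogonality_loss(A):
    """||I - A A^T||_F^2 for a square matrix A given as a list of rows."""
    k = len(A)
    loss = 0.0
    for i in range(k):
        for j in range(k):
            aat = sum(A[i][m] * A[j][m] for m in range(k))  # entry (i, j) of A A^T
            target = 1.0 if i == j else 0.0
            loss += (target - aat) ** 2
    return loss

def total_loss(cls_loss, feature_transform, reg_weight=0.001):
    # Classification loss plus the weighted orthogonality penalty
    return cls_loss + reg_weight * orthogonality_loss(feature_transform)

def learning_rate(epoch, base_lr=1e-3, decay=0.95):
    """Learning rate after `epoch` decay steps (epoch 0 uses base_lr)."""
    return base_lr * decay ** epoch

print(orthogonality_loss([[2.0, 0.0], [0.0, 1.0]]))  # 9.0
print([learning_rate(e) for e in range(3)])
```

In PyTorch, torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95), stepped once per epoch, produces this schedule. Whether train.py uses that class is not stated here.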

Results

The overall accuracies of both models are listed below.

n (points)   PointNet    PointNet+GCN
128          86.7099%    87.6823%
256          87.2771%    87.4392%
512          87.3582%    87.4797%
1024         89.5057%    88.5332%

The overall accuracy of PointNet over the 40 classes is comparable to the official results reported in [1]. The GCN version performs slightly better than PointNet when n = 128, 256 and 512, but worse when n = 1024. Since the relationships between points matter most when n is small, these results suggest that the graph convolution layers may help capture structural information. However, training the GCN version becomes very inefficient when n is large. Careful hyperparameter tuning might still yield better accuracies and clearer conclusions.
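Overall accuracy is the number of correct predictions divided by the number of test shapes, so the percentages can be turned back into counts. This sketch assumes the 2,468-shape ModelNet40 test split of the HDF5 release; the README itself does not give the test-set size. Under that assumption, every entry in the table above corresponds to a whole number of shapes. The sketch recovers those counts and the per-n differences.

```python
TEST_SHAPES = 2468  # assumed size of the ModelNet40 test split (HDF5 release)

results = {  # n: (PointNet %, PointNet+GCN %), copied from the results table
    128: (86.7099, 87.6823),
    256: (87.2771, 87.4392),
    512: (87.3582, 87.4797),
    1024: (89.5057, 88.5332),
}

def correct_count(accuracy_percent, total=TEST_SHAPES):
    """Number of correctly classified shapes implied by an overall accuracy."""
    return round(accuracy_percent / 100 * total)

for n, (pointnet, gcn) in results.items():
    pn_ok, gcn_ok = correct_count(pointnet), correct_count(gcn)
    print(f"n={n:5d}  PointNet {pn_ok}/{TEST_SHAPES}  "
          f"PointNet+GCN {gcn_ok}/{TEST_SHAPES}  difference {gcn_ok - pn_ok:+d} shapes")
```

Under that assumption, the GCN version gains 24 shapes at n = 128 but only 4 and 3 shapes at n = 256 and 512. It loses 24 shapes at n = 1024. Gaps of a few shapes are small enough that run-to-run variation, which the table does not report, could account for them.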

How to Use

First, change into the directory of the model you want to use:

cd Models/PointNet          # PointNet
cd Models/PointNet+GCN      # PointNet+GCN

Training

Run

python train.py -lr=1e-3 -n=1024 --model=PointNet.pt
python train.py -lr=1e-3 -n=1024 --model=PointNetGCN.pt

to train PointNet or PointNet+GCN, respectively. The ModelNet40 dataset will be downloaded automatically. You can also download it manually from https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip and unzip it into the ./Datasets/ directory.

The trained model will be stored in PointNet.pt or PointNetGCN.pt.
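For the manual route, here is a minimal standard-library sketch that downloads the archive and unzips it into ./Datasets/. The helper names are hypothetical, and the download step assumes the URL above is reachable. The training script's own download logic is not shown in this README.

```python
import os
import urllib.request
import zipfile

DATA_URL = "https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip"

def download_dataset(url=DATA_URL, zip_path="modelnet40_ply_hdf5_2048.zip"):
    """Fetch the archive unless a local copy already exists; return its path."""
    if not os.path.exists(zip_path):
        urllib.request.urlretrieve(url, zip_path)
    return zip_path

def extract_dataset(zip_path, target_dir="./Datasets"):
    """Unzip a local archive into target_dir and return the names it contained."""
    os.makedirs(target_dir, exist_ok=True)
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(target_dir)
        return archive.namelist()
```

Calling extract_dataset(download_dataset()) fetches the archive once and then unpacks it.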

Testing

Run

python test.py -n=1024 --model=PointNet.pt
python test.py -n=1024 --model=PointNetGCN.pt

to test the trained model (loaded from PointNet.pt or PointNetGCN.pt).

Using pretrained models

If you want to use our pretrained models, please download them from the following links:

Save the *.pt files in Models/PointNet and Models/PointNet+GCN, rename them to PointNet.pt and PointNetGCN.pt, respectively, and run the testing commands above.

References

[1] Qi, C. R., Su, H., Mo, K., & Guibas, L. J. (2017). PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR).

[2] Kipf, T. N., & Welling, M. (2017). Semi-Supervised Classification with Graph Convolutional Networks. In International Conference on Learning Representations (ICLR).

[3] Wu, Z., Song, S., Khosla, A., Yu, F., Zhang, L., Tang, X., & Xiao, J. (2015). 3D ShapeNets: A Deep Representation for Volumetric Shapes. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR).


