pvnieo / surfmnet-pytorch

A pytorch implementation of: "Unsupervised Deep Learning for Structured Shape Matching"

License: MIT License

Python 90.20% CMake 1.74% C++ 8.06%
unsupervised-deep-learning surfmnet-pytorch shape-matching shape-descriptor shape-correspondence python3 pytorch functional-maps

surfmnet-pytorch's Introduction

SURFMNet-pytorch

A pytorch implementation of: "Unsupervised Deep Learning for Structured Shape Matching" [link]

Installation

This implementation requires Python >= 3.7; use pip to install the dependencies:

pip3 install -r requirements.txt

Download data & preprocessing

Download the desired dataset and put it in the data folder. Multiple datasets are available here.

An example with the faust-remeshed dataset is provided.

Build shot calculator:

cd fmnet/utils/shot
cmake .
make

If you get any errors while compiling shot, please see here.

Use fmnet/preprocess.py to compute the Laplace eigendecomposition, the geodesic distances (using Dijkstra's algorithm), and the SHOT descriptors of the input shapes; the results are saved in .mat format:

usage: preprocess.py [-h] [-d DATAROOT] [-sd SAVE_DIR] [-ne NUM_EIGEN] [-nj NJOBS] [--nn NN] [--geo]

Preprocess data for FMNet training. Compute Laplacian eigen decomposition, shot features, and geodesic distance for each shape.

optional arguments:
  -h, --help            show this help message and exit
  -d DATAROOT, --dataroot DATAROOT
                        root directory of the dataset
  -sd SAVE_DIR, --save-dir SAVE_DIR
                        root directory to save the processed dataset
  -ne NUM_EIGEN, --num-eigen NUM_EIGEN
                        number of eigenvectors kept.
  -nj NJOBS, --njobs NJOBS
                        Number of parallel processes to use.
  --nn NN               Number of neighbors to consider when computing the geodesic matrix.
  --geo                 Compute geodesic distances.

NB: if the shapes have many vertices, computing the geodesic distances will consume a lot of memory and time.
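For intuition, the eigendecomposition step can be sketched on a toy graph. This is a minimal NumPy sketch using a plain graph Laplacian as a stand-in; the actual preprocess.py operates on triangle meshes (and its internals are not shown here):

```python
import numpy as np

# Toy 4-vertex cycle graph standing in for a mesh; the real script
# would use the Laplacian of the actual triangle mesh instead.
adjacency = np.array([[0, 1, 0, 1],
                      [1, 0, 1, 0],
                      [0, 1, 0, 1],
                      [1, 0, 1, 0]], dtype=float)
degree = np.diag(adjacency.sum(axis=1))
laplacian = degree - adjacency

# Keep only the num_eigen smallest eigenpairs (cf. the -ne/--num-eigen flag).
num_eigen = 3
evals, evecs = np.linalg.eigh(laplacian)  # eigh returns ascending eigenvalues
evals, evecs = evals[:num_eigen], evecs[:, :num_eigen]

# The smallest eigenvalue is ~0, with a constant eigenvector.
print(evecs.shape)  # → (4, 3)
```

The truncated basis (evals, evecs) is exactly what the network later uses as the spectral representation of each shape.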

Usage

Use the train.py script to train the SURFMNet network.

usage: train.py [-h] [--lr LR] [--b1 B1] [--b2 B2] [-bs BATCH_SIZE] [--n-epochs N_EPOCHS] [--dim-basis DIM_BASIS] [-nv N_VERTICES] [-nb NUM_BLOCKS] [--wb WB] [--wo WO] [--wl WL] [--wd WD]
                [--sub-wd SUB_WD] [-d DATAROOT] [--save-dir SAVE_DIR] [--n-cpu N_CPU] [--no-cuda] [--checkpoint-interval CHECKPOINT_INTERVAL] [--log-interval LOG_INTERVAL]

Launch the training of SURFMNet model.

optional arguments:
  -h, --help            show this help message and exit
  --lr LR               adam: learning rate
  --b1 B1               adam: decay of first order momentum of gradient
  --b2 B2               adam: decay of second order momentum of gradient
  -bs BATCH_SIZE, --batch-size BATCH_SIZE
                        size of the batches
  --n-epochs N_EPOCHS   number of epochs of training
  --dim-basis DIM_BASIS
                        number of eigenvectors used for representation.
  -nv N_VERTICES, --n-vertices N_VERTICES
                        Number of vertices used per shape
  -nb NUM_BLOCKS, --num-blocks NUM_BLOCKS
                        number of resnet blocks
  --wb WB               Bijectivity penalty weight
  --wo WO               Orthogonality penalty weight
  --wl WL               Laplacian commutativity penalty weight
  --wd WD               Descriptor preservation via commutativity penalty weight
  --sub-wd SUB_WD       Percentage of subsampled vertices used to compute descriptor preservation commutativity penalty
  -d DATAROOT, --dataroot DATAROOT
                        root directory of the dataset
  --save-dir SAVE_DIR   root directory where checkpoints are saved
  --n-cpu N_CPU         number of cpu threads to use during batch generation
  --no-cuda             Disable GPU computation
  --checkpoint-interval CHECKPOINT_INTERVAL
                        interval between model checkpoints
  --log-interval LOG_INTERVAL
                        interval between logging train information

Example

python3 train.py -bs 4 --n-epochs 20
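The penalty weights --wb, --wo, --wl, and --wd correspond to the four unsupervised losses of the SURFMNet paper. Here is a minimal NumPy sketch of the first three (the descriptor-preservation term is omitted; the function name and signature are illustrative, not the repository's API):

```python
import numpy as np

def spectral_penalties(C1, C2, evals_x, evals_y):
    """Sketch of three SURFMNet penalties. C1 maps X->Y, C2 maps Y->X (k x k)."""
    identity = np.eye(C1.shape[0])
    # Bijectivity (--wb): composing the two maps should give the identity.
    e_bij = np.linalg.norm(C2 @ C1 - identity)**2 + np.linalg.norm(C1 @ C2 - identity)**2
    # Orthogonality (--wo): area-preserving maps have orthonormal C.
    e_orth = np.linalg.norm(C1.T @ C1 - identity)**2 + np.linalg.norm(C2.T @ C2 - identity)**2
    # Laplacian commutativity (--wl): near-isometries commute with the Laplacian.
    Lx, Ly = np.diag(evals_x), np.diag(evals_y)
    e_lap = np.linalg.norm(C1 @ Lx - Ly @ C1)**2 + np.linalg.norm(C2 @ Ly - Lx @ C2)**2
    return float(e_bij), float(e_orth), float(e_lap)

# An identity map between shapes with identical spectra incurs zero penalty:
ev = np.arange(4.0)
print(spectral_penalties(np.eye(4), np.eye(4), ev, ev))  # → (0.0, 0.0, 0.0)
```

The training loss is the weighted sum of these terms, so the relative magnitudes of the --w* flags control which structural property the learned maps favor.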

surfmnet-pytorch's People

Contributors: pvnieo
surfmnet-pytorch's Issues

Overfit or wrong loss

Hi @pvnieo, I am training the network with the Faust_original dataset and something weird is happening. Within 50 mini-batch (the term used in the paper) iterations, the loss goes up quickly. I tried batch sizes 1, 2, 4, and 8, but none worked, e.g.

python3 train.py -bs 8 --n-epochs 2 -d data/Faust_original/MAT

I have tried lowering the lr by two orders of magnitude (e.g. to 1e-5), which does delay the blow-up until the 6th epoch, but the loss never went below 600. (The paper mentions a scale of 10 for their results.) I am not sure whether this is a sign of overfitting or of something else in the implementation. Any suggestions?

My setup is:
Ubuntu 20.04
Python 3.8.5
and

Package           Version
----------------- -----------
numpy             1.19.4
Pillow            8.0.1
pip               20.2.4
scipy             1.5.4
setuptools        50.3.2
torch             1.7.0+cu110
torchaudio        0.7.0
torchvision       0.8.1+cu110
...
#epoch:1, #batch:1, #iteration:1, loss:618366.5
#epoch:1, #batch:2, #iteration:2, loss:191922.375
#epoch:1, #batch:3, #iteration:3, loss:390996.5625
#epoch:1, #batch:4, #iteration:4, loss:155653.78125
#epoch:1, #batch:5, #iteration:5, loss:122857.890625
#epoch:1, #batch:6, #iteration:6, loss:56080.80859375
#epoch:1, #batch:7, #iteration:7, loss:54986.0859375
#epoch:1, #batch:8, #iteration:8, loss:38508.7421875
#epoch:1, #batch:9, #iteration:9, loss:34346.59375
#epoch:1, #batch:10, #iteration:10, loss:48555.078125
#epoch:1, #batch:11, #iteration:11, loss:30304.625
#epoch:1, #batch:12, #iteration:12, loss:29533.83984375
#epoch:1, #batch:13, #iteration:13, loss:31013.697265625
#epoch:1, #batch:14, #iteration:14, loss:32305.798828125
#epoch:1, #batch:15, #iteration:15, loss:26479.25
#epoch:1, #batch:16, #iteration:16, loss:21710.76171875
#epoch:1, #batch:17, #iteration:17, loss:20447.109375
#epoch:1, #batch:18, #iteration:18, loss:19567.962890625
#epoch:1, #batch:19, #iteration:19, loss:23133.2578125
#epoch:1, #batch:20, #iteration:20, loss:15480.0986328125
#epoch:1, #batch:21, #iteration:21, loss:16198.91796875
#epoch:1, #batch:22, #iteration:22, loss:17660.18359375
#epoch:1, #batch:23, #iteration:23, loss:14030.8515625
#epoch:1, #batch:24, #iteration:24, loss:8511.982421875   <---------------------------
#epoch:1, #batch:25, #iteration:25, loss:43280.13671875
#epoch:1, #batch:26, #iteration:26, loss:34185.2734375
#epoch:1, #batch:27, #iteration:27, loss:15525.109375
#epoch:1, #batch:28, #iteration:28, loss:11670.400390625
#epoch:1, #batch:29, #iteration:29, loss:18548.89453125
#epoch:1, #batch:30, #iteration:30, loss:13342.068359375
#epoch:1, #batch:31, #iteration:31, loss:24310.1484375
#epoch:1, #batch:32, #iteration:32, loss:18103.34765625
#epoch:1, #batch:33, #iteration:33, loss:12983.232421875
#epoch:1, #batch:34, #iteration:34, loss:15744.41015625
#epoch:1, #batch:35, #iteration:35, loss:19102.33203125
#epoch:1, #batch:36, #iteration:36, loss:14325.658203125
#epoch:1, #batch:37, #iteration:37, loss:15989.193359375
#epoch:1, #batch:38, #iteration:38, loss:22554.8984375
#epoch:1, #batch:39, #iteration:39, loss:26798.611328125
#epoch:1, #batch:40, #iteration:40, loss:42119.96875
#epoch:1, #batch:41, #iteration:41, loss:93560.6328125
#epoch:1, #batch:42, #iteration:42, loss:117597.265625
#epoch:1, #batch:43, #iteration:43, loss:254039.015625
#epoch:1, #batch:44, #iteration:44, loss:134116.859375
#epoch:1, #batch:45, #iteration:45, loss:158679.6875
#epoch:1, #batch:46, #iteration:46, loss:165269.21875
#epoch:1, #batch:47, #iteration:47, loss:162998.328125
#epoch:1, #batch:48, #iteration:48, loss:245227.140625
#epoch:1, #batch:49, #iteration:49, loss:232698.59375
#epoch:1, #batch:50, #iteration:50, loss:629152.125
#epoch:1, #batch:51, #iteration:51, loss:379238.6875
#epoch:1, #batch:52, #iteration:52, loss:311628.1875
#epoch:1, #batch:53, #iteration:53, loss:421314.84375
...

Error with different number of vertices

When testing with the remesh set (MAT_SHOT), the program failed with

Original Traceback (most recent call last):
  File "...\Python383\lib\site-packages\torch\utils\data\_utils\worker.py", line 198, in _worker_loop
    data = fetcher.fetch(index)
  File "...\Python383\lib\site-packages\torch\utils\data\_utils\fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "...\Python383\lib\site-packages\torch\utils\data\_utils\collate.py", line 83, in default_collate
    return [default_collate(samples) for samples in transposed]
  File "...\Python383\lib\site-packages\torch\utils\data\_utils\collate.py", line 83, in <listcomp>
    return [default_collate(samples) for samples in transposed]
  File "...\Python383\lib\site-packages\torch\utils\data\_utils\collate.py", line 55, in default_collate
    return torch.stack(batch, 0, out=out)
RuntimeError: stack expects each tensor to be equal size, but got [5001, 352] at entry 0 and [5000, 352] at entry 1

All train.py parameters were left at their defaults.
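The stack error occurs because default_collate cannot batch shapes with different vertex counts. One possible workaround (a sketch, not the repository's code) is to subsample every shape in a batch to a common vertex count before stacking:

```python
import numpy as np

def collate_subsample(batch, n_vertices=1500, seed=0):
    """Subsample each shape's per-vertex features to a common vertex count.

    batch: list of (n_i, d) arrays with varying n_i (e.g. SHOT descriptors).
    This is one possible fix; alternatives are padding, or subsampling in
    the dataset itself (cf. the -nv/--n-vertices flag) before collation.
    """
    rng = np.random.default_rng(seed)
    out = []
    for feats in batch:
        # Draw n_vertices distinct vertex indices from this shape.
        idx = rng.choice(feats.shape[0], size=n_vertices, replace=False)
        out.append(feats[idx])
    return np.stack(out)  # (batch, n_vertices, d)

# Shapes with 5001 and 5000 vertices now batch together:
batch = [np.zeros((5001, 352)), np.ones((5000, 352))]
print(collate_subsample(batch, n_vertices=4000).shape)  # → (2, 4000, 352)
```

A torch version of the same idea could be passed to DataLoader via its collate_fn argument.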

Moreover, could the evaluation/inference code be provided?

Loss goes to nan after 6 iterations

I'm following through your examples and am having issues when I start training.

It starts off fine but after 6 iterations the loss turns to nan. Do you have any guidance on what could be causing this?

about code

Hello, this code doesn't seem to have a test section.

Point-to-point correspondence reconstruction

Hi @pvnieo

I am having a hard time reconstructing the point-to-point (p2p) correspondence from the trained model.
To isolate the issue, I tried to train on a super-tiny dataset containing only 000.off and 001.off (from the provided data zoo), with basis counts of 40, 60, 100, and 120, each for 4000 iterations with 1500 vertices.

None of these configurations reproduced an appropriate p2p map.

The following is a snapshot of the relevant part:

  ...
  feat_x, _, evecs_x, evecs_trans_x = dataset.loader(args.shapeX) #path to 000.mat
  feat_y, _, evecs_y, evecs_trans_y = dataset.loader(args.shapeY) #path to 001.mat
  ...
  C1, C2, feat_1, feat_2 = surfmnet(feat_x, feat_y, evecs_trans_x, evecs_trans_y)
  FM_t = torch.transpose(C1, 1, 2)

  P = torch.bmm(evecs_x, FM_t)
  P = P.cpu().detach()

  tree = KDTree(P[0])  # Tree on (n1,k2)
  p2p = tree.query(evecs_y, k=1, return_distance=False).flatten()  # (n2,)
  ...

Do you have any ideas?
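For reference, the standard way to turn a functional map C (mapping spectral coefficients from X to Y) into a point-to-point map is a nearest-neighbor search between the aligned spectral embeddings, which is what the snippet above attempts with a KDTree. A brute-force NumPy sketch (variable names are illustrative):

```python
import numpy as np

def fmap_to_p2p(C, evecs_x, evecs_y):
    """Recover a point-to-point map from a functional map.

    C:        (k, k) functional map taking coefficients on X to Y.
    evecs_x:  (n1, k) Laplacian eigenvectors of shape X.
    evecs_y:  (n2, k) Laplacian eigenvectors of shape Y.
    Returns p2p with p2p[j] = index of the X vertex matched to Y vertex j.
    """
    aligned_x = evecs_x @ C.T  # (n1, k): X embedding pushed through the map
    # Brute-force nearest neighbor; a KDTree computes the same thing faster.
    d2 = ((evecs_y[:, None, :] - aligned_x[None, :, :]) ** 2).sum(-1)  # (n2, n1)
    return d2.argmin(axis=1)

# Sanity check: identical shapes with the identity map give the identity p2p.
evecs = np.random.default_rng(0).standard_normal((50, 10))
p2p = fmap_to_p2p(np.eye(10), evecs, evecs)
print((p2p == np.arange(50)).all())  # → True
```

If this sanity check passes but real pairs fail, the problem is more likely in the learned C (or in basis/descriptor preprocessing) than in the conversion step itself.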

Edit:

When preprocessing, the old and new sqrt areas are identical; is that normal?

Unable to build the shot.cpp file

Hi,

Could you please share more information about the system (was Docker used?) that was used to build the shot script? I have tried multiple Python configurations, including different versions of Ubuntu and g++, but have had no luck.

Thank you!
