I quickly put together a script for comparison:
```python
import torch
from graphs.losses.dist_chamfer import ChamferDist  # project-local CUDA Chamfer module

cuda0 = torch.device('cuda:0')
x1 = torch.ones([1, 2, 3], device=cuda0)   # batch of 1 cloud: 2 points at (1, 1, 1)
x2 = torch.zeros([1, 4, 3], device=cuda0)  # batch of 1 cloud: 4 points at the origin

distChamfer = ChamferDist()
dist1, dist2 = distChamfer(x1, x2)
dist_total = torch.mean(dist1) + torch.mean(dist2)
print("PyTorch Chamfer Loss:", dist_total.item())

import numpy as np
from sklearn.neighbors import KDTree

def chamfer_distance_sklearn(array1, array2):
    batch_size, num_point = array1.shape[:2]
    dist = 0
    for i in range(batch_size):
        tree1 = KDTree(array1[i], leaf_size=num_point + 1)
        tree2 = KDTree(array2[i], leaf_size=num_point + 1)
        # Euclidean distance from each point of one cloud to its nearest neighbour in the other
        distances1, _ = tree1.query(array2[i])
        distances2, _ = tree2.query(array1[i])
        av_dist1 = np.mean(distances1)
        av_dist2 = np.mean(distances2)
        # Sum of the two directional means, averaged over the batch
        dist = dist + (av_dist1 + av_dist2) / batch_size
    return dist

x1_np = np.ones([1, 2, 3])
x2_np = np.zeros([1, 4, 3])
print("sklearn Chamfer Loss:", chamfer_distance_sklearn(x1_np, x2_np))
```
This returns:
PyTorch Chamfer Loss: 6.0
sklearn Chamfer Loss: 3.46
Manual calculation validates the sklearn output.
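Spelled out for this toy input: every point of `x1` sits at (1, 1, 1) and every point of `x2` at the origin, so in both directions each nearest-neighbour distance is $\sqrt{3}$ and each squared distance is $3$. Averaging per direction and summing the two directions, as both scripts do:

$$
\underbrace{\tfrac{1}{2}(3+3) + \tfrac{1}{4}(3+3+3+3)}_{\text{squared distances}} = 6,
\qquad
\underbrace{\tfrac{1}{2}\left(2\sqrt{3}\right) + \tfrac{1}{4}\left(4\sqrt{3}\right)}_{\text{Euclidean distances}} = 2\sqrt{3} \approx 3.46
$$

The first value matches the PyTorch output and the second matches the sklearn output.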
The difference seems to be simply a missing square root: the PyTorch implementation returns squared nearest-neighbour distances. Computing the loss as `dist_total = torch.mean(torch.sqrt(dist1)) + torch.mean(torch.sqrt(dist2))` gives the same output as sklearn.
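For reference, both conventions can be compared without the CUDA extension or a KD-tree. The sketch below is a brute-force NumPy version; the helper name `chamfer_distance` and its `squared` flag are hypothetical (not part of 3D-CODED or sklearn). It takes a single cloud per argument rather than a batch and builds the full pairwise distance matrix, so it is only practical for small point clouds.

```python
import numpy as np

def chamfer_distance(a, b, squared=False):
    """Symmetric Chamfer distance between point clouds a (N, 3) and b (M, 3).

    Returns the sum of the two mean nearest-neighbour distances. With
    squared=True the nearest-neighbour distances are squared before averaging
    (the convention of the PyTorch snippet); otherwise plain Euclidean
    distances are used (the convention of the sklearn snippet).
    """
    # Pairwise squared Euclidean distances via broadcasting, shape (N, M)
    diff = a[:, None, :] - b[None, :, :]
    d2 = np.sum(diff ** 2, axis=-1)
    nn_a_to_b = d2.min(axis=1)  # for each point of a, closest point in b
    nn_b_to_a = d2.min(axis=0)  # for each point of b, closest point in a
    if not squared:
        # The square root must be taken per point, before averaging
        nn_a_to_b = np.sqrt(nn_a_to_b)
        nn_b_to_a = np.sqrt(nn_b_to_a)
    return float(nn_a_to_b.mean() + nn_b_to_a.mean())

x1 = np.ones((2, 3))   # 2 points at (1, 1, 1)
x2 = np.zeros((4, 3))  # 4 points at the origin
print(chamfer_distance(x1, x2, squared=True))   # 6.0
print(chamfer_distance(x1, x2, squared=False))  # about 3.464, i.e. 2 * sqrt(3)
```

Note that applying the root per point is not the same as taking the root of the final squared value: here `sqrt(6.0)` is about 2.45, not 3.46.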
from 3d-coded.
Hi @kongsgard,
Yes, your analysis is correct. It is actually intentional: it corresponds to how we define the loss in the paper (https://arxiv.org/pdf/1806.05228.pdf).
Best regards,
Thibault
I see. Thank you for your response!
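For reference, the two conventions discussed in this thread can be written generically for point sets $S_1$ and $S_2$. The normalization below (a mean over each set) mirrors the two scripts in the first comment; it is an assumption, not copied from the paper, whose exact normalization is not restated in this thread:

$$
d_{\text{sq}}(S_1,S_2) = \frac{1}{|S_1|}\sum_{x\in S_1}\min_{y\in S_2}\lVert x-y\rVert_2^2 + \frac{1}{|S_2|}\sum_{y\in S_2}\min_{x\in S_1}\lVert x-y\rVert_2^2
$$

$$
d(S_1,S_2) = \frac{1}{|S_1|}\sum_{x\in S_1}\min_{y\in S_2}\lVert x-y\rVert_2 + \frac{1}{|S_2|}\sum_{y\in S_2}\min_{x\in S_1}\lVert x-y\rVert_2
$$

As confirmed in the reply, the 3D-CODED loss intentionally uses squared distances, as in $d_{\text{sq}}$, while the sklearn script measures $d$. In general $d \neq \sqrt{d_{\text{sq}}}$, so the square root has to be applied per point, as in the fix from the first comment.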
Related Issues
- Get error by training the train_unsup.py
- _pickle.UnpicklingError: invalid load key, '<'.
- How to compute the error?
- About unsup_training
- About dataset processing
- data is preprocessed
- Results of computing correspondences
- Rotation invariance
- How is the error for correspondences computed?
- Problems with downloading data
- SMALL-R animals dataset
- average Euclidean error on SCAPE
- FileNotFoundError: [Errno 2] No such file or directory: 'chamfer_cuda.cpp'
- how can i get the correspondence index?
- Training ERRor
- Get terrible result
- what's the network.mesh
- Problem executing the demo - `GLIBCXX_3.4.29' not found