Comments (3)

kongsgard commented on June 4, 2024

I quickly put together a script for comparison:

import torch
from graphs.losses.dist_chamfer import ChamferDist

cuda0 = torch.device('cuda:0')
x1 = torch.ones([1,2,3], device=cuda0)
x2 = torch.zeros([1,4,3], device=cuda0)

distChamfer = ChamferDist()
dist1, dist2 = distChamfer(x1, x2)
dist_total = torch.mean(dist1) + torch.mean(dist2)
print("PyTorch Chamfer Loss:", dist_total.item())


import numpy as np
from sklearn.neighbors import KDTree

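def chamfer_distance_sklearn(array1, array2):
    # Symmetric Chamfer distance using Euclidean (not squared) distances,
    # averaged over the batch dimension.
    batch_size, num_point = array1.shape[:2]
    dist = 0
    for i in range(batch_size):
        tree1 = KDTree(array1[i], leaf_size=num_point+1)
        tree2 = KDTree(array2[i], leaf_size=num_point+1)
        # Nearest-neighbour distance from each point of array2 to array1 ...
        distances1, _ = tree1.query(array2[i])
        # ... and from each point of array1 to array2.
        distances2, _ = tree2.query(array1[i])
        av_dist1 = np.mean(distances1)
        av_dist2 = np.mean(distances2)
        dist = dist + (av_dist1+av_dist2)/batch_size
    return dist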

x1_np = np.ones([1,2,3])
x2_np = np.zeros([1,4,3])
print("sklearn Chamfer Loss:", chamfer_distance_sklearn(x1_np, x2_np))

This returns:
PyTorch Chamfer Loss: 6.0
sklearn Chamfer Loss: 3.46

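A manual calculation confirms the sklearn output: every point in x1 lies at Euclidean distance sqrt(3) ≈ 1.732 from every point in x2, so the two directional means sum to 2·sqrt(3) ≈ 3.46. The PyTorch value of 6.0 is the same sum taken over squared distances (3 + 3).
So it seems the PyTorch implementation returns squared distances, and only the square root is missing.
Taking the root before averaging, dist_total = torch.mean(torch.sqrt(dist1)) + torch.mean(torch.sqrt(dist2)), gives the same output as sklearn.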
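
The two numbers can also be reproduced without PyTorch or sklearn. The following is a minimal brute-force sketch, not the 3d-coded implementation: the helper name `chamfer` and its `squared` flag are hypothetical, introduced only to show that the same nearest-neighbour sums give 6.0 with squared distances and about 3.46 with Euclidean distances.

```python
import math


def chamfer(points_a, points_b, squared):
    """Brute-force symmetric Chamfer distance between two small point sets.

    Hypothetical helper for illustration only. With squared=True the
    nearest-neighbour distances are left squared; with squared=False the
    square root is taken before averaging.
    """
    def nearest(p, others):
        # Smallest squared Euclidean distance from p to any point in others.
        d2 = min(sum((a - b) ** 2 for a, b in zip(p, q)) for q in others)
        return d2 if squared else math.sqrt(d2)

    # Mean nearest-neighbour distance in each direction, then summed.
    mean_ab = sum(nearest(p, points_b) for p in points_a) / len(points_a)
    mean_ba = sum(nearest(q, points_a) for q in points_b) / len(points_b)
    return mean_ab + mean_ba


# Same inputs as in the comparison script: 2 points of ones, 4 points of zeros.
x1 = [[1.0, 1.0, 1.0]] * 2
x2 = [[0.0, 0.0, 0.0]] * 4

chamfer_squared = chamfer(x1, x2, squared=True)
chamfer_euclidean = chamfer(x1, x2, squared=False)
print("squared:", chamfer_squared)
print("euclidean:", round(chamfer_euclidean, 2))
```

Because the root is applied per point before the mean, this matches torch.mean(torch.sqrt(dist1)) + torch.mean(torch.sqrt(dist2)) rather than the root of the averaged squared distances; the two only coincide here because all pairwise distances are equal.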

from 3d-coded.

ThibaultGROUEIX commented on June 4, 2024

Hi @kongsgard,
Yes, your analysis is correct. This is actually intentional; it corresponds to how we define it in the paper: https://arxiv.org/pdf/1806.05228.pdf
Best regards,
Thibault

kongsgard commented on June 4, 2024

I see. Thank you for your response!
