
torch-splatting's Introduction

torch-splatting

A pure PyTorch implementation of 3D Gaussian splatting.

Train

Clone the repo

git clone https://github.com/hbb1/torch-splatting.git --recursive

and run

python train.py

Tile-based rendering is implemented. Because looping in Python is slow, it uses 64x64 tiles instead of the 16x16 tiles used in the original 3DGS. Training takes about 2 hours for a 512x512-resolution image over 30k iterations, tested on an RTX 2080 Ti. The number of 3D Gaussians is fixed at 16384 points. Under this setting, it matches the original diff-gaussian-rasterization implementation (~39 PSNR on my synthetic data).
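
For readers curious about the mechanics, here is a minimal sketch of what tile-based compositing looks like in pure PyTorch. It is an illustration under simplifying assumptions (hypothetical function and tensor names, an isotropic falloff in place of the full 2D covariance), not the repo's code; the actual renderer lives in gaussian_splatting/gauss_render.py.

import torch

def render_tiles(means2d, depths, colors, opacities, radii, H, W, tile=64):
    """Simplified tile-based compositing sketch (isotropic footprints).

    means2d:   (N, 2) projected Gaussian centers in pixel coordinates
    depths:    (N,)   camera-space depths, used for front-to-back sorting
    colors:    (N, 3) per-Gaussian RGB
    opacities: (N,)   per-Gaussian opacity in [0, 1]
    radii:     (N,)   screen-space radii, used for tile culling and the falloff
    """
    device = means2d.device
    image = torch.zeros(H, W, 3, device=device)
    ys, xs = torch.meshgrid(torch.arange(H, device=device),
                            torch.arange(W, device=device), indexing="ij")
    pix = torch.stack([xs, ys], dim=-1).float()                       # (H, W, 2)

    for ty in range(0, H, tile):
        for tx in range(0, W, tile):
            # Cull: keep only Gaussians whose footprint overlaps this tile.
            hit = ((means2d[:, 0] + radii > tx) & (means2d[:, 0] - radii < tx + tile) &
                   (means2d[:, 1] + radii > ty) & (means2d[:, 1] - radii < ty + tile))
            idx = hit.nonzero(as_tuple=True)[0]
            if idx.numel() == 0:
                continue
            idx = idx[depths[idx].argsort()]                          # sort front to back

            p = pix[ty:ty + tile, tx:tx + tile]                       # (th, tw, 2)
            th, tw = p.shape[:2]
            d2 = ((p.reshape(-1, 1, 2) - means2d[idx]) ** 2).sum(-1)  # (P, K) squared distances
            alpha = (opacities[idx] * torch.exp(-0.5 * d2 / radii[idx] ** 2)).clamp(max=0.99)
            # Transmittance of each Gaussian = product of (1 - alpha) of everything in front.
            T = torch.cat([torch.ones_like(alpha[:, :1]), 1 - alpha[:, :-1]], 1).cumprod(1)
            image[ty:ty + th, tx:tx + tw] = ((T * alpha) @ colors[idx]).reshape(th, tw, 3)
    return image

The larger 64x64 tiles trade more Gaussians per tile for far fewer Python-level loop iterations, which is why the tile size is bigger than the 16x16 tiles of the CUDA rasterizer.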

Stay Tuned.

Reference

https://github.com/graphdeco-inria/gaussian-splatting/tree/main

https://github.com/graphdeco-inria/diff-gaussian-rasterization

https://github.com/openai/point-e/tree/main/point_e

torch-splatting's People

Contributors

hbb1


torch-splatting's Issues

Bug in gauss renderer

Hello,

This is a great project; it really makes prototyping splatting features much easier!

I found that in order for the project to run, I had to change this line: https://github.com/hbb1/torch-splatting/blob/85c362152883cda066d6a1cc949c2bab7e2cc1c7/gaussian_splatting/gauss_render.py#L209C1-L210C1

from

T = torch.cat([torch.ones_like(alpha[:,:1]), 1-alpha[:,-1:]], dim=1).cumprod(dim=1)

to:

T = torch.cat([torch.ones_like(alpha[:,:1]), 1-alpha[:,:-1]], dim=1).cumprod(dim=1)

Basically, just changing the indexing on the second alpha from [:,-1:] to [:, :-1].

I'm not sure whether this is specific to my machine, but it fixed a shape-mismatch error on the next line, and I believe the proposed version matches the original implementation.
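
For what it's worth, a tiny numerical check of why the [:, :-1] indexing is the right one: with per-pixel opacities sorted front to back, the transmittance reaching the i-th Gaussian should be the product of (1 - alpha) over the Gaussians in front of it, which is exactly what prepending a 1 and dropping the last alpha before the cumprod produces (toy numbers below).

import torch

# One pixel, four Gaussians sorted front to back.
alpha = torch.tensor([[0.5, 0.25, 0.1, 0.8]])

# Proposed fix: T[:, i] = product over j < i of (1 - alpha[:, j]).
T = torch.cat([torch.ones_like(alpha[:, :1]), 1 - alpha[:, :-1]], dim=1).cumprod(dim=1)
print(T)            # tensor([[1.0000, 0.5000, 0.3750, 0.3375]]), shape (1, 4) matches alpha

# The original indexing keeps only the last alpha, giving shape (1, 2) and the
# shape mismatch on the following line.
T_bad = torch.cat([torch.ones_like(alpha[:, :1]), 1 - alpha[:, -1:]], dim=1).cumprod(dim=1)
print(T_bad.shape)  # torch.Size([1, 2])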

No module named 'simple_knn'

I have installed simple_knn 1.1.6 with "pip install simple_knn", but the error "No module named 'simple_knn'" still occurs.
What can I do to fix it?
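
In case it helps: in 3DGS-style code, simple_knn usually refers to the compiled simple-knn submodule of the official gaussian-splatting repository (it provides distCUDA2 for initialising the Gaussian scales), and the package installed by pip install simple_knn may not be the same thing. One possible workaround, assuming all that is needed is the mean squared distance from each point to its nearest neighbours, is a pure-PyTorch stand-in (hypothetical function name):

import torch

def knn_mean_sq_dist(points: torch.Tensor, k: int = 3) -> torch.Tensor:
    """Mean squared distance from each point to its k nearest neighbours.

    points: (N, 3). Brute force, O(N^2) memory, so only practical for
    modest point counts (e.g. the 16384 Gaussians used here).
    """
    d = torch.cdist(points, points)             # (N, N) pairwise Euclidean distances
    d.fill_diagonal_(float("inf"))              # exclude self-matches
    knn_d, _ = d.topk(k, dim=1, largest=False)  # k smallest distances per point
    return (knn_d ** 2).mean(dim=1)

Something like dist2 = knn_mean_sq_dist(xyz).clamp_min(1e-7) could then replace the distCUDA2 call wherever it appears.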

How to calculate the viewspace_point in your renderer code

@hbb1 Hi! Thanks for the awesome work!

Recently I have been trying to add adaptive control on top of your torch rasterizer. I noticed that your code does not seem to create screenspace points the way the official implementation does. The official implementation first creates screenspace_points with the same shape as gaussians.xyz but filled with zeros, calls retain_grad() so the gradients of the 2D (screen-space) means are kept, sets means2D = screenspace_points, passes means2D and the other parameters into the rasterizer to render the image, and finally returns screenspace_points as the viewspace point tensor, whose gradient drives adaptive control. I noticed that your code's means2D is created by

means2D = torch.stack([mean_coord_x, mean_coord_y], dim=-1)
I would like to know how I can get the correct gradients of the 2D (screen-space) means for adaptive control. My current idea is to return your means2D as the viewspace point tensor and retain its grad. Do you think this is correct, or is there another way to implement adaptive control? Any suggestions are welcome! Looking forward to your reply!
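
For reference, the retain_grad() idea described above does work in principle for a non-leaf tensor like means2D. A self-contained toy sketch (hypothetical shapes and projection, not the repo's code) showing that the per-Gaussian screen-space gradient survives the backward pass:

import torch

xyz = torch.randn(100, 3, requires_grad=True)   # 3D Gaussian centers (leaf tensor)
view = torch.randn(3, 3)                        # stand-in for the camera transform

means2D = (xyz @ view.T)[:, :2]                 # non-leaf screen-space means
means2D.retain_grad()                           # without this, .grad stays None

loss = means2D.sum()                            # stand-in for render + photometric loss
loss.backward()

grad_norm = means2D.grad.norm(dim=-1)           # per-Gaussian screen-space gradient norm
print(grad_norm.shape)                          # torch.Size([100]); the densification signal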

adaptive gaussian number

Awesome repo! Are there any plans for implementing the adaptive control of the gaussians? If I want to implement it myself, what files should I look at in particular? Thanks.
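
Not an official answer, but a rough sketch of what 3DGS-style adaptive density control does may help orient the search: clone small Gaussians with a large view-space gradient, split large ones, and prune nearly transparent ones. The 1.6 split factor and the thresholds below follow the paper; everything else (names, the split offset) is illustrative only. In this repo the natural places to hook such logic in would be the renderer in gaussian_splatting/gauss_render.py (for the screen-space gradients) and the training loop in train.py.

import torch

def densify_and_prune(xyz, scales, opacity, grad_norm,
                      grad_thresh=2e-4, scale_thresh=0.01, min_opacity=0.005):
    """Illustrative 3DGS-style density control.

    xyz: (N, 3) centers, scales: (N, 3), opacity: (N,),
    grad_norm: (N,) accumulated view-space gradient norms.
    """
    hot = grad_norm > grad_thresh                  # regions that are poorly reconstructed
    big = scales.max(dim=1).values > scale_thresh

    clone = hot & ~big                             # under-reconstruction: duplicate in place
    split = hot & big                              # over-reconstruction: add smaller copies

    xyz = torch.cat([xyz, xyz[clone], xyz[split] + scales[split]])   # crude offset for splits
    scales = torch.cat([scales, scales[clone], scales[split] / 1.6])
    opacity = torch.cat([opacity, opacity[clone], opacity[split]])

    keep = opacity > min_opacity                   # prune nearly transparent Gaussians
    return xyz[keep], scales[keep], opacity[keep]

In the real method the split positions are sampled from the original Gaussian (which is then removed), opacities are periodically reset, and the optimizer state has to be extended to match the new parameter tensors.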

Custom Dataset

Hello,

I am a relative beginner and could not figure out how to create a custom dataset to run train.py on. Could you provide some guidance?

Thank you in advance.

The problem with training speed

I deployed your project on Windows 10 with PyCharm, but training is still at 10% after more than 5 hours. Is there any way to speed this up?

Cannot achieve PSNR≈39 on provided dataset

Aloha, great repo!
Just wondering, which dataset did you use to reach PSNR≈39? Was it the dataset included in this repo, or another synthetic dataset?
When I cloned the repo and ran train.py for 30,000 iterations, the PSNR only reached about 35 rather than 39.
Any idea why? Thanks :)
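
For context, PSNR here is presumably the standard image PSNR between rendered and ground-truth views (a generic definition, not necessarily the exact metric code in this repo). The gap between 35 and 39 dB corresponds to roughly 2.5x the mean squared error.

import torch

def psnr(pred: torch.Tensor, gt: torch.Tensor) -> torch.Tensor:
    """PSNR in dB for images scaled to [0, 1]."""
    mse = ((pred - gt) ** 2).mean()
    return -10.0 * torch.log10(mse)

# 10 ** ((39 - 35) / 10) ≈ 2.51, i.e. about 2.5x higher MSE at 35 dB than at 39 dB.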

error in build cov2d

Hello! Thanks for your great work; I am also trying to implement Gaussian splatting in PyTorch. I noticed that lines 92-93 of gauss_render.py read:

tx = (t[..., 0] / t[..., 2]).clip(min=-tan_fovx*1.3, max=tan_fovx*1.3) * t[..., 0]
ty = (t[..., 1] / t[..., 2]).clip(min=-tan_fovy*1.3, max=tan_fovy*1.3) * t[..., 1]

but in diff-gaussian-rasterization

[screenshot of the corresponding CUDA code]

the clipped x/y ratio is multiplied by t.z, not by t.x / t.y.
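
For comparison, the reference CUDA behaviour paraphrased in PyTorch (a sketch with the same variable names as the snippet above): the x/z and y/z ratios are clamped to 1.3x the frustum and then scaled back by t.z.

import torch

def clip_to_frustum(t: torch.Tensor, tan_fovx: float, tan_fovy: float):
    """Clamp view-space x/z and y/z, then multiply the clipped ratio by z."""
    limx, limy = 1.3 * tan_fovx, 1.3 * tan_fovy
    tz = t[..., 2]
    tx = (t[..., 0] / tz).clamp(min=-limx, max=limx) * tz
    ty = (t[..., 1] / tz).clamp(min=-limy, max=limy) * tz
    return tx, ty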

Also, I hope you can give more information about your PyTorch demo (dataset format, ...). Thanks a lot!

Bug?

      [0., 0., 0.,  ..., 0., 0., 0.],
      [0., 0., 0.,  ..., 0., 0., 0.],
      [0., 0., 0.,  ..., 0., 0., 0.]]]], device='cuda:0')

loss tensor(1.0310, device='cuda:0')
Traceback (most recent call last):
  File "/root/superi/superv.py", line 886, in <module>
    main()
  File "/root/superi/superv.py", line 882, in main
    train(image_base=image_base, image_size=image_size, batch_size_train=2, batch_size_valid=1, epochs=20, checkpoint_path=os.path.join(work_path, checkpoint_base_path), checkpoint_best_file=os.path.join(work_path, checkpoint_base_path, checkpoint_best_name), device=device)
  File "/root/superi/superv.py", line 805, in train
    loss.backward()
  File "/opt/anaconda3/envs/superi/lib/python3.12/site-packages/torch/_tensor.py", line 522, in backward
    torch.autograd.backward(
  File "/opt/anaconda3/envs/superi/lib/python3.12/site-packages/torch/autograd/__init__.py", line 266, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
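
This traceback comes from the reporter's own training script (superv.py) rather than from torch-splatting itself. The RuntimeError generally means the loss tensor is detached from the autograd graph, for example because it was computed under torch.no_grad(), because of a .detach()/.numpy() call somewhere in the loss computation, or because the parameters were created with requires_grad=False. A minimal reproduction of the failure mode:

import torch

w = torch.randn(3, requires_grad=True)

with torch.no_grad():               # everything computed here is detached from the graph
    bad_loss = (w * 2).sum()
print(bad_loss.requires_grad)       # False; bad_loss.backward() raises the RuntimeError above

good_loss = (w * 2).sum()           # same computation outside no_grad()
good_loss.backward()                # works, and w.grad is populated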

Add requirements.txt file

Please add a requirements.txt file. There is still an issue running the code because of simple-knn.

