vita-group / gnt

[ICLR 2023] "Is Attention All That NeRF Needs?" by Mukund Varma T*, Peihao Wang*, Xuxi Chen, Tianlong Chen, Subhashini Venugopalan, Zhangyang Wang

Home Page: https://vita-group.github.io/GNT

License: MIT License

Languages: Python 100.0%
Topics: neural-radiance-fields, neural-rendering, transformer

gnt's Introduction

Is Attention All That NeRF Needs?

Mukund Varma T1*, Peihao Wang2*, Xuxi Chen2, Tianlong Chen2, Subhashini Venugopalan3, Zhangyang Wang2

1Indian Institute of Technology Madras, 2University of Texas at Austin, 3Google Research

* denotes equal contribution.

Project Page | Paper

This repository is built on IBRNet's official repository.

  • News! GNT has been accepted at ICLR 2023 🎉. Our updated cross-scene trained checkpoint should generalize to complex scenes, and even achieves results comparable to SOTA per-scene optimized methods without further tuning!
  • News! Our work was presented by Prof. Atlas Wang in his talk at the MIT Vision and Graphics Seminar on 10/17/22.

Introduction

We present Generalizable NeRF Transformer (GNT), a pure, unified transformer-based architecture that efficiently reconstructs Neural Radiance Fields (NeRFs) on the fly from source views. Unlike prior works on NeRF that optimize a per-scene implicit representation by inverting a handcrafted rendering equation, GNT achieves generalizable neural scene representation and rendering by encapsulating two transformer-based stages. The first stage of GNT, called the view transformer, leverages multi-view geometry as an inductive bias for attention-based scene representation, and predicts coordinate-aligned features by aggregating information from epipolar lines on the neighboring views. The second stage of GNT, named the ray transformer, renders novel views by ray marching and directly decodes the sequence of sampled point features using the attention mechanism. Our experiments demonstrate that when optimized on a single scene, GNT can successfully reconstruct NeRF without an explicit rendering formula, and even improves PSNR by ~1.3 dB on complex scenes thanks to the learnable ray renderer. When trained across various scenes, GNT consistently achieves state-of-the-art performance when transferring to the forward-facing LLFF dataset (LPIPS ~20%↓, SSIM ~25%↑) and the synthetic Blender dataset (LPIPS ~20%↓, SSIM ~4%↑). In addition, we show that depth and occlusion can be inferred from the learned attention maps, which implies that the pure attention mechanism is capable of learning a physically grounded rendering process. All these results bring us one step closer to the tantalizing hope of using transformers as the "universal modeling tool" even for graphics.


Installation

Clone this repository:

git clone https://github.com/MukundVarmaT/GNT.git
cd GNT/

The code has been tested with Python 3.8, CUDA 11.1, and PyTorch 1.10.1. Additional dependencies include:

torchvision
ConfigArgParse
imageio
matplotlib
numpy
opencv_contrib_python
Pillow
scipy
imageio-ffmpeg
lpips
scikit-image
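If the repository does not ship a requirements file, the list above can be installed in one step with pip (package names transcribed verbatim from the list; pip treats underscores and hyphens interchangeably):

# install the additional dependencies listed above
pip install torchvision ConfigArgParse imageio matplotlib numpy opencv_contrib_python Pillow scipy imageio-ffmpeg lpips scikit-image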

Datasets

We reuse the training and evaluation datasets from IBRNet. All datasets must be downloaded to a data/ directory within the project folder and must follow the organization below.

├──data/
    ├──ibrnet_collected_1/
    ├──ibrnet_collected_2/
    ├──real_iconic_noface/
    ├──spaces_dataset/
    ├──RealEstate10K-subset/
    ├──google_scanned_objects/
    ├──nerf_synthetic/
    ├──nerf_llff_data/

Please refer to IBRNet's repository to download and prepare the data. For convenience, we consolidate the instructions below:

mkdir data
cd data/

# IBRNet captures
gdown https://drive.google.com/uc?id=1rkzl3ecL3H0Xxf5WTyc2Swv30RIyr1R_
unzip ibrnet_collected.zip

# LLFF
gdown https://drive.google.com/uc?id=1ThgjloNt58ZdnEuiCeRf9tATJ-HI0b01
unzip real_iconic_noface.zip

## [IMPORTANT] remove scenes that appear in the test set
cd real_iconic_noface/
rm -rf data2_fernvlsb data2_hugetrike data2_trexsanta data3_orchid data5_leafscene data5_lotr data5_redflower
cd ../

# Spaces dataset
git clone https://github.com/augmentedperception/spaces_dataset

# RealEstate 10k
## make sure to install ffmpeg - sudo apt-get install ffmpeg
git clone https://github.com/qianqianwang68/RealEstate10K_Downloader
cd RealEstate10K_Downloader
python3 generate_dataset.py train
cd ../

# Google Scanned Objects
gdown https://drive.google.com/uc?id=1w1Cs0yztH6kE3JIz7mdggvPGCwIKkVi2
unzip google_scanned_objects_renderings.zip

# Blender dataset
gdown https://drive.google.com/uc?id=18JxhpWD-4ZmuFKLzKlAw-w5PpzZxXOcG
unzip nerf_synthetic.zip

# LLFF dataset (eval)
gdown https://drive.google.com/uc?id=16VnMcF1KJYxN9QId6TClMsZRahHNMW5g
unzip nerf_llff_data.zip

Usage

Training

# single scene
# python3 train.py --config <config> --train_scenes <scene> --eval_scenes <scene> --optional[other kwargs]. Example:
python3 train.py --config configs/gnt_blender.txt --train_scenes drums --eval_scenes drums
python3 train.py --config configs/gnt_llff.txt --train_scenes orchids --eval_scenes orchids

# cross scene
# python3 train.py --config <config> --optional[other kwargs]. Example:
python3 train.py --config configs/gnt_full.txt 

To decode coarse-fine outputs, set --N_importance > 0; to use a separate fine network, set --single_net = False.
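For example, a hedged single-scene command combining these flags (whether --single_net accepts this boolean syntax depends on how config.py declares the flag):

# coarse-fine decoding with 64 fine samples and a separate fine network
python3 train.py --config configs/gnt_llff.txt --train_scenes orchids --eval_scenes orchids \
    --N_importance 64 --single_net False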

Pre-trained Models

Dataset        | Scene     | Download
LLFF           | fern      | ckpt / renders
               | flower    | ckpt / renders
               | fortress  | ckpt / renders
               | horns     | ckpt / renders
               | leaves    | ckpt / renders
               | orchids   | ckpt / renders
               | room      | ckpt / renders
               | trex      | ckpt / renders
Synthetic      | chair     | ckpt / renders
               | drums     | ckpt / renders
               | ficus     | ckpt / renders
               | hotdog    | ckpt / renders
               | lego      | ckpt / renders
               | materials | ckpt / renders
               | mic       | ckpt / renders
               | ship      | ckpt / renders
Generalization | N.A.      | ckpt / renders

To reuse pretrained models, download the required checkpoints and place them in the appropriate directory, named gnt_<scene-name> (single scene) or gnt_full (generalization). Then proceed to evaluation / rendering. To facilitate future research, we also provide half-resolution renderings of our method on several benchmark scenes. In case there are issues with any of the above checkpoints, please feel free to open an issue.
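A minimal placement sketch, assuming checkpoints are discovered under the default output root ./out/<expname>/ (the training logs in the issues below print, e.g., "outputs will be saved to ./out/gnt_llff"); the checkpoint file name here is a placeholder:

# hypothetical layout for the single-scene orchids checkpoint
mkdir -p out/gnt_orchids
mv ~/Downloads/orchids_model_XXXXXX.pth out/gnt_orchids/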

Evaluation

# single scene
# python3 eval.py --config <config> --eval_scenes <scene> --expname <out-dir> --run_val --optional[other kwargs]. Example:
python3 eval.py --config configs/gnt_llff.txt --eval_scenes orchids --expname gnt_orchids --chunk_size 500 --run_val --N_samples 192
python3 eval.py --config configs/gnt_blender.txt --eval_scenes drums --expname gnt_drums --chunk_size 500 --run_val --N_samples 192

# cross scene
# python3 eval.py --config <config> --expname <out-dir> --run_val --optional[other kwargs]. Example:
python3 eval.py --config configs/gnt_full.txt --expname gnt_full --chunk_size 500 --run_val --N_samples 192

Rendering

To render videos of smooth camera paths for the real forward-facing scenes:

# python3 render.py --config <config> --eval_dataset llff_render --eval_scenes <scene> --expname <out-dir> --optional[other kwargs]. Example:
python3 render.py --config configs/gnt_llff.txt --eval_dataset llff_render --eval_scenes orchids --expname gnt_orchids --chunk_size 500 --N_samples 192

The code has recently been tidied up for release and may still contain minor bugs. Please feel free to open an issue.

Cite this work

If you find our work / code implementation useful for your own research, please cite our paper.

@inproceedings{
    t2023is,
    title={Is Attention All That Ne{RF} Needs?},
    author={Mukund Varma T and Peihao Wang and Xuxi Chen and Tianlong Chen and Subhashini Venugopalan and Zhangyang Wang},
    booktitle={The Eleventh International Conference on Learning Representations},
    year={2023},
    url={https://openreview.net/forum?id=xE-LtsE-xx}
}

gnt's People

Contributors

buesma, mukundvarmat, peihaowang


gnt's Issues

Export the model as Mesh ply or obj format

Hi,
Thanks to the authors for the amazing work.
Could you please explain how to export the model as a mesh in PLY or OBJ format?
And can it render videos of 360° scenes?

Thanks in advance for your response.

About Cross-scene Training Setting

Thanks for your attention. You provided the training command as python3 train.py --config configs/gnt_full.txt without other arguments, but I find that the default number of sampled rays is 512 in config.py, while your paper says this number is 4096, i.e., 8 times the default. Does this mean you trained your model with 8 GPUs?

How to apply multi-gpu training?

Hello, thanks for the great work! I'm wondering how we can apply multi-GPU training.

I use the following command

python train.py --config configs/gnt_ft_rffr.txt --distributed --local_rank 2

but it raises the following error:

Error initializing torch.distributed using env:// rendezvous: environment variable RANK expected, but not set

The distributed training code in train.py is shown below:

    if args.distributed:
        torch.distributed.init_process_group(backend="nccl", init_method="env://localhost:50000")
        args.local_rank = int(os.environ.get("LOCAL_RANK"))
        torch.cuda.set_device(args.local_rank)
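For context: the env:// rendezvous requires the launcher to set the RANK, WORLD_SIZE, and LOCAL_RANK environment variables, so the script should be started through a launcher rather than with a hand-set --local_rank. A sketch with torchrun (bundled with PyTorch >= 1.10):

# torchrun sets RANK, WORLD_SIZE, and LOCAL_RANK for each spawned process;
# train.py then reads LOCAL_RANK from the environment as shown above
torchrun --nproc_per_node=2 train.py --config configs/gnt_ft_rffr.txt --distributed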

Image cropping during training

Dear authors,

Thanks for your great work! I have a question about the image cropping operations during training.

Starting from L123 of ./gnt/data_loaders/llff.py

if self.mode == "train":
there are some cropping operations during training with the LLFF dataset.

In the default setting, when factor = 4 for the LLFF dataset, the original resolution should be 1008*756. I think crop_h = np.random.randint(low=250, high=750) means taking a cropped patch with height within [250, 750]. But crop_w = int(400 * 600 / crop_h) does not give the corresponding width for that patch.

I think it should be something like crop_w = int(600 * crop_h / 400), but in that case the width:height ratio of the cropped patch becomes 3:2, which differs from the original 4:3. I'm wondering whether this is a bug.

Thank you in advance!
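One hedged reading of the existing expression: it keeps the crop area roughly constant at 400 * 600 = 240,000 pixels rather than preserving the aspect ratio. A sketch contrasting the two choices (only crop_h and the constant-area formula come from the repo; the rest is illustrative):

import numpy as np

crop_h = np.random.randint(low=250, high=750)
# as in llff.py: constant-area crop, crop_h * crop_w ~ 400 * 600 = 240,000 px
crop_w_const_area = int(400 * 600 / crop_h)
# aspect-preserving alternative for a 1008x756 (4:3) image
crop_w_const_ratio = int(crop_h * 1008 / 756)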

abnormal near_depth in render result

Hi, authors. Thank you for your great work. I tried to render pictures of horns using the model pretrained on the generalization dataset. But when I run render.py with the parameters "--config configs/gnt_llff.txt --eval_dataset llff_render --eval_scenes horns --expname gnt_horns --chunk_size 500 --N_samples 192", it returns "ValueError: loaded state dict contains a parameter group that doesn't match the size of optimizer's group".

Traceback (most recent call last):
  File "/home/zheshi/Documents/GNT/render.py", line 191, in <module>
    render(args)
  File "/home/zheshi/.local/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "/home/zheshi/Documents/GNT/render.py", line 66, in render
    model = GNTModel(
  File "/home/zheshi/Documents/GNT/gnt/model.py", line 76, in __init__
    self.start_step = self.load_from_ckpt(
  File "/home/zheshi/Documents/GNT/gnt/model.py", line 161, in load_from_ckpt
    self.load_model(fpath, load_opt, load_scheduler)
  File "/home/zheshi/Documents/GNT/gnt/model.py", line 127, in load_model
    self.optimizer.load_state_dict(to_load["optimizer"])
  File "/home/zheshi/.local/lib/python3.8/site-packages/torch/optim/optimizer.py", line 146, in load_state_dict
    raise ValueError("loaded state dict contains a parameter group "
ValueError: loaded state dict contains a parameter group that doesn't match the size of optimizer's group

Can you help me with this problem? Thanks in advance.

AttributeError: Can't pickle local object 'train.<locals>.<lambda>'

D:\Users\gracejwang\AppData\Local\Programs\Python\Python39\python.exe train.py --config configs/gnt_llff.txt --train_scenes fern --eval_scenes fern
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
D:\Users\gracejwang\AppData\Local\Programs\Python\Python39\lib\site-packages\torchvision\models\_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
D:\Users\gracejwang\AppData\Local\Programs\Python\Python39\lib\site-packages\torchvision\models\_utils.py:223: UserWarning: Arguments other than a weight enum or None for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing weights=AlexNet_Weights.IMAGENET1K_V1. You can also use weights=AlexNet_Weights.DEFAULT to get the most up-to-date weights.
  warnings.warn(msg)
Loading model from: D:\Users\gracejwang\AppData\Local\Programs\Python\Python39\lib\site-packages\lpips\weights\v0.1\alex.pth
Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]
D:\Users\gracejwang\AppData\Local\Programs\Python\Python39\lib\site-packages\torchvision\models\_utils.py:223: UserWarning: Arguments other than a weight enum or None for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing weights=VGG16_Weights.IMAGENET1K_V1. You can also use weights=VGG16_Weights.DEFAULT to get the most up-to-date weights.
  warnings.warn(msg)
Loading model from: D:\Users\gracejwang\AppData\Local\Programs\Python\Python39\lib\site-packages\lpips\weights\v0.1\vgg.pth
outputs will be saved to ./out\gnt_llff
training dataset: llff_test
loading ['fern'] for train
loading ['fern'] for validation
No ckpts found, training from scratch...
Traceback (most recent call last):
  File "D:\code\3D_Reconstruction\GNT\train.py", line 319, in <module>
    train(args)
  File "D:\code\3D_Reconstruction\GNT\train.py", line 97, in train
    for train_data in train_loader:
  File "D:\Users\gracejwang\AppData\Local\Programs\Python\Python39\lib\site-packages\torch\utils\data\dataloader.py", line 441, in __iter__
    return self._get_iterator()
  File "D:\Users\gracejwang\AppData\Local\Programs\Python\Python39\lib\site-packages\torch\utils\data\dataloader.py", line 388, in _get_iterator
    return _MultiProcessingDataLoaderIter(self)
  File "D:\Users\gracejwang\AppData\Local\Programs\Python\Python39\lib\site-packages\torch\utils\data\dataloader.py", line 1042, in __init__
    w.start()
  File "D:\Users\gracejwang\AppData\Local\Programs\Python\Python39\lib\multiprocessing\process.py", line 121, in start
    self._popen = self._Popen(self)
  File "D:\Users\gracejwang\AppData\Local\Programs\Python\Python39\lib\multiprocessing\context.py", line 224, in _Popen
    return _default_context.get_context().Process._Popen(process_obj)
  File "D:\Users\gracejwang\AppData\Local\Programs\Python\Python39\lib\multiprocessing\context.py", line 327, in _Popen
    return Popen(process_obj)
  File "D:\Users\gracejwang\AppData\Local\Programs\Python\Python39\lib\multiprocessing\popen_spawn_win32.py", line 93, in __init__
    reduction.dump(process_obj, to_child)
  File "D:\Users\gracejwang\AppData\Local\Programs\Python\Python39\lib\multiprocessing\reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
AttributeError: Can't pickle local object 'train.<locals>.<lambda>'
PS D:\code\3D_Reconstruction\GNT> Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
D:\Users\gracejwang\AppData\Local\Programs\Python\Python39\lib\site-packages\torchvision\models\_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
D:\Users\gracejwang\AppData\Local\Programs\Python\Python39\lib\site-packages\torchvision\models\_utils.py:223: UserWarning: Arguments other than a weight enum or None for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing weights=AlexNet_Weights.IMAGENET1K_V1. You can also use weights=AlexNet_Weights.DEFAULT to get the most up-to-date weights.
  warnings.warn(msg)
Loading model from: D:\Users\gracejwang\AppData\Local\Programs\Python\Python39\lib\site-packages\lpips\weights\v0.1\alex.pth
Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]
D:\Users\gracejwang\AppData\Local\Programs\Python\Python39\lib\site-packages\torchvision\models\_utils.py:223: UserWarning: Arguments other than a weight enum or None for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing weights=VGG16_Weights.IMAGENET1K_V1. You can also use weights=VGG16_Weights.DEFAULT to get the most up-to-date weights.
  warnings.warn(msg)
Loading model from: D:\Users\gracejwang\AppData\Local\Programs\Python\Python39\lib\site-packages\lpips\weights\v0.1\vgg.pth
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "D:\Users\gracejwang\AppData\Local\Programs\Python\Python39\lib\multiprocessing\spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "D:\Users\gracejwang\AppData\Local\Programs\Python\Python39\lib\multiprocessing\spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
EOFError: Ran out of input
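For context, this is the standard Windows multiprocessing pitfall: DataLoader workers are spawned rather than forked, so every argument handed to the loader must be picklable, and a lambda defined inside train() is not. A hedged sketch of the two usual fixes (the worker_init_fn role of the lambda is an assumption; check where train.py actually defines it):

import numpy as np
from torch.utils.data import DataLoader

# Fix 1: replace the local lambda with a picklable module-level function
def seed_worker(worker_id):
    np.random.seed()

# train_loader = DataLoader(dataset, num_workers=4, worker_init_fn=seed_worker)

# Fix 2: avoid worker pickling entirely by loading data in the main process
# train_loader = DataLoader(dataset, num_workers=0)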

Training strategy for the released model

Hi, thanks for the fantastic work!

I've been attempting to replicate the results using the training configurations provided in the repository. However, it appears that the iterations of the pretrained model don't quite align with the instructions in the configs. In your paper, you mentioned training GNT with N_rand set to 4096 for 250k iterations across all examples, while in the released model, it seems that much longer iterations were employed based on the model names (for instance, the fern model was trained for 840k iterations, while the generalization model underwent 720k iterations).

As I attempted to train the models following your configs, I noticed a significant discrepancy compared to the released models. I was wondering if you could possibly update the configurations or training strategies so that we can accurately reproduce the numbers for the model you released. Thank you so much!

Question about result of generalization model

Thank you for sharing your interesting work!

Using the generalization model parameters uploaded to GitHub, I got results different from those in the paper:
PSNR of NeRF synthetic in paper: 27.29
PSNR of NeRF synthetic reproduced: 25.327
PSNR of LLFF in paper: 25.86
PSNR of LLFF reproduced: 25.642

Were the checkpoints corresponding to the performance published in the paper uploaded correctly? If not, could I get a checkpoint that matches that performance?

color issue

Hi, authors. Thanks again for your excellent work on GNT.
I tested the code on the nerf_synthetic drums model and the result looks good.
But I ran into an issue: for an indoor scene, the colors of the predicted images look strange (see the attached render).
It seems that black and white are partly reversed. Where does the problem lie?

loaded state dict contains a parameter group that doesn't match the size of optimizer's group

Description

I use the pretrained generalization model at https://drive.google.com/file/d/1AMN0diPeHvf2fw53IO5EE2Qp4os5SkoX/view?usp=share_link.
The command is:

CUDA_VISIBLE_DEVICES=2 python3 eval.py --config configs/gnt_blender.txt \
    --eval_dataset nerf_synthetic \
    --eval_scenes mic --run_val \
    --expname gnt_author_pretrained_cross_mic \
    --ckpt_path /home/cbe/lwy/GNT/out/pretrained_cross_720000.pth

Bug

  File "/home/cbe/lwy/GNT/gnt/model.py", line 133, in load_from_ckpt
    self.load_model(fpath, load_opt, load_scheduler)
  File "/home/cbe/lwy/GNT/gnt/model.py", line 102, in load_model
    self.optimizer.load_state_dict(to_load["optimizer"])
  File "/home/cbe/miniconda3/envs/vision/lib/python3.8/site-packages/torch/optim/optimizer.py", line 146, in load_state_dict
    raise ValueError("loaded state dict contains a parameter group "
ValueError: loaded state dict contains a parameter group that doesn't match the size of optimizer's group

Questions

I wonder whether you can run the pretrained generalization model, or why I am hitting this problem.

Related Issue

Before being renamed, #8 hit the same problem when running render.py, but I am running eval.py.

Thanks for your great work !!!
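A hedged workaround: evaluation never steps the optimizer, so skipping its state sidesteps the size check. If GNT inherits IBRNet's --no_load_opt and --no_load_scheduler flags (an assumption worth verifying in config.py), the command becomes:

CUDA_VISIBLE_DEVICES=2 python3 eval.py --config configs/gnt_blender.txt \
    --eval_dataset nerf_synthetic --eval_scenes mic --run_val \
    --expname gnt_author_pretrained_cross_mic \
    --ckpt_path /home/cbe/lwy/GNT/out/pretrained_cross_720000.pth \
    --no_load_opt --no_load_scheduler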

Implementation details of view transformer

In the provided code, attn = k - q[:, :, None, :] + pos is followed by attn = self.attn_fc(attn). However, in Fig. 2a and Alg. 1, there is no self.attn_fc component. Could you give an explanation?
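For readers mapping the snippet to the figure: this is a subtraction-style (vector) attention, where an MLP rather than a plain dot product turns the query-key difference into per-channel weights. A minimal sketch with made-up shapes (all names and dimensions are illustrative, not the repo's):

import torch
import torch.nn as nn

dim = 32
attn_fc = nn.Linear(dim, dim)        # the projection the question is about
q = torch.randn(4, 8, dim)           # (rays, samples, dim)
k = torch.randn(4, 8, 10, dim)       # (rays, samples, views, dim)
pos = torch.randn(4, 8, 10, dim)     # positional / epipolar encoding
v = torch.randn(4, 8, 10, dim)

attn = k - q[:, :, None, :] + pos    # as in the snippet
attn = attn_fc(attn)                 # learned map from differences to scores
attn = attn.softmax(dim=-2)          # normalize over the view axis
out = (attn * v).sum(dim=-2)         # per-channel weighted aggregation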

keys mismatch when loading pre-trained models

Hi, authors. I hit an error when loading the pre-trained models. The details are as follows.

outputs will be saved to ./out/gnt_lego
loading ['lego'] for val
Traceback (most recent call last):
  File "eval.py", line 236, in <module>
    eval(args)
  File "/home/vt/anaconda3/envs/gnt/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "eval.py", line 82, in eval
    model = GNTModel(
  File "/media/vt/work/NYXcode/GNT_new/gnt/model.py", line 77, in __init__
    self.start_step = self.load_from_ckpt(
  File "/media/vt/work/NYXcode/GNT_new/gnt/model.py", line 166, in load_from_ckpt
    self.load_model(fpath, load_opt, load_scheduler)
  File "/media/vt/work/NYXcode/GNT_new/gnt/model.py", line 136, in load_model
    self.net_coarse.load_state_dict(to_load["net_coarse"])
  File "/home/vt/anaconda3/envs/gnt/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1482, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for GNT:
	Missing key(s) in state_dict: "view_trans.0.attn_norm.weight", "view_trans.0.attn_norm.bias", ...
	Unexpected key(s) in state_dict: "view_selftrans.0.attn_norm.weight", "view_selftrans.0.attn_norm.bias", ...

I tried to rename the mismatched keys, but for the later layers the correspondence is barely discernible.
Do you have any solutions? Thanks in advance.
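A hedged first attempt, assuming only the module prefix changed between releases (if deeper layers were also reordered, as suspected above, this will not be enough):

import torch

ckpt = torch.load("out/gnt_lego/model_XXXXXX.pth", map_location="cpu")  # hypothetical path
# rename legacy "view_selftrans" prefixes to the current "view_trans" naming
ckpt["net_coarse"] = {
    k.replace("view_selftrans", "view_trans"): v
    for k, v in ckpt["net_coarse"].items()
}
torch.save(ckpt, "out/gnt_lego/model_renamed.pth")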

cuda out of memory

Hi, authors. Thanks for your great work!
I cannot wait to see how it works, but training on my RTX 3090 gives "CUDA out of memory".
Could you kindly share the trained ckpts so we can check? Or how can I set the parameters to run it successfully?
Looking forward to your reply!
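A hedged mitigation, using only flags that already appear in this README and its issues: lower the per-iteration ray count and the inference chunk size, trading speed for memory:

# smaller ray batches and chunks reduce peak GPU memory
python3 train.py --config configs/gnt_llff.txt --train_scenes orchids --eval_scenes orchids \
    --N_rand 512 --chunk_size 500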

How to filter the epipolar features

I appreciate your work and am very interested in it!
By the way, I looked through your code and wonder where the code that adds an inductive bias for epipolar constraints is.
I found "gnt/data_loaders/data_verifier.py", but it does not apply to the actual model implementation.
Or is it just omitted because all the datasets you use already satisfy the epipolar constraints?

Evaluation Mismatch with Paper Result

Hello, thanks for your great work! We are very interested in exploring further based on your work.

However, there is a small issue. When trying to reproduce the evaluation of the cross-scene generalization result with the model '720000.pth' you released, we found the result is significantly lower than the result reported in your paper. We show the results below.

          Name         |  PSNR |  SSIM | LPIPS |
       llff-paper      | 25.86 | 0.867 | 0.116 |
   llff-reproduce      | 25.53 | 0.855 | 0.130 |
    blender-paper      | 27.29 | 0.937 | 0.056 |
blender-reproduce      | 26.02 | 0.926 | 0.073 |

Do you have any idea why? It is important for us to reproduce your paper accurately.
All evaluation settings strictly follow the README, and the evaluation is performed on a single NVIDIA 3090 GPU.

By the way, could you tell us what training hardware you used?

We appreciate your attention!

Questions about proper training epoch

Hi, thanks for the nice work!
I trained GNT for 50k~100k epochs without using a pretrained model, but the results didn't seem as good as the pretrained ones. Is there a recommended number of training epochs? And does N_rand affect the performance? The config value I changed was N_rand, because of limited CUDA memory.
Thanks for any assistance.

Pre-trained optimizer mismatch with the model when resuming training

Dear authors. Thank you for your great work first.

I would like to resume training based on your provided pre-trained models, so I ran the following command:
python train.py --config configs/gnt_llff.txt --ckpt_path=./trex_model_300000.pth --train_scenes trex --eval_scenes trex --expname resume_trex --chunk_size 500 --N_samples 20

But there is a RuntimeError:
The size of tensor a (64) must match the size of tensor b (4) at non-singleton dimension 1, raised when executing this line (GNT/train.py, line 144 at commit 33a99a9):

model.optimizer.step()

The following is the traceback:

outputs will be saved to ./out/resume_trex
training dataset: llff_test
loading ['trex'] for train
loading ['trex'] for validation
Reloading from ./trex_model_300000.pth, starting at step=300000
Traceback (most recent call last):
  File "/home/rt/Downloads/GNT/train.py", line 319, in <module>
    train(args)
  File "/home/rt/Downloads/GNT/train.py", line 144, in train
    model.optimizer.step()
  File "/home/rt/anaconda3/envs/GNT/lib/python3.9/site-packages/torch/optim/lr_scheduler.py", line 65, in wrapper
    return wrapped(*args, **kwargs)
  File "/home/rt/anaconda3/envs/GNT/lib/python3.9/site-packages/torch/optim/optimizer.py", line 113, in wrapper
    return func(*args, **kwargs)
  File "/home/rt/anaconda3/envs/GNT/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/rt/anaconda3/envs/GNT/lib/python3.9/site-packages/torch/optim/adam.py", line 157, in step
    adam(params_with_grad,
  File "/home/rt/anaconda3/envs/GNT/lib/python3.9/site-packages/torch/optim/adam.py", line 213, in adam
    func(params,
  File "/home/rt/anaconda3/envs/GNT/lib/python3.9/site-packages/torch/optim/adam.py", line 262, in _single_tensor_adam
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
RuntimeError: The size of tensor a (64) must match the size of tensor b (4) at non-singleton dimension 1

By checking the pre-trained optimizer, I found that the ['optimizer']['state'] entries of some layers have shapes mismatched with their corresponding layers. For example, the optimizer state belonging to the 15th layer is a tensor with shape torch.Size([64, 64]), but the 15th layer of the GNT model has dimensions torch.Size([8, 4]).

Furthermore, I realized the problem is that the layer orders in the pre-trained optimizer and the initialized GNT model are different. I am not sure whether this issue applies here.

I don't know what caused this mismatch. I hope you can help me.

Rendering results of nerf_synthetic lego validation dataset by provided single-scene model

Description:

I use the pretrained single-scene model for the lego synthetic scene that you provided at https://drive.google.com/file/d/1IbhbBr5XfxQz0jSQM3nLX_htTbvc59kj/view?usp=share_link
The command is:

CUDA_VISIBLE_DEVICES=2 python3 eval.py --config configs/gnt_blender.txt \
    --eval_dataset nerf_synthetic \
    --eval_scenes lego --run_val \
    --expname gnt_author_pretrained_single_lego \
    --ckpt_path out/gnt_lego_from_single_ckpt/pretrained_lego_model_435000.pth

(gnt_blender.txt is not modified)

Results:

  1. The rendering results on the validation set seem incorrect, since the colors have a green tint (see the attached image).

  2. But the metrics seem reasonable, as the attached screenshot shows.

Questions:

I wonder what causes such unexpected rendering results. Is it caused by my command or something else?
