
GRAM: Generative Radiance Manifolds for 3D-Aware Image Generation

This is an official PyTorch implementation of the following paper:

Y. Deng, J. Yang, J. Xiang, and X. Tong, GRAM: Generative Radiance Manifolds for 3D-Aware Image Generation, IEEE Computer Vision and Pattern Recognition (CVPR), 2022. (Oral Presentation)

Project page | Paper | Video

Abstract: 3D-aware image generative modeling aims to generate 3D-consistent images with explicitly controllable camera poses. Recent works have shown promising results by training neural radiance field (NeRF) generators on unstructured 2D images, but still cannot generate highly-realistic images with fine details. A critical reason is that the high memory and computation cost of volumetric representation learning greatly restricts the number of point samples for radiance integration during training. Deficient sampling not only limits the expressive power of the generator to handle fine details but also impedes effective GAN training due to the noise caused by unstable Monte Carlo sampling. We propose a novel approach that regulates point sampling and radiance field learning on 2D manifolds, embodied as a set of learned implicit surfaces in the 3D volume. For each viewing ray, we calculate ray-surface intersections and accumulate their radiance generated by the network. By training and rendering such radiance manifolds, our generator can produce high quality images with realistic fine details and strong visual 3D consistency.
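
To give a rough intuition for the rendering step described above: at each ray-surface intersection the generator predicts a color and an occupancy value, and these are composited from the nearest to the farthest intersection. The sketch below is a minimal illustration of such front-to-back compositing in PyTorch; it is not the repository's renderer, and the tensor names and the simple alpha-compositing formulation are assumptions.

import torch

def composite_manifold_radiance(rgb, alpha):
    """Accumulate radiance over per-ray manifold intersections.

    rgb:   (num_rays, num_intersections, 3) colors predicted at each
           ray-surface intersection, ordered from nearest to farthest.
    alpha: (num_rays, num_intersections) occupancy at each intersection.
    Returns composited colors of shape (num_rays, 3).
    """
    # Transmittance before each intersection: product of (1 - alpha)
    # over all closer intersections (exclusive cumulative product).
    one_minus_alpha = 1.0 - alpha
    transmittance = torch.cumprod(
        torch.cat([torch.ones_like(alpha[:, :1]), one_minus_alpha[:, :-1]], dim=1),
        dim=1,
    )
    weights = alpha * transmittance                    # (num_rays, K)
    return (weights.unsqueeze(-1) * rgb).sum(dim=1)    # (num_rays, 3)

# Example: 4 rays, 24 intersections (e.g. 24 learned surface manifolds)
colors = composite_manifold_radiance(torch.rand(4, 24, 3), torch.rand(4, 24))
print(colors.shape)  # torch.Size([4, 3])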

Requirements

  • Currently only Linux is supported.
  • 64-bit Python 3.6 installation or newer. We recommend using Anaconda3.
  • One or more high-end NVIDIA GPUs, NVIDIA drivers, and CUDA toolkit 10.1 or newer. We recommend using 8 Tesla V100 GPUs with 32 GB memory for training to reproduce the results in the paper.

Installation

Clone the repository and set up a conda environment with all dependencies as follows:

git clone https://github.com/microsoft/GRAM.git
cd GRAM
conda env create -f environment.yml
source activate gram

Alternatively, we provide a Dockerfile to build an image with the required dependencies.

Pre-trained models

Checkpoints for pre-trained models used in our paper (default settings) are as follows.

| Dataset | Config | Resolution | Training iterations | Batchsize | FID 20k | KID 20k (x100) | Download |
|---------|--------|------------|---------------------|-----------|---------|----------------|----------|
| FFHQ | FFHQ_default | 256x256 | 150k | 32 | 14.5 | 0.65 | Github link |
| Cats | CATS_default | 256x256 | 80k | 16 | 14.6 | 0.75 | Github link |
| CARLA | CARLA_default | 128x128 | 70k | 32 | 26.3 | 1.15 | Github link |

Generating multi-view images with pre-trained models

Run the following script to render multi-view images of generated subjects using a pre-trained model:

# face images are generated by default (FFHQ_default)
python render_multiview_images.py

# custom setting for image generation
python render_multiview_images.py --config=<CONFIG_NAME> --generator_file=<GENERATOR_PATH.pth> --output_dir=<OUTPUT_FOLDER> --seeds=0,1,2

By default, the script generates images with watermarks. Use the --no_watermark argument to remove them.

Training a model from scratch

Data preparation

  • FFHQ: Download the original 1024x1024 FFHQ images, along with the detected 5 facial landmarks and the estimated face poses, and organize all files as follows:
GRAM/
│
└─── raw_data/
    │
    └─── ffhq/
        │
        └─── *.png   # original 1024x1024 images
        │
        └─── lm5p/   # detected 5 facial landmarks
        │   │
        │   └─── *.txt
        │
        └─── poses/  # estimated face poses
            │
            └─── *.mat
  • Cats: Download the original cat images and provided landmarks using this link and organize all files as follows:
GRAM/
│
└─── raw_data/
    │
    └─── cats/
        │
        └─── *.jpg       # original images
        │
        └─── *.jpg.cat   # provided landmarks
  • CARLA: Download the original images and poses from GRAF and organize all files as follows:
GRAM/
│
└─── raw_data/
    │
    └─── carla/
        │
        └─── *.png   # original images
        │
        └─── poses/  # provided poses
            │
            └─── *_extrinsics.npy

Finally, run the following script for data preprocessing:

python preprocess_dataset.py --raw_dataset_path=./raw_data/<CATEGORY> --cate=<CATEGORY>

It aligns all images and saves them, together with the estimated/provided poses, into ./datasets for the subsequent training process.
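
If preprocessing fails, a quick first check is whether raw_data matches the layouts shown above. The snippet below is a small hypothetical helper for that check; it is not part of the repository, and the expected subfolders are taken directly from the directory trees above.

import os

# Expected subfolders per category, following the directory trees above.
EXPECTED_SUBFOLDERS = {
    "ffhq":  ["lm5p", "poses"],   # plus *.png images at the top level
    "cats":  [],                  # *.jpg images and *.jpg.cat landmark files
    "carla": ["poses"],           # *.png images plus poses/*_extrinsics.npy
}

def check_raw_data(root="./raw_data", category="ffhq"):
    cate_dir = os.path.join(root, category)
    if not os.path.isdir(cate_dir):
        raise FileNotFoundError(f"{cate_dir} does not exist")
    for sub in EXPECTED_SUBFOLDERS[category]:
        if not os.path.isdir(os.path.join(cate_dir, sub)):
            print(f"warning: missing subfolder {sub}/ in {cate_dir}")
    num_images = len([f for f in os.listdir(cate_dir)
                      if f.endswith((".png", ".jpg"))])
    print(f"{category}: found {num_images} images in {cate_dir}")

check_raw_data(category="ffhq")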

Training networks

Run the following script to train a generator from scratch using the preprocessed data:

python train.py --config=<CONFIG_NAME> --output_dir=<OUTPUT_FOLDER>

The code automatically detects all available GPUs and uses DDP training. You can use the default configs provided in configs.py or add your own. By default, we use the batch split suggested by pi-GAN to increase the effective batchsize during training.
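
For intuition, the batch-split trick amounts to gradient accumulation: each batch is divided into smaller sub-batches that are processed sequentially, and the optimizer steps only after all sub-batches have contributed gradients, so the effective batch size exceeds what fits in memory at once. The sketch below only illustrates this idea; it is not the repository's training loop, and model, loss_fn, and the batch structure are placeholders.

import torch

def train_step_with_batch_split(model, optimizer, loss_fn, batch, num_splits=4):
    """One optimizer step over a batch processed in `num_splits` sub-batches."""
    optimizer.zero_grad()
    inputs, targets = batch  # placeholder batch structure
    for sub_in, sub_tgt in zip(inputs.chunk(num_splits), targets.chunk(num_splits)):
        # Scale the loss so the accumulated gradient matches a full-batch step.
        loss = loss_fn(model(sub_in), sub_tgt) / num_splits
        loss.backward()  # gradients accumulate across sub-batches
    optimizer.step()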

The following table lists training times for different configs using 8 NVIDIA Tesla V100 GPUs (32GB memory):

| Config | Resolution | Training iterations | Batchsize | Training time |
|--------|------------|---------------------|-----------|---------------|
| FFHQ_default | 256x256 | 150k | 32 | 12d 4h |
| CATS_default | 256x256 | 80k | 16 | 4d 6h |
| CARLA_default | 128x128 | 70k | 32 | 3d 15h |

Training GRAM at 256x256 image resolution requires around 30GB of memory for a typical forward-backward cycle with a batchsize of 1 using PyTorch Automatic Mixed Precision. To enable training on GPUs with limited memory, we provide an alternative patch-level forward and backward process (see here for a detailed explanation):

python train.py --config=<CONFIG_NAME> --output_dir=<OUTPUT_FOLDER> --patch_split=<NUMBER_OF_PATCHES> 

Currently, we support patch splits that are powers of 2 (e.g. 2, 4, 8, ...). This effectively reduces memory consumption at the expense of a slight increase in training time.
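
Conceptually, the patch split renders and backpropagates through one subset of pixels (rays) at a time and accumulates gradients across patches, so the activations of the full image never need to reside in memory for a single backward pass. The sketch below illustrates this idea together with AMP gradient scaling; it is not the repository's implementation, and render_patch and adversarial_loss are hypothetical placeholders.

import torch

def generator_step_patchwise(render_patch, adversarial_loss, optimizer,
                             scaler, z, num_patches=4):
    """Accumulate generator gradients patch by patch under mixed precision.

    render_patch(z, patch_idx, num_patches) and adversarial_loss(patch)
    are hypothetical stand-ins for the real renderer/discriminator calls.
    """
    optimizer.zero_grad()
    for patch_idx in range(num_patches):
        with torch.cuda.amp.autocast():
            patch = render_patch(z, patch_idx, num_patches)  # rays of one patch
            loss = adversarial_loss(patch) / num_patches
        scaler.scale(loss).backward()  # activations of this patch can be freed
    scaler.step(optimizer)
    scaler.update()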

Evaluation

Run the following script for FID & KID calculation:

python fid_evaluation.py --no_watermark --config=<CONFIG_NAME> --generator_file=<GENERATOR_PATH.pth> --output_dir=<OUTPUT_FOLDER>

By default, 8000 real images and 1000 images generated by the EMA model are used for evaluation. You can adjust the numbers according to your own needs.
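
For reference, FID is the Fréchet distance between Gaussians fitted to Inception features of real and generated images. The snippet below is a minimal sketch of that distance computation, assuming the feature matrices have already been extracted; it is not fid_evaluation.py.

import numpy as np
from scipy import linalg

def frechet_distance(real_feats, fake_feats):
    """FID between two feature sets of shape (N, D), e.g. Inception features."""
    mu_r, mu_f = real_feats.mean(axis=0), fake_feats.mean(axis=0)
    cov_r = np.cov(real_feats, rowvar=False)
    cov_f = np.cov(fake_feats, rowvar=False)
    # Matrix square root of the covariance product (may have tiny imaginary parts).
    covmean, _ = linalg.sqrtm(cov_r @ cov_f, disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    diff = mu_r - mu_f
    return float(diff @ diff + np.trace(cov_r + cov_f - 2.0 * covmean))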

Responsible AI Considerations

The goal of this work is to study generative modeling of 3D objects from 2D images, and to provide a method for generating multi-view images of non-existing, virtual objects. It is not intended to manipulate existing images nor to create content used to mislead or deceive. The method has no semantic understanding or control of the generated content, so adding targeted facial expressions or mouth movements is out of the scope of this work. However, like all other related AI image generation techniques, the method could still potentially be misused for impersonating humans. Currently, the images generated by this method contain visual artifacts, unnatural texture patterns, and other unpredictable failures that can be spotted by humans and fake-image detection algorithms. We also plan to investigate applying this technology to advance 3D- and video-based forgery detection.

License

Due to concerns about potential misuse of this method, the code is available under a research-only license.

Citation

Please cite the following paper if this work helps your research:

@inproceedings{deng2022gram,
	title={GRAM: Generative Radiance Manifolds for 3D-Aware Image Generation},
	author={Deng, Yu and Yang, Jiaolong and Xiang, Jianfeng and Tong, Xin},
	booktitle={IEEE/CVF Conference on Computer Vision and Pattern Recognition},
	year={2022}
}

Contact

If you have any questions, please contact Yu Deng ([email protected]) and Jiaolong Yang ([email protected])

Acknowledgements

We thank Harry Shum for the fruitful advice and discussions that improved the paper. This implementation takes pi-GAN as a reference; we thank its authors for their excellent work.


gram's Issues

Fig. 7 and Fig. 8

Hello 👋,
Thanks for sharing this great work.
Could you please provide more details on reproducing Figures 7 and 8?
Thanks

How to generate color patterns for individual surfaces?

Thanks for sharing the great work!

I am trying to reproduce the results in Fig. 7 (which shows the color patterns for individual surfaces), but cannot find the relevant code in this repository. Could you please advise on how to implement it? Many thanks!

How to use my own image to generate images with different yaw angles

Hello, how can I use render_multiview_images.py to take my own image as input and generate images from different angles?
I cannot find a corresponding input image path in render_multiview_images.py.
Can you tell me how to specify the input image path and generate the corresponding images?
Thank you

An error when training

Thank you for your great work!
When I run the training code, I get an error like the following

Traceback (most recent call last):
  File "C:\Apps\Python39\lib\site-packages\torch\multiprocessing\spawn.py", line 69, in _wrap
    fn(i, *args)
  File "D:\Projects\GRAM\train.py", line 33, in train
    training_process(rank, world_size, opt, device)
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "C:\Apps\Python39\lib\site-packages\torch\nn\parallel\distributed.py", line 1110, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])  # type: ignore[index]
  File "C:\Apps\Python39\lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "D:\Projects\GRAM\generators\generators.py", line 112, in forward
    img, _ = self.renderer.render(self._intersections, self._volume(z, truncation_psi), img_size, camera_origin, camera_pos, fov, ray_start, ray_end, z.device)
  File "D:\Projects\GRAM\generators\renderers\manifold_renderer.py", line 146, in render
    transformed_points_sample, transformed_ray_directions, transformed_ray_origins, _ = transform_sampled_points(points_cam, rays_d_cam, camera_origin, camera_pos, device=device)
  File "D:\Projects\GRAM\generators\renderers\manifold_renderer.py", line 92, in transform_sampled_points
    cam2world_matrix = create_cam2world_matrix(forward_vector, camera_origin, device=device)
  File "D:\Projects\GRAM\generators\renderers\manifold_renderer.py", line 123, in create_cam2world_matrix
    cam2world = translation_matrix @ rotation_matrix
RuntimeError: CUDA error: CUBLAS_STATUS_NOT_INITIALIZED when calling `cublasCreate(handle)`

Please help solve this, thank you very much!

Multi-gpu training cause RuntimeError

Hi, I tried to run train.py on the FFHQ dataset in multi-GPU mode. However, I get the following RuntimeError:

training_process(rank, world_size, opt, device)
File "/home/xintian/workspace/GRAM-main/training_loop.py", line 217, in training_process
d_loss = process.train_D(real_imgs, real_poses, generator_ddp, discriminator_ddp, optimizer_D, scaler, config, device)
File "/home/xintian/workspace/GRAM-main/processes/processes.py", line 38, in train_D
g_imgs, g_pos = generator_ddp(subset_z, **config['camera'])
File "/home/xintian/anaconda3/envs/torch18/lib/python3.6/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/xintian/anaconda3/envs/torch18/lib/python3.6/site-packages/torch/nn/parallel/distributed.py", line 705, in forward
output = self.module(*inputs[0], **kwargs[0])
File "/home/xintian/anaconda3/envs/torch18/lib/python3.6/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/xintian/workspace/GRAM-main/generators/generators.py", line 112, in forward
img, _ = self.renderer.render(self._intersections, self._volume(z, truncation_psi), img_size, camera_origin, camera_pos, fov, ray_start, ray_end, z.device)
File "/home/xintian/workspace/GRAM-main/generators/renderers/manifold_renderer.py", line 194, in render
coarse_output = volume(transformed_points, transformed_ray_directions_expanded).reshape(batchsize, img_size * img_size, self.num_manifolds, 4)
File "/home/xintian/workspace/GRAM-main/generators/generators.py", line 76, in
return lambda points, ray_directions: self.representation.get_radiance(z, points, ray_directions, truncation_psi)
File "/home/xintian/workspace/GRAM-main/generators/representations/gram.py", line 317, in get_radiance
return self.rf_network(x, z, ray_directions, truncation_psi)
File "/home/xintian/anaconda3/envs/torch18/lib/python3.6/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/xintian/workspace/GRAM-main/generators/representations/gram.py", line 260, in forward
frequencies_2, phase_shifts_2 = self.mapping_network(z2)
File "/home/xintian/anaconda3/envs/torch18/lib/python3.6/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/xintian/workspace/GRAM-main/generators/representations/gram.py", line 93, in forward
frequencies_offsets = self.network(z)
File "/home/xintian/anaconda3/envs/torch18/lib/python3.6/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/xintian/anaconda3/envs/torch18/lib/python3.6/site-packages/torch/nn/modules/container.py", line 119, in forward
input = module(input)
File "/home/xintian/anaconda3/envs/torch18/lib/python3.6/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/xintian/anaconda3/envs/torch18/lib/python3.6/site-packages/torch/nn/modules/linear.py", line 94, in forward
return F.linear(input, self.weight, self.bias)
File "/home/xintian/anaconda3/envs/torch18/lib/python3.6/site-packages/torch/nn/functional.py", line 1753, in linear
return torch._C._nn.linear(input, weight, bias)
RuntimeError: Expected tensor for 'out' to have the same device as tensor for argument #2 'mat1'; but device 1 does not equal 0 (while checking arguments for addmm)

How can I fix this? By the way, when I run train.py on a single GPU, it works.

Missing poses for FFHQ dataset

Hello,
First of all, thank you for providing all the links to the data you used for your experiments.

After downloading the face poses archive from this link you provide in the README, I noticed that the archive contains 69994 poses, whereas the original FFHQ dataset contains 70000 images.

These are the indices missing from the archive: [43079, 27757, 36783, 51858, 17972, 42709]

Is this intentional, or a bug? You don't have any functionality in your FFHQ dataset implementation to support a mismatch between extracted images and poses.

Question about extracted proxy 3D shapes

Thanks for your open-source work! You mention that you "extract proxy 3D shapes of the generated object using the volume-based marching cubes algorithm", and you use 24/48 surface manifolds in total. I wonder how you decide which isosurfaces of the 24/48 surface manifolds to reconstruct?
