
ogroth / tf-gqn


TensorFlow implementation of "Neural Scene Representation and Rendering"

License: Apache License 2.0

Python 85.28% Jupyter Notebook 14.72%
view-synthesis gqn-datasets neural-processes paper-implementations

tf-gqn's People

Contributors

buesma, dependabot[bot], ogroth, stefanwayon


tf-gqn's Issues

Dataset

Dear Ogroth,

I checked the GQN dataset from https://github.com/deepmind/gqn-datasets. shepard_metzler_7_parts contains 900 tfrecords for training, and each tfrecord has 20 scenes, so there are only 18,000 scenes. For mazes, there are 1,080 records with 100 scenes each, which means 108,000 in total.

So I think this link (https://github.com/deepmind/gqn-datasets) only contains part of the full data. Did you get the complete data from that link?
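
(For anyone double-checking the counts, a minimal sketch, assuming TF 1.x and a local copy of the shards under a hypothetical data/gqn-dataset path:)

import glob
import os
import tensorflow as tf  # TF 1.x

# hypothetical local path; adjust to wherever the shards were downloaded
shard_pattern = os.path.join('data', 'gqn-dataset', 'shepard_metzler_7_parts', 'train', '*.tfrecord')
total_scenes = 0
for shard in sorted(glob.glob(shard_pattern)):
    # each serialized Example in a shard corresponds to one scene
    total_scenes += sum(1 for _ in tf.python_io.tf_record_iterator(shard))
print('training scenes found:', total_scenes)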

Best,
Bing

Measuring the quality of your pre-trained snapshot

Hi ogroth,

Thank you for sharing your great code.

I would like to see how well your snapshot works on the rooms_ring_camera dataset, not just in terms of the ELBO and the KL divergence between prior and posterior, but also in terms of image quality.

Are there any options for doing that?
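
(One simple option, sketched here under the assumption that you use the GqnViewPredictor interface from the repo's view-interpolation notebook and frames as float arrays in [0, 1]: compare rendered query views against ground-truth target frames with MSE/PSNR.)

import numpy as np

def psnr(pred, target, max_val=1.0):
    # pred, target: float arrays in [0, max_val] with shape (H, W, 3)
    mse = np.mean((pred.astype(np.float64) - target.astype(np.float64)) ** 2)
    return 10.0 * np.log10((max_val ** 2) / mse)

# hypothetical usage with the notebook's variables:
# pred_frame = predictor.render_query_view(tgt_pose[0])[0]
# print('MSE: %.5f  PSNR: %.2f dB' % (np.mean((pred_frame - tgt_frame[0]) ** 2), psnr(pred_frame, tgt_frame[0])))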

Thanks.

Best Regards,

Jaesik Yoon.

How many training steps do we need?

When I train the model it runs, but around step 200,000 the loss drops to 6,900 and never goes any lower. I trained for 370,000 steps and it is still at 6,900. Have you run into this problem before?

eval_summary_hook error

Hi,
I started GQN training on my self-generated dataset with the newly updated code.
However, I ran into an error from eval_summary_hook with my self-generated dataset:

"Exactly one of scaffold or summary_op must be provided.")
ValueError: Exactly one of scaffold or summary_op must be provided.

This issue occurs only with my self-generated dataset; the DeepMind dataset works without any problem.

I found a way to work around this issue by changing the eval_summary_hook in gqn_model.py

from

eval_summary_hook = tf.train.SummarySaverHook(
    save_steps=1, output_dir=os.path.join(params['model_dir'], 'eval'),
    summary_op=tf.summary.merge_all())

to

eval_summary_hook = tf.train.SummarySaverHook(
    save_steps=1, output_dir=os.path.join(params['model_dir'], 'eval'),
    scaffold=tf.train.Scaffold(summary_op=tf.summary.merge_all()))  # <<< modified this line
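
(For reference, a minimal guarded variant of the same hook, only a sketch with illustrative variable names: SummarySaverHook raises exactly this ValueError when neither scaffold nor summary_op is set, which happens when tf.summary.merge_all() returns None because no summaries were registered in the eval graph.)

merged_summaries = tf.summary.merge_all()
if merged_summaries is not None:
    eval_summary_hook = tf.train.SummarySaverHook(
        save_steps=1, output_dir=os.path.join(params['model_dir'], 'eval'),
        summary_op=merged_summaries)
    evaluation_hooks = [eval_summary_hook]
else:
    evaluation_hooks = []  # nothing to save; skip the hook entirely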

Do you have any clues on how this modification will affect the results? And do you have a rough idea of what might be wrong with my dataset to cause this issue?

Thank you.

Cannot reproduce visualization from snapshots

Hello, thank you for providing this code.

I tried the code provided by @slimm in this issue but was unable to reproduce the results @Yangshell provided.

I get the following error message:
Traceback (most recent call last):
  File "test_model.py", line 26, in <module>
    for prediction in estimator.predict(input_fn=input_fn):
  File "/home/kstan/tf-gqn/venv/lib/python3.5/site-packages/tensorflow/python/estimator/estimator.py", line 551, in predict
    features, None, model_fn_lib.ModeKeys.PREDICT, self.config)
  File "/home/kstan/tf-gqn/venv/lib/python3.5/site-packages/tensorflow/python/estimator/estimator.py", line 1169, in _call_model_fn
    model_fn_results = self._model_fn(features=features, **kwargs)
  File "/home/kstan/tf-gqn/gqn/gqn_model.py", line 102, in gqn_draw_model_fn
    predictions=mu_target),
  File "/home/kstan/tf-gqn/venv/lib/python3.5/site-packages/tensorflow/python/ops/metrics_impl.py", line 1312, in mean_squared_error
    squared_error = math_ops.square(labels - predictions)
  File "/home/kstan/tf-gqn/venv/lib/python3.5/site-packages/tensorflow/python/ops/math_ops.py", line 889, in r_binary_op_wrapper
    x = ops.convert_to_tensor(x, dtype=y.dtype.base_dtype, name="x")
  File "/home/kstan/tf-gqn/venv/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 1048, in convert_to_tensor
    as_ref=False)
  File "/home/kstan/tf-gqn/venv/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 1144, in internal_convert_to_tensor
    ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
  File "/home/kstan/tf-gqn/venv/lib/python3.5/site-packages/tensorflow/python/framework/constant_op.py", line 228, in _constant_tensor_conversion_function
    return constant(v, dtype=dtype, name=name)
  File "/home/kstan/tf-gqn/venv/lib/python3.5/site-packages/tensorflow/python/framework/constant_op.py", line 207, in constant
    value, dtype=dtype, shape=shape, verify_shape=verify_shape))
  File "/home/kstan/tf-gqn/venv/lib/python3.5/site-packages/tensorflow/python/framework/tensor_util.py", line 430, in make_tensor_proto
    raise ValueError("None values not supported.")
ValueError: None values not supported.

I suspect the error comes from the input_fn and that the data is not being fed properly to estimator.predict(). What is the mistake and how can it be fixed?
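
(One possible cause, sketched under the assumption that the metric in gqn_draw_model_fn is computed unconditionally: in PREDICT mode the Estimator calls the model_fn with labels=None, so tf.metrics.mean_squared_error(labels=..., predictions=mu_target) receives None and fails with exactly this error. A mode guard around the metric, with hypothetical variable names, would look like:)

if mode != tf.estimator.ModeKeys.PREDICT:
    # the labels (target frames) are None in PREDICT mode, so only build the metric for TRAIN/EVAL
    mse_metric = tf.metrics.mean_squared_error(
        labels=target_frame,  # hypothetical name for the ground-truth frame tensor
        predictions=mu_target)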

Thanks for the help.

Best,
Kevin

DRAW module as Generation architecture

First off, thank you for posting this code; it is very helpful! It's not a problem, but I am curious why you decided to use the DRAW module in your code even though it is not cited anywhere in the GQN paper? Granted, the paper lacks many implementation details...

GQN trained on CLEVR dataset

Thanks for the GQN implementation. I thought you might enjoy seeing some pictures of how it does when trained on a different dataset. (Albeit with a limited amount of training time; I plan to train for longer.)

[screenshot, 2019-06-11]

Even on the test set it works pretty well with a relatively small amount of training. It seems to generalize better than on the flat-shaded DeepMind dataset.

[image]

I'm curious what kinds of changes you might be interested in receiving via pull request. I have some changes to the training parameters, and I've also found self-attention to improve generalization and overall training speed, although that wasn't in the original paper.

Thanks,
logan

ERROR:tensorflow:Model diverged with loss = NaN

When I try to train, I run into a problem. Can you help me?
Here is my error log:

(tensorflowgqn) E:\LH\tf-gqn-master-2\tf-gqn-master>python train_gqn.py ^
More?   --data_dir data\gqn-dataset ^
More?   --dataset shepard_metzler_5_parts ^
More?   --model_dir models\shepard_metzler_5_parts\gqn
Training a GQN.
PARSED ARGV: Namespace(adam_lr_alpha=0.0005, adam_lr_beta=5e-05, anneal_lr_tau=1600000, batch_size=36, chkpt_steps=10000, context_size=5, data_dir='data\\gqn-dataset', dataset='shepard_metzler_5_parts', debug=False, img_size=64, initial_eval=False, log_steps=100, memcap=1.0, model_dir='models\\shepard_metzler_5_parts\\gqn', queue_buffer=4, queue_threads=4, seq_length=8, train_epochs=2)
UNPARSED ARGV: []
Saved model config to models\shepard_metzler_5_parts\gqn\gqn_config.json
INFO:tensorflow:Using config: {'_model_dir': 'models\\shepard_metzler_5_parts\\gqn', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': 10000, '_save_checkpoints_secs': None, '_session_config': gpu_options {
  per_process_gpu_memory_fraction: 1.0
  allow_growth: true
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x000001B00D429DA0>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
2023-08-13 15:47:55.159960: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2
2023-08-13 15:47:55.256456: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1432] Found device 0 with properties:
name: NVIDIA GeForce RTX 3080 Ti major: 8 minor: 6 memoryClockRate(GHz): 1.665
pciBusID: 0000:01:00.0
totalMemory: 12.00GiB freeMemory: 10.84GiB
2023-08-13 15:47:55.257468: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1511] Adding visible gpu devices: 0
2023-08-13 15:47:55.704851: I tensorflow/core/common_runtime/gpu/gpu_device.cc:982] Device interconnect StreamExecutor with strength 1 edge matrix:
2023-08-13 15:47:55.704983: I tensorflow/core/common_runtime/gpu/gpu_device.cc:988]      0
2023-08-13 15:47:55.705062: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1001] 0:   N
2023-08-13 15:47:55.705225: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1115] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 12287 MB memory) -> physical GPU (device: 0, name: NVIDIA GeForce RTX 3080 Ti, pci bus id: 0000:01:00.0, compute capability: 8.6)
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 0 into models\shepard_metzler_5_parts\gqn\model.ckpt.
INFO:tensorflow:l2_reconstruction = [0.         0.03854151]
INFO:tensorflow:loss = 598648.6, step = 0
ERROR:tensorflow:Model diverged with loss = NaN.
Traceback (most recent call last):
  File "train_gqn.py", line 227, in <module>
    tf.app.run(argv=[sys.argv[0]] + UNPARSED_ARGS)
  File "E:\JZL\deeping_learning\Anaconda\envs\tensorflowgqn\lib\site-packages\tensorflow\python\platform\app.py", line 125, in run
    _sys.exit(main(argv))
  File "train_gqn.py", line 212, in main
    hooks=[logging_hook],
  File "E:\JZL\deeping_learning\Anaconda\envs\tensorflowgqn\lib\site-packages\tensorflow\python\estimator\estimator.py", line 354, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "E:\JZL\deeping_learning\Anaconda\envs\tensorflowgqn\lib\site-packages\tensorflow\python\estimator\estimator.py", line 1207, in _train_model
    return self._train_model_default(input_fn, hooks, saving_listeners)
  File "E:\JZL\deeping_learning\Anaconda\envs\tensorflowgqn\lib\site-packages\tensorflow\python\estimator\estimator.py", line 1241, in _train_model_default
    saving_listeners)
  File "E:\JZL\deeping_learning\Anaconda\envs\tensorflowgqn\lib\site-packages\tensorflow\python\estimator\estimator.py", line 1471, in _train_with_estimator_spec
    _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
  File "E:\JZL\deeping_learning\Anaconda\envs\tensorflowgqn\lib\site-packages\tensorflow\python\training\monitored_session.py", line 671, in run
    run_metadata=run_metadata)
  File "E:\JZL\deeping_learning\Anaconda\envs\tensorflowgqn\lib\site-packages\tensorflow\python\training\monitored_session.py", line 1156, in run
    run_metadata=run_metadata)
  File "E:\JZL\deeping_learning\Anaconda\envs\tensorflowgqn\lib\site-packages\tensorflow\python\training\monitored_session.py", line 1255, in run
    raise six.reraise(*original_exc_info)
  File "E:\JZL\deeping_learning\Anaconda\envs\tensorflowgqn\lib\site-packages\six.py", line 693, in reraise
    raise value
  File "E:\JZL\deeping_learning\Anaconda\envs\tensorflowgqn\lib\site-packages\tensorflow\python\training\monitored_session.py", line 1240, in run
    return self._sess.run(*args, **kwargs)
  File "E:\JZL\deeping_learning\Anaconda\envs\tensorflowgqn\lib\site-packages\tensorflow\python\training\monitored_session.py", line 1320, in run
    run_metadata=run_metadata))
  File "E:\JZL\deeping_learning\Anaconda\envs\tensorflowgqn\lib\site-packages\tensorflow\python\training\basic_session_run_hooks.py", line 753, in after_run
    raise NanLossDuringTrainingError
tensorflow.python.training.basic_session_run_hooks.NanLossDuringTrainingError: NaN loss during training.

data format

Hello, thank you for providing this code!
I want to train this GQN model on my own dataset, but I don't know the required data format or how much data is needed to use the model in a real scene such as my apartment.
Could you give me a rough idea of how much data is required?
Thank you!
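
(For orientation, a minimal sketch of writing one scene in the layout the DeepMind gqn-datasets reader expects: per scene a list of JPEG-encoded 'frames' and a flat 'cameras' float list with (x, y, z, yaw, pitch) per view. Treat the exact keys, image size, and sequence length as assumptions to verify against data_provider/gqn_provider.py.)

import tensorflow as tf  # TF 1.x

def make_scene_example(jpeg_frames, poses):
    # jpeg_frames: list of JPEG-encoded byte strings, one per view of the scene
    # poses: list of (x, y, z, yaw, pitch) tuples, one per view
    cameras = [float(v) for pose in poses for v in pose]  # flattened into a single float list
    return tf.train.Example(features=tf.train.Features(feature={
        'frames': tf.train.Feature(bytes_list=tf.train.BytesList(value=jpeg_frames)),
        'cameras': tf.train.Feature(float_list=tf.train.FloatList(value=cameras)),
    }))

# with tf.python_io.TFRecordWriter('my_dataset/train/0001-of-0001.tfrecord') as writer:
#     writer.write(make_scene_example(jpeg_frames, poses).SerializeToString())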

Gradient flowing between inference_cell and generator_cell

There seems to be an issue in the inference_rnn function where the inference_cell and generator_cell are connected together:

tf-gqn/gqn/gqn_draw.py, lines 417 to 421 (commit bc84f24):

# estimate statistics and sample state from posterior
mu_q, sigma_q, z_q = compute_eta_and_sample_z(inf_state.lstm.h,
                                              scope="Sample_eta_q")
# input into generator RNN
gen_input = _GeneratorCellInput(representations, query_poses, z_q)

It looks like the gradient flows through z_q.

Adding the line z_q = tf.stop_gradient(z_q) seems to improve the results when only the generator_rnn is used during testing.
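
(For clarity, a sketch of where the suggested line would sit in the snippet quoted above:)

# estimate statistics and sample state from posterior
mu_q, sigma_q, z_q = compute_eta_and_sample_z(inf_state.lstm.h,
                                              scope="Sample_eta_q")
z_q = tf.stop_gradient(z_q)  # proposed: keep generator gradients from flowing back into the inference LSTM
# input into generator RNN
gen_input = _GeneratorCellInput(representations, query_poses, z_q)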

Download Dataset

Hi,

Could you provide another way to download the dataset?
I am stuck with the current download method.

Thank you,
Mingjia

How can I get the right result?

I use the view interpolation notebook to load shepard_metzler_5_parts, but I can't get a correct result.
This is my result:
[GIF: view_interpolation_preview]
Here is my process:

'''imports'''
# stdlib
import os
import sys
import logging
# numerical computing
import numpy as np
import tensorflow as tf
# plotting
import imageio
logging.getLogger("imageio").setLevel(logging.ERROR)  # switch off warnings during lossy GIF-generation
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
from IPython.display import Image, display
# GQN src
root_path = os.path.abspath(os.path.join(os.getcwd(), '..'))
sys.path.append(root_path)
print(sys.path)
from data_provider.gqn_provider import gqn_input_fn
from gqn.gqn_predictor import GqnViewPredictor

['C:\\Users\\lenovo\\AppData\\Local\\conda\\conda\\envs\\tensorflownew\\python36.zip', 'C:\\Users\\lenovo\\AppData\\Local\\conda\\conda\\envs\\tensorflownew\\DLLs', 'C:\\Users\\lenovo\\AppData\\Local\\conda\\conda\\envs\\tensorflownew\\lib', 'C:\\Users\\lenovo\\AppData\\Local\\conda\\conda\\envs\\tensorflownew', '', 'C:\\Users\\lenovo\\AppData\\Local\\conda\\conda\\envs\\tensorflownew\\lib\\site-packages', 'C:\\Users\\lenovo\\AppData\\Local\\conda\\conda\\envs\\tensorflownew\\lib\\site-packages\\win32', 'C:\\Users\\lenovo\\AppData\\Local\\conda\\conda\\envs\\tensorflownew\\lib\\site-packages\\win32\\lib', 'C:\\Users\\lenovo\\AppData\\Local\\conda\\conda\\envs\\tensorflownew\\lib\\site-packages\\Pythonwin', 'C:\\Users\\lenovo\\AppData\\Local\\conda\\conda\\envs\\tensorflownew\\lib\\site-packages\\IPython\\extensions', 'C:\\Users\\lenovo\\.ipython', 'E:\\Desktop\\tf-gqn-master\\tf-gqn-master']

'''directory setup'''
data_dir = os.path.join(root_path, 'data')
model_dir = os.path.join(root_path, 'models')
tmp_dir = os.path.join(root_path, 'notebooks', 'tmp')
gqn_dataset_path = os.path.join(data_dir, 'gqn-dataset')
# dataset flags
# dataset_name = 'jaco'  # one of the GQN dataset names
# dataset_name = 'rooms_ring_camera'  # one of the GQN dataset names
# dataset_name = 'rooms_free_camera_no_object_rotations'  # one of the GQN dataset names
# dataset_name = 'rooms_free_camera_with_object_rotations'  # one of the GQN dataset names
dataset_name = 'shepard_metzler_5_parts'#'shepard_metzler_5_parts'  # one of the GQN dataset names
# dataset_name = 'shepard_metzler_7_parts'  # one of the GQN dataset names
data_path = os.path.join(gqn_dataset_path, dataset_name)
print("Data path: %s" % (data_path, ))
# model flags
model_name = 'gqn'#'gqn8'
# model_name = 'gqn12'
gqn_model_path = os.path.join(model_dir, dataset_name)
model_path = os.path.join(gqn_model_path, model_name)
print("Model path: %s" % (model_path, ))
# tmp
notebook_name = 'view_interpolation'
notebook_tmp_path = os.path.join(tmp_dir, notebook_name)
os.makedirs(notebook_tmp_path, exist_ok=True)
print("Tmp path: %s" % (notebook_tmp_path, ))

Data path: E:\Desktop\tf-gqn-master\tf-gqn-master\data\gqn-dataset\shepard_metzler_5_parts
Model path: E:\Desktop\tf-gqn-master\tf-gqn-master\models\shepard_metzler_5_parts\gqn
Tmp path: E:\Desktop\tf-gqn-master\tf-gqn-master\notebooks\tmp\view_interpolation

'''data reader setup'''
mode = tf.estimator.ModeKeys.EVAL
ctx_size=5  # needs to be the same as the context size defined in gqn_config.json in the model_path
batch_size=1  # should be kept at 1
dataset = gqn_input_fn(
        dataset_name=dataset_name, root=gqn_dataset_path, mode=mode,
        context_size=ctx_size, batch_size=batch_size, num_epochs=1,
        num_threads=4, buffer_size=1)
iterator = dataset.make_initializable_iterator()
data = iterator.get_next()
'''video predictor & session setup'''
os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # run on CPU only, adjust to GPU id for speedup
#os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
predictor = GqnViewPredictor(model_path)
sess = predictor.sess
sess.run(iterator.initializer)
print("Loop completed.")

>>> Instantiated GQN:
enc_r Tensor("GQN/Sum:0", shape=(1, 1, 1, 256), dtype=float32)
canvas_0 Tensor("GQN/GQN_RNN/Generator/LSTM_gen/add:0", shape=(1, 64, 64, 256), dtype=float32)
canvas_1 Tensor("GQN/GQN_RNN/Generator/LSTM_gen/add_1:0", shape=(1, 64, 64, 256), dtype=float32)
canvas_2 Tensor("GQN/GQN_RNN/Generator/LSTM_gen/add_2:0", shape=(1, 64, 64, 256), dtype=float32)
canvas_3 Tensor("GQN/GQN_RNN/Generator/LSTM_gen/add_3:0", shape=(1, 64, 64, 256), dtype=float32)
canvas_4 Tensor("GQN/GQN_RNN/Generator/LSTM_gen/add_4:0", shape=(1, 64, 64, 256), dtype=float32)
canvas_5 Tensor("GQN/GQN_RNN/Generator/LSTM_gen/add_5:0", shape=(1, 64, 64, 256), dtype=float32)
canvas_6 Tensor("GQN/GQN_RNN/Generator/LSTM_gen/add_6:0", shape=(1, 64, 64, 256), dtype=float32)
canvas_7 Tensor("GQN/GQN_RNN/Generator/LSTM_gen/add_7:0", shape=(1, 64, 64, 256), dtype=float32)
mu_target Tensor("GQN/eta_g/BiasAdd:0", shape=(1, 64, 64, 3), dtype=float32)
INFO:tensorflow:Restoring parameters from E:\Desktop\tf-gqn-master\tf-gqn-master\models\shepard_metzler_5_parts\gqn\model.ckpt-0
>>> Restored parameters from: E:\Desktop\tf-gqn-master\tf-gqn-master\models\shepard_metzler_5_parts\gqn\model.ckpt-0
Loop completed.

'''data visualization'''
skip_load = 1  # adjust this to skip through records
print("Loop completed.")
# fetch & parse
for _ in range(skip_load):
    d, _ = sess.run(data)
ctx_frames = d.query.context.frames
ctx_poses = d.query.context.cameras
tgt_frame = d.target
tgt_pose = d.query.query_camera
tuple_length = ctx_size + 1  # context points + 1 target

print(">>> Context frames:\t%s" % (ctx_frames.shape, ))
print(">>> Context poses: \t%s" % (ctx_poses.shape, ))
print(">>> Target frame:  \t%s" % (tgt_frame.shape, ))
print(">>> Target pose:   \t%s" % (tgt_pose.shape, ))

# visualization constants
MAX_COLS_PER_ROW = 6
TILE_HEIGHT, TILE_WIDTH, TILE_PAD = 3.0, 3.0, 0.8
np.set_printoptions(precision=2, suppress=True)

# visualize all data tuples in the batch
for n in range(batch_size):
    # define image grid
    ncols = int(np.min([tuple_length, MAX_COLS_PER_ROW]))
    nrows = int(np.ceil(tuple_length / MAX_COLS_PER_ROW))
    fig = plt.figure(figsize=(TILE_WIDTH * ncols, TILE_HEIGHT * nrows))
    grid = ImageGrid(
        fig, 111,  # similar to subplot(111)
        nrows_ncols=(nrows, ncols),
        axes_pad=TILE_PAD,  # pad between axes in inch.
    )
    # visualize context
    for ctx_idx in range(ctx_size):
        rgb = ctx_frames[n, ctx_idx]
        pose = ctx_poses[n, ctx_idx]
        caption = "ctx: %02d\nxyz:%s\nyp:%s" % \
            (ctx_idx + 1, pose[0:3], pose[3:])
        grid[ctx_idx].imshow(rgb)
        grid[ctx_idx].set_title(caption, loc='center')
    # visualize target
    rgb = tgt_frame[n]
    pose = tgt_pose[n]
    caption = "target\nxyz:%s\nyp:%s" % \
        (pose[0:3], pose[3:])
    grid[-1].imshow(rgb)
    grid[-1].set_title(caption, loc='center')
    plt.show()

Loop completed.
>>> Context frames: (1, 5, 64, 64, 3)
>>> Context poses: (1, 5, 7)
>>> Target frame: (1, 64, 64, 3)
>>> Target pose: (1, 7)
[image]

'''run the view prediction'''

# visualize all data tuples in the batch
for n in range(batch_size):

    print(">>> Predictions:")
    # define image grid for predictions
    ncols = int(np.min([tuple_length, MAX_COLS_PER_ROW]))
    nrows = int(np.ceil(tuple_length / MAX_COLS_PER_ROW))
    fig = plt.figure(figsize=(TILE_WIDTH * ncols, TILE_HEIGHT * nrows))
    grid = ImageGrid(
        fig, 111,  # similar to subplot(111)
        nrows_ncols=(nrows, ncols),
        axes_pad=TILE_PAD,  # pad between axes in inch.
    )
    # load the scene context
    predictor.clear_context()
    for i in range(ctx_size):
        ctx_frame = ctx_frames[n, i]
        ctx_pose = ctx_poses[n, i]
        predictor.add_context_view(ctx_frame, ctx_pose)
    # render query
    query_pose = tgt_pose[n]
    pred_frame = predictor.render_query_view(query_pose)[0]
    caption = "query\nxyz:%s\nyp:%s" % \
        (query_pose[0:3], query_pose[3:])
    grid[0].imshow(pred_frame)
    grid[0].set_title(caption, loc='center')
    # re-render context (auto-encoding consistency)
    for ctx_idx in range(ctx_size):
        query_pose = ctx_poses[n, ctx_idx]
        pred_frame = predictor.render_query_view(query_pose)[0]
        caption = "ctx: %02d\nxyz:%s\nyp:%s" % \
            (ctx_idx + 1, query_pose[0:3], query_pose[3:])
        grid[ctx_idx + 1].imshow(pred_frame)
        grid[ctx_idx + 1].set_title(caption, loc='center')
    plt.show()

    print(">>> Ground truth:")
    # define image grid for predictions
    ncols = int(np.min([tuple_length, MAX_COLS_PER_ROW]))
    nrows = int(np.ceil(tuple_length / MAX_COLS_PER_ROW))
    fig = plt.figure(figsize=(TILE_WIDTH * ncols, TILE_HEIGHT * nrows))
    grid = ImageGrid(
        fig, 111,  # similar to subplot(111)
        nrows_ncols=(nrows, ncols),
        axes_pad=TILE_PAD,  # pad between axes in inch.
    )
    # query
    pose = tgt_pose[n]
    rgb = tgt_frame[n]
    caption = "query\nxyz:%s\nyp:%s" % \
        (pose[0:3], pose[3:])
    grid[0].imshow(rgb)
    grid[0].set_title(caption, loc='center')
    # context
    for ctx_idx in range(ctx_size):
        pose = ctx_poses[n, ctx_idx]
        rgb = ctx_frames[n, ctx_idx]
        caption = "ctx: %02d\nxyz:%s\nyp:%s" % \
            (ctx_idx + 1, pose[0:3], pose[3:])
        grid[ctx_idx + 1].imshow(rgb)
        grid[ctx_idx + 1].set_title(caption, loc='center')
    plt.show()

[image]

'''render a view interpolation trajectory'''

# query pose trajectory per dataset
# [[0, 0, 0, yaw, 0] for yaw in range(0, 360, 10)]
def _query_poses(num_poses=40, radius=3.0, height=2.5, angle=30.0):
    x = list(radius * np.sin(np.linspace(np.pi, -np.pi, num_poses)))
    y = list(radius * np.cos(np.linspace(np.pi, -np.pi, num_poses)))
    z = list(height * np.ones((num_poses, )))
    yaw = list(np.linspace(0.0, 360.0, num_poses))
    pitch = list(angle * np.ones((num_poses, )))
    poses = list(zip(x, y, z, yaw, pitch))
    return poses

QUERY_POSES = {
    'shepard_metzler_5_parts' : _query_poses(),
    'shepard_metzler_7_parts' : _query_poses(),
}

# generate query poses
query_poses = QUERY_POSES[dataset_name]
query_poses = [np.array(qp) for qp in query_poses]

# render corresponding views
print(">>> Rendering interpolation trajectory for %d query poses..." % (len(query_poses), ))
frame_buffer = []
for i, query_pose in enumerate(query_poses):
    pred_frame = predictor.render_query_view(query_pose)[0]
    frame_buffer.append(pred_frame)
    if (i+1) % 10 == 0:
        print("    %d / %d frames rendered." % ((i+1), len(query_poses)))

# show gif of view interpolation trajectory
gif_tmp_path = os.path.join(notebook_tmp_path, 'view_interpolation_preview.gif')
imageio.mimsave(gif_tmp_path, frame_buffer)
with open(gif_tmp_path, 'rb') as file:
    display(Image(file.read()))

>>> Rendering interpolation trajectory for 40 query poses...
10 / 40 frames rendered.
20 / 40 frames rendered.
30 / 40 frames rendered.
40 / 40 frames rendered.
[GIF: view_interpolation_preview]

Self-Attention and other extensions

@ogroth, as mentioned in #27, I've found it useful to add self-attention as well as some additional training options to limit the amount of data used for training, plus a simple test script. Would you please take a look at my branch and give me an idea of which parts you might consider including? I'll then create a pull request with that subset.

master...loganbruns:sa_gqn

Thanks,
logan

AttributeError: 'GQNTFRecordDataset' object has no attribute '_graph_attr'

Hello, I am trying to use your code to run training again, but I got this error.
I have downloaded the dataset and placed it properly, and the dataset I used is mazes.
My TensorFlow version is 1.13.1.
Thanks a lot.
[screenshot: No_graph_attr]

I changed the TensorFlow version to 1.12.0, but another error occurs: 'ValueError: Dimension 1 in both shapes must be equal, but are 84 and 64. Shapes are [?,84,84] and [?,64,64]. for 'GQN/GQN_RNN/Inference/LSTM_inf/concat' (op: 'ConcatV2') with input shapes: [?,84,84,3], [?,64,64,256], [] and with computed input tensors: input[2] = <-1>.'
I think the problem comes from the call tf.concat([query_image, canvas], axis=-1).
[image]
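
(A possible workaround sketch, purely an assumption rather than a verified fix: the mazes records come as 84x84 frames while the model here is configured for 64x64, so the frames could be downscaled in the input pipeline before they reach the network.)

# frames: float tensor of shape [N, 84, 84, 3] (hypothetical name from the input pipeline);
# tf.image.resize_images (TF 1.x) expects rank-3 or rank-4 tensors, so reshape batched
# context frames to [B * K, 84, 84, 3] first if necessary.
frames = tf.image.resize_images(frames, size=[64, 64])  # -> [N, 64, 64, 3]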

Training GQN with different dataset

Hi,

I am trying to train the GQN with the shepard_metzler_5_parts dataset.
After training finished, I used the script mentioned here to visualize the result; however, the generated image is just black.

So, are the default training parameters only meant for rooms_ring_camera? Do I need to change some of the training parameters in gqn_params.py for a different dataset?

How to train using multiple GPUs

Thanks for the code!
I wonder whether there is any way to train this model on 2 GPUs. Right now I am training GQN on the shepard_metzler_5_parts dataset and only 1 GPU is being used.
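
(One possible direction, untested here and only a sketch assuming the TF 1.12-era Estimator APIs this repo uses: pass a MirroredStrategy into the RunConfig that train_gqn.py builds for its Estimator.)

import tensorflow as tf  # TF 1.x

strategy = tf.contrib.distribute.MirroredStrategy(num_gpus=2)
run_config = tf.estimator.RunConfig(
    model_dir='models/shepard_metzler_5_parts/gqn',  # hypothetical path
    train_distribute=strategy)
# estimator = tf.estimator.Estimator(model_fn=gqn_draw_model_fn, config=run_config, params=...)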

Question about pack_context function

Hello,
I am reading your code and there is one thing I don't understand: the _pack_context function that you use to build the input to the encoder.

def _pack_context(context_poses, context_frames, model_params):
  # shorthand notations for model parameters
  _DIM_POSE = model_params.POSE_CHANNELS
  _DIM_H_IMG = model_params.IMG_HEIGHT
  _DIM_W_IMG = model_params.IMG_WIDTH
  _DIM_C_IMG = model_params.IMG_CHANNELS

  # pack scene context into pseudo-batch for encoder
  context_poses_packed = tf.reshape(context_poses, shape=[-1, _DIM_POSE])
  context_frames_packed = tf.reshape(
      context_frames, shape=[-1, _DIM_H_IMG, _DIM_W_IMG, _DIM_C_IMG])

context_poses has the shape (?, 5, 7) and you obtain context_poses_packed by reshaping it to (?, 7). My question is: isn't that a problem when concatenating it with the feature maps in the encoder?
You also did the same thing with the context frames. Can you explain why you did this instead of writing a loop that feeds each original view + pose into the encoder and summing the results afterwards? I am sorry if this is a silly question, because your approach seems to work.
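
(A minimal sketch of the idea behind the pseudo-batch, with illustrative names and a stand-in encoder rather than the repo's exact code: every context view is encoded as its own batch element, then the per-view representations are unpacked and summed over the context axis, which computes the same thing as looping over the views and adding their encodings.)

import tensorflow as tf  # TF 1.x

B, K, H, W, C = 36, 5, 64, 64, 3  # illustrative sizes: batch, context views, image dims

def encoder(frames, poses):
    # stand-in for the repo's tower/pool encoder: one feature map per packed view
    del poses  # the real encoder broadcasts the 7 pose channels and concatenates them
    return tf.layers.conv2d(frames, filters=256, kernel_size=3, strides=4, padding='same')

context_frames = tf.placeholder(tf.float32, [B, K, H, W, C])
context_poses = tf.placeholder(tf.float32, [B, K, 7])

frames_packed = tf.reshape(context_frames, [-1, H, W, C])  # [B*K, H, W, C]
poses_packed = tf.reshape(context_poses, [-1, 7])          # [B*K, 7]
r_packed = encoder(frames_packed, poses_packed)            # [B*K, h, w, c], single encoder pass
_, h, w, c = r_packed.shape.as_list()                      # static output dims
r = tf.reduce_sum(tf.reshape(r_packed, [-1, K, h, w, c]), axis=1)  # [B, h, w, c]
# summing over the K axis is exactly what a per-view loop with repeated addition would compute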

It stays in a local minimum and never goes down

Just like the question raised above: at around step 200k the loss goes down to 6,930 and never decreases further.
However, the visualizations show that the model does not reach the expected results and only produces a mess.
I repeated the training process and the same thing happened again.
Can anyone help me reproduce the results, or share your experience?
Thanks a lot!
