
How can I test the model? (tf-gqn issue, closed)

ogroth commented on June 2, 2024
How can I test the model?

from tf-gqn.

Comments (8)

ogroth commented on June 2, 2024

Hi Yangshell, you can run a trained model by setting up an estimator.
gqn_model = tf.estimator.Estimator(model_fn=gqn_draw_model_fn, model_dir='your/model/directory')
Obviously, 'your/model/directory' needs to point to the directory where a snapshot is stored (e.g. one of the model directories downloaded from the README).
Then you can call gqn_model.predict(.) feeding context images and poses as well as a query pose. The data_provider test case shows how to run the DataReader as a standalone object providing the correct data.
The predict(.) function returns a dictionary with the predicted mean image and its variance.
For a better understanding of the model's inputs and outputs, please refer to the gqn_draw_model_fn.
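
The poses mentioned above are not raw camera parameters. As far as we can tell, the GQN data readers (DeepMind's gqn-datasets reader and, following it, this repo's DataReader) turn each raw (x, y, z, yaw, pitch) camera into a 7-D vector [x, y, z, sin(yaw), cos(yaw), sin(pitch), cos(pitch)]. If you build context and query poses yourself, a minimal NumPy sketch of that mapping might look like the following. The helper `encode_pose` and the example values are hypothetical; please check the encoding against the DataReader in your checkout before relying on it.

```python
import numpy as np

def encode_pose(raw_pose) -> np.ndarray:
    """Map raw (x, y, z, yaw, pitch) cameras to 7-D pose vectors.

    Assumption: mirrors the [x, y, z, sin(yaw), cos(yaw), sin(pitch), cos(pitch)]
    encoding of the GQN data readers; angles are expected in radians.
    Works on a single pose of shape (5,) or a stack of shape (..., 5).
    """
    raw_pose = np.asarray(raw_pose, dtype=np.float32)
    position = raw_pose[..., 0:3]
    yaw = raw_pose[..., 3:4]
    pitch = raw_pose[..., 4:5]
    return np.concatenate(
        [position, np.sin(yaw), np.cos(yaw), np.sin(pitch), np.cos(pitch)], axis=-1)

# Hypothetical example: three context cameras plus one query camera.
context_raw = np.array([[0.0, 0.0, 0.0, 0.0, 0.0],
                        [1.0, 0.0, 0.0, np.pi / 2, 0.0],
                        [0.0, 1.0, 0.0, np.pi, 0.1]])
query_raw = np.array([0.5, 0.5, 0.0, -np.pi / 4, 0.0])

context_poses = encode_pose(context_raw)  # shape (3, 7)
query_pose = encode_pose(query_raw)       # shape (7,)
```

Encoding the angles as sine/cosine pairs avoids the discontinuity at ±π that a raw angle would present to the network.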

Yangshell commented on June 2, 2024

Can I get both the image the model predicts and the real one?

ogroth commented on June 2, 2024

predict(.) returns a dictionary with the predicted mean image. The ground-truth image is fetched by the DataReader along with the rest of the data tuple, see here.
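
Once you have both, one simple way to inspect them is to put them side by side. Below is a minimal NumPy sketch, assuming the query image from the data tuple and `predicted_mean` are float arrays of shape (H, W, 3) with values roughly in [0, 1]. `side_by_side` is a hypothetical helper, not part of tf-gqn.

```python
import numpy as np

def side_by_side(predicted: np.ndarray, ground_truth: np.ndarray) -> np.ndarray:
    """Stack ground truth (left) and prediction (right) into one uint8 RGB image.

    Assumes both inputs are float arrays of shape (H, W, 3) with values in [0, 1].
    """
    if predicted.shape != ground_truth.shape:
        raise ValueError(f"shape mismatch: {predicted.shape} vs {ground_truth.shape}")
    pair = np.concatenate([ground_truth, predicted], axis=1)  # concatenate along width
    return (np.clip(pair, 0.0, 1.0) * 255).round().astype(np.uint8)

# Hypothetical 64x64 arrays standing in for a query image and 'predicted_mean'.
rng = np.random.default_rng(0)
truth = rng.random((64, 64, 3))
pred = np.clip(truth + rng.normal(0.0, 0.05, truth.shape), 0.0, 1.0)
panel = side_by_side(pred, truth)  # shape (64, 128, 3)
```

The resulting uint8 array can be handed to an image library such as Pillow or matplotlib for display or saving.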

Yangshell commented on June 2, 2024

Hello, I use:
with tf.train.SingularMonitoredSession() as sess:
    d = sess.run(data)

import tensorflow as tf
from gqn.gqn_model import gqn_draw_model_fn
from gqn.gqn_params import _DEFAULTS

gqn_model = tf.estimator.Estimator(model_fn=gqn_draw_model_fn, model_dir='/Users/yangshell/Downloads/rooms_ring_debug/gqn_pool_draw2', params=_DEFAULTS)
result = gqn_model.predict(d)

The "result" I get is a generator, not a dict. What am I doing wrong?
Thank you for your patience!

SliMM commented on June 2, 2024

Hi Yangshell,

A few notes on your code to help you get going with this:

The _DEFAULTS constant in gqn.gqn_params is meant for internal use (hence the underscore at the beginning). The actual default parameters are gqn.gqn_params.PARAMS.

If you take a look at the training script, where the Estimator is configured, the params attribute for the estimator is a dict { "gqn_params": PARAMS, "debug": FLAGS.debug }. You can set "debug": False.

Estimator.predict works similarly to Estimator.train or Estimator.evaluate, meaning you need to pass in an input function. We provide an input function that works with the GQN datasets: data_provider.gqn_tfr_provider.gqn_input_fn.

The result of Estimator.predict is indeed a generator object: a Python object that you can iterate over (e.g. for i in ...).
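
To see why printing the result shows a generator rather than a dict, here is a plain-Python sketch. `fake_predict` is a hypothetical stand-in for Estimator.predict that, like it, yields one prediction dict per example; the string values only mark which example each dict came from.

```python
import itertools

def fake_predict():
    """Hypothetical stand-in for Estimator.predict: yields one dict per example."""
    for i in range(5):
        yield {"predicted_mean": f"mean_{i}", "predicted_variance": f"var_{i}"}

predictions = fake_predict()
print(type(predictions).__name__)  # generator, not dict

first = next(predictions)          # pull a single prediction dict
print(first["predicted_mean"])     # mean_0

# Take the next two predictions without exhausting the generator.
for p in itertools.islice(predictions, 2):
    print(p["predicted_mean"])     # mean_1, then mean_2
```

The real Estimator.predict also accepts yield_single_examples=False if you would rather receive whole batches per step instead of one dict per example.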

Putting all that together, your code will end up looking something like:

import tensorflow as tf
from gqn.gqn_model import gqn_draw_model_fn
from gqn.gqn_params import PARAMS
from data_provider.gqn_tfr_provider import gqn_input_fn

MODEL_DIR='/Users/yangshell/Downloads/rooms_ring_debug/gqn_pool_draw2'
DATA_DIR='/tmp/data/gqn-dataset'
DATASET='rooms_ring_camera'

estimator = tf.estimator.Estimator(
    model_fn=gqn_draw_model_fn,
    model_dir=MODEL_DIR,
    params={'gqn_params': PARAMS, 'debug': False})

input_fn = lambda mode: gqn_input_fn(
        dataset=DATASET,
        context_size=PARAMS.CONTEXT_SIZE,
        root=DATA_DIR,
        mode=mode)

for prediction in estimator.predict(input_fn=input_fn):
    # prediction is the dict @ogroth was mentioning
    print(prediction['predicted_mean'])  # this is probably what you want to look at
    print(prediction['predicted_variance'])  # or use this to sample a noisy image
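
To turn the mean and variance into a noisy sample, one option is to treat each pixel as an independent Gaussian. The NumPy sketch below assumes 'predicted_variance' really holds a variance, as its name suggests (if your version stores a standard deviation, drop the square root), and works whether it is a full (H, W, 3) array or a scalar. `sample_image` is a hypothetical helper.

```python
import numpy as np

def sample_image(mean, variance, rng=None):
    """Draw one image from an independent per-pixel Gaussian N(mean, variance).

    Assumes `variance` is a variance (not a standard deviation) that is either
    the same shape as `mean` or broadcastable to it (e.g. a scalar).
    """
    rng = np.random.default_rng() if rng is None else rng
    mean = np.asarray(mean, dtype=np.float64)
    std = np.sqrt(np.asarray(variance, dtype=np.float64))
    sample = mean + std * rng.standard_normal(mean.shape)
    return np.clip(sample, 0.0, 1.0)  # keep values in a displayable [0, 1] range

# Hypothetical prediction dict with 64x64 RGB arrays.
prediction = {
    "predicted_mean": np.full((64, 64, 3), 0.5),
    "predicted_variance": np.full((64, 64, 3), 0.01),
}
noisy = sample_image(prediction["predicted_mean"], prediction["predicted_variance"],
                     rng=np.random.default_rng(42))
```

For viewing, the mean image is usually the clearer choice; sampling mostly shows how confident the model is per pixel.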

If you already have a data_provider.gqn_tfr_provider.TaskData object with, say, numpy images as your input to the network, you could write a custom input_fn that maps that into tensors using something like tf.contrib.framework.nest, and then predict using that input function.

Let us know how this goes. :-)

Best,
Ștefan

Yangshell commented on June 2, 2024

I got the test process working, but I found a problem in the resulting image.
This is the true image:
[image: query0]
This is the predicted image:
[image: test0]
I used the model "gqn_pool_draw12". As you can see, the walls are predicted well, but the model did not predict the blue cylinder.
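
To compare a prediction against the ground truth beyond eyeballing, a per-pixel error map plus a whole-image MSE/PSNR is a common starting point. The NumPy sketch below assumes both images are float arrays in [0, 1]; `compare_images` and the synthetic "missing object" arrays are hypothetical.

```python
import numpy as np

def compare_images(predicted, ground_truth, max_value=1.0):
    """Return (per-pixel squared-error map, MSE, PSNR in dB) for two same-shape images."""
    predicted = np.asarray(predicted, dtype=np.float64)
    ground_truth = np.asarray(ground_truth, dtype=np.float64)
    error_map = (predicted - ground_truth) ** 2
    mse = float(error_map.mean())
    # PSNR is infinite for identical images.
    psnr = float("inf") if mse == 0.0 else float(10.0 * np.log10(max_value ** 2 / mse))
    return error_map, mse, psnr

# Hypothetical arrays: the truth has a small blue patch the prediction misses.
truth = np.zeros((64, 64, 3))
truth[20:30, 20:30] = [0.0, 0.0, 1.0]
pred = np.zeros((64, 64, 3))
error_map, mse, psnr = compare_images(pred, truth)
# Summing over channels gives a (64, 64) heat map showing where the object is missing.
heat = error_map.sum(axis=-1)
```

A small missing object contributes little to the whole-image MSE, so the heat map is usually the more telling view for failures like the missing cylinder.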

ogroth commented on June 2, 2024

Hey Yangshell,
This is not unexpected: it happens occasionally when the model seems to be "unsure" about the geometry of objects. In such cases, it seems to fall back to predicting only the room's geometry. This behaviour might be mitigated by training the model longer (we've trained for ~200K iterations, whereas the paper reports ~2M) or by feeding different context views.
But if you find more odd behavior, please feel free to open a new issue and post your findings. We've just started to experiment with the model ourselves and sharing failure cases like this can be super helpful for other people working with the code. :)

tarunsharma1 commented on June 2, 2024

Can someone provide a clean example of how to use the network to predict an image from the test set? It seems people have got it to work, but no one has posted clean, working code yet. This would be super helpful. Thanks

@Yangshell

EDIT: this works
#17 (comment)
