Comments (8)
Hi Yangshell, you can run a trained model by setting up an estimator:

```python
gqn_model = tf.estimator.Estimator(
    model_fn=gqn_draw_model_fn,
    model_dir='your/model/directory')
```

Here, 'your/model/directory' needs to point to a directory where a snapshot is stored (e.g. one of the model directories downloaded from the README).
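As a quick sanity check (this snippet is illustrative and not part of tf-gqn), you can verify that a directory actually contains a TensorFlow snapshot by looking for the `checkpoint` index file the Estimator expects:

```python
import os

def has_checkpoint(model_dir):
    """True if model_dir contains a TensorFlow 'checkpoint' index file."""
    return os.path.isfile(os.path.join(model_dir, 'checkpoint'))

print(has_checkpoint('your/model/directory'))  # True if a snapshot is present
```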
Then you can call gqn_model.predict(...), feeding in context images and poses as well as a query pose. The data_provider test case shows how to run the DataReader as a standalone object providing the correct data. The predict(...) function returns a dictionary with the predicted mean image and its variance. For a better understanding of the model's inputs and outputs, please refer to the gqn_draw_model_fn.
from tf-gqn.
Can I get both the image the model predicts and the real image?
The predict(...) call returns a dictionary with the predicted mean image. The ground truth image is fetched by the DataReader as part of the remaining data tuple, see here.
Hello, I use:

```python
import tensorflow as tf
from gqn.gqn_model import gqn_draw_model_fn
from gqn.gqn_params import _DEFAULTS

with tf.train.SingularMonitoredSession() as sess:
    d = sess.run(data)

gqn_model = tf.estimator.Estimator(
    model_fn=gqn_draw_model_fn,
    model_dir='/Users/yangshell/Downloads/rooms_ring_debug/gqn_pool_draw2',
    params=_DEFAULTS)
result = gqn_model.predict(d)
```

The "result" I get is a generator class, not a dict. What is the mistake in it?
Thank you for your patience!
Hi Yangshell,
A few notes on your code to help you get going with this:
- The _DEFAULTS constant in gqn.gqn_params is meant for internal use (hence the underscore at the beginning). The actual default parameters are gqn.gqn_params.PARAMS.
- If you take a look at the training script, where the Estimator is configured, the params attribute for the estimator is a dict { "gqn_params": PARAMS, "debug": FLAGS.debug }. You can set "debug": False.
- Estimator.predict works similarly to Estimator.train or Estimator.evaluate, meaning you need to pass in an input function. We provide an input function that works with the GQN datasets: data_provider.gqn_tfr_provider.gqn_input_fn.
- The result of Estimator.predict is indeed a generator. This is a Python object that you can iterate over (i.e. do things like for i in ...).
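To make the generator point concrete, here is a plain-Python illustration (unrelated to GQN itself) of consuming a generator the same way you would consume the result of Estimator.predict:

```python
def predictions():
    # stand-in for estimator.predict(...): yields one dict per example
    for i in range(3):
        yield {'predicted_mean': i * 0.5}

gen = predictions()
print(type(gen).__name__)  # -> generator

for pred in gen:  # iterate to pull out each dict
    print(pred['predicted_mean'])

# or materialize everything at once:
results = list(predictions())
print(len(results))  # -> 3
```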
Putting all that together, your code will end up looking something like:
```python
import tensorflow as tf

from gqn.gqn_model import gqn_draw_model_fn
from gqn.gqn_params import PARAMS
from data_provider.gqn_tfr_provider import gqn_input_fn

MODEL_DIR = '/Users/yangshell/Downloads/rooms_ring_debug/gqn_pool_draw2'
DATA_DIR = '/tmp/data/gqn-dataset'
DATASET = 'rooms_ring_camera'

estimator = tf.estimator.Estimator(
    model_fn=gqn_draw_model_fn,
    model_dir=MODEL_DIR,
    params={'gqn_params': PARAMS, 'debug': False})

input_fn = lambda mode: gqn_input_fn(
    dataset=DATASET,
    context_size=PARAMS.CONTEXT_SIZE,
    root=DATA_DIR,
    mode=mode)

for prediction in estimator.predict(input_fn=input_fn):
  # prediction is the dict @ogroth was mentioning
  print(prediction['predicted_mean'])  # this is probably what you want to look at
  print(prediction['predicted_variance'])  # or use this to sample a noisy image
```
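On that last point, the predicted variance can be used to sample a noisy image by drawing each pixel from a Gaussian centered at the predicted mean. A minimal pure-Python sketch of the idea (the flat lists here are stand-ins for the real image arrays):

```python
import random

def sample_noisy_image(mean, variance, seed=None):
    """Draw each pixel from N(mean, variance), element-wise."""
    rng = random.Random(seed)
    return [rng.gauss(m, v ** 0.5) for m, v in zip(mean, variance)]

# toy flat "image": with zero variance, pixels stay at their mean
pixels = sample_noisy_image([0.2, 0.8], [0.0, 0.0])
print(pixels)  # -> [0.2, 0.8]
```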
If you already have a data_provider.gqn_tfr_provider.TaskData
object with, say, numpy images as your input to the network, you could write a custom input_fn that maps that into tensors using something like tf.contrib.framework.nest
, and then predict using that input function.
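To illustrate the kind of structure-mapping nest performs (this toy map_structure is my own stand-in, not the TensorFlow API, and the TaskData fields here are hypothetical), the idea is to apply a conversion function to every leaf of a nested container while preserving its shape:

```python
from collections import namedtuple

# hypothetical stand-in for data_provider.gqn_tfr_provider.TaskData
TaskData = namedtuple('TaskData', ['context_frames', 'query_pose'])

def map_structure(fn, structure):
    """Apply fn to every leaf of a (possibly nested) tuple/list structure."""
    if isinstance(structure, tuple) and hasattr(structure, '_fields'):
        # namedtuple: rebuild with the same type
        return type(structure)(*(map_structure(fn, v) for v in structure))
    if isinstance(structure, (list, tuple)):
        return type(structure)(map_structure(fn, v) for v in structure)
    return fn(structure)  # leaf value

task = TaskData(context_frames=[1, 2, 3], query_pose=4)
doubled = map_structure(lambda x: x * 2, task)
print(doubled)  # -> TaskData(context_frames=[2, 4, 6], query_pose=8)
```

In a real input_fn, fn would be something like tf.convert_to_tensor instead of a doubling lambda.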
Let us know how this goes. :-)
Best,
Ștefan
I have got the test process working, but I found a problem in the image result.
This is the ground truth image:
This is the predicted image:
I used the model "gqn_pool_draw12". You can see that the effect of the wall is good, but the model did not predict the blue cylinder.
Hey Yangshell,
This is not unexpected and happens occasionally when the model seems to be "unsure" about the geometry of objects. In such cases, it seems to fall back to the prediction of the room's geometry. This behaviour might be mitigated by training the model longer (we've trained for ~200K iterations, the paper reported ~2M iterations) or feeding different context views.
But if you find more odd behavior, please feel free to open a new issue and post your findings. We've just started to experiment with the model ourselves and sharing failure cases like this can be super helpful for other people working with the code. :)
Can someone provide a clean implementation of how to use the network to predict an image from the test set? It seems like people have gotten it to work, but no one has posted clean code that currently works. This would be super helpful. Thanks
EDIT: this works
#17 (comment)