import tensorflow as tf
import sys
import skimage
import skimage.io
import skimage.transform
import numpy as np
# Class labels, one per line of synset.txt; print_prob maps a predicted class
# index to its label by indexing this list.
# The file is read inside a `with` block so the handle is closed immediately
# instead of being left open for the lifetime of the process.
with open('/home/ubuntu/tensorflow-vgg16/synset.txt') as _synset_file:
    synset = [l.strip() for l in _synset_file]
# NOTE(review): not referenced in the visible code; these look like
# per-channel mean pixel values (0-255 scale, presumably BGR order) - confirm.
VGG_MEAN = [103.939, 116.779, 123.68]
def load_image(path, size=224):
    """Load an image, center-crop it to a square and resize it.

    Args:
        path: image file to read with ``skimage.io.imread``.
        size: edge length, in pixels, of the square output (default 224).

    Returns:
        A float array of shape [size, size, 3] ([height, width, depth])
        with values in [0, 1].

    Raises:
        ValueError: if the image does not have 1, 3 or 4 channels, or its
            pixel values fall outside [0, 1] after scaling.
    """
    img = skimage.io.imread(path)

    # Scale integer images by their dtype's maximum (255 for uint8, giving
    # exactly img / 255.0; 65535 for uint16). Non-integer images are assumed
    # to already be in [0, 1]; the range check below enforces that.
    if np.issubdtype(img.dtype, np.integer):
        img = img / float(np.iinfo(img.dtype).max)

    # Normalize the channel layout to exactly 3 channels so the documented
    # [size, size, 3] shape holds for grayscale and RGBA inputs too.
    if img.ndim == 2:
        img = img[:, :, np.newaxis]
    if img.ndim != 3 or img.shape[2] not in (1, 3, 4):
        raise ValueError("unsupported image shape %s for %s" % (img.shape, path))
    if img.shape[2] == 1:
        img = np.repeat(img, 3, axis=2)
    elif img.shape[2] == 4:
        img = img[:, :, :3]  # drop the alpha channel

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not ((img >= 0).all() and (img <= 1.0).all()):
        raise ValueError("pixel values of %s are outside [0, 1]" % (path,))

    # Crop the largest centered square.
    short_edge = min(img.shape[:2])
    yy = (img.shape[0] - short_edge) // 2
    xx = (img.shape[1] - short_edge) // 2
    crop_img = img[yy:yy + short_edge, xx:xx + short_edge]

    return skimage.transform.resize(crop_img, (size, size))
def print_prob(prob, labels=None, top_k=5):
    """Print the best-scoring labels for one probability vector.

    Args:
        prob: 1-D sequence of per-class scores; index i corresponds to
            labels[i].
        labels: class names indexed by class id; defaults to the
            module-level ``synset`` list.
        top_k: how many labels to print on the "Top<k>" line (default 5).
            Fewer are printed if there are fewer classes.

    Returns:
        The label of the highest-scoring class.

    Raises:
        ValueError: if ``prob`` is not a non-empty 1-D array.
    """
    if labels is None:
        labels = synset
    prob = np.asarray(prob)
    if prob.ndim != 1 or prob.size == 0:
        raise ValueError(
            "prob must be a non-empty 1-D array, got shape %s" % (prob.shape,))

    print("prob shape %s" % (prob.shape,))
    # argsort is ascending; reverse it so class indices come best-first.
    pred = np.argsort(prob)[::-1]

    top1 = labels[pred[0]]
    print("Top1:  %s" % (top1,))

    # Slicing (rather than indexing range(top_k)) tolerates fewer classes.
    top5 = [labels[i] for i in pred[:top_k]]
    print("Top%d:  %s" % (top_k, top5))
    return top1
# Load the serialized VGG16 GraphDef, classify one image and print the result.
model_path = "/home/ubuntu/vgg16-v4.tfmodel"
with open(model_path, mode='rb') as f:
    fileContent = f.read()

graph_def = tf.GraphDef()
try:
    graph_def.ParseFromString(fileContent)
except Exception as err:
    # protobuf reports a bad file only as e.g. "DecodeError: Truncated
    # message." (raised when an encoded length runs past the end of the
    # buffer). Re-raise with the path and size so the cause is visible.
    raise ValueError(
        "could not parse %s (%d bytes) as a serialized GraphDef: %s. "
        "The file is most likely truncated or corrupt (e.g. an incomplete "
        "download); try obtaining it again." % (model_path, len(fileContent), err))

images = tf.placeholder("float", [None, 224, 224, 3])
# Feed the graph's "images" input from our placeholder. No `name` is given,
# so imported ops get the default "import/" prefix used below.
tf.import_graph_def(graph_def, input_map={"images": images})
print("graph loaded from disk")

graph = tf.get_default_graph()

cat = load_image("cat.jpg")

with tf.Session() as sess:
    # NOTE(review): later TF releases rename this to
    # tf.global_variables_initializer(); kept for the TF version targeted here.
    init = tf.initialize_all_variables()
    sess.run(init)
    print("variables initialized")

    # Add a batch dimension of 1; reshape raises if the size does not match.
    batch = cat.reshape((1, 224, 224, 3))
    feed_dict = {images: batch}

    prob_tensor = graph.get_tensor_by_name("import/prob:0")
    prob = sess.run(prob_tensor, feed_dict=feed_dict)
    print_prob(prob[0])
---------------------------------------------------------------------------
DecodeError Traceback (most recent call last)
<ipython-input-1-c8f1d9f927de> in <module>()
48
49 graph_def = tf.GraphDef()
---> 50 graph_def.ParseFromString(fileContent)
51
52 images = tf.placeholder("float", [None, 224, 224, 3])
/usr/local/lib/python2.7/dist-packages/google/protobuf/message.py in ParseFromString(self, serialized)
183 """
184 self.Clear()
--> 185 self.MergeFromString(serialized)
186
187 def SerializeToString(self):
/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/python_message.py in MergeFromString(self, serialized)
1006 length = len(serialized)
1007 try:
-> 1008 if self._InternalParse(serialized, 0, length) != length:
1009 # The only reason _InternalParse would return early is if it
1010 # encountered an end-group tag.
/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/python_message.py in InternalParse(self, buffer, pos, end)
1042 pos = new_pos
1043 else:
-> 1044 pos = field_decoder(buffer, new_pos, end, self, field_dict)
1045 if field_desc:
1046 self._UpdateOneofState(field_desc)
/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/decoder.py in DecodeRepeatedField(buffer, pos, end, message, field_dict)
626 new_pos = pos + size
627 if new_pos > end:
--> 628 raise _DecodeError('Truncated message.')
629 # Read sub-message.
630 if value.add()._InternalParse(buffer, pos, new_pos) != new_pos:
DecodeError: Truncated message.