Coder Social home page Coder Social logo

Run ROLO on tensorflow 1.1.8 about rolo HOT 8 OPEN

guanghan avatar guanghan commented on August 28, 2024 1
Run ROLO on tensorflow 1.1.8

from rolo.

Comments (8)

wanjinchang avatar wanjinchang commented on August 28, 2024 1

I modified the LSTM_single function as follows:
with tf.device('/gpu:0'):
# X, input shape: (batch_size, time_step_size, input_vec_size)
# XT shape: (time_step_size, batch_size, input_vec_size)
_X = tf.transpose(_X, [1, 0, 2]) # permute time_step_size and batch_size
# Reshape to prepare input to hidden activation
_X = tf.reshape(_X, [self.num_steps * self.batch_size, self.num_input]) # (num_steps*batch_size, num_input)
# Split data because rnn cell needs a list of inputs for the RNN inner loop
# Each array shape: (batch_size, num_input)
_X = tf.split(_X, self.num_steps, 0) # n_steps * (batch_size, num_input)
print(_X)

    cell = tf.nn.rnn_cell.LSTMCell(self.num_input, self.num_input)
    state = cell.zero_state(self.batch_size, dtype=tf.float32)
    outputs, state = tf.nn.static_rnn(cell, _X, initial_state=state, dtype=tf.float32)
    tf.get_variable_scope().reuse_variables()

It runs correctly, but the results with the pretrained model are wrong; I cannot reproduce the results reported by the author...

from rolo.

GarryLau avatar GarryLau commented on August 28, 2024 1

@bluetooth12
Take a look at my rewritten file 'ROLO_network_test_single.py':


import sys,os
# Extend the import path with two extra directories, echoing each one:
#   1. "<current working directory>/utils"
#   2. the absolute form of "../../../tensorflow"
CURRENT_DIR = os.path.abspath('.')
MODEl_DIR = os.path.abspath('../../../tensorflow')
for _extra_path in (CURRENT_DIR + '/utils', MODEl_DIR):
    print(_extra_path)
    sys.path.append(_extra_path)
del _extra_path

# Imports
import ROLO_utils as utils

import tensorflow as tf
from tensorflow.python.ops import rnn, rnn_cell
import cv2

import numpy as np
import os.path
import time
import random


class ROLO_TF:
    """Test harness for the ROLO LSTM tracker on TensorFlow 1.x.

    Configuration lives in class attributes. The graph inputs (`x`, `istate`,
    `y`) and the `weights` / `biases` variables are created while the class
    body executes, i.e. at class-definition time, in the order written below.
    """
    # When true, build_networks() prints progress messages.
    disp_console = True
    # When true, testing() restores weights from the latest checkpoint in model_path.
    restore_weights = True#False

    # YOLO parameters
    # NOTE(review): most of the YOLO settings below are not read by the methods
    # visible in this class — presumably kept from the original ROLO script.
    fromfile = None
    tofile_img = 'test/output.jpg'
    tofile_txt = 'test/output.txt'
    imshow = True
    filewrite_img = False
    filewrite_txt = False
    disp_console = True
    yolo_weights_file = 'weights/YOLO_small.ckpt'
    alpha = 0.1
    threshold = 0.2
    iou_threshold = 0.5
    num_class = 20
    num_box = 2
    grid_size = 7
    classes =  ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train","tvmonitor"]
    # Default frame size; ROLO() overwrites both with the values returned by
    # utils.choose_video_sequence() before testing.
    w_img, h_img = [352, 240]

    # ROLO Network Parameters
    # Directory passed to tf.train.latest_checkpoint() when restoring weights.
    model_path = '/home/lg/projects/ROLO/new'
    # rolo_weights_file = '/u03/Guanghan/dev/ROLO-dev/model_dropout_30.ckpt'
    # Meta-graph file passed to tf.train.import_meta_graph() in build_networks().
    meta_path = '/home/lg/projects/ROLO/new/model_demo_new.meta'

    lstm_depth = 3
    num_steps = 3  # number of frames as an input sequence
    num_feat = 4096
    num_predict = 6 # final output of LSTM 6 loc parameters
    num_gt = 4
    num_input = num_feat + num_predict # data input: 4096+6= 4102

    # ROLO Parameters
    batch_size = 1
    display_step = 1  # testing() evaluates and prints the loss every display_step frames

    # tf Graph input
    x = tf.placeholder("float32", [None, num_steps, num_input])
    istate = tf.placeholder("float32", [None, 2*num_input]) #state & cell => 2x num_input
    y = tf.placeholder("float32", [None, num_gt])

    # Define weights
    # Output projection variables; they are passed to LSTM_single(), which does
    # not use them in its current form.
    weights = {
        'out': tf.Variable(tf.random_normal([num_input, num_predict]))
    }
    biases = {
        'out': tf.Variable(tf.random_normal([num_predict]))
    }


    def __init__(self, argvs=None):
        """Create the tracker and immediately run the ROLO pipeline.

        Args:
            argvs: argument sequence forwarded to ``self.ROLO``. ``None``
                (the default) is replaced by a fresh empty list, so instances
                never share one mutable default object.
        """
        print("ROLO init")
        self.ROLO([] if argvs is None else argvs)


    def LSTM_single(self, name,  _X, _istate, _weights, _biases):
        """Build a single-layer LSTM over the input sequence and return its outputs.

        Args:
            name: unused; kept for interface compatibility.
            _X: input tensor laid out as (batch_size, num_steps, num_input).
            _istate: unused; the cell always starts from its zero state.
            _weights: unused; kept for interface compatibility.
            _biases: unused; kept for interface compatibility.

        Returns:
            The list of per-step output tensors produced by
            ``tf.nn.static_rnn`` (``self.num_steps`` entries).
        """
        # (batch, step, input) -> (step, batch, input)
        _X = tf.transpose(_X, [1, 0, 2])
        # Flatten to (num_steps * batch_size, num_input) ...
        _X = tf.reshape(_X, [self.num_steps * self.batch_size, self.num_input])
        # ... then split into num_steps tensors, one per time step, because
        # static_rnn consumes a Python list of per-step inputs.
        _X = tf.split(_X, self.num_steps, 0)

        # Pass only num_units. The old second positional argument was the
        # ignored ``input_size`` of TF 0.x; in later 1.x releases that position
        # is ``use_peepholes``, so a non-zero integer there silently enabled
        # peephole connections and added variables absent from the checkpoint.
        cell = tf.nn.rnn_cell.LSTMCell(num_units=self.num_input)
        state = cell.zero_state(self.batch_size, dtype=tf.float32)
        outputs, state = tf.nn.static_rnn(cell, _X, initial_state=state, dtype=tf.float32)
        # Later calls in the same variable scope reuse the LSTM variables
        # created above instead of trying to create them a second time.
        tf.get_variable_scope().reuse_variables()
        return outputs

        # Experiment with dropout
    def dropout_features(self, feature, prob, num_feat=4096):
        """Zero a random subset of the leading ``num_feat`` entries of ``feature``.

        Args:
            feature: mutable, indexable sequence with at least ``num_feat``
                entries; it is modified in place.
            prob: fraction of the leading entries to zero. Exactly
                ``int(prob * num_feat)`` distinct positions are chosen.
            num_feat: how many leading entries are eligible for dropping
                (defaults to 4096, the value previously hard-coded).

        Returns:
            The same ``feature`` object that was passed in.

        Raises:
            ValueError: propagated from ``random.sample`` when the number of
                positions to drop is negative or larger than ``num_feat``.
        """
        num_drop = int(prob * num_feat)
        # range() instead of the Python-2-only xrange(), so this also runs on Python 3.
        for index in random.sample(range(num_feat), num_drop):
            feature[index] = 0
        return feature
    '''---------------------------------------------------------------------------------------'''
    def build_networks(self):
        """Build the LSTM graph, open ``self.sess`` and load the saved model.

        Sets ``self.lstm_module``, ``self.ious``, ``self.sess`` and
        ``self.saver``. The saver comes from the meta graph at
        ``self.meta_path`` and is restored from the latest checkpoint found
        in ``self.model_path``.
        """
        if self.disp_console : print ("Building ROLO graph...")

        # Build rolo layers
        self.lstm_module = self.LSTM_single('lstm_test', self.x, self.istate, self.weights, self.biases)
        self.ious= tf.Variable(tf.zeros([self.batch_size]), name="ious")
        self.sess = tf.Session()
        # Give every variable created so far an initial value before restoring.
        self.sess.run(tf.global_variables_initializer())
        # NOTE(review): the meta graph is imported after the LSTM graph above
        # has already been built, so the restore below may fill the imported
        # copies of the variables rather than the ones LSTM_single created —
        # confirm against the checkpoint's variable names.
        self.saver = tf.train.import_meta_graph(self.meta_path)       #loading map
        self.saver.restore(self.sess, tf.train.latest_checkpoint(self.model_path))
        if self.disp_console : print ("Loading complete!" + '\n')


    def testing(self, x_path, y_path):
        """Run the LSTM over one sequence, save the predictions and print the loss.

        Requires ``build_networks()`` to have run first (it provides
        ``self.saver``) and ``self.testing_iters`` / ``self.output_path`` to
        be set by the caller.

        Args:
            x_path: location handed to ``self.rolo_utils.load_yolo_output_test``.
            y_path: location handed to ``self.rolo_utils.load_rolo_gt_test``.

        Returns:
            None. Predictions go through ``utils.save_rolo_output_test`` to
            ``self.output_path``; per-frame and average losses are printed.
        """
        total_loss = 0
        num_losses = 0  # how many per-frame losses were accumulated

        print("TESTING ROLO...")
        # Build the prediction and loss ops on top of the LSTM outputs
        pred = self.LSTM_single('lstm_train', self.x, self.istate, self.weights, self.biases)
        print("pred: ", pred)
        # Four columns of the first step's output are taken as the box location.
        self.pred_location = pred[0][:, 4097:4101]
        print("pred_location: ", self.pred_location)
        print("self.y: ", self.y)

        # Squared error against the ground truth, averaged and scaled by 100
        self.correct_prediction = tf.square(self.pred_location - self.y)
        print("self.correct_prediction: ", self.correct_prediction)
        self.accuracy = tf.reduce_mean(self.correct_prediction) * 100
        print("self.accuracy: ", self.accuracy)

        # Initializing the variables (same initializer as build_networks uses;
        # tf.initialize_all_variables is deprecated)
        init = tf.global_variables_initializer()

        # Launch the graph
        with tf.Session() as sess:
            sess.run(init)
            if self.restore_weights:
                # Restore into the session that actually runs the evaluation.
                # Restoring into self.sess left this session with its random
                # initial values.
                self.saver.restore(sess, tf.train.latest_checkpoint(self.model_path))
                print ("Loading complete!" + '\n')

            frame_id = 0 #don't change this

            # The istate placeholder is always fed zeros; build them once.
            zero_state = np.zeros((self.batch_size, 2*self.num_input))

            # Step through the sequence until fewer than num_steps frames remain
            while frame_id < self.testing_iters - self.num_steps:
                # Load test data & ground truth
                batch_xs = self.rolo_utils.load_yolo_output_test(x_path, self.batch_size, self.num_steps, frame_id) # [num_of_examples, num_input] (depth == 1)

                # Apply dropout to batch_xs
                #for item in range(len(batch_xs)):
                #    batch_xs[item] = self.dropout_features(batch_xs[item], 0.4)

                batch_ys = self.rolo_utils.load_rolo_gt_test(y_path, self.batch_size, self.num_steps, frame_id)
                print("Batch_ys_initial: ", batch_ys)
                batch_ys = utils.locations_from_0_to_1(self.w_img, self.h_img, batch_ys)

                # Reshape data to (batch_size, num_steps, num_input)
                batch_xs = np.reshape(batch_xs, [self.batch_size, self.num_steps, self.num_input])
                batch_ys = np.reshape(batch_ys, [self.batch_size, 4])
                print("Batch_ys: ", batch_ys)

                feed = {self.x: batch_xs, self.y: batch_ys, self.istate: zero_state}

                pred_location = sess.run(self.pred_location, feed_dict=feed)
                print("ROLO Pred: ", pred_location)
                print("ROLO Pred in pixel: ", pred_location[0][0]*self.w_img, pred_location[0][1]*self.h_img, pred_location[0][2]*self.w_img, pred_location[0][3]*self.h_img)

                # Save pred_location to file
                utils.save_rolo_output_test(self.output_path, pred_location, frame_id, self.num_steps, self.batch_size)

                if frame_id % self.display_step == 0:
                    # Calculate batch loss
                    loss = sess.run(self.accuracy, feed_dict=feed)
                    print ("Iter " + str(frame_id*self.batch_size) + ", Minibatch Loss= " + "{:.6f}".format(loss))
                    total_loss += loss
                    num_losses += 1
                frame_id += 1
                print(frame_id)

            print ("Testing Finished!")
            # Average over the losses actually collected. This also avoids a
            # ZeroDivisionError when the sequence is too short to run a frame.
            if num_losses:
                avg_loss = total_loss / num_losses
                print ("Avg loss: " + str(avg_loss))
            else:
                print ("Avg loss: n/a (no frames were evaluated)")

        return None

    def ROLO(self, argvs):
        """Load the configuration, parse ``argvs`` and dispatch to the chosen mode.

        The flags on ``self.rolo_utils`` are checked in the order train,
        track, detect. When none of them is ``True`` the default test run is
        executed on video sequence 8.
        """
        self.rolo_utils = utils.ROLO_utils()
        self.rolo_utils.loadCfg()
        self.params = self.rolo_utils.params

        # Called for its effect on self.rolo_utils; the return value is not used.
        self.rolo_utils.argv_parser(argvs)

        if self.rolo_utils.flag_train is True:
            self.training(utils.x_path, utils.y_path)
            return

        if self.rolo_utils.flag_track is True:
            self.build_networks()
            self.track_from_file(utils.file_in_path)
            return

        if self.rolo_utils.flag_detect is True:
            self.build_networks()
            self.detect_from_file(utils.file_in_path)
            return

        # No mode flag set: run the test pipeline.
        print ("Default: running ROLO test.")
        self.build_networks()

        test = 8
        sequence_info = utils.choose_video_sequence(test)
        [self.w_img, self.h_img, sequence_name, dummy_1, self.testing_iters] = sequence_info

        yolo_dir = os.path.join('../DATA', sequence_name, 'yolo_out/')
        gt_file = os.path.join('../DATA', sequence_name, 'groundtruth_rect.txt')
        self.output_path = os.path.join('../DATA', sequence_name, 'rolo_out_test/')
        utils.createFolder(self.output_path)

        self.testing(yolo_dir, gt_file)

    '''----------------------------------------main-----------------------------------------------------'''
def main(argvs):
    """Construct a ROLO_TF with ``argvs``; its constructor runs the pipeline."""
    ROLO_TF(argvs)


# Script entry point: run with a single-space argument string.
if __name__ == '__main__':
    main(' ')

from rolo.

wei-1234567 avatar wei-1234567 commented on August 28, 2024

请问你在跑的过程中遇到
ValueError: Restore called with invalid save path /u03/Guanghan/dev/ROLO-dev/output/ROLO_model/model_step3_exp1_old.ckpt
问题吗,是怎样解决的

from rolo.

wanjinchang avatar wanjinchang commented on August 28, 2024

你把这个路径改成你本地模型的路径就可以啦

from rolo.

dvidal8 avatar dvidal8 commented on August 28, 2024

@wanjinchang which .ckpt model did you use? I got the following error:

NotFoundError (see above for traceback): Tensor name "rnn/lstm_cell/bias" not found in checkpoint files /home/dvn/ROLO/models/model_demo.ckpt

And the error shows up with any of the uploaded models...

from rolo.

wanjinchang avatar wanjinchang commented on August 28, 2024

The TensorFlow version used by this project is probably 0.x. If you use a newer 1.x version, the LSTM API has changed. You can use the following code as a reference to convert the model and fix this problem.
# Convert a checkpoint written with the old LSTM variable names into one that
# uses the names in vars_to_rename, so a newer TensorFlow can restore it.
OLD_CHECKPOINT_FILE = "your path/model_step3_exp3.ckpt"
NEW_CHECKPOINT_FILE = "your path/model_step3_exp3_new.ckpt"

import tensorflow as tf

# old variable name -> new variable name
vars_to_rename = {
    "RNN/LSTMCell/W_0": "rnn/lstm_cell/kernel",
    "RNN/LSTMCell/B": "rnn/lstm_cell/bias",
}
new_checkpoint_vars = {}
reader = tf.train.NewCheckpointReader(OLD_CHECKPOINT_FILE)
for old_name in reader.get_variable_to_shape_map():
    print(old_name)
    # Variables without a rename entry keep their original name.
    new_name = vars_to_rename.get(old_name, old_name)
    new_checkpoint_vars[new_name] = tf.Variable(reader.get_tensor(old_name))

init = tf.global_variables_initializer()
# The dict keys become the names under which the variables are saved.
saver = tf.train.Saver(new_checkpoint_vars)

with tf.Session() as sess:
    sess.run(init)
    saver.save(sess, NEW_CHECKPOINT_FILE)

from rolo.

bluetooth12 avatar bluetooth12 commented on August 28, 2024

I tried the code above (see below) and I get three files:
model_step3_exp3_new.ckpt.data-00000-of-00001
model_step3_exp3_new.ckpt.index
model_step3_exp3_new.ckpt.meta

What do I put in ROLO_network_test_all.py (approx line 264)?
self.rolo_weights_file= 'C:/Users/el006794/ROLO-master/model_step3_exp3.ckpt'

# Convert a checkpoint written with the old LSTM variable names into one that
# uses the names in vars_to_rename, so a newer TensorFlow can restore it.
OLD_CHECKPOINT_FILE = "C:/Users/el006794/ROLO-master/model_step3_exp3.ckpt"
NEW_CHECKPOINT_FILE = "C:/Users/el006794/ROLO-master/model_step3_exp3_new.ckpt"

import tensorflow as tf

# old variable name -> new variable name
vars_to_rename = {
    "RNN/LSTMCell/W_0": "rnn/lstm_cell/kernel",
    "RNN/LSTMCell/B": "rnn/lstm_cell/bias",
}
new_checkpoint_vars = {}
reader = tf.train.NewCheckpointReader(OLD_CHECKPOINT_FILE)
for old_name in reader.get_variable_to_shape_map():
    print(old_name)
    # Variables without a rename entry keep their original name.
    new_name = vars_to_rename.get(old_name, old_name)
    new_checkpoint_vars[new_name] = tf.Variable(reader.get_tensor(old_name))

init = tf.global_variables_initializer()
# The dict keys become the names under which the variables are saved.
saver = tf.train.Saver(new_checkpoint_vars)

with tf.Session() as sess:
    sess.run(init)
    saver.save(sess, NEW_CHECKPOINT_FILE)
    #saver.restore(sess, NEW_CHECKPOINT_FILE)

# NOTE(review): the indentation of this snippet was lost when it was posted.
# The statements after the `with` line presumably belong inside it — confirm
# the intended nesting before running.
# Attempt 1: take the saver from the saved .meta file and restore using the
# checkpoint prefix (NEW_CHECKPOINT_FILE).
with tf.Session() as sess:
saver = tf.train.import_meta_graph('model_step3_exp3_new.ckpt.meta')
# Earlier try, disabled: restoring from the .data-00000-of-00001 file name.
#saver.restore(sess, "model_step3_exp3_new.ckpt.data-00000-of-00001")
saver.restore(sess, NEW_CHECKPOINT_FILE)

#with tf.Session() as sess:

# Attempt 2: a default Saver, restored from the same checkpoint prefix.
# NOTE(review): with the `with` line above commented out, this reuses `sess`
# from the earlier block — verify that session is still open at this point.
saver = tf.train.Saver()

saver.restore(sess, NEW_CHECKPOINT_FILE)

from rolo.

pari93411 avatar pari93411 commented on August 28, 2024

Hi!
How can I get these three files?
I ran the code, but it gives me an error about the restore file.

from rolo.

Related Issues (20)

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Something interesting about the web. A new door to the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Something interesting about games, to make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.