Comments (2)
You need to modify main.py, for example:
def train(self):
    create_dir(SAVE_DIR)
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=10)
    ntrain = len(self.train_data.image)  # traindata : 55000 testdata : 10000
    ntest = len(self.test_data.image)
    nbatch = ntrain // self.batch_size
    nbatch_t = ntest // self.batch_size
    for epoch in range(self.epoch):
        # shuffle start
        index = np.arange(ntrain)
        np.random.shuffle(index)
        shuffle_image = self.train_data.image[index]  # shuffle_image: 55000,784
        shuffle_label = self.train_data.label[index]
        ########################## shuffle the testdata #####################
        index_t = np.arange(ntest)
        np.random.shuffle(index_t)
        shuffle_test = self.test_data.image[index_t]  # shuffle_test: 10000,784
        shuffle_label_t = self.test_data.label[index_t]
        ######################################################################
        # shuffle end
        Train_accuracy = 0
        Test_accuracy = 0
        for batch in tqdm(range(nbatch), ascii=True, desc="batch"):
            start = self.batch_size * batch
            end = self.batch_size * (batch + 1)
            train_feed_dict = {self.image: shuffle_image[start:end], self.label: shuffle_label[start:end]}
            _, batch_accuracy = self.sess.run([self.run_train, self.accuracy], feed_dict=train_feed_dict)
            Train_accuracy += batch_accuracy
        Train_accuracy /= nbatch  # avg acc
        ###################### I modified it .. #########################
        # test_feed_dict = {self.image : self.test_data.image, self.label : self.test_data.label} #
        for batch in tqdm(range(nbatch_t), ascii=True, desc="batch"):
            start = self.batch_size * batch
            end = self.batch_size * (batch + 1)
            test_feed_dict = {self.image: shuffle_test[start:end], self.label: shuffle_label_t[start:end]}
            test_accuracy = self.sess.run(self.accuracy, feed_dict=test_feed_dict)
            Test_accuracy += test_accuracy
        Test_accuracy /= nbatch_t  # avg test acc
        logger.info("Epoch({}/{}) train_accuracy : {}%, test_accuracy : {}%".format(epoch+1, self.epoch, Train_accuracy, Test_accuracy))
        if epoch % self.save_every == self.save_every - 1:
            saver.save(self.sess, os.path.join(SAVE_DIR, 'model'), global_step=epoch+1)
Hope it's not too late for you...
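One thing to keep in mind: because nbatch_t is ntest // self.batch_size, the test loop above skips any leftover samples when the test set size is not a multiple of the batch size, and it averages the per-batch accuracies. A minimal sketch of the same batched-evaluation idea that also covers the last partial batch is below. It uses plain NumPy and a hypothetical `predict_fn` in place of `self.sess.run(...)`, so `batched_accuracy` and `predict_fn` are illustration names, not part of main.py.

```python
import numpy as np

def batched_accuracy(predict_fn, images, labels, batch_size):
    """Accuracy over the whole set, evaluated one mini-batch at a time.

    predict_fn is a hypothetical stand-in for the model's forward pass:
    it takes a batch of inputs and returns predicted class indices.
    """
    n = len(images)
    correct = 0
    # range(0, n, batch_size) also yields the final, possibly smaller batch
    for start in range(0, n, batch_size):
        end = min(start + batch_size, n)
        preds = predict_fn(images[start:end])
        correct += int(np.sum(preds == labels[start:end]))
    # Dividing by n weights every sample equally, regardless of batch size
    return correct / n

# Toy data: 105 samples, so the last batch holds only 5 of them
rng = np.random.default_rng(0)
images = rng.normal(size=(105, 4))
labels = (images[:, 0] > 0).astype(int)

def perfect_predict(x):
    # Reproduces the rule used to build the labels
    return (x[:, 0] > 0).astype(int)

def always_zero(x):
    # Baseline that predicts class 0 for every sample
    return np.zeros(len(x), dtype=int)

acc_perfect = batched_accuracy(perfect_predict, images, labels, batch_size=10)
acc_zero = batched_accuracy(always_zero, images, labels, batch_size=10)
```

Counting correct predictions and dividing by the total sample count gives the same number as a single full-set evaluation, while only one batch is fed to the model at a time.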
from deformable_convnet.
Thank you, it runs well.