Hi,
It is sometimes difficult to suggest a possible solution without seeing the minimal code that will reproduce the error.
-
Regarding the out-of-memory (OOM) error:
In trn.py, you will get an OOM error in at least the following two cases:
a. When you use a large batch size. This is true in general.
b. When you increase the number of iterations K so much that the model no longer fits on the GPU.
In the paper, we used K=10 with a batch size of 1, and that fits well on a 16 GB P100 card.
So for a 12 GB card, I would suggest trying K=7 at most.
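The K=7 suggestion can be reproduced with a quick back-of-the-envelope calculation. The sketch below assumes (this is an assumption, not something measured in trn.py) that GPU memory grows roughly linearly with the number of unrolled iterations K at batch size 1, and scales from the K=10 / 16 GB reference point. `max_unrolls` is a hypothetical helper name.

```python
import math


def max_unrolls(gpu_gb, ref_gpu_gb=16.0, ref_k=10):
    """Rough upper bound on K for a GPU with `gpu_gb` of memory.

    Assumes (hypothetically) that memory use is proportional to K at batch
    size 1, using K=10 on a 16 GB card as the reference point. Fixed costs
    such as framework overhead are ignored, so treat the result as a
    starting point and check the real usage with nvidia-smi.
    """
    return math.floor(ref_k * gpu_gb / ref_gpu_gb)


print(max_unrolls(12))  # 12 GB card -> 7
```

Since the linear model ignores fixed overhead, if K=7 still runs out of memory, step K down by one and retry.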
Another issue could be the RAM size: if you try to load the entire dataset at once, it may not fit in RAM, since (I guess) TensorFlow internally creates several copies of some tensors.
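One common way around the RAM limit is to avoid holding the whole dataset in memory and instead read one mini-batch at a time. The sketch below is a generic generator, not code from trn.py; `iterate_minibatches` is a hypothetical name. It assumes the data is stored in something sliceable, for example a `.npy` file opened with `numpy.load(path, mmap_mode='r')`, so that only the requested slice is actually read from disk.

```python
def iterate_minibatches(data, batch_size):
    """Yield consecutive slices of `data` containing `batch_size` items.

    `data` can be any sequence that supports len() and slicing. With a
    memory-mapped array (e.g. numpy.load(path, mmap_mode='r')), each slice
    is read from disk only when it is used, so the full dataset never has
    to sit in RAM at once. The last batch may be smaller than batch_size.
    """
    n = len(data)
    for start in range(0, n, batch_size):
        yield data[start:start + batch_size]


# Usage with a plain list standing in for the real (hypothetical) dataset.
batches = list(iterate_minibatches(list(range(5)), 2))
print(batches)  # [[0, 1], [2, 3], [4]]
```

Each yielded batch can then be passed through `feed_dict` for one training step, instead of placing the full arrays in the graph.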
You can use `nvidia-smi` to check how much GPU memory is being used.
-
Sharing the variables is easy: use `AUTO_REUSE` in the variable scope, like this:
```python
with tf.variable_scope('SomeName', reuse=tf.AUTO_REUSE):
    ...  # variables created here with tf.get_variable() are reused on every call
```
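To illustrate what `AUTO_REUSE` does, here is a minimal sketch in which a hypothetical one-layer `denoiser` (a stand-in for the real, deeper CNN; the data-consistency step is left out) is called K times. Because every call opens the same variable scope with `reuse=tf.AUTO_REUSE`, only one weight tensor is created no matter how large K is. It uses `tensorflow.compat.v1` so that it runs in graph mode under both TF 1.x and 2.x; the names `denoiser` and `unrolled` are illustrative only.

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # graph mode, as in TF 1.x


def denoiser(x):
    # Hypothetical one-layer stand-in for the CNN denoiser.
    with tf.variable_scope('SomeName', reuse=tf.AUTO_REUSE):
        # Created on the first call, reused on every later call.
        w = tf.get_variable('w', shape=[3, 3, 2, 2],
                            initializer=tf.glorot_uniform_initializer())
        return tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME')


def unrolled(x, K):
    # Apply the denoiser K times; all iterations share the same weights.
    out = x
    for _ in range(K):
        out = denoiser(out)
    return out


tf.reset_default_graph()
x = tf.placeholder(tf.float32, shape=[1, 8, 8, 2])
y = unrolled(x, K=5)
shared_vars = [v.name for v in tf.trainable_variables()]
print(shared_vars)  # ['SomeName/w:0']
```

Only the ops are duplicated per iteration (they get scopes like `SomeName_1`), while the variable name stays `SomeName/w:0`, which is what lets checkpoints move between different values of K.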
-
In my opinion, there are two ways to use a pretrained model.
First, by restoring the weights. As you noticed, this is possible only if you have exactly the same model, i.e. the same variables with the same names. Since the weights are shared across the iterations, a model trained with K=1 has the same variables as one with K=2 or any higher value, so you can restore those weights there. But in that case you should not change the model or the variable names.
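A minimal sketch of the first approach, assuming a toy model in which the only variable is one shared convolution weight (the hypothetical `build_model` below is not the MoDL network): a checkpoint saved from a K=1 graph is restored into a K=3 graph. This works only because the variable names (`Wts/w:0`) are identical in both graphs. The "training" step is replaced by plain initialization to keep the example self-contained.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()


def build_model(K):
    # Hypothetical unrolled model: one shared conv weight applied K times.
    x = tf.placeholder(tf.float32, shape=[1, 8, 8, 2])
    out = x
    for _ in range(K):
        with tf.variable_scope('Wts', reuse=tf.AUTO_REUSE):
            w = tf.get_variable('w', shape=[3, 3, 2, 2],
                                initializer=tf.glorot_uniform_initializer())
        out = tf.nn.conv2d(out, w, strides=[1, 1, 1, 1], padding='SAME')
    return x, out


ckpt_dir = tempfile.mkdtemp()

# "Train" (here: only initialize) a K=1 model and save a checkpoint.
tf.reset_default_graph()
build_model(K=1)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    w_saved = sess.run(tf.trainable_variables()[0])
    ckpt_path = tf.train.Saver().save(sess, os.path.join(ckpt_dir, 'model'))

# Rebuild with K=3: the variable names are unchanged, so restore succeeds.
tf.reset_default_graph()
build_model(K=3)
with tf.Session() as sess:
    tf.train.Saver().restore(sess, ckpt_path)
    w_restored = sess.run(tf.trainable_variables()[0])

print(np.allclose(w_saved, w_restored))  # True
```

If a variable is renamed or a layer is added, `Saver.restore` fails with a "not found in checkpoint" error, which is why the model and variable names must stay the same for this route.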
Second, you can extract the weights from the trained model as numpy arrays and, instead of using a Xavier or random initializer, initialize the variables with those values. I have this small piece of code that extracts the weights into a dictionary of numpy arrays, so you may try to adapt it.
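```python
import tensorflow as tf  # TensorFlow 1.x API


def getWeights(wtsDir, chkPointNum):
    """
    Input:
        wtsDir: full path of the directory containing modelTst.meta
        chkPointNum: 'last' to load the latest checkpoint, otherwise the
            checkpoint number (e.g. '50' loads wtsDir/model50)
    Output:
        wt: dictionary of numpy arrays containing the weights. The keys are
            the full names of the corresponding tensors in the model.
    """
    tf.reset_default_graph()
    if chkPointNum == 'last':
        loadChkPoint = tf.train.latest_checkpoint(wtsDir)
    else:
        loadChkPoint = wtsDir + '/model' + str(chkPointNum)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # do not grab all GPU memory up front
    with tf.Session(config=config) as s1:
        # Rebuild the graph from the saved meta file, then load the weights.
        saver = tf.train.import_meta_graph(wtsDir + '/modelTst.meta')
        saver.restore(s1, loadChkPoint)
        # Tensor names of all variable nodes in the graph, e.g. 'Wts/w:0'.
        keys = [n.name + ':0' for n in tf.get_default_graph().as_graph_def().node
                if "Variable" in n.op]
        var = tf.global_variables()
        wt = {}
        for key in keys:
            va = [v for v in var if v.name == key][0]
            wt[key] = s1.run(va)  # evaluate the variable into a numpy array
    # Reset outside the session block; resetting inside it raises an error.
    tf.reset_default_graph()
    return wt
```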
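Once the weights are in a dictionary like the one `getWeights` returns, each variable of the new model can be created with `tf.constant_initializer` instead of a Xavier or random initializer. The sketch below fakes that dictionary with random numbers (the key `'Wts/w:0'` and the single-variable model are hypothetical); in practice you would look up each variable by its full tensor name.

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# Hypothetical dictionary in the format returned by getWeights():
# keys are full tensor names, values are numpy arrays.
wt = {'Wts/w:0': np.random.randn(3, 3, 2, 2).astype(np.float32)}

tf.reset_default_graph()
with tf.variable_scope('Wts', reuse=tf.AUTO_REUSE):
    w0 = wt['Wts/w:0']
    # Start from the pretrained values instead of a Xavier/random init.
    w = tf.get_variable('w', shape=w0.shape,
                        initializer=tf.constant_initializer(w0))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    w_init = sess.run(w)

print(np.allclose(w_init, wt['Wts/w:0']))  # True
```

Unlike restoring a checkpoint, this route leaves you free to change the architecture or the variable names: only the variables you explicitly initialize from the dictionary need to have a matching shape.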
from modl.
Very detailed and kind reply, thanks!