
Comments (2)

hkaggarwal avatar hkaggarwal commented on July 19, 2024

Hi,
It is sometimes difficult to suggest a solution without seeing a minimal code example that reproduces the error.

  1. Regarding the out-of-memory (OOM) error:
    In trn.py, you will get an OOM error in at least the following two cases:
    a. When you use a large batch size. This is true in general.
    b. When you increase the number of iterations K so much that the model no longer fits on the GPU.
    In the paper, we used K=10 with a batch size of 1, and that fits well on a 16 GB P100 card.
    So for a 12 GB card, I would suggest trying a maximum of K=7.
    Another issue could be RAM size: if you try to load the entire dataset at once, it may not fit in RAM, since I guess TensorFlow internally creates several copies of some tensors.
    You can use "$nvidia-smi" to check how much GPU memory is being used.
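As a rough sketch of the nvidia-smi check above, the snippet below reads per-GPU memory usage from Python. The helper names `parse_memory_csv` and `query_gpu_memory` are hypothetical, and the sample string is made-up illustrative data, not a measurement from trn.py. It assumes `nvidia-smi` is on the PATH when `query_gpu_memory` is called.

```python
import subprocess


def parse_memory_csv(text):
    """Parse 'used, total' lines (MiB, one line per GPU) into (used, total) tuples."""
    usage = []
    for line in text.strip().splitlines():
        used, total = (int(field) for field in line.split(","))
        usage.append((used, total))
    return usage


def query_gpu_memory():
    """Ask nvidia-smi for used and total memory of every visible GPU."""
    output = subprocess.check_output(
        [
            "nvidia-smi",
            "--query-gpu=memory.used,memory.total",
            "--format=csv,noheader,nounits",
        ],
        text=True,
    )
    return parse_memory_csv(output)


# Illustrative sample in the same CSV layout, so the parser can be tried
# on a machine without a GPU. Replace with query_gpu_memory() on a GPU box.
sample = "10240, 12288\n512, 16384\n"
for index, (used, total) in enumerate(parse_memory_csv(sample)):
    print(f"GPU {index}: {used}/{total} MiB ({100 * used / total:.0f}% used)")
```

Running this while training with different values of K should show how close each setting comes to the card's limit.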

  2. Sharing variables is easy with AUTO_REUSE in the variable scope, like this:
    with tf.variable_scope('SomeName', reuse=tf.AUTO_REUSE):
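To make the effect of AUTO_REUSE concrete, here is a minimal sketch, not the actual MoDL model. The function `conv_block`, the tensor shapes, and the loop count are assumptions chosen for illustration. It uses `tensorflow.compat.v1` so that it also runs under TensorFlow 2.x; with TensorFlow 1.x a plain `import tensorflow as tf` would do.

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()


def conv_block(x, name="SomeName"):
    """Hypothetical block: one 3x3 convolution whose kernel lives in a shared scope."""
    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        # Created on the first call, silently reused on every later call.
        w = tf.get_variable(
            "w", shape=[3, 3, 2, 2], initializer=tf.glorot_uniform_initializer()
        )
    return tf.nn.relu(tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding="SAME"))


tf.reset_default_graph()
x = tf.placeholder(tf.float32, [None, 8, 8, 2])

out = x
K = 3  # number of unrolled iterations (illustrative value)
for _ in range(K):
    out = conv_block(out)

# Only one kernel exists, even though the block was applied K times.
variable_names = [v.name for v in tf.global_variables()]
print(variable_names)
```

Because the variable name does not depend on K, the same set of variables is created whether the block is unrolled once or many times.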

  3. In my opinion, there are two ways to use a pretrained model.
    First, by restoring the weights. As you noticed, this is possible only if you have exactly the same model. So if you trained a model with K=1, you can restore those weights with K=2 or any higher value (the weights are shared across iterations), but in that case you should not change the model or the variable names.
    Second, you can extract the weights from the trained model into NumPy arrays and, instead of using a Xavier or random initializer, initialize the variables with those arrays. I have this small piece of code that extracts the weights into a dictionary of NumPy arrays, so you may try to adapt it.

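import tensorflow as tf  # TensorFlow 1.x API

def getWeights(wtsDir, chkPointNum):
    """
    Input:
        wtsDir: full path of the directory containing modelTst.meta
        chkPointNum: 'last' to load the latest checkpoint, otherwise the
        suffix of the checkpoint file, i.e. wtsDir/model<chkPointNum>
    Output:
        wt: dictionary of numpy arrays containing the weights. The key names
        are the full names of the corresponding tensors in the model.
    """
    tf.reset_default_graph()
    if chkPointNum == 'last':
        loadChkPoint = tf.train.latest_checkpoint(wtsDir)
    else:
        loadChkPoint = wtsDir + '/model' + str(chkPointNum)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # do not grab all GPU memory up front
    with tf.Session(config=config) as s1:
        # Rebuild the graph from the meta file, then load the checkpoint values.
        saver = tf.train.import_meta_graph(wtsDir + '/modelTst.meta')
        saver.restore(s1, loadChkPoint)
        # Names of the output tensors of all variable ops in the graph.
        keys = [n.name + ':0' for n in tf.get_default_graph().as_graph_def().node if "Variable" in n.op]
        varByName = {v.name: v for v in tf.global_variables()}

        wt = {}
        for key in keys:
            # Skip variable ops that are not registered as global variables.
            if key in varByName:
                wt[key] = s1.run(varByName[key])

    tf.reset_default_graph()
    return wt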
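One possible way to use the returned dictionary for the second approach is sketched below. This is an assumption about how the weights might be plugged in, not code from the repository: `get_initializer` is a hypothetical helper, and the `wt` dictionary here is a hand-made stand-in for the output of getWeights so that the snippet runs without a checkpoint. It uses `tensorflow.compat.v1` so that it also runs under TensorFlow 2.x.

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# Stand-in for the dictionary returned by getWeights (illustrative values only).
wt = {"SomeName/w:0": np.full((3, 3, 2, 2), 0.5, dtype=np.float32)}


def get_initializer(weights, full_name, fallback):
    """Use the pretrained array when the name is known, else the fallback initializer."""
    if full_name in weights:
        return tf.constant_initializer(weights[full_name])
    return fallback


tf.reset_default_graph()
with tf.variable_scope("SomeName", reuse=tf.AUTO_REUSE):
    init = get_initializer(wt, "SomeName/w:0", tf.glorot_uniform_initializer())
    w = tf.get_variable("w", shape=[3, 3, 2, 2], initializer=init)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    w_value = sess.run(w)

print(w_value.shape, float(w_value.mean()))
```

Unlike restoring a checkpoint, this route does not require the new graph to contain the same set of variables; any variable whose name is missing from the dictionary simply falls back to the Xavier (Glorot) initializer.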

from modl.

duancaohui avatar duancaohui commented on July 19, 2024

Very detailed and kind reply, thanks!

