
Comments (8)

uricohen commented on April 28, 2024

An example that reproduces these failures is available in my fork:

python demogen/parse_tuning.py

from google-research.

yidingjiang commented on April 28, 2024

Are the problems only with the resnet models?

uricohen commented on April 28, 2024

yidingjiang commented on April 28, 2024

I couldn't find extract_layers_util in your repo, but here is one quick thing to try: can you call tf.reset_default_graph() between loading different models?
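
To illustrate why resetting can help, here is a toy, pure-Python model of the problem. It is not TensorFlow code: `ToyGraph` is a hypothetical stand-in for a single TF1 default graph with variable reuse enabled (which the "Trying to share variable" error suggests), and `load_all` is a hypothetical loader. When every model is built into the same namespace, a second model whose first conv layer has a different width collides with the first one; starting from a fresh namespace for each model avoids the collision.

```python
class ToyGraph:
    """Toy stand-in for one TF1 default graph: a flat map of variable names to shapes.

    Not TensorFlow code. It mimics only variable sharing with reuse enabled:
    requesting an existing name succeeds if the shape matches, otherwise it fails.
    """

    def __init__(self):
        self.variables = {}

    def get_variable(self, name, shape):
        shape = tuple(shape)
        found = self.variables.setdefault(name, shape)
        if found != shape:
            raise ValueError(
                f"Trying to share variable {name}, but specified shape "
                f"{shape} and found shape {found}."
            )
        return name


def load_all(first_conv_widths, reset_between_models):
    """Build one toy model per width and return the widths that built successfully."""
    graph = ToyGraph()
    loaded = []
    for width in first_conv_widths:
        if reset_between_models:
            graph = ToyGraph()  # toy analogue of calling tf.reset_default_graph()
        try:
            # Hypothetical model: only the first conv kernel is modelled here.
            graph.get_variable("resnet/conv2d/kernel", (3, 3, 3, width))
            loaded.append(width)
        except ValueError as err:
            print("failed:", err)
    return loaded


print(load_all([16, 32], reset_between_models=False))  # the 32-wide model fails
print(load_all([16, 32], reset_between_models=True))   # both models build
```

The first call prints a shape-sharing error like the one reported in this thread and returns `[16]`; the second returns `[16, 32]`.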

uricohen commented on April 28, 2024

uricohen commented on April 28, 2024

This indeed solves the issue for all batchnorm models in resnet, but not for the groupnorm ones!

The following errors no longer occur:

  • Most failed with Not found: Key resnet/group_norm/beta not found in checkpoint
  • A few failed with ValueError: Trying to share variable resnet/conv2d/kernel, but specified shape (3, 3, 3, 32) and found shape (3, 3, 3, 16).

The following error still occurs in all groupnorm models:

  • Many fail with Invalid argument: Assign requires shapes of both tensors to match. lhs shape= [1,32,1,1] rhs shape= [32]

That is, for resnet I could load 108 / 216 CIFAR-10 models and 162 / 324 CIFAR-100 models.
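
Errors like "Key ... not found in checkpoint" and "Assign requires shapes of both tensors to match" can be diagnosed before restoring by comparing the shapes the model expects with the shapes stored on disk. A minimal sketch follows: `checkpoint_shapes` and `shape_mismatches` are hypothetical helpers, not part of demogen; the only TensorFlow call used is `tf.train.list_variables`, imported lazily so the comparison itself runs without TensorFlow. The variable names in the example are illustrative.

```python
def checkpoint_shapes(path):
    """Map each variable name stored in a checkpoint to its shape (as a list)."""
    import tensorflow as tf  # lazy import; tf.train.list_variables exists in TF 1.x and 2.x
    return {name: list(shape) for name, shape in tf.train.list_variables(path)}


def shape_mismatches(expected, stored):
    """Compare shapes the model graph expects with shapes found in a checkpoint.

    Both arguments map variable names to shapes. Returns (missing, mismatched):
    the expected names absent from the checkpoint and, for names present in
    both with different shapes, {name: (expected_shape, stored_shape)}.
    """
    missing = sorted(name for name in expected if name not in stored)
    mismatched = {
        name: (list(shape), list(stored[name]))
        for name, shape in expected.items()
        if name in stored and list(shape) != list(stored[name])
    }
    return missing, mismatched


# Illustrative names and shapes, echoing the errors quoted above.
expected = {"resnet/group_norm/gamma": [1, 32, 1, 1], "resnet/group_norm/beta": [1, 32, 1, 1]}
stored = {"resnet/group_norm/gamma": [32]}
print(shape_mismatches(expected, stored))
```

In a TF1 graph, the expected shapes could be collected after building the model with something like `{v.op.name: v.shape.as_list() for v in tf.global_variables()}`.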

yidingjiang commented on April 28, 2024

I think the issue is that in the original code the gamma and beta variables were created with shape [c] and then reshaped to [1, c, 1, 1], but the code was later changed to create them with shape [1, c, 1, 1] directly, so their shapes no longer match the saved checkpoints. My bad that I didn't catch it. It might take me a bit of time to push the change, but doing the following should fix the issue:

  1. Go to models/resnet.py
  2. Go to the function group_norm
  3. Change:
    gamma = tf.get_variable('gamma', [1, c, 1, 1],
                            initializer=tf.constant_initializer(1.0))
    beta = tf.get_variable('beta', [1, c, 1, 1],
                           initializer=tf.constant_initializer(0.0))

to

    gamma = tf.get_variable('gamma', [c],
                            initializer=tf.constant_initializer(1.0))
    beta = tf.get_variable('beta', [c],
                           initializer=tf.constant_initializer(0.0))
    gamma = tf.reshape(gamma, [1, c, 1, 1])
    beta = tf.reshape(beta, [1, c, 1, 1])
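
The point of the change is that the variables keep the shape the checkpoints were saved with ([c]) and are reshaped to [1, c, 1, 1] only in the computation, where that shape lets them broadcast over an NCHW activation. A minimal NumPy sketch of group norm written that way follows; it assumes NCHW layout (implied by [1, c, 1, 1]), and `group_norm_nchw`, `groups=8` and `eps=1e-5` are illustrative assumptions, not demogen's actual code or settings.

```python
import numpy as np


def group_norm_nchw(x, gamma, beta, groups, eps=1e-5):
    """Group normalization of an NCHW array.

    gamma and beta are stored with shape [c], matching the checkpoints, and
    reshaped to [1, c, 1, 1] only when used, so they broadcast over N, H, W.
    """
    n, c, h, w = x.shape
    if c % groups:
        raise ValueError("number of channels must be divisible by groups")
    xg = x.reshape(n, groups, c // groups, h, w)
    # Normalize each (sample, group) over its channels and spatial positions.
    mean = xg.mean(axis=(2, 3, 4), keepdims=True)
    var = xg.var(axis=(2, 3, 4), keepdims=True)
    x_hat = ((xg - mean) / np.sqrt(var + eps)).reshape(n, c, h, w)
    return x_hat * gamma.reshape(1, c, 1, 1) + beta.reshape(1, c, 1, 1)


c = 32
saved_gamma = np.ones(c)  # shape (32,), like "rhs shape= [32]" in the error
saved_beta = np.zeros(c)
x = np.random.default_rng(0).standard_normal((2, c, 4, 4))
y = group_norm_nchw(x, saved_gamma, saved_beta, groups=8)
print(saved_gamma.shape, saved_gamma.reshape(1, c, 1, 1).shape, y.shape)
```

This prints `(32,) (1, 32, 1, 1) (2, 32, 4, 4)`: the stored parameters stay one-dimensional, and only the values used in the computation take the broadcastable shape.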

uricohen commented on April 28, 2024
