Comments (3)

faustomorales commented on August 16, 2024

Thanks for sharing that fix -- I made a similar one in e9a62ba which is now published in v0.0.12. Please comment back if this doesn't resolve the issue. Thanks, again!

from vit-keras.

faustomorales commented on August 16, 2024

I'm not sure about the reason for this error message. I don't think we're re-using any layer names. But it appears to work if you use the input layer created as part of the ViT model (see below).

import tensorflow as tf
import vit_keras.vit as vit

# Build the ViT backbone; it creates its own input layer.
base = vit.vit_b16(
    image_size=256,
    pretrained=True,
    include_top=False,
    pretrained_top=False,
)
# Attach the classification head to the backbone's output.
x = tf.keras.layers.Dense(64, activation='relu')(base.output)
x = tf.keras.layers.Dense(5, activation='softmax')(x)
# Reuse the backbone's own inputs instead of a separately created Input layer.
model = tf.keras.Model(inputs=base.inputs, outputs=x)
opt = tf.keras.optimizers.Adam()
loss = tf.keras.losses.CategoricalCrossentropy()
model.compile(optimizer=opt, loss=loss, metrics=['categorical_accuracy'])
model.save('ViT.h5')

Alternatively, we could (and perhaps should) provide the user an option to supply their own input layer. Feedback welcome!
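One way to test the "not re-using any layer names" assumption is to count how often each weight name occurs. The following is only a minimal sketch: find_duplicate_names is a hypothetical helper, and the names in example_names are made up for illustration. In practice the list would come from something like [w.name for w in model.weights].

```python
from collections import Counter


def find_duplicate_names(names):
    """Return {name: count} for every name that occurs more than once."""
    counts = Counter(names)
    return {name: count for name, count in counts.items() if count > 1}


# Hypothetical weight names, standing in for [w.name for w in model.weights].
example_names = [
    "Dense_0/kernel:0",
    "Dense_0/bias:0",
    "Dense_0/kernel:0",
    "Dense_1/kernel:0",
]

duplicates = find_duplicate_names(example_names)
print(duplicates)
```

An empty result would suggest the names are unique; any entry points at a name that is shared by more than one variable.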

awsaf49 commented on August 16, 2024

@faustomorales I think the layer names Dense_0 and Dense_1 are the reason behind this error. If you check the variable names, you'll see that Dense has 12 duplicates. I'm currently using the workaround below: I simply give each Dense layer a unique name. I also had to make sure the variable names still match the pretrained weights.
It's not the best solution, but it solves the issue.

import tensorflow as tf
import tensorflow_addons as tfa


class TransformerBlock(tf.keras.layers.Layer):
    """Implements a Transformer block."""

    def __init__(self, *args, num_heads, mlp_dim, dropout, n, **kwargs):
        super().__init__(*args, **kwargs)
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.dropout = dropout
        # Index of this block, used to give its Dense layers unique names.
        self.n = n

    def build(self, input_shape):
        # MultiHeadSelfAttention is the library's existing attention layer.
        self.att = MultiHeadSelfAttention(
            num_heads=self.num_heads,
            name="MultiHeadDotProductAttention_1",
        )
        self.mlpblock = tf.keras.Sequential(
            [
                # The "TB{n}_" prefix makes each Dense name unique per block.
                tf.keras.layers.Dense(
                    self.mlp_dim, activation=tfa.activations.gelu, name=f"TB{self.n}_Dense_0"
                ),
                tf.keras.layers.Dropout(self.dropout),
                tf.keras.layers.Dense(input_shape[-1], name=f"TB{self.n}_Dense_1"),
                tf.keras.layers.Dropout(self.dropout),
            ],
            name="MlpBlock_3",
        )
        # (the snippet in the original comment ends here)

To match the variable names with the weights:

for match in matches:
    source_keys_used.extend(match["keys"])
    # Strip the "TB<n>_" prefix so the key matches the pretrained parameter name.
    # Note: k[4:] drops exactly four characters, so it assumes a single-digit n.
    source_weights = [
        params_dict[k[4:] if k.startswith("TB") else k] for k in match["keys"]
    ]
    if match.get("reshape", False):
        source_weights = [
            source.reshape(expected.shape)
            for source, expected in zip(
                source_weights, match["layer"].get_weights()
            )
        ]
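
Since the slice k[4:] removes a fixed four characters, it would leave a stray underscore for block indices of 10 or more (for example "TB11_"). A possible variant, shown only as a sketch: strip_block_prefix is a hypothetical helper, and the key format "TB<n>_<original name>" is an assumption based on the layer names used in the workaround.

```python
import re

# Matches a leading "TB<n>_" prefix with any number of digits in <n>.
_TB_PREFIX = re.compile(r"^TB\d+_")


def strip_block_prefix(key):
    """Remove a leading "TB<n>_" prefix; keys without it are returned unchanged."""
    return _TB_PREFIX.sub("", key, count=1)


# Hypothetical keys for illustration only.
for key in ["TB0_Dense_0/kernel", "TB11_Dense_1/bias", "Dense_0/kernel"]:
    print(key, "->", strip_block_prefix(key))
```

With such a helper, the lookup in the loop could read params_dict[strip_block_prefix(k)] instead of slicing by a fixed length.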
