Thanks for sharing that fix. I made a similar one in e9a62ba, which is now published in v0.0.12. Please comment back if this doesn't resolve the issue. Thanks again!
from vit-keras.
I'm not sure what is causing this error message. I don't think we're reusing any layer names, but it appears to work if you use the input layer created as part of the ViT model (see below).
```python
import tensorflow as tf
import vit_keras.vit as vit

# Build the ViT-B/16 backbone without its classification head.
base = vit.vit_b16(
    image_size=256,
    pretrained=True,
    include_top=False,
    pretrained_top=False,
)

# Attach a new head to the backbone's output, reusing the backbone's own
# input layer instead of creating a separate tf.keras.Input.
x = tf.keras.layers.Dense(64, activation='relu')(base.output)
x = tf.keras.layers.Dense(5, activation='softmax')(x)
model = tf.keras.Model(inputs=base.inputs, outputs=x)

opt = tf.keras.optimizers.Adam()
loss = tf.keras.losses.CategoricalCrossentropy()
model.compile(optimizer=opt, loss=loss, metrics=['categorical_accuracy'])
model.save('ViT.h5')
```
Alternatively, we could (and perhaps should) provide the user an option to supply their own input layer. Feedback welcome!
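A quick way to check whether names are being reused is simply to count them. The sketch below uses a hypothetical helper, `duplicate_names` (not part of vit-keras or TensorFlow); the idea is to pass it the weight names of a built model, for example `[w.name for w in model.weights]`, and inspect any name that appears more than once. The names in the usage lines are made up for illustration and are not the exact strings TensorFlow reports.

```python
from collections import Counter


def duplicate_names(names):
    """Return {name: count} for every name that occurs more than once."""
    counts = Counter(names)
    return {name: count for name, count in counts.items() if count > 1}


# Illustrative names only: 12 blocks that all name their MLP layers the same way.
names = [f"Dense_{k}/kernel" for _ in range(12) for k in (0, 1)]
print(duplicate_names(names))
# {'Dense_0/kernel': 12, 'Dense_1/kernel': 12}

# Prefixing each name with its block index removes the collisions.
unique = [f"TB{n}_Dense_{k}/kernel" for n in range(12) for k in (0, 1)]
print(duplicate_names(unique))
# {}
```

Depending on the TensorFlow version, the names reported by `model.weights` may or may not include parent layer scopes, so the exact duplicates you see can differ.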
@faustomorales I think the names `Dense_0` and `Dense_1` are the reason behind this error. If you check the variable names, you'll see that each `Dense` name has 12 duplicates, one per transformer block. I'm currently using a workaround: I simply give each `Dense` layer a unique name. I also had to make sure the variable names still match the names in the pretrained weights. This isn't the best solution, but it solves the issue.
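```python
import tensorflow as tf
import tensorflow_addons as tfa

# MultiHeadSelfAttention is defined in vit_keras/layers.py; if you edit that
# file directly, it is already in scope and this import is unnecessary.
from vit_keras.layers import MultiHeadSelfAttention


class TransformerBlock(tf.keras.layers.Layer):
    """Implements a Transformer block."""

    def __init__(self, *args, num_heads, mlp_dim, dropout, n, **kwargs):
        super().__init__(*args, **kwargs)
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.dropout = dropout
        self.n = n  # index of this block, used to make the Dense names unique

    def build(self, input_shape):
        self.att = MultiHeadSelfAttention(
            num_heads=self.num_heads,
            name="MultiHeadDotProductAttention_1",
        )
        self.mlpblock = tf.keras.Sequential(
            [
                # Prefix each Dense name with the block index ("TB0_Dense_0",
                # "TB1_Dense_0", ...) so no two blocks share a weight name.
                tf.keras.layers.Dense(
                    self.mlp_dim,
                    activation=tfa.activations.gelu,
                    name=f"TB{self.n}_Dense_0",
                ),
                tf.keras.layers.Dropout(self.dropout),
                tf.keras.layers.Dense(input_shape[-1], name=f"TB{self.n}_Dense_1"),
                tf.keras.layers.Dropout(self.dropout),
            ],
            name="MlpBlock_3",
        )
        # ... rest of build() and call() unchanged ...
```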
To match the renamed variables with the pretrained weights, I changed the key lookup in the weight-loading loop:
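```python
# Excerpt from the loop over `matches` in the weight-loading function;
# `matches`, `params_dict`, and `source_keys_used` are defined earlier there.
for match in matches:
    source_keys_used.extend(match["keys"])
    # Drop the 4-character "TB<n>_" prefix before looking up the pretrained
    # weight (this assumes n is a single digit).
    source_weights = [
        params_dict[k if not k.startswith('TB') else k[4:]] for k in match["keys"]
    ]
    if match.get("reshape", False):
        source_weights = [
            source.reshape(expected.shape)
            for source, expected in zip(
                source_weights, match["layer"].get_weights()
            )
        ]
```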
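One caveat with stripping the prefix by position: `k[4:]` removes exactly four characters, which fits `TB0_` through `TB9_`, but ViT-B/16 has 12 transformer blocks, so `TB10_` and `TB11_` are five characters long. The sketch below compares fixed-width slicing with a regular-expression strip. It assumes the renamed keys have the form `TB<n>_<original key>`; `strip_fixed`, `strip_regex`, and the sample key are hypothetical and only for illustration.

```python
import re


def strip_fixed(k):
    # Same idea as the k[4:] lookup: drop the first 4 characters of "TB..." keys.
    return k if not k.startswith("TB") else k[4:]


def strip_regex(k):
    # Drop a leading "TB<digits>_" prefix of any width; leave other keys as-is.
    return re.sub(r"^TB\d+_", "", k)


original = "MlpBlock_3/Dense_0/kernel"  # illustrative key, not a real checkpoint entry
for n in (0, 9, 10, 11):
    prefixed = f"TB{n}_{original}"
    print(n, strip_fixed(prefixed) == original, strip_regex(prefixed) == original)
# 0 True True
# 9 True True
# 10 False True
# 11 False True
```

Using `params_dict[strip_regex(k)]` in the lookup would be one way to cover all 12 blocks, provided the renamed keys really do carry the `TB<n>_` prefix.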