Comments (2)
This is the whole error I'm getting:
---------------------------------------------------------------------------
ScopeParamShapeError Traceback (most recent call last)
Cell In[16], line 2
1 # run the forward pass (JIT compiled the first time it is called)
----> 2 pred_ids = p_generate(input_features)
3 output_ids = device_get(pred_ids.reshape(-1, model.config.max_length))
[... skipping hidden 12 frame]
Cell In[9], line 8, in generate_fn(input_features)
7 def generate_fn(input_features):
----> 8 pred_ids = model.generate(
9 input_features, task="transcribe", return_timestamps=False, max_length=model.config.max_length, params=params,
10 )
11 return pred_ids.sequences
File /opt/conda/lib/python3.10/site-packages/whisper_jax/modeling_flax_whisper.py:1588, in FlaxWhisperForConditionalGeneration.generate(self, input_features, generation_config, logits_processor, return_timestamps, task, language, is_multilingual, **kwargs)
1585 if len(forced_decoder_ids) > 0:
1586 generation_config.forced_decoder_ids = forced_decoder_ids
-> 1588 return super().generate(
1589 input_features,
1590 generation_config,
1591 logits_processor=logits_processor,
1592 **kwargs,
1593 )
File /opt/conda/lib/python3.10/site-packages/transformers/generation/flax_utils.py:372, in FlaxGenerationMixin.generate(self, input_ids, generation_config, prng_key, trace, params, logits_processor, **kwargs)
369 if self.config.is_encoder_decoder:
370 # add encoder_outputs to model_kwargs
371 if model_kwargs.get("encoder_outputs") is None:
--> 372 model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, params, model_kwargs)
373 # prepare decoder_input_ids for generation
374 input_ids = self._prepare_decoder_input_ids_for_generation(
375 batch_size,
376 decoder_start_token_id=generation_config.decoder_start_token_id,
377 bos_token_id=generation_config.bos_token_id,
378 model_kwargs=model_kwargs,
379 )
File /opt/conda/lib/python3.10/site-packages/transformers/generation/flax_utils.py:167, in FlaxGenerationMixin._prepare_encoder_decoder_kwargs_for_generation(self, input_ids, params, model_kwargs)
161 def _prepare_encoder_decoder_kwargs_for_generation(self, input_ids, params, model_kwargs):
162 encoder_kwargs = {
163 argument: value
164 for argument, value in model_kwargs.items()
165 if not (argument.startswith("decoder_") or argument.startswith("cross_attn"))
166 }
--> 167 model_kwargs["encoder_outputs"] = self.encode(input_ids, params=params, return_dict=True, **encoder_kwargs)
168 return model_kwargs
File /opt/conda/lib/python3.10/site-packages/whisper_jax/modeling_flax_whisper.py:1132, in FlaxWhisperPreTrainedModel.encode(self, input_features, attention_mask, output_attentions, output_hidden_states, return_dict, train, params, dropout_rng, **kwargs)
1129 encode_module = module._get_encoder_module()
1130 return encode_module(input_features, **kwargs)
-> 1132 return self.module.apply(
1133 {"params": params or self.params},
1134 input_features=jnp.array(input_features, dtype="f4"),
1135 output_attentions=output_attentions,
1136 output_hidden_states=output_hidden_states,
1137 return_dict=return_dict,
1138 deterministic=not train,
1139 rngs=rngs,
1140 method=_encoder_forward,
1141 )
[... skipping hidden 4 frame]
File /opt/conda/lib/python3.10/site-packages/whisper_jax/modeling_flax_whisper.py:1130, in FlaxWhisperPreTrainedModel.encode.<locals>._encoder_forward(module, input_features, **kwargs)
1128 def _encoder_forward(module, input_features, **kwargs):
1129 encode_module = module._get_encoder_module()
-> 1130 return encode_module(input_features, **kwargs)
[... skipping hidden 2 frame]
File /opt/conda/lib/python3.10/site-packages/whisper_jax/modeling_flax_whisper.py:824, in FlaxWhisperEncoder.__call__(self, input_features, output_attentions, output_hidden_states, return_dict, deterministic)
817 raise ValueError(
818 "input_features.shape[1:], must be equal to (self.config.num_mel_bins,"
819 f" self.config.max_source_positions * 2) (got {input_features.shape[1:]}, but should be"
820 f" ({self.config.num_mel_bins}, {self.config.max_source_positions * 2}))"
821 )
823 input_features = input_features.transpose(0, 2, 1)
--> 824 hidden_states = jax.nn.gelu(self.conv1(input_features), approximate=False)
825 hidden_states = with_sharding_constraint(hidden_states, ("batch", "embed", "num_mel"))
826 hidden_states = jax.nn.gelu(self.conv2(hidden_states), approximate=False)
[... skipping hidden 2 frame]
File /opt/conda/lib/python3.10/site-packages/whisper_jax/layers.py:1205, in _Conv.__call__(self, inputs)
1200 if self.mask is not None and self.mask.shape != kernel_shape:
1201 raise ValueError(
1202 "Mask needs to have the same shape as weights. " f"Shapes are: {self.mask.shape}, {kernel_shape}"
1203 )
-> 1205 kernel = param_with_axes(
1206 "kernel",
1207 self.kernel_init,
1208 kernel_shape,
1209 self.params_dtype,
1210 axes=self.kernel_axes,
1211 )
1213 if self.mask is not None:
1214 kernel *= self.mask
File /opt/conda/lib/python3.10/site-packages/flax/linen/partitioning.py:159, in param_with_axes(name, init_fn, axes, module, *init_args, **init_kwargs)
157 assert module is not None
158 # define/fetch parameter on that module
--> 159 module_param = module.param(name, init_fn, *init_args, **init_kwargs)
160 if axes is not None:
161 # apply logical axis constraint immediately
162 module_param = with_sharding_constraint(
163 module_param, jax.sharding.PartitionSpec(*axes)
164 )
[... skipping hidden 1 frame]
File /opt/conda/lib/python3.10/site-packages/flax/core/scope.py:982, in Scope.param(self, name, init_fn, unbox, *init_args, **init_kwargs)
977 for val, abs_val in zip(value_flat, abs_value_flat):
978 # NOTE: We could check dtype consistency here as well but it's
979 # usefuleness is less obvious. We might intentionally change the dtype
980 # for inference to a half float type for example.
981 if jnp.shape(val) != jnp.shape(abs_val):
--> 982 raise errors.ScopeParamShapeError(
983 name, self.path_text, jnp.shape(abs_val), jnp.shape(val)
984 )
985 else:
986 if not self.is_mutable_collection('params'):
ScopeParamShapeError: Initializer expected to generate shape (2, 3, 80, 768) but got shape (3, 80, 768) instead for parameter "kernel" in "/model/encoder/conv1"
from whisper-jax.
I am also getting this error. I tried downgrading whisper-jax to an older commit, but other things start breaking.
@sanchit-gandhi Please let us know how to solve this.
from whisper-jax.
Related Issues (20)
- How to add millisecond for the timestamp?
- I have downloaded the flax_model, where can I call it?
- why whisper-jax did not use my GPU? HOT 3
- Rust impl
- Unsuccessful deployment HOT 1
- Coral TPU support HOT 2
- Slower than openai whisper with my gpu HOT 2
- I want to use whisper-at models HOT 1
- Has translate be integrated into transcribe? It returns English but expect Chinese. HOT 3
- Slow post processing HOT 1
- unable to run TPU using current kaggle environment HOT 1
- Large Model causing performance degradation?
- HuggingFace space erroring more often than usual HOT 1
- Transcription issues.
- Punctuation mark
- Confidence score and average log probability on Whisper-JAX
- whisper-large-v3 (in demo code) VS whisper-large-v2 (in kaggle notebook)
- Add wrapper for wyoming API
- Kernel always restarting when JIT compiling the forward call on MacBook Pro M3 Max
Recommend Projects
-
React
A declarative, efficient, and flexible JavaScript library for building user interfaces.
-
Vue.js
🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
-
Typescript
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
-
TensorFlow
An Open Source Machine Learning Framework for Everyone
-
Django
The Web framework for perfectionists with deadlines.
-
Laravel
A PHP framework for web artisans
-
D3
Bring data to life with SVG, Canvas and HTML. 📊📈🎉
-
Recommend Topics
-
javascript
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
-
web
Some thing interesting about web. New door for the world.
-
server
A server is a program made to process requests and deliver data to clients.
-
Machine learning
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
-
Visualization
Some thing interesting about visualization, use data art
-
Game
Some thing interesting about game, make everyone happy.
Recommend Org
-
Facebook
We are working to build community through open source technology. NB: members must have two-factor auth.
-
Microsoft
Open source projects and samples from Microsoft.
-
Google
Google ❤️ Open Source for everyone.
-
Alibaba
Alibaba Open Source for everyone
-
D3
Data-Driven Documents codes.
-
Tencent
China Tencent open source team.
from whisper-jax.