Comments (2)

LasseRogers commented on May 28, 2024

This is the whole error I'm getting:

---------------------------------------------------------------------------
ScopeParamShapeError                      Traceback (most recent call last)
Cell In[16], line 2
      1 # run the forward pass (JIT compiled the first time it is called)
----> 2 pred_ids = p_generate(input_features)
      3 output_ids = device_get(pred_ids.reshape(-1, model.config.max_length))

    [... skipping hidden 12 frame]

Cell In[9], line 8, in generate_fn(input_features)
      7 def generate_fn(input_features):
----> 8     pred_ids = model.generate(
      9         input_features, task="transcribe", return_timestamps=False, max_length=model.config.max_length, params=params,
     10     )
     11     return pred_ids.sequences

File /opt/conda/lib/python3.10/site-packages/whisper_jax/modeling_flax_whisper.py:1588, in FlaxWhisperForConditionalGeneration.generate(self, input_features, generation_config, logits_processor, return_timestamps, task, language, is_multilingual, **kwargs)
   1585 if len(forced_decoder_ids) > 0:
   1586     generation_config.forced_decoder_ids = forced_decoder_ids
-> 1588 return super().generate(
   1589     input_features,
   1590     generation_config,
   1591     logits_processor=logits_processor,
   1592     **kwargs,
   1593 )

File /opt/conda/lib/python3.10/site-packages/transformers/generation/flax_utils.py:372, in FlaxGenerationMixin.generate(self, input_ids, generation_config, prng_key, trace, params, logits_processor, **kwargs)
    369 if self.config.is_encoder_decoder:
    370     # add encoder_outputs to model_kwargs
    371     if model_kwargs.get("encoder_outputs") is None:
--> 372         model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, params, model_kwargs)
    373     # prepare decoder_input_ids for generation
    374     input_ids = self._prepare_decoder_input_ids_for_generation(
    375         batch_size,
    376         decoder_start_token_id=generation_config.decoder_start_token_id,
    377         bos_token_id=generation_config.bos_token_id,
    378         model_kwargs=model_kwargs,
    379     )

File /opt/conda/lib/python3.10/site-packages/transformers/generation/flax_utils.py:167, in FlaxGenerationMixin._prepare_encoder_decoder_kwargs_for_generation(self, input_ids, params, model_kwargs)
    161 def _prepare_encoder_decoder_kwargs_for_generation(self, input_ids, params, model_kwargs):
    162     encoder_kwargs = {
    163         argument: value
    164         for argument, value in model_kwargs.items()
    165         if not (argument.startswith("decoder_") or argument.startswith("cross_attn"))
    166     }
--> 167     model_kwargs["encoder_outputs"] = self.encode(input_ids, params=params, return_dict=True, **encoder_kwargs)
    168     return model_kwargs

File /opt/conda/lib/python3.10/site-packages/whisper_jax/modeling_flax_whisper.py:1132, in FlaxWhisperPreTrainedModel.encode(self, input_features, attention_mask, output_attentions, output_hidden_states, return_dict, train, params, dropout_rng, **kwargs)
   1129     encode_module = module._get_encoder_module()
   1130     return encode_module(input_features, **kwargs)
-> 1132 return self.module.apply(
   1133     {"params": params or self.params},
   1134     input_features=jnp.array(input_features, dtype="f4"),
   1135     output_attentions=output_attentions,
   1136     output_hidden_states=output_hidden_states,
   1137     return_dict=return_dict,
   1138     deterministic=not train,
   1139     rngs=rngs,
   1140     method=_encoder_forward,
   1141 )

    [... skipping hidden 4 frame]

File /opt/conda/lib/python3.10/site-packages/whisper_jax/modeling_flax_whisper.py:1130, in FlaxWhisperPreTrainedModel.encode.<locals>._encoder_forward(module, input_features, **kwargs)
   1128 def _encoder_forward(module, input_features, **kwargs):
   1129     encode_module = module._get_encoder_module()
-> 1130     return encode_module(input_features, **kwargs)

    [... skipping hidden 2 frame]

File /opt/conda/lib/python3.10/site-packages/whisper_jax/modeling_flax_whisper.py:824, in FlaxWhisperEncoder.__call__(self, input_features, output_attentions, output_hidden_states, return_dict, deterministic)
    817     raise ValueError(
    818         "input_features.shape[1:], must be equal to (self.config.num_mel_bins,"
    819         f" self.config.max_source_positions * 2) (got {input_features.shape[1:]}, but should be"
    820         f" ({self.config.num_mel_bins}, {self.config.max_source_positions * 2}))"
    821     )
    823 input_features = input_features.transpose(0, 2, 1)
--> 824 hidden_states = jax.nn.gelu(self.conv1(input_features), approximate=False)
    825 hidden_states = with_sharding_constraint(hidden_states, ("batch", "embed", "num_mel"))
    826 hidden_states = jax.nn.gelu(self.conv2(hidden_states), approximate=False)

    [... skipping hidden 2 frame]

File /opt/conda/lib/python3.10/site-packages/whisper_jax/layers.py:1205, in _Conv.__call__(self, inputs)
   1200 if self.mask is not None and self.mask.shape != kernel_shape:
   1201     raise ValueError(
   1202         "Mask needs to have the same shape as weights. " f"Shapes are: {self.mask.shape}, {kernel_shape}"
   1203     )
-> 1205 kernel = param_with_axes(
   1206     "kernel",
   1207     self.kernel_init,
   1208     kernel_shape,
   1209     self.params_dtype,
   1210     axes=self.kernel_axes,
   1211 )
   1213 if self.mask is not None:
   1214     kernel *= self.mask

File /opt/conda/lib/python3.10/site-packages/flax/linen/partitioning.py:159, in param_with_axes(name, init_fn, axes, module, *init_args, **init_kwargs)
    157   assert module is not None
    158 # define/fetch parameter on that module
--> 159 module_param = module.param(name, init_fn, *init_args, **init_kwargs)
    160 if axes is not None:
    161   # apply logical axis constraint immediately
    162   module_param = with_sharding_constraint(
    163       module_param, jax.sharding.PartitionSpec(*axes)
    164   )

    [... skipping hidden 1 frame]

File /opt/conda/lib/python3.10/site-packages/flax/core/scope.py:982, in Scope.param(self, name, init_fn, unbox, *init_args, **init_kwargs)
    977   for val, abs_val in zip(value_flat, abs_value_flat):
    978     # NOTE: We could check dtype consistency here as well but it's
    979     # usefuleness is less obvious. We might intentionally change the dtype
    980     # for inference to a half float type for example.
    981     if jnp.shape(val) != jnp.shape(abs_val):
--> 982       raise errors.ScopeParamShapeError(
    983         name, self.path_text, jnp.shape(abs_val), jnp.shape(val)
    984       )
    985 else:
    986   if not self.is_mutable_collection('params'):

ScopeParamShapeError: Initializer expected to generate shape (2, 3, 80, 768) but got shape (3, 80, 768) instead for parameter "kernel" in "/model/encoder/conv1"

from whisper-jax.

nairajay2k commented on May 28, 2024

I am also getting this error. I tried downgrading whisper-jax to an older commit, but then other things start breaking.
@sanchit-gandhi Please let us know how to solve this.

