[FlaxSpeechEncoderDecoder] Fix input shape bug in weights init #16728
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
The tuple
input_shape
is required in theinit
method of the FlaxSpeechEncoderDecoderModel in order to initialise the model weights - one must specify these input shapes to enable JAX to trace through the model dimensions.This tuple consists of two entries: the encoder and decoder input lengths. Speech encoders almost always downsample the sequence length dimension. Given an encoder input length, the decoder input length is computed through a convolutional formula. This convolutional formula should take into consideration two convolutional based modules:
Currently, only the first of these two convolutional based modules is accounted for. This PR amends the model script to account for the second of the two, i.e. the adapter module.