Add Moonshine to KerasHub #2093


Open

wants to merge 57 commits into master

Conversation

harshaljanjani
Collaborator

@harshaljanjani harshaljanjani commented Feb 12, 2025

Moonshine ASR Model Implementation in Keras

This PR introduces the Moonshine Automatic Speech Recognition (ASR) model into the Keras ecosystem. The Moonshine model, originally developed by UsefulSensors and available via Hugging Face, is a transformer-based architecture designed to transcribe audio inputs into text. This implementation ports the model into Keras, complete with support for pre-trained weights from Hugging Face.

Overview

The Moonshine ASR model employs an encoder-decoder architecture. The encoder processes audio features, while the decoder generates text transcriptions. This implementation includes custom layers and components to mirror the original model's behavior, validated against the Hugging Face version for accuracy.
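The encoder-decoder flow described above can be sketched backend-agnostically. This is a hedged illustration only: `encode`, `decode_step`, and the token ids are placeholders, not the PR's actual API.

```python
def transcribe(audio_features, encode, decode_step, start_id, end_id, max_len=32):
    # Encoder-decoder ASR loop: encode the audio once, then decode tokens
    # autoregressively; each step cross-attends to the encoder output.
    encoder_out = encode(audio_features)
    tokens = [start_id]
    for _ in range(max_len):
        next_id = decode_step(tokens, encoder_out)
        tokens.append(next_id)
        if next_id == end_id:
            break
    return tokens
```

The real task model wraps this loop behind `generate()`, with caching so each decode step is incremental rather than recomputing from scratch.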

Files Added

The following files have been added to implement the Moonshine ASR model:

  • moonshine_backbone.py defines the MoonshineBackbone class, the core of the model. It integrates the encoder and decoder blocks, embeddings, and layer normalization, forming the complete encoder-decoder pipeline.

  • moonshine_decoder.py contains the MoonshineDecoderBlock class, a decoder block with self-attention (causal), cross-attention, and feedforward layers. It supports caching for efficient generation and uses SwiGLU activation by default.

  • moonshine_encoder.py implements the MoonshineEncoderBlock class, the encoder component with self-attention and feedforward layers. It optionally uses SwiGLU activation, matching the original model's configuration.

  • moonshine_multi_head_attention.py provides a custom multi-head attention layer, the MoonshineMultiHeadAttention class.

  • moonshine_layers.py includes the following utility layers:

    • MoonshineRotaryEmbedding: Rotary positional embeddings with dynamic scaling support.
    • MoonshineMLP: Can be configured to use SwiGLU activation for feedforward networks or as a linear layer with GeLU activation.
  • moonshine_audio_converter.py implements the MoonshineAudioConverter class, a specialized audio preprocessing layer that converts raw audio waveforms into feature representations suitable for the Moonshine ASR model. It includes downsampling and feature extraction, normalization, and handling of attention masks.

  • moonshine_tokenizer.py provides the MoonshineTokenizer class, which extends the LlamaTokenizer to handle text tokenization for the Moonshine model. It incorporates Moonshine-specific special tokens, including position embedding tokens, hex tokens, and empty tokens, and manages the conversion between raw text and token IDs.

  • moonshine_audio_to_text.py implements the MoonshineAudioToText class, a task model that extends the Seq2SeqLM base class. This class integrates the audio converter, backbone, and tokenizer components to create a complete end-to-end ASR pipeline. It includes methods for text generation from audio inputs, with support for customizable generation parameters and built-in trimming of output sequences.

  • moonshine_seq_2_seq_lm_preprocessor.py implements the MoonshineSeq2SeqLMPreprocessor class, which extends the Seq2SeqLMPreprocessor base class. It handles the conversion of raw audio inputs and text into a format suitable for MoonshineAudioToText. The preprocessor supports both training mode (with paired audio-text inputs) and generation mode (with audio inputs only), including methods for preprocessing and postprocessing during text generation.

  • Weights Conversion Script

    • Converts pre-trained weights from Hugging Face into a Keras-compatible format.
    • Loads them into the MoonshineBackbone model.
    • Validates the Keras implementation by comparing outputs with the Hugging Face model using random inputs.
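As an illustration of the SwiGLU activation mentioned for the encoder, decoder, and MoonshineMLP, here is the underlying math in plain NumPy; this is a sketch of the activation pattern, not the PR's Keras code, and the weight names are invented.

```python
import numpy as np

def silu(x):
    # SiLU (swish) activation: x * sigmoid(x).
    return x * (1.0 / (1.0 + np.exp(-x)))

def swiglu_mlp(x, w_gate, w_hidden, w_out):
    # SwiGLU feedforward: the gate branch goes through SiLU and is
    # multiplied elementwise with a linear hidden branch, then projected
    # back down to the model dimension.
    return (silu(x @ w_gate) * (x @ w_hidden)) @ w_out

rng = np.random.default_rng(0)
d_model, d_ff = 8, 16
x = rng.standard_normal((2, 4, d_model))
y = swiglu_mlp(
    x,
    rng.standard_normal((d_model, d_ff)),
    rng.standard_normal((d_model, d_ff)),
    rng.standard_normal((d_ff, d_model)),
)
assert y.shape == x.shape
```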

Dependencies

  • Keras 3: Required for backend-agnostic operations.
  • Hugging Face Transformers: Needed by the weights conversion script for loading the original model.
  • Librosa: Required for audio processing.

Notes for Reviewers

  • The implementation is fully functional with pre-trained weights and ready for immediate use.
  • The modular design allows for easy extension or modification of individual components (e.g., attention layers or embeddings).
  • All custom layers are serializable with get_config() and registered with @keras.saving.register_keras_serializable, ensuring compatibility with Keras model saving/loading.
  • End-To-End Demo Notebook: Colab Notebook.
  • Functionality Tests Notebook Independent From HF: Colab Notebook.

Closes issue #2083.

Collaborator

@divyashreepathihalli divyashreepathihalli left a comment


Thank you for the PR! I left some initial comments.
I would suggest following the format, structure, and naming conventions of the Whisper model here - https://github.com/keras-team/keras-hub/tree/master/keras_hub/src/models/whisper

  • add docstrings
  • convert backbone to a functional model
  • add a moonshine_audio_converter.py
  • Add a numerics verification colab to verify the implementation

@harshaljanjani
Collaborator Author

Will make the changes at the earliest, thanks for the review!

@divyashreepathihalli
Collaborator

You will need to run shell/api_gen.sh and shell/format.sh at the repo root to resolve the code formatting errors.

@harshaljanjani
Collaborator Author

Thanks for the review, made the changes! The issue regarding the build still persists.

@harshaljanjani harshaljanjani self-assigned this Feb 19, 2025
@harshaljanjani
Collaborator Author

Summary of Changes:

  1. Added MoonshineDecoderBlock (it passes numeric checks; a few issues with the reversible embeddings keep me from integrating the whole decoder, but I'll try to fix them and report back).
  2. Made a testable encoder component subclassed from keras.Model, separate from the MoonshineBackbone class, as it's easier to test weight loading this way: the preprocessor, decoder, and encoder each have separate weight files.

@harshaljanjani
Collaborator Author

TODO:

  1. Verify the build methods, as the sanity checks for serialization don’t pass, even though the numerics are aligned.
  2. Write weight conversion scripts.

@harshaljanjani
Collaborator Author

Status of the PR:
Weight assignment works, but the numerics differ.

Outputs of the convert_moonshine_checkpoints.py script:

MD5 Checksum Comparison
Decoder Weights Assignment
Preprocessor Weights Assignment
Encoder Weights Assignment

@harshaljanjani harshaljanjani marked this pull request as ready for review April 12, 2025 12:25
… the PyTorch backend, integrated into the KerasHub infra!
@sachinprasadhs sachinprasadhs removed the WIP (pull requests which are work in progress and not ready yet for review) label Apr 14, 2025
@harshaljanjani
Collaborator Author

Updated the Colab notebook with results from the latest commit. The PR is now open for review.
What's New?
The task model has been integrated across all three backends into the KerasHub infra, including the custom caching strategy used by Moonshine.
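For context, a generic decoder key/value caching step looks like the following NumPy sketch. This is illustrative only; Moonshine's actual custom caching strategy lives in the PR and may differ in layout and update logic.

```python
import numpy as np

def cached_attention_step(query, key_cache, value_cache, new_key, new_value, index):
    # Write this step's key/value into the preallocated caches at `index`,
    # then attend over all cached positions up to and including it.
    key_cache[:, index] = new_key
    value_cache[:, index] = new_value
    keys = key_cache[:, : index + 1]      # (batch, index + 1, dim)
    values = value_cache[:, : index + 1]
    scores = np.einsum("bd,bsd->bs", query, keys) / np.sqrt(query.shape[-1])
    weights = np.exp(scores - scores.max(-1, keepdims=True))
    weights /= weights.sum(-1, keepdims=True)
    return np.einsum("bs,bsd->bd", weights, values)
```

Preallocating the caches and writing at a fixed index is what keeps shapes static, which matters for JIT compilation on TensorFlow and JAX.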

@divyashreepathihalli
Collaborator

divyashreepathihalli commented Apr 15, 2025

I don't see a demo notebook with the KerasHub model implemented here; I am seeing a demo of the Hugging Face model in the Colab.
Please add a demo with the KH model and verify that the outputs match with model.generate.

@harshaljanjani
Collaborator Author

I don't see a demo notebook with the KerasHub model implemented here; I am seeing a demo of the Hugging Face model in the Colab.
Please add a demo with the KH model and verify that the outputs match with model.generate.

@divyashreepathihalli The outputs you see across the first three cells are the KH model outputs for four test samples per preset, using the generate() function. I've run tools/checkpoint_conversion/convert_moonshine_checkpoints.py in each of the cells across the three backends, which both verifies the numerics and contains the end-to-end example.

The cell links are:

  1. PyTorch Backend Output of keras_model.generate()
  2. TensorFlow Backend Output of keras_model.generate()
  3. JAX Backend Output of keras_model.generate()

You may also review the checkpoint conversion file to verify the same.

The HF model is only used in the last cell, where I point out a bug in the HF implementation and show how for the same sample, the KH model presets give good transcripts across all three backends. (The sample used in this test is the "Female Clear Voice (Maximum Length - 64 Sec)" one.)

@harshaljanjani
Collaborator Author

harshaljanjani commented Apr 19, 2025

@mattdangerw / @abheesht17 / @divyashreepathihalli Whenever you have a chance, could you please take a look at this PR and the notebook, thanks!

harshaljanjani and others added 5 commits April 21, 2025 14:15
The rope_scaling parameter was largely a direct port from HF, where it took a dict and pulled the type key from it. The Moonshine presets nowhere explicitly use the dynamic mode, and it isn't crucial to the model. If it becomes necessary in the future we can add it, but for an initial port I think it's best to leave it out. It's better to inherit from the KH RotaryEmbedding class and leave the scaling_factor arg up to it instead; that works perfectly well as a replacement and is much better integrated into the existing infra.
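The linear position scaling that a scaling_factor argument provides can be illustrated in NumPy. This is a generic RoPE sketch, not the KH RotaryEmbedding implementation, and the function signature is invented for illustration.

```python
import numpy as np

def rotary_embed(x, scaling_factor=1.0, base=10000.0):
    # Rotate pairs of feature channels by position-dependent angles (RoPE).
    # scaling_factor > 1 stretches positions, the linear-scaling analogue
    # of the dropped rope_scaling dict.
    seq_len, dim = x.shape
    half = dim // 2
    inv_freq = 1.0 / (base ** (np.arange(half) / half))
    angles = np.outer(np.arange(seq_len) / scaling_factor, inv_freq)
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[:, :half], x[:, half:]
    return np.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)
```

Because each pair of channels is only rotated, the per-position norm is preserved; position 0 passes through unchanged.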
@mattdangerw
Member

Dropping a few comments. I think we still need to get the generation here working similarly to other models, and make the preprocessing actual preprocessing (no weights!). I still think a clearer high-level colab with intended usage might help clarify things.

  • Do some weight conversion, upload to huggingface or kaggle (doesn't matter which) on your own user.
  • Make a colab that does not touch huggingface at all that shows the intended usage here.
  • Try to show some of the usages here Add Moonshine to KerasHub #2093 (comment)

How much of this is working today? Have we tried running fine-tuning? That will run preprocessing via a tf.data.Dataset map, does that work?

!pip install git+https://github.com/harshaljanjani/keras-hub@moonshine

import os
os.environ["KERAS_BACKEND"] = "jax"  # Or "tensorflow" or "torch" with zero other changes.

import keras
import keras_hub

audio_to_text = keras_hub.models.AudioToText.from_preset(
    "hf://harshaljanjani/keras-moonshine",
)

audio_to_text.generate(audio_tensor)
audio_to_text.generate(audio_batch)

audio_to_text.compile(sampler="top_k")
audio_to_text.generate(audio_tensor)

audio_to_text.compile(...)
audio_to_text.enable_lora(4)  # Optional.
audio_to_text.fit(audio_dataset)
audio_to_text.generate(audio_batch)

@harshaljanjani
Collaborator Author

harshaljanjani commented Apr 29, 2025

Will check the comments out; thanks for the review @mattdangerw. I left a few replies and would love your opinion on a few non-trivial points mentioned there; I'll proceed to make the other changes.

How much of this is working today? Have we tried running fine-tuning? That will run preprocessing via a tf.data.Dataset map, does that work?

I haven't tested fine-tuning yet, but I'll see what I can do. Since you mentioned that the change in the generate() strategy was key, I focused on it for this round.

- MoonshineAudioConverter now has no trainable weights; all feature extraction has moved to MoonshineBackbone.

- Removed the logits() function and used self.token_embedding(reverse=True) instead.

- Resolved test_causal_lm_basics() for all backends, thereby resolving tf.data.Dataset.map compatibility issues on the JAX and Torch backends.

- Removed the 64-second test file.
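The token_embedding(reverse=True) change relies on weight tying: a single embedding matrix both embeds tokens and produces output logits. A NumPy sketch of the idea (illustrative, not the KerasHub ReversibleEmbedding code):

```python
import numpy as np

# One matrix serves both directions of a tied ("reversible") embedding.
rng = np.random.default_rng(0)
vocab_size, hidden_dim = 10, 4
embedding_matrix = rng.standard_normal((vocab_size, hidden_dim))

def embed(token_ids):
    # Forward: look up embeddings for token ids.
    return embedding_matrix[token_ids]

def reverse_embed(hidden_states):
    # Reverse: project hidden states back to vocabulary logits by
    # multiplying with the transposed embedding matrix.
    return hidden_states @ embedding_matrix.T
```

Tying the matrices removes a separate output projection and its parameters, which is why the standalone logits() function became unnecessary.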
@harshaljanjani
Collaborator Author

Addressed reviews (JIT compile + dynamic shapes issue). Looking forward to guidance on the same; I'll see if I can solve it in the meantime.

Fixed JIT compile issues on TensorFlow and JAX without unnecessary shenanigans

Reverted to KerasNLP style of caching without stateful cache modes.
@harshaljanjani
Collaborator Author

The PR should be ready for the next round of reviews @mattdangerw. Here's the new Colab you mentioned. I've tested the functionality with dummy inputs for now; hope you don't mind! I'll check the weights upload thing and the presets once the design is approved.

  1. Functionality Tests Notebook Independent From HF.
  2. Same Outputs Notebook, Updated To The Current PR's Version.

@harshaljanjani harshaljanjani requested a review from mattdangerw May 2, 2025 16:19
@divyashreepathihalli
Collaborator

The PR should be ready for the next round of reviews @mattdangerw. Here's the new Colab you mentioned. I've tested the functionality with dummy inputs for now; hope you don't mind! I'll check the weights upload thing and the presets once the design is approved.

  1. Functionality Tests Notebook Independent From HF.
  2. Same Outputs Notebook, Updated To The Current PR's Version.

Please add demo colabs, verifications, etc. to the PR description so that they are easier to find.

@harshaljanjani
Collaborator Author

harshaljanjani commented May 2, 2025

Please add demo colabs, verifications, etc. to the PR description so that they are easier to find.

Apologies, the end-to-end demo notebook has been linked in the PR description from the beginning. I've just linked the functionality tests I added today in the PR description!
