ssm_enhancement #689

Open
wants to merge 2 commits into
base: main

vishesh9131

Pull Request: Enhancements to Mamba and Jamba State-space Models

Summary

This pull request introduces several enhancements to the implementation of the Mamba and Jamba state-space models (SSMs), including new recurrence methods, hybrid approaches, and comprehensive testing.

Changes

1. New Recurrence Methods

  • HybridMambaRecurrence: Combines different recurrence methods to leverage their strengths.
  • AlternativeMambaRecurrence: Implements an alternative recurrence method for potentially better performance or accuracy.
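
For readers unfamiliar with the term, the following is a minimal, generic sketch of what "different recurrence methods" can mean for a linear state-space update. It is not the code from this PR or from axlearn: the function names are hypothetical, the state is reduced to a scalar, and the recurrence h_t = a_t * h_{t-1} + b_t with a zero initial state is an assumed simplification. It shows the same result computed step by step and by composing affine maps with an associative operator, which is the property a parallel scan relies on.

```python
from itertools import accumulate
from typing import List, Sequence, Tuple

# Assumed toy recurrence (scalar state, zero initial state):
#   h_t = a_t * h_{t-1} + b_t


def sequential_recurrence(a: Sequence[float], b: Sequence[float]) -> List[float]:
    """Computes every h_t with an explicit left-to-right loop."""
    h = 0.0
    states = []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        states.append(h)
    return states


def combine(left: Tuple[float, float], right: Tuple[float, float]) -> Tuple[float, float]:
    """Composes two affine maps h -> a * h + b, applying `left` first."""
    a_l, b_l = left
    a_r, b_r = right
    return (a_r * a_l, a_r * b_l + b_r)


def scan_recurrence(a: Sequence[float], b: Sequence[float]) -> List[float]:
    """Computes every h_t as a prefix composition of per-step affine maps.

    `accumulate` still folds left to right; the point is that `combine` is
    associative, so a parallel prefix scan could regroup the same work.
    """
    return [b_acc for _, b_acc in accumulate(zip(a, b), combine)]


if __name__ == "__main__":
    a = [0.5, 0.5, 0.5]
    b = [1.0, 1.0, 1.0]
    print(sequential_recurrence(a, b))
    print(scan_recurrence(a, b))
```

Both functions return the same states for the same inputs; they differ only in how the work is grouped, which is where efficiency differences between recurrence implementations can come from.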

2. Enhancements to ssm.py

  • Added HybridMambaRecurrence and AlternativeMambaRecurrence classes.
  • Updated MambaMixerLayer and JambaMixerLayer to integrate the new recurrence methods.

3. Comprehensive Testing in ssm_test.py

  • Added tests for HybridMambaRecurrence and AlternativeMambaRecurrence in MambaMixerLayerTest.
  • Added tests for hybrid and alternative recurrence methods in StackedMambaTest.
  • Added tests for hybrid and alternative recurrence methods in StackedMixedSSMTransformerTest.

Documentation and examples: updated docstrings and comments to reflect the new features and changes.

Testing

All new features have been thoroughly tested with the following configurations:

  • Different data types (jnp.float32, jnp.bfloat16).
  • Various model dimensions, state dimensions, and hidden dimensions.
  • Integration within MambaBlock, JambaMambaBlock, and StackedMixedSSMTransformerLayer.

Conclusion

These enhancements provide additional flexibility and options for implementing and experimenting with different recurrence methods in the Mamba and Jamba models, potentially improving performance and accuracy for various tasks.

@vishesh9131
Author

Hey @markblee,

Could you please take a look at my PR when you get a chance? Thanks!

@vishesh9131
Author

Hey @swiseman,

Could you please take a look at my PR when you get a chance? Thanks!

Contributor

@swiseman swiseman left a comment


Thanks for the PR!

axlearn/common/ssm.py (outdated, resolved)
axlearn/common/ssm.py (outdated, resolved)
axlearn/common/ssm_test.py (resolved)



class HybridMambaRecurrence(BaseMambaRecurrence):
Contributor


Thanks for these new classes. Do people use either the hybrid recurrences or alternative recurrences defined below? Is there evidence that they are useful empirically? If not, I think it would be simpler to leave these classes out for now, and if necessary let people define them in downstream experiment files which import axlearn.common.ssm.

Author

@vishesh9131 vishesh9131 Sep 11, 2024


Hey @swiseman, thank you for your valuable input. I've reviewed the hybrid and alternative recurrences, and they do not appear to have been used extensively in practice. Based on your benchmarking results, AssociativeScanMambaRecurrence appears to be more efficient than HybridMambaRecurrence.

Given the lack of empirical evidence and the performance advantage of the AssociativeScanMambaRecurrence, I agree that it's reasonable to remove the HybridMambaRecurrence and other less-used recurrences from the core axlearn.common.ssm module for now.

This will simplify the codebase and make it easier for users to understand and use. If there's a strong need for these recurrences in the future, they can be defined in downstream experiment files as you suggested.

-Vishesh

- fixed redundant function definitions
- fixed incorrect module import in layers.py
@vishesh9131 vishesh9131 reopened this Sep 11, 2024
Contributor

@ruomingp ruomingp left a comment


Hi, shall we close the PR or turn it into draft? Thanks.

@vishesh9131
Author

> Hi, shall we close the PR or turn it into draft? Thanks.

Hey @ruomingp, we can close this PR. I am working on these hybrid structures, and it will take a while.

Thanks for reading,
@vishesh9131

3 participants