@utsab345
Contributor

Summary

Fixes keras.optimizers.Muon failing with AttributeError: 'ResourceVariable' object has no attribute 'path' in Keras 3 / TF 2.16–2.20.

Changes

  • Replaced deprecated .path references with .name for variable identification.
  • Updated _should_use_adamw() logic to match modern Keras 3 variable handling.
  • Added robust tests covering:
    • No .path attribute access
    • Layer exclusion behavior
    • Compatibility with embeddings and dense layers

Result

All 11 tests passed locally on TensorFlow 2.16+ (Windows).

Closes #21793

@gemini-code-assist
Contributor

Summary of Changes

Hello @utsab345, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical compatibility issue within the Keras Muon optimizer, specifically targeting an AttributeError that arises in TensorFlow 2.16+ due to the removal of the .path attribute from ResourceVariable objects. The core change involves updating the optimizer's internal logic to use the .name attribute for variable identification, ensuring seamless operation with modern TensorFlow versions. Additionally, the PR enhances the _should_use_adamw method with improved variable dimension checks and more resilient handling of layer exclusion patterns. The changes are thoroughly validated with a suite of new, targeted unit tests, and a minor refinement to image saving utilities is also included.

Highlights

  • TensorFlow 2.16+ Compatibility: Resolves an AttributeError in the Muon optimizer by replacing deprecated .path variable access with .name for compatibility with newer TensorFlow versions.
  • Updated AdamW Logic: Refines the _should_use_adamw() method in the Muon optimizer, including an expanded dimension check and robust handling of exclude_layers regex patterns.
  • Comprehensive Testing: Introduces new, specific unit tests for the Muon optimizer to cover variable identification, layer exclusion, dimension rules, and explicitly verifies the absence of the .path attribute error.
  • Image Utility Refinement: Standardizes "jpg" to "jpeg" file format handling in save_img and clarifies the warning message for RGBA to RGB conversion.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request effectively addresses the AttributeError for variable.path in recent TensorFlow versions by migrating to variable.name in the Muon optimizer. The changes are well-implemented and include several improvements:

  • The logic in _should_use_adamw is corrected to properly handle 4D tensors as intended for the Muon update step.
  • Robustness is improved by adding error handling for invalid regex patterns in exclude_layers.
  • The test suite for the Muon optimizer has been significantly improved by refactoring a large test into smaller, more focused tests with clear purposes, and by adding a specific test to prevent regressions of the .path attribute issue.

The changes look solid. I have one suggestion to further improve the robustness of the error handling in _should_use_adamw.

Comment on lines 150 to 152
except re.error:
    # Skip invalid regex patterns in exclude_layers
    continue
Contributor

medium

The try...except block is a great addition for robustness against invalid regex patterns. However, re.search can also raise a TypeError if an element in self.exclude_layers is not a string or a compiled regex pattern (e.g., a number). To make this even more robust, consider catching TypeError as well.

Suggested change
- except re.error:
+ except (re.error, TypeError):
      # Skip invalid regex patterns in exclude_layers
      continue
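A minimal sketch of the behaviour this suggestion relies on, using only the standard library. The helper name `matches_any` is hypothetical and is not the PR's `_should_use_adamw`; it only illustrates that `re.search` raises `re.error` for a malformed pattern and `TypeError` for a pattern that is not a string or compiled pattern, and that catching both lets the loop skip unusable entries.

```python
import re


def matches_any(patterns, text):
    """Return True if any usable pattern matches `text`.

    Hypothetical helper for illustration only; unusable patterns are skipped.
    """
    for pattern in patterns:
        try:
            if re.search(pattern, text):
                return True
        except (re.error, TypeError):
            # re.error: malformed regex such as "("
            # TypeError: pattern is not a str or compiled pattern, e.g. 123
            continue
    return False


# The malformed pattern and the integer are skipped; "dense" still matches.
print(matches_any(["(", 123, "dense"], "dense_1/kernel"))
```

Silently skipping bad entries is the behaviour proposed in this thread; whether invalid `exclude_layers` values should instead fail loudly is a separate design question.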

@codecov-commenter

codecov-commenter commented Oct 29, 2025

Codecov Report

❌ Patch coverage is 81.81818% with 4 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.65%. Comparing base (960133e) to head (ad582fa).
⚠️ Report is 3 commits behind head on master.

Files with missing lines | Patch % | Lines
keras/src/optimizers/muon.py | 76.92% | 2 Missing and 1 partial ⚠️
keras/src/utils/image_utils.py | 83.33% | 0 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #21797      +/-   ##
==========================================
+ Coverage   82.63%   82.65%   +0.01%     
==========================================
  Files         577      577              
  Lines       59321    59331      +10     
  Branches     9300     9303       +3     
==========================================
+ Hits        49020    49040      +20     
+ Misses       7913     7890      -23     
- Partials     2388     2401      +13     
Flag | Coverage Δ
keras | 82.47% <81.81%> (+0.01%) ⬆️
keras-jax | 63.34% <81.81%> (+0.02%) ⬆️
keras-numpy | 57.58% <81.81%> (+0.02%) ⬆️
keras-openvino | 34.28% <4.54%> (-0.01%) ⬇️
keras-tensorflow | 64.13% <81.81%> (+0.02%) ⬆️
keras-torch | 63.64% <81.81%> (+0.02%) ⬆️

@utsab345 utsab345 force-pushed the fix-muon-variable-path-issue-21793 branch from 586b5c0 to 4a299d3 Compare October 29, 2025 05:18

  def _muon_update_step(self, gradient, variable, lr):
-     m = self.adam_momentums[variable.path]
+     m = self.adam_momentums[variable.name]
Collaborator

The issue is that variable names are not unique. For instance, all Dense layers have a "kernel" variable, so state keyed by variable.name will have multiple variables clobbering one another.
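The collision can be sketched without Keras at all. `FakeVariable` below is a hypothetical stand-in, not the real Keras variable class, and the `dense/kernel` style paths are assumed for illustration; the point is only that a dictionary keyed by a non-unique name keeps one entry per name, while a key that includes the layer keeps one entry per variable.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeVariable:
    """Hypothetical stand-in for a Keras variable (illustration only)."""

    name: str  # short name, shared across layers of the same type
    path: str  # assumed to be unique per variable, e.g. "<layer>/<name>"


# Two Dense-like layers, each with a variable named "kernel".
variables = [
    FakeVariable(name="kernel", path="dense/kernel"),
    FakeVariable(name="kernel", path="dense_1/kernel"),
]

# Keyed by name: the second kernel overwrites the first.
slots_by_name = {v.name: f"momentum for {v.path}" for v in variables}

# Keyed by a unique identifier: each variable keeps its own slot.
slots_by_path = {v.path: f"momentum for {v.path}" for v in variables}

print(len(slots_by_name), len(slots_by_path))
```

With name keys, both layers would read and write the same momentum entry, which is the clobbering described in this comment.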

Development

Successfully merging this pull request may close these issues.

keras.optimizers.Muon Fails with AttributeError on variable.path in Keras 3 / TF 2.16-2.20

4 participants