-
Notifications
You must be signed in to change notification settings - Fork 441
Fix nested use_cache disabling for calibration #1704
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -920,9 +920,29 @@ def get_supported_datasets() -> list[str]: | |
| return list(SUPPORTED_DATASET_CONFIG.keys()) + list(DATASET_COMBOS.keys()) | ||
|
|
||
|
|
||
| _NESTED_USE_CACHE_CONFIG_ATTRS = ("text_config",) | ||
|
|
||
|
|
||
| def _iter_use_cache_configs(model: torch.nn.Module) -> Iterator[Any]: | ||
| """Yield the top-level config and Step3.7-style nested text config.""" | ||
| seen: set[int] = set() | ||
| config = getattr(model, "config", None) | ||
| if config is None: | ||
| return | ||
|
|
||
| for candidate in ( | ||
| config, | ||
| *(getattr(config, attr, None) for attr in _NESTED_USE_CACHE_CONFIG_ATTRS), | ||
| ): | ||
| if candidate is None or id(candidate) in seen: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
The PR description says affected models keep the language config under
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Addressed. The PR body now describes the targeted Step 3.7 path only: `model.config.text_config`. I did not add `model.language_model.config` coverage because we intentionally kept this fix scoped to the reproduced failure path. |
||
| continue | ||
| seen.add(id(candidate)) | ||
| yield candidate | ||
|
|
||
|
|
||
| @contextmanager | ||
| def _disable_use_cache(model: torch.nn.Module) -> Iterator[None]: | ||
| """Set ``model.config.use_cache = False`` for the duration of the block. | ||
| """Set model config ``use_cache`` flags to ``False`` for the duration of the block. | ||
|
|
||
| KV caching is unwanted during calibration / memory-probe forward passes: | ||
| it wastes memory, and for hybrid Mamba/attention models (e.g., NemotronH) | ||
|
|
@@ -931,23 +951,26 @@ def _disable_use_cache(model: torch.nn.Module) -> Iterator[None]: | |
| present) also sidesteps configs that never assign the attribute at all | ||
| — e.g., ``Step3p5Config`` from stepfun-ai/Step-3.5-Flash — where forward | ||
| code that reads ``self.config.use_cache`` would otherwise raise | ||
| ``AttributeError``. The prior value is restored on exit if one existed. | ||
| ``AttributeError``. Step3.7 keeps the relevant language config nested | ||
| under ``text_config``; that config object is handled the same way. The | ||
| prior value is restored on exit if one existed. | ||
| """ | ||
| config = getattr(model, "config", None) | ||
| if config is None: | ||
| yield | ||
| return | ||
| had_attr = hasattr(config, "use_cache") | ||
| prev = config.use_cache if had_attr else None | ||
| config.use_cache = False | ||
| states = [] | ||
| for config in _iter_use_cache_configs(model): | ||
| had_attr = hasattr(config, "use_cache") | ||
| prev = config.use_cache if had_attr else None | ||
| config.use_cache = False | ||
| states.append((config, had_attr, prev)) | ||
|
|
||
| try: | ||
| yield | ||
| finally: | ||
| if had_attr: | ||
| config.use_cache = prev | ||
| else: | ||
| with suppress(AttributeError): | ||
| delattr(config, "use_cache") | ||
| for config, had_attr, prev in reversed(states): | ||
| if had_attr: | ||
| config.use_cache = prev | ||
| else: | ||
| with suppress(AttributeError): | ||
| delattr(config, "use_cache") | ||
|
|
||
|
|
||
| def get_max_batch_size( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The new nested-config path isn't covered by any test.
`tests/unit/torch/utils/test_dataset_utils.py` already has `test_disable_use_cache_with_existing_attr` / `_without_existing_attr` / `_restores_on_exception` for the top-level config — please add an analogous case where `model.config.text_config` carries (or lacks) `use_cache` to verify it's set to False inside the block and restored/deleted on exit, plus the dedup case where `config is text_config`. The PR body checks the "wrote tests" box but the diff is code-only.