Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 37 additions & 14 deletions modelopt/torch/utils/dataset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -920,9 +920,29 @@ def get_supported_datasets() -> list[str]:
return list(SUPPORTED_DATASET_CONFIG.keys()) + list(DATASET_COMBOS.keys())


_NESTED_USE_CACHE_CONFIG_ATTRS = ("text_config",)


def _iter_use_cache_configs(model: torch.nn.Module) -> Iterator[Any]:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bot comment.

The new nested-config path isn't covered by any test. tests/unit/torch/utils/test_dataset_utils.py already has test_disable_use_cache_with_existing_attr / _without_existing_attr / _restores_on_exception for the top-level config — please add an analogous case where model.config.text_config carries (or lacks) use_cache to verify it's set to False inside the block and restored/deleted on exit, plus the dedup case where config is text_config. The PR body checks the "wrote tests" box but the diff is code-only.

"""Yield the top-level config and Step3.7-style nested text config."""
seen: set[int] = set()
config = getattr(model, "config", None)
if config is None:
return

for candidate in (
config,
*(getattr(config, attr, None) for attr in _NESTED_USE_CACHE_CONFIG_ATTRS),
):
if candidate is None or id(candidate) in seen:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bot comment.

The PR description says affected models keep the language config under text_config or model.language_model.config, but this only iterates config and config.text_config. If model.language_model.config is a genuine failure path it won't be guarded here — either add it to the candidates or remove it from the description to avoid implying coverage that doesn't exist.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed. The PR body now describes the targeted Step 3.7 path only: model.config.text_config. I did not add model.language_model.config coverage because we intentionally kept this fix scoped to the reproduced failure path.

continue
seen.add(id(candidate))
yield candidate


@contextmanager
def _disable_use_cache(model: torch.nn.Module) -> Iterator[None]:
"""Set ``model.config.use_cache = False`` for the duration of the block.
"""Set model config ``use_cache`` flags to ``False`` for the duration of the block.

KV caching is unwanted during calibration / memory-probe forward passes:
it wastes memory, and for hybrid Mamba/attention models (e.g., NemotronH)
Expand All @@ -931,23 +951,26 @@ def _disable_use_cache(model: torch.nn.Module) -> Iterator[None]:
present) also sidesteps configs that never assign the attribute at all
— e.g., ``Step3p5Config`` from stepfun-ai/Step-3.5-Flash — where forward
code that reads ``self.config.use_cache`` would otherwise raise
``AttributeError``. The prior value is restored on exit if one existed.
``AttributeError``. Step3.7 keeps the relevant language config nested
under ``text_config``; that config object is handled the same way. The
prior value is restored on exit if one existed.
"""
config = getattr(model, "config", None)
if config is None:
yield
return
had_attr = hasattr(config, "use_cache")
prev = config.use_cache if had_attr else None
config.use_cache = False
states = []
for config in _iter_use_cache_configs(model):
had_attr = hasattr(config, "use_cache")
prev = config.use_cache if had_attr else None
config.use_cache = False
states.append((config, had_attr, prev))

try:
yield
finally:
if had_attr:
config.use_cache = prev
else:
with suppress(AttributeError):
delattr(config, "use_cache")
for config, had_attr, prev in reversed(states):
if had_attr:
config.use_cache = prev
else:
with suppress(AttributeError):
delattr(config, "use_cache")


def get_max_batch_size(
Expand Down
40 changes: 40 additions & 0 deletions tests/unit/torch/utils/test_dataset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
DATASET_COMBOS,
_disable_use_cache,
_forward_loop,
_iter_use_cache_configs,
_pack_documents_into_rows,
_process_batch,
get_dataset_dataloader,
Expand Down Expand Up @@ -222,6 +223,45 @@ def test_disable_use_cache_without_existing_attr():
assert not hasattr(model.config, "use_cache")


@pytest.mark.parametrize("prev_value", [True, False])
def test_disable_use_cache_with_nested_text_config_existing_attr(prev_value):
"""Nested text config `use_cache` is disabled and restored."""
model = torch.nn.Linear(4, 4)
model.config = _Config()
model.config.text_config = _Config()
model.config.text_config.use_cache = prev_value

with _disable_use_cache(model):
assert model.config.use_cache is False
assert model.config.text_config.use_cache is False

assert not hasattr(model.config, "use_cache")
assert model.config.text_config.use_cache is prev_value


def test_disable_use_cache_with_nested_text_config_without_existing_attr():
"""Nested text config `use_cache` is removed if it was added by the context."""
model = torch.nn.Linear(4, 4)
model.config = _Config()
model.config.text_config = _Config()

with _disable_use_cache(model):
assert model.config.use_cache is False
assert model.config.text_config.use_cache is False

assert not hasattr(model.config, "use_cache")
assert not hasattr(model.config.text_config, "use_cache")


def test_iter_use_cache_configs_deduplicates_text_config_alias():
"""The same config object is patched once if `config.text_config is config`."""
model = torch.nn.Linear(4, 4)
model.config = _Config()
model.config.text_config = model.config

assert list(_iter_use_cache_configs(model)) == [model.config]


def test_forward_loop_runs_under_disabled_use_cache():
"""`_forward_loop` runs forward on every batch and restores `use_cache` on exit."""
seen_use_cache: list[bool] = []
Expand Down
Loading