diff --git a/modelopt/torch/utils/dataset_utils.py b/modelopt/torch/utils/dataset_utils.py index e9b53897fee..8b3fd0b067f 100644 --- a/modelopt/torch/utils/dataset_utils.py +++ b/modelopt/torch/utils/dataset_utils.py @@ -920,9 +920,29 @@ def get_supported_datasets() -> list[str]: return list(SUPPORTED_DATASET_CONFIG.keys()) + list(DATASET_COMBOS.keys()) +_NESTED_USE_CACHE_CONFIG_ATTRS = ("text_config",) + + +def _iter_use_cache_configs(model: torch.nn.Module) -> Iterator[Any]: + """Yield the top-level config and Step3.7-style nested text config.""" + seen: set[int] = set() + config = getattr(model, "config", None) + if config is None: + return + + for candidate in ( + config, + *(getattr(config, attr, None) for attr in _NESTED_USE_CACHE_CONFIG_ATTRS), + ): + if candidate is None or id(candidate) in seen: + continue + seen.add(id(candidate)) + yield candidate + + @contextmanager def _disable_use_cache(model: torch.nn.Module) -> Iterator[None]: - """Set ``model.config.use_cache = False`` for the duration of the block. + """Set model config ``use_cache`` flags to ``False`` for the duration of the block. KV caching is unwanted during calibration / memory-probe forward passes: it wastes memory, and for hybrid Mamba/attention models (e.g., NemotronH) @@ -931,23 +951,26 @@ def _disable_use_cache(model: torch.nn.Module) -> Iterator[None]: present) also sidesteps configs that never assign the attribute at all — e.g., ``Step3p5Config`` from stepfun-ai/Step-3.5-Flash — where forward code that reads ``self.config.use_cache`` would otherwise raise - ``AttributeError``. The prior value is restored on exit if one existed. + ``AttributeError``. Step3.7 keeps the relevant language config nested + under ``text_config``; that config object is handled the same way. The + prior value is restored on exit if one existed. """ - config = getattr(model, "config", None) - if config is None: - yield - return - had_attr = hasattr(config, "use_cache") - prev = config.use_cache if had_attr else None - config.use_cache = False + states = [] + for config in _iter_use_cache_configs(model): + had_attr = hasattr(config, "use_cache") + prev = config.use_cache if had_attr else None + config.use_cache = False + states.append((config, had_attr, prev)) + try: yield finally: - if had_attr: - config.use_cache = prev - else: - with suppress(AttributeError): - delattr(config, "use_cache") + for config, had_attr, prev in reversed(states): + if had_attr: + config.use_cache = prev + else: + with suppress(AttributeError): + delattr(config, "use_cache") def get_max_batch_size( diff --git a/tests/unit/torch/utils/test_dataset_utils.py b/tests/unit/torch/utils/test_dataset_utils.py index 65623528a4e..71bdd97daf7 100644 --- a/tests/unit/torch/utils/test_dataset_utils.py +++ b/tests/unit/torch/utils/test_dataset_utils.py @@ -25,6 +25,7 @@ DATASET_COMBOS, _disable_use_cache, _forward_loop, + _iter_use_cache_configs, _pack_documents_into_rows, _process_batch, get_dataset_dataloader, @@ -222,6 +223,45 @@ def test_disable_use_cache_without_existing_attr(): assert not hasattr(model.config, "use_cache") +@pytest.mark.parametrize("prev_value", [True, False]) +def test_disable_use_cache_with_nested_text_config_existing_attr(prev_value): + """Nested text config `use_cache` is disabled and restored.""" + model = torch.nn.Linear(4, 4) + model.config = _Config() + model.config.text_config = _Config() + model.config.text_config.use_cache = prev_value + + with _disable_use_cache(model): + assert model.config.use_cache is False + assert model.config.text_config.use_cache is False + + assert not hasattr(model.config, "use_cache") + assert model.config.text_config.use_cache is prev_value + + +def test_disable_use_cache_with_nested_text_config_without_existing_attr(): + """Nested text config `use_cache` is removed if it was added by the context.""" + model = torch.nn.Linear(4, 4) + model.config = _Config() + model.config.text_config = _Config() + + with _disable_use_cache(model): + assert model.config.use_cache is False + assert model.config.text_config.use_cache is False + + assert not hasattr(model.config, "use_cache") + assert not hasattr(model.config.text_config, "use_cache") + + +def test_iter_use_cache_configs_deduplicates_text_config_alias(): + """The same config object is patched once if `config.text_config is config`.""" + model = torch.nn.Linear(4, 4) + model.config = _Config() + model.config.text_config = model.config + + assert list(_iter_use_cache_configs(model)) == [model.config] + + def test_forward_loop_runs_under_disabled_use_cache(): """`_forward_loop` runs forward on every batch and restores `use_cache` on exit.""" seen_use_cache: list[bool] = []