diff --git a/src/llmcompressor/core/helpers.py b/src/llmcompressor/core/helpers.py index 1840488ea..cdb2e09c9 100644 --- a/src/llmcompressor/core/helpers.py +++ b/src/llmcompressor/core/helpers.py @@ -6,7 +6,7 @@ conditional logging and parameter tracking. """ -from typing import Any, Generator, Optional, Tuple, Union +from typing import Any, Generator from llmcompressor.core.state import State from llmcompressor.metrics import LoggerManager @@ -21,7 +21,7 @@ def should_log_model_info( model: Any, loggers: LoggerManager, current_log_step: float, - last_log_step: Optional[float] = None, + last_log_step: float | None = None, ) -> bool: """ Check if we should log model level info @@ -65,9 +65,7 @@ def log_model_info(state: State, current_log_step): ) -def _log_current_step( - logger_manager: LoggerManager, current_log_step: Union[float, int] -): +def _log_current_step(logger_manager: LoggerManager, current_log_step: float | int): """ Log the Current Log Step to the logger_manager @@ -80,7 +78,7 @@ def _log_current_step( def _log_model_loggable_items( logger_manager: LoggerManager, - loggable_items: Generator[Tuple[str, Any], None, None], + loggable_items: Generator[tuple[str, Any], None, None], epoch: float, ): """ @@ -93,9 +91,10 @@ def _log_model_loggable_items( """ for loggable_item in loggable_items: log_tag, log_value = loggable_item - if isinstance(log_value, dict): - logger_manager.log_scalars(tag=log_tag, values=log_value, step=epoch) - elif isinstance(log_value, (int, float)): - logger_manager.log_scalar(tag=log_tag, value=log_value, step=epoch) - else: - logger_manager.log_string(tag=log_tag, string=log_value, step=epoch) + match log_value: + case dict(): + logger_manager.log_scalars(tag=log_tag, values=log_value, step=epoch) + case int() | float(): + logger_manager.log_scalar(tag=log_tag, value=log_value, step=epoch) + case _: + logger_manager.log_string(tag=log_tag, string=log_value, step=epoch)