diff --git a/build_and_train_models/sm-distributed_model_parallel_v2/shared-scripts/memory_tracker.py b/build_and_train_models/sm-distributed_model_parallel_v2/shared-scripts/memory_tracker.py
index 00ca124451..6520f14c59 100644
--- a/build_and_train_models/sm-distributed_model_parallel_v2/shared-scripts/memory_tracker.py
+++ b/build_and_train_models/sm-distributed_model_parallel_v2/shared-scripts/memory_tracker.py
@@ -38,9 +38,9 @@ def memory_status(  # pylint: disable=too-many-locals
     tag: str = "",
     reset_max: bool = True,
     sync: bool = True,
-    writers: Tuple[Any] = (),
+    writers: Tuple[Any, ...] = (),
     step: int = 0,
-) -> Tuple[float]:
+) -> Tuple[float, float, float, float]:
     """Memory status gpu."""
     rank = dist.get_rank()
     local_rank = rank % torch.cuda.device_count()
@@ -93,8 +93,8 @@
 
 
 def memory_status_cpu(  # pylint: disable=too-many-locals
-    tag: str = "", writers: Tuple[Any] = (), step: int = 0
-) -> Tuple[float]:
+    tag: str = "", writers: Tuple[Any, ...] = (), step: int = 0
+) -> Tuple[float, float, float, float]:
     """Memory status cpu."""
     rank = dist.get_rank()
     local_rank = rank % torch.cuda.device_count()
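
Context for the change (not part of the diff itself): in `typing`, `Tuple[X]` denotes a tuple of exactly one element, so the old annotations claimed `writers` was a 1-tuple and that both functions return a single float, while the code accepts any number of writers and returns four floats. `Tuple[Any, ...]` is the variable-length form. Below is a minimal sketch of the distinction; the function names `wrong` and `right` are hypothetical and do not appear in memory_tracker.py:

```python
from typing import Any, Tuple

# Tuple[float] means "a tuple of exactly one float", so a static type
# checker such as mypy rejects returning four values under this annotation.
def wrong() -> Tuple[float]:
    return (1.0, 2.0, 3.0, 4.0)  # mypy: incompatible return value type

# Tuple[float, float, float, float] pins the arity to four, and
# Tuple[Any, ...] accepts a tuple of any length, including the
# empty default used for `writers`.
def right(writers: Tuple[Any, ...] = ()) -> Tuple[float, float, float, float]:
    return (1.0, 2.0, 3.0, 4.0)
```

Both versions behave identically at runtime; only the annotations change, so the fix is visible to type checkers without altering the training scripts' behavior.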