Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions megatron/training/one_logger_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@

_one_logger_utils_version = "1.2.0-mlm"

# Optional rankpulse integration. `rankpulse` stays None unless the
# RANKPULSE_ENABLE environment variable is set to a truthy value AND the
# one_logger.rankpulse module can be imported; code elsewhere in this module
# checks `if rankpulse:` before using it.
rankpulse = None
if os.environ.get('RANKPULSE_ENABLE', '0').lower() in ('1', 'true', 'yes', 'y'):
    try:
        import one_logger.rankpulse as rankpulse
    except ImportError:
        # Opt-in was requested but the package is missing: warn and carry on
        # with rankpulse left as None instead of failing the import of this module.
        print("WARNING: RANKPULSE_ENABLE is set but rankpulse module is not available. Please install one-logger package with rankpulse support.")

def get_timestamp_in_ms():
"""Helper function to get timestamp in ms
Expand All @@ -31,6 +37,17 @@ def on_train_start(iteration, consumed_train_samples, train_samples, seq_length,
log_throughput (bool): log throughput or not
num_floating_point_operations_so_far (int): flops so far
"""
if rankpulse:
try:
rankpulse.start(
interval_seconds = int(os.getenv('RANKPULSE_INTERVAL_SECONDS', '15')),
twindow_seconds = int(os.getenv('RANKPULSE_TWINDOW_SECONDS', '300')),
enable_gpu_debug_info = False if \
os.environ.get("RANKPULSE_GPU_DEBUG_INFO", "1").lower() in ["0", "false", "no", "n"] else True
)
except Exception as e:
print(f"WARNING: Failed to start rankpulse: {e}")

args = get_args()
one_logger = get_one_logger()

Expand Down Expand Up @@ -463,6 +480,12 @@ def track_app_tag(batch_size, world_size, seq_length):
def finish():
"""Flush E2E metrics to remote server
"""
if rankpulse:
try:
rankpulse.stop(timeout_seconds=3.0)
except Exception as e:
print(f"WARNING: Failed to stop rankpulse: {e}")

one_logger = get_one_logger()
if one_logger:
with one_logger.get_context_manager():
Expand Down