Skip to content

Commit 30a40e2

Browse files
committed
Use a clearer hierarchy of run id -> experiment name -> run name
Signed-off-by: Nathan Azrak <nathan.azrak@gmail.com>
1 parent 60d0139 commit 30a40e2

2 files changed

Lines changed: 35 additions & 19 deletions

File tree

nemo_rl/utils/logger.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -783,8 +783,21 @@ def __init__(self, cfg: MLflowConfig, log_dir: Optional[str] = None):
783783
run_name = cfg.get("run_name") or os.getenv("MLFLOW_RUN_NAME")
784784

785785
run = mlflow.active_run()
786-
# Start a new run if there is no active run, or if the active run doesn't match the requested run_id
787-
if run is None or (run_id and run.info.run_id != run_id):
786+
787+
# If run_id is provided, try to use it directly
788+
if run_id:
789+
# If there is an active run but it's not the one we want, end it
790+
if run and run.info.run_id != run_id:
791+
mlflow.end_run()
792+
run = None
793+
794+
# Start/resume the specified run
795+
if run is None:
796+
run = mlflow.start_run(run_id=run_id)
797+
798+
# If no run_id provided, fall back to experiment name logic
799+
else:
800+
# End any existing active run to start fresh or ensure correct context
788801
if run:
789802
mlflow.end_run()
790803

@@ -794,14 +807,13 @@ def __init__(self, cfg: MLflowConfig, log_dir: Optional[str] = None):
794807
if experiment is None:
795808
mlflow.create_experiment(
796809
name=experiment_name,
797-
**{"artifact_location": cfg.get("artifact_location", log_dir)}
798-
if "artifact_location" in cfg or log_dir
799-
else {},
810+
artifact_location=cfg.get("artifact_location") or log_dir,
800811
)
801812
# set the experiment context manager
802813
mlflow.set_experiment(experiment_name)
803-
# if run_id is set explicitly, will use. Otherwise, from env var. Otherwise, new run with run name
804-
run = mlflow.start_run(run_name=run_name, run_id=run_id)
814+
# Start a new run
815+
run = mlflow.start_run(run_name=run_name)
816+
805817
self.run = run
806818
self.run_id = run.info.run_id
807819
print(

tests/unit/utils/test_logger.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -503,7 +503,7 @@ def test_init_basic_config(self, mock_mlflow, temp_dir):
503503
MLflowLogger(cfg, log_dir=temp_dir)
504504

505505
mock_mlflow.set_experiment.assert_called_once_with("test-experiment")
506-
mock_mlflow.start_run.assert_called_once_with(run_name="test-run", run_id=None)
506+
mock_mlflow.start_run.assert_called_once_with(run_name="test-run")
507507

508508
@patch("nemo_rl.utils.logger.mlflow")
509509
def test_init_full_config(self, mock_mlflow, temp_dir):
@@ -522,7 +522,7 @@ def test_init_full_config(self, mock_mlflow, temp_dir):
522522

523523
mock_mlflow.set_tracking_uri.assert_called_once_with("http://localhost:5000")
524524
mock_mlflow.set_experiment.assert_called_once_with("test-experiment")
525-
mock_mlflow.start_run.assert_called_once_with(run_name="test-run", run_id=None)
525+
mock_mlflow.start_run.assert_called_once_with(run_name="test-run")
526526

527527
@patch("nemo_rl.utils.logger.mlflow")
528528
def test_log_metrics(self, mock_mlflow, temp_dir):
@@ -619,6 +619,9 @@ def test_cleanup(self, mock_mlflow, temp_dir):
619619
}
620620
logger = MLflowLogger(cfg, log_dir=temp_dir)
621621

622+
# Reset mocks to avoid counting calls from init
623+
mock_mlflow.end_run.reset_mock()
624+
622625
# Trigger cleanup
623626
logger.__del__()
624627

@@ -641,8 +644,10 @@ def test_init_with_none_log_dir(self, mock_mlflow):
641644
MLflowLogger(cfg, log_dir=None)
642645

643646
# Verify create_experiment was called with artifact_location=None
644-
mock_mlflow.create_experiment.assert_called_once_with(name="test-experiment")
645-
mock_mlflow.start_run.assert_called_once_with(run_name="test-run", run_id=None)
647+
mock_mlflow.create_experiment.assert_called_once_with(
648+
name="test-experiment", artifact_location=None
649+
)
650+
mock_mlflow.start_run.assert_called_once_with(run_name="test-run")
646651

647652
@patch("nemo_rl.utils.logger.mlflow")
648653
def test_init_with_custom_log_dir(self, mock_mlflow):
@@ -663,7 +668,7 @@ def test_init_with_custom_log_dir(self, mock_mlflow):
663668
mock_mlflow.create_experiment.assert_called_once_with(
664669
name="test-experiment", artifact_location="/custom/path"
665670
)
666-
mock_mlflow.start_run.assert_called_once_with(run_name="test-run", run_id=None)
671+
mock_mlflow.start_run.assert_called_once_with(run_name="test-run")
667672

668673
@patch("nemo_rl.utils.logger.mlflow")
669674
def test_init_with_artifact_location_in_config(self, mock_mlflow):
@@ -689,12 +694,12 @@ def test_init_with_artifact_location_in_config(self, mock_mlflow):
689694
)
690695
mock_mlflow.set_tracking_uri.assert_called_once_with(cfg["tracking_uri"])
691696
mock_mlflow.start_run.assert_called_once_with(
692-
run_name=cfg["run_name"], run_id=None
697+
run_name=cfg["run_name"]
693698
)
694699

695700
@patch("nemo_rl.utils.logger.mlflow")
696701
def test_init_with_artifact_location_none_in_config(self, mock_mlflow):
697-
"""Test initialization with artifact_location=None in config uses server default."""
702+
"""Test initialization with artifact_location=None in config uses log_dir."""
698703
# Ensure active_run returns None so initialization logic runs
699704
mock_mlflow.active_run.return_value = None
700705
# Mock is_tracking_uri_set to return False so set_tracking_uri is called
@@ -710,14 +715,13 @@ def test_init_with_artifact_location_none_in_config(self, mock_mlflow):
710715

711716
MLflowLogger(cfg, log_dir="/fallback/path")
712717

713-
# Verify create_experiment was called without artifact_location
714-
# (None is explicitly set, so we don't pass it to MLflow)
718+
# Verify create_experiment falls back to log_dir since artifact_location in the config is None
715719
mock_mlflow.create_experiment.assert_called_once_with(
716-
name=cfg["experiment_name"], artifact_location=cfg["artifact_location"]
720+
name=cfg["experiment_name"], artifact_location="/fallback/path"
717721
)
718722
mock_mlflow.set_tracking_uri.assert_called_once_with(cfg["tracking_uri"])
719723
mock_mlflow.start_run.assert_called_once_with(
720-
run_name=cfg["run_name"], run_id=None
724+
run_name=cfg["run_name"]
721725
)
722726

723727
@patch("nemo_rl.utils.logger.mlflow")
@@ -744,7 +748,7 @@ def test_init_without_artifact_location_uses_log_dir(self, mock_mlflow):
744748
)
745749
mock_mlflow.set_tracking_uri.assert_called_once_with(cfg["tracking_uri"])
746750
mock_mlflow.start_run.assert_called_once_with(
747-
run_name=cfg["run_name"], run_id=None
751+
run_name=cfg["run_name"]
748752
)
749753

750754

0 commit comments

Comments
 (0)