@@ -503,7 +503,7 @@ def test_init_basic_config(self, mock_mlflow, temp_dir):
503503 MLflowLogger (cfg , log_dir = temp_dir )
504504
505505 mock_mlflow .set_experiment .assert_called_once_with ("test-experiment" )
506- mock_mlflow .start_run .assert_called_once_with (run_name = "test-run" , run_id = None )
506+ mock_mlflow .start_run .assert_called_once_with (run_name = "test-run" )
507507
508508 @patch ("nemo_rl.utils.logger.mlflow" )
509509 def test_init_full_config (self , mock_mlflow , temp_dir ):
@@ -522,7 +522,7 @@ def test_init_full_config(self, mock_mlflow, temp_dir):
522522
523523 mock_mlflow .set_tracking_uri .assert_called_once_with ("http://localhost:5000" )
524524 mock_mlflow .set_experiment .assert_called_once_with ("test-experiment" )
525- mock_mlflow .start_run .assert_called_once_with (run_name = "test-run" , run_id = None )
525+ mock_mlflow .start_run .assert_called_once_with (run_name = "test-run" )
526526
527527 @patch ("nemo_rl.utils.logger.mlflow" )
528528 def test_log_metrics (self , mock_mlflow , temp_dir ):
@@ -619,6 +619,9 @@ def test_cleanup(self, mock_mlflow, temp_dir):
619619 }
620620 logger = MLflowLogger (cfg , log_dir = temp_dir )
621621
622+ # Reset mocks to avoid counting calls from init
623+ mock_mlflow .end_run .reset_mock ()
624+
622625 # Trigger cleanup
623626 logger .__del__ ()
624627
@@ -641,8 +644,10 @@ def test_init_with_none_log_dir(self, mock_mlflow):
641644 MLflowLogger (cfg , log_dir = None )
642645
643646 # Verify create_experiment was called without artifact_location
644- mock_mlflow .create_experiment .assert_called_once_with (name = "test-experiment" )
645- mock_mlflow .start_run .assert_called_once_with (run_name = "test-run" , run_id = None )
647+ mock_mlflow .create_experiment .assert_called_once_with (
648+ name = "test-experiment" , artifact_location = None
649+ )
650+ mock_mlflow .start_run .assert_called_once_with (run_name = "test-run" )
646651
647652 @patch ("nemo_rl.utils.logger.mlflow" )
648653 def test_init_with_custom_log_dir (self , mock_mlflow ):
@@ -663,7 +668,7 @@ def test_init_with_custom_log_dir(self, mock_mlflow):
663668 mock_mlflow .create_experiment .assert_called_once_with (
664669 name = "test-experiment" , artifact_location = "/custom/path"
665670 )
666- mock_mlflow .start_run .assert_called_once_with (run_name = "test-run" , run_id = None )
671+ mock_mlflow .start_run .assert_called_once_with (run_name = "test-run" )
667672
668673 @patch ("nemo_rl.utils.logger.mlflow" )
669674 def test_init_with_artifact_location_in_config (self , mock_mlflow ):
@@ -689,12 +694,12 @@ def test_init_with_artifact_location_in_config(self, mock_mlflow):
689694 )
690695 mock_mlflow .set_tracking_uri .assert_called_once_with (cfg ["tracking_uri" ])
691696 mock_mlflow .start_run .assert_called_once_with (
692- run_name = cfg ["run_name" ], run_id = None
697+ run_name = cfg ["run_name" ]
693698 )
694699
695700 @patch ("nemo_rl.utils.logger.mlflow" )
696701 def test_init_with_artifact_location_none_in_config (self , mock_mlflow ):
697- """Test initialization with artifact_location=None in config uses server default ."""
702+ """Test initialization with artifact_location=None in config uses log_dir ."""
698703 # Ensure active_run returns None so initialization logic runs
699704 mock_mlflow .active_run .return_value = None
700705 # Mock is_tracking_uri_set to return False so set_tracking_uri is called
@@ -710,14 +715,13 @@ def test_init_with_artifact_location_none_in_config(self, mock_mlflow):
710715
711716 MLflowLogger (cfg , log_dir = "/fallback/path" )
712717
713- # Verify create_experiment was called without artifact_location
714- # (None is explicitly set, so we don't pass it to MLflow)
718+ # Verify create_experiment was called with log_dir since config is None
715719 mock_mlflow .create_experiment .assert_called_once_with (
716- name = cfg ["experiment_name" ], artifact_location = cfg [ "artifact_location" ]
720+ name = cfg ["experiment_name" ], artifact_location = "/fallback/path"
717721 )
718722 mock_mlflow .set_tracking_uri .assert_called_once_with (cfg ["tracking_uri" ])
719723 mock_mlflow .start_run .assert_called_once_with (
720- run_name = cfg ["run_name" ], run_id = None
724+ run_name = cfg ["run_name" ]
721725 )
722726
723727 @patch ("nemo_rl.utils.logger.mlflow" )
@@ -744,7 +748,7 @@ def test_init_without_artifact_location_uses_log_dir(self, mock_mlflow):
744748 )
745749 mock_mlflow .set_tracking_uri .assert_called_once_with (cfg ["tracking_uri" ])
746750 mock_mlflow .start_run .assert_called_once_with (
747- run_name = cfg ["run_name" ], run_id = None
751+ run_name = cfg ["run_name" ]
748752 )
749753
750754
0 commit comments