diff --git a/nemo_rl/utils/logger.py b/nemo_rl/utils/logger.py index 27a7f60a2a..a4214032f7 100644 --- a/nemo_rl/utils/logger.py +++ b/nemo_rl/utils/logger.py @@ -19,7 +19,6 @@ import os import re import subprocess -import tempfile import threading import time from abc import ABC, abstractmethod @@ -63,9 +62,10 @@ class TensorboardConfig(TypedDict): class MLflowConfig(TypedDict): - experiment_name: str - run_name: str - tracking_uri: NotRequired[str] + experiment_name: NotRequired[str | None] + run_id: NotRequired[str | None] + run_name: NotRequired[str | None] + tracking_uri: NotRequired[str | None] artifact_location: NotRequired[str | None] @@ -772,26 +772,50 @@ def __init__(self, cfg: MLflowConfig, log_dir: Optional[str] = None): cfg: MLflow configuration log_dir: Optional log directory (used as fallback if artifact_location not in cfg) """ - tracking_uri = cfg.get("tracking_uri") - if tracking_uri: + tracking_uri = cfg.get("tracking_uri") or os.getenv("MLFLOW_TRACKING_URI") + if tracking_uri and not mlflow.is_tracking_uri_set(): mlflow.set_tracking_uri(tracking_uri) - experiment_name = cfg["experiment_name"] - experiment = mlflow.get_experiment_by_name(experiment_name) - if experiment is None: - mlflow.create_experiment( - name=experiment_name, - **{"artifact_location": cfg.get("artifact_location", log_dir)} - if "artifact_location" in cfg or log_dir - else {}, - ) + run_id = cfg.get("run_id") or os.getenv("MLFLOW_RUN_ID") + experiment_name = cfg.get("experiment_name") or os.getenv( + "MLFLOW_EXPERIMENT_NAME" + ) + run_name = cfg.get("run_name") or os.getenv("MLFLOW_RUN_NAME") + + run = mlflow.active_run() + + # If run_id is provided, try to use it directly + if run_id: + # If there is an active run but it's not the one we want, end it + if run and run.info.run_id != run_id: + mlflow.end_run() + run = None + + # Start/resume the specified run + if run is None: + run = mlflow.start_run(run_id=run_id) + + # If no run_id provided, fall back to experiment name logic else: - mlflow.set_experiment(experiment_name) + # End any existing active run to start fresh or ensure correct context + if run: + mlflow.end_run() + + if experiment_name is not None: + experiment = mlflow.get_experiment_by_name(experiment_name) + # if name is set but experiment is not found, create it + if experiment is None: + mlflow.create_experiment( + name=experiment_name, + artifact_location=cfg.get("artifact_location") or log_dir, + ) + # set the experiment context manager + mlflow.set_experiment(experiment_name) + # Start a new run + run = mlflow.start_run(run_name=run_name) - # Start run - run_name = cfg["run_name"] - run_kwargs = {"run_name": run_name} - self.run = mlflow.start_run(**run_kwargs) + self.run = run + self.run_id = run.info.run_id print( f"Initialized MLflowLogger for experiment {experiment_name}, run {run_name}" ) @@ -812,10 +836,14 @@ def log_metrics( prefix: Optional prefix for metric names step_metric: Optional step metric name (ignored in MLflow) """ - for name, value in metrics.items(): + metrics_to_log = {} + flattened_metrics = flatten_dict(metrics) + for name, value in flattened_metrics.items(): if prefix: name = f"{prefix}/{name}" - mlflow.log_metric(name, value, step=step) + metrics_to_log[name] = value + + mlflow.log_metrics(metrics_to_log, step=step, run_id=self.run_id) def log_hyperparams(self, params: Mapping[str, Any]) -> None: """Log hyperparameters to MLflow. @@ -824,7 +852,7 @@ def log_hyperparams(self, params: Mapping[str, Any]) -> None: params: Dictionary of hyperparameters to log """ # MLflow does not support nested dicts - mlflow.log_params(flatten_dict(params)) + mlflow.log_params(flatten_dict(params), run_id=self.run_id) def log_plot(self, figure: plt.Figure, step: int, name: str) -> None: """Log a plot to MLflow. @@ -834,9 +862,10 @@ def log_plot(self, figure: plt.Figure, step: int, name: str) -> None: step: Global step value name: Name of the plot """ - with tempfile.NamedTemporaryFile(suffix=".png", delete=True) as tmp_file: - figure.savefig(tmp_file.name, format="png", bbox_inches="tight") - mlflow.log_artifact(tmp_file.name, f"plots/{name}") + # Use bbox_inches="tight" to remove extra whitespace/padding around the plot + mlflow.log_figure( + figure, f"plots/{name}.png", save_kwargs={"bbox_inches": "tight"} + ) def log_histogram(self, histogram: list[Any], step: int, name: str) -> None: """Log histogram metrics to MLflow.""" diff --git a/tests/unit/utils/test_logger.py b/tests/unit/utils/test_logger.py index 52b380a213..f44c2d083b 100644 --- a/tests/unit/utils/test_logger.py +++ b/tests/unit/utils/test_logger.py @@ -492,6 +492,9 @@ def temp_dir(self): @patch("nemo_rl.utils.logger.mlflow") def test_init_basic_config(self, mock_mlflow, temp_dir): """Test initialization of MLflowLogger with basic config.""" + # Ensure active_run returns None so initialization logic runs + mock_mlflow.active_run.return_value = None + cfg = { "experiment_name": "test-experiment", "run_name": "test-run", @@ -505,6 +508,11 @@ def test_init_basic_config(self, mock_mlflow, temp_dir): @patch("nemo_rl.utils.logger.mlflow") def test_init_full_config(self, mock_mlflow, temp_dir): """Test initialization of MLflowLogger with full config.""" + # Ensure active_run returns None so initialization logic runs + mock_mlflow.active_run.return_value = None + # Mock is_tracking_uri_set to return False so set_tracking_uri is called + mock_mlflow.is_tracking_uri_set.return_value = False + cfg = { "experiment_name": "test-experiment", "run_name": "test-run", @@ -531,9 +539,10 @@ def test_log_metrics(self, mock_mlflow, temp_dir): logger.log_metrics(metrics, step) # Check that log_metric was called for each metric - assert mock_mlflow.log_metric.call_count == 2 - mock_mlflow.log_metric.assert_any_call("loss", 0.5, step=10) - mock_mlflow.log_metric.assert_any_call("accuracy", 0.8, step=10) + assert mock_mlflow.log_metrics.call_count == 1 + mock_mlflow.log_metrics.assert_any_call( + {"loss": 0.5, "accuracy": 0.8}, step=10, run_id=logger.run_id + ) @patch("nemo_rl.utils.logger.mlflow") def test_log_metrics_with_prefix(self, mock_mlflow, temp_dir): @@ -551,9 +560,10 @@ def test_log_metrics_with_prefix(self, mock_mlflow, temp_dir): logger.log_metrics(metrics, step, prefix) # Check that log_metric was called for each metric with prefix - assert mock_mlflow.log_metric.call_count == 2 - mock_mlflow.log_metric.assert_any_call("train/loss", 0.5, step=10) - mock_mlflow.log_metric.assert_any_call("train/accuracy", 0.8, step=10) + assert mock_mlflow.log_metrics.call_count == 1 + mock_mlflow.log_metrics.assert_any_call( + {"train/loss": 0.5, "train/accuracy": 0.8}, step=10, run_id=logger.run_id + ) @patch("nemo_rl.utils.logger.mlflow") def test_log_hyperparams(self, mock_mlflow, temp_dir): @@ -574,16 +584,14 @@ def test_log_hyperparams(self, mock_mlflow, temp_dir): "lr": 0.001, "batch_size": 32, "model.hidden_size": 128, - } + }, + run_id=logger.run_id, ) @patch("nemo_rl.utils.logger.mlflow") @patch("nemo_rl.utils.logger.plt") - @patch("nemo_rl.utils.logger.os") - def test_log_plot(self, mock_os, mock_plt, mock_mlflow, temp_dir): + def test_log_plot(self, mock_plt, mock_mlflow, temp_dir): """Test logging plots to MLflowLogger.""" - import tempfile - cfg = { "experiment_name": "test-experiment", "run_name": "test-run", @@ -594,21 +602,12 @@ def test_log_plot(self, mock_os, mock_plt, mock_mlflow, temp_dir): # Mock the figure mock_figure = mock_plt.Figure.return_value - # Mock tempfile.NamedTemporaryFile - mock_temp_file = type("MockTempFile", (), {"name": "/tmp/test.png"})() - with patch.object(tempfile, "NamedTemporaryFile") as mock_tempfile: - mock_tempfile.return_value.__enter__.return_value = mock_temp_file - mock_tempfile.return_value.__exit__.return_value = None - - logger.log_plot(mock_figure, step=10, name="test_plot") + logger.log_plot(mock_figure, step=10, name="test_plot") - # Check that figure was saved and logged as artifact - mock_figure.savefig.assert_called_once_with( - "/tmp/test.png", format="png", bbox_inches="tight" - ) - mock_mlflow.log_artifact.assert_called_once_with( - "/tmp/test.png", "plots/test_plot" - ) + # Check that log_figure was called + mock_mlflow.log_figure.assert_called_once_with( + mock_figure, "plots/test_plot.png", save_kwargs={"bbox_inches": "tight"} + ) @patch("nemo_rl.utils.logger.mlflow") def test_cleanup(self, mock_mlflow, temp_dir): @@ -620,6 +619,9 @@ def test_cleanup(self, mock_mlflow, temp_dir): } logger = MLflowLogger(cfg, log_dir=temp_dir) + # Reset mocks to avoid counting calls from init + mock_mlflow.end_run.reset_mock() + # Trigger cleanup logger.__del__() @@ -629,6 +631,9 @@ def test_cleanup(self, mock_mlflow, temp_dir): @patch("nemo_rl.utils.logger.mlflow") def test_init_with_none_log_dir(self, mock_mlflow): """Test initialization with None log_dir uses server default artifact location.""" + # Ensure active_run returns None so initialization logic runs + mock_mlflow.active_run.return_value = None + cfg = { "experiment_name": "test-experiment", "run_name": "test-run", @@ -639,12 +644,17 @@ def test_init_with_none_log_dir(self, mock_mlflow): MLflowLogger(cfg, log_dir=None) # Verify create_experiment was called without artifact_location - mock_mlflow.create_experiment.assert_called_once_with(name="test-experiment") + mock_mlflow.create_experiment.assert_called_once_with( + name="test-experiment", artifact_location=None + ) mock_mlflow.start_run.assert_called_once_with(run_name="test-run") @patch("nemo_rl.utils.logger.mlflow") def test_init_with_custom_log_dir(self, mock_mlflow): """Test initialization with custom log_dir sets artifact_location.""" + # Ensure active_run returns None so initialization logic runs + mock_mlflow.active_run.return_value = None + cfg = { "experiment_name": "test-experiment", "run_name": "test-run", @@ -663,6 +673,11 @@ def test_init_with_custom_log_dir(self, mock_mlflow): @patch("nemo_rl.utils.logger.mlflow") def test_init_with_artifact_location_in_config(self, mock_mlflow): """Test initialization with artifact_location in config takes precedence over log_dir.""" + # Ensure active_run returns None so initialization logic runs + mock_mlflow.active_run.return_value = None + # Mock is_tracking_uri_set to return False so set_tracking_uri is called + mock_mlflow.is_tracking_uri_set.return_value = False + cfg = { "experiment_name": "test-experiment", "run_name": "test-run", @@ -682,7 +697,12 @@ def test_init_with_artifact_location_in_config(self, mock_mlflow): @patch("nemo_rl.utils.logger.mlflow") def test_init_with_artifact_location_none_in_config(self, mock_mlflow): - """Test initialization with artifact_location=None in config uses server default.""" + """Test initialization with artifact_location=None in config uses log_dir.""" + # Ensure active_run returns None so initialization logic runs + mock_mlflow.active_run.return_value = None + # Mock is_tracking_uri_set to return False so set_tracking_uri is called + mock_mlflow.is_tracking_uri_set.return_value = False + cfg = { "experiment_name": "test-experiment", "run_name": "test-run", @@ -693,10 +713,9 @@ def test_init_with_artifact_location_none_in_config(self, mock_mlflow): MLflowLogger(cfg, log_dir="/fallback/path") - # Verify create_experiment was called without artifact_location - # (None is explicitly set, so we don't pass it to MLflow) + # Verify create_experiment was called with log_dir since config is None mock_mlflow.create_experiment.assert_called_once_with( - name=cfg["experiment_name"], artifact_location=cfg["artifact_location"] + name=cfg["experiment_name"], artifact_location="/fallback/path" ) mock_mlflow.set_tracking_uri.assert_called_once_with(cfg["tracking_uri"]) mock_mlflow.start_run.assert_called_once_with(run_name=cfg["run_name"]) @@ -704,6 +723,11 @@ def test_init_with_artifact_location_none_in_config(self, mock_mlflow): @patch("nemo_rl.utils.logger.mlflow") def test_init_without_artifact_location_uses_log_dir(self, mock_mlflow): """Test initialization without artifact_location in config uses log_dir.""" + # Ensure active_run returns None so initialization logic runs + mock_mlflow.active_run.return_value = None + # Mock is_tracking_uri_set to return False so set_tracking_uri is called + mock_mlflow.is_tracking_uri_set.return_value = False + cfg = { "experiment_name": "test-experiment", "run_name": "test-run", @@ -1672,7 +1696,10 @@ def test_log_plot_token_mult_prob_error( @patch("nemo_rl.utils.logger.WandbLogger") @patch("nemo_rl.utils.logger.TensorboardLogger") - def test_init_mlflow_only(self, mock_tb_logger, mock_wandb_logger, temp_dir): + @patch("nemo_rl.utils.logger.MLflowLogger") + def test_init_mlflow_only( + self, mock_mlflow_logger, mock_tb_logger, mock_wandb_logger, temp_dir + ): """Test initialization with only MLflowLogger enabled.""" cfg = { "wandb_enabled": False, @@ -1692,6 +1719,7 @@ def test_init_mlflow_only(self, mock_tb_logger, mock_wandb_logger, temp_dir): assert len(logger.loggers) == 1 mock_wandb_logger.assert_not_called() mock_tb_logger.assert_not_called() + mock_mlflow_logger.assert_called_once() @patch("nemo_rl.utils.logger.WandbLogger") @patch("nemo_rl.utils.logger.TensorboardLogger")