Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 55 additions & 26 deletions nemo_rl/utils/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import os
import re
import subprocess
import tempfile
import threading
import time
from abc import ABC, abstractmethod
Expand Down Expand Up @@ -63,9 +62,10 @@ class TensorboardConfig(TypedDict):


class MLflowConfig(TypedDict):
experiment_name: str
run_name: str
tracking_uri: NotRequired[str]
experiment_name: NotRequired[str | None]
run_id: NotRequired[str | None]
run_name: NotRequired[str | None]
tracking_uri: NotRequired[str | None]
artifact_location: NotRequired[str | None]


Expand Down Expand Up @@ -772,26 +772,50 @@ def __init__(self, cfg: MLflowConfig, log_dir: Optional[str] = None):
cfg: MLflow configuration
log_dir: Optional log directory (used as fallback when artifact_location is missing or None in cfg)
"""
tracking_uri = cfg.get("tracking_uri")
if tracking_uri:
tracking_uri = cfg.get("tracking_uri") or os.getenv("MLFLOW_TRACKING_URI")
if tracking_uri and not mlflow.is_tracking_uri_set():
mlflow.set_tracking_uri(tracking_uri)

experiment_name = cfg["experiment_name"]
experiment = mlflow.get_experiment_by_name(experiment_name)
if experiment is None:
mlflow.create_experiment(
name=experiment_name,
**{"artifact_location": cfg.get("artifact_location", log_dir)}
if "artifact_location" in cfg or log_dir
else {},
)
run_id = cfg.get("run_id") or os.getenv("MLFLOW_RUN_ID")
experiment_name = cfg.get("experiment_name") or os.getenv(
"MLFLOW_EXPERIMENT_NAME"
)
run_name = cfg.get("run_name") or os.getenv("MLFLOW_RUN_NAME")

run = mlflow.active_run()

# If run_id is provided, try to use it directly
if run_id:
# If there is an active run but it's not the one we want, end it
if run and run.info.run_id != run_id:
mlflow.end_run()
run = None

# Start/resume the specified run
if run is None:
run = mlflow.start_run(run_id=run_id)

# If no run_id provided, fall back to experiment name logic
else:
mlflow.set_experiment(experiment_name)
# End any existing active run to start fresh or ensure correct context
if run:
mlflow.end_run()

if experiment_name is not None:
experiment = mlflow.get_experiment_by_name(experiment_name)
# if name is set but experiment is not found, create it
if experiment is None:
mlflow.create_experiment(
name=experiment_name,
artifact_location=cfg.get("artifact_location") or log_dir,
)
# Make this the active experiment so the run started below is created under it
mlflow.set_experiment(experiment_name)
# Start a new run
run = mlflow.start_run(run_name=run_name)

# Start run
run_name = cfg["run_name"]
run_kwargs = {"run_name": run_name}
self.run = mlflow.start_run(**run_kwargs)
self.run = run
self.run_id = run.info.run_id
print(
f"Initialized MLflowLogger for experiment {experiment_name}, run {run_name}"
)
Expand All @@ -812,10 +836,14 @@ def log_metrics(
prefix: Optional prefix for metric names
step_metric: Optional step metric name (ignored in MLflow)
"""
for name, value in metrics.items():
metrics_to_log = {}
flattened_metrics = flatten_dict(metrics)
for name, value in flattened_metrics.items():
if prefix:
name = f"{prefix}/{name}"
mlflow.log_metric(name, value, step=step)
metrics_to_log[name] = value

mlflow.log_metrics(metrics_to_log, step=step, run_id=self.run_id)

def log_hyperparams(self, params: Mapping[str, Any]) -> None:
"""Log hyperparameters to MLflow.
Expand All @@ -824,7 +852,7 @@ def log_hyperparams(self, params: Mapping[str, Any]) -> None:
params: Dictionary of hyperparameters to log
"""
# MLflow does not support nested dicts
mlflow.log_params(flatten_dict(params))
mlflow.log_params(flatten_dict(params), run_id=self.run_id)

def log_plot(self, figure: plt.Figure, step: int, name: str) -> None:
"""Log a plot to MLflow.
Expand All @@ -834,9 +862,10 @@ def log_plot(self, figure: plt.Figure, step: int, name: str) -> None:
step: Global step value
name: Name of the plot
"""
with tempfile.NamedTemporaryFile(suffix=".png", delete=True) as tmp_file:
figure.savefig(tmp_file.name, format="png", bbox_inches="tight")
mlflow.log_artifact(tmp_file.name, f"plots/{name}")
# Use bbox_inches="tight" to remove extra whitespace/padding around the plot
mlflow.log_figure(
figure, f"plots/{name}.png", save_kwargs={"bbox_inches": "tight"}
)

def log_histogram(self, histogram: list[Any], step: int, name: str) -> None:
"""Log histogram metrics to MLflow."""
Expand Down
90 changes: 59 additions & 31 deletions tests/unit/utils/test_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,9 @@ def temp_dir(self):
@patch("nemo_rl.utils.logger.mlflow")
def test_init_basic_config(self, mock_mlflow, temp_dir):
"""Test initialization of MLflowLogger with basic config."""
# Ensure active_run returns None so initialization logic runs
mock_mlflow.active_run.return_value = None

cfg = {
"experiment_name": "test-experiment",
"run_name": "test-run",
Expand All @@ -505,6 +508,11 @@ def test_init_basic_config(self, mock_mlflow, temp_dir):
@patch("nemo_rl.utils.logger.mlflow")
def test_init_full_config(self, mock_mlflow, temp_dir):
"""Test initialization of MLflowLogger with full config."""
# Ensure active_run returns None so initialization logic runs
mock_mlflow.active_run.return_value = None
# Mock is_tracking_uri_set to return False so set_tracking_uri is called
mock_mlflow.is_tracking_uri_set.return_value = False

cfg = {
"experiment_name": "test-experiment",
"run_name": "test-run",
Expand All @@ -531,9 +539,10 @@ def test_log_metrics(self, mock_mlflow, temp_dir):
logger.log_metrics(metrics, step)

# Check that log_metrics was called once with all metrics
assert mock_mlflow.log_metric.call_count == 2
mock_mlflow.log_metric.assert_any_call("loss", 0.5, step=10)
mock_mlflow.log_metric.assert_any_call("accuracy", 0.8, step=10)
assert mock_mlflow.log_metrics.call_count == 1
mock_mlflow.log_metrics.assert_any_call(
{"loss": 0.5, "accuracy": 0.8}, step=10, run_id=logger.run_id
)

@patch("nemo_rl.utils.logger.mlflow")
def test_log_metrics_with_prefix(self, mock_mlflow, temp_dir):
Expand All @@ -551,9 +560,10 @@ def test_log_metrics_with_prefix(self, mock_mlflow, temp_dir):
logger.log_metrics(metrics, step, prefix)

# Check that log_metrics was called once with all prefixed metrics
assert mock_mlflow.log_metric.call_count == 2
mock_mlflow.log_metric.assert_any_call("train/loss", 0.5, step=10)
mock_mlflow.log_metric.assert_any_call("train/accuracy", 0.8, step=10)
assert mock_mlflow.log_metrics.call_count == 1
mock_mlflow.log_metrics.assert_any_call(
{"train/loss": 0.5, "train/accuracy": 0.8}, step=10, run_id=logger.run_id
)

@patch("nemo_rl.utils.logger.mlflow")
def test_log_hyperparams(self, mock_mlflow, temp_dir):
Expand All @@ -574,16 +584,14 @@ def test_log_hyperparams(self, mock_mlflow, temp_dir):
"lr": 0.001,
"batch_size": 32,
"model.hidden_size": 128,
}
},
run_id=logger.run_id,
)

@patch("nemo_rl.utils.logger.mlflow")
@patch("nemo_rl.utils.logger.plt")
@patch("nemo_rl.utils.logger.os")
def test_log_plot(self, mock_os, mock_plt, mock_mlflow, temp_dir):
def test_log_plot(self, mock_plt, mock_mlflow, temp_dir):
"""Test logging plots to MLflowLogger."""
import tempfile

cfg = {
"experiment_name": "test-experiment",
"run_name": "test-run",
Expand All @@ -594,21 +602,12 @@ def test_log_plot(self, mock_os, mock_plt, mock_mlflow, temp_dir):
# Mock the figure
mock_figure = mock_plt.Figure.return_value

# Mock tempfile.NamedTemporaryFile
mock_temp_file = type("MockTempFile", (), {"name": "/tmp/test.png"})()
with patch.object(tempfile, "NamedTemporaryFile") as mock_tempfile:
mock_tempfile.return_value.__enter__.return_value = mock_temp_file
mock_tempfile.return_value.__exit__.return_value = None

logger.log_plot(mock_figure, step=10, name="test_plot")
logger.log_plot(mock_figure, step=10, name="test_plot")

# Check that figure was saved and logged as artifact
mock_figure.savefig.assert_called_once_with(
"/tmp/test.png", format="png", bbox_inches="tight"
)
mock_mlflow.log_artifact.assert_called_once_with(
"/tmp/test.png", "plots/test_plot"
)
# Check that log_figure was called
mock_mlflow.log_figure.assert_called_once_with(
mock_figure, "plots/test_plot.png", save_kwargs={"bbox_inches": "tight"}
)

@patch("nemo_rl.utils.logger.mlflow")
def test_cleanup(self, mock_mlflow, temp_dir):
Expand All @@ -620,6 +619,9 @@ def test_cleanup(self, mock_mlflow, temp_dir):
}
logger = MLflowLogger(cfg, log_dir=temp_dir)

# Reset mocks to avoid counting calls from init
mock_mlflow.end_run.reset_mock()

# Trigger cleanup
logger.__del__()

Expand All @@ -629,6 +631,9 @@ def test_cleanup(self, mock_mlflow, temp_dir):
@patch("nemo_rl.utils.logger.mlflow")
def test_init_with_none_log_dir(self, mock_mlflow):
"""Test initialization with None log_dir uses server default artifact location."""
# Ensure active_run returns None so initialization logic runs
mock_mlflow.active_run.return_value = None

cfg = {
"experiment_name": "test-experiment",
"run_name": "test-run",
Expand All @@ -639,12 +644,17 @@ def test_init_with_none_log_dir(self, mock_mlflow):
MLflowLogger(cfg, log_dir=None)

# Verify create_experiment was called with artifact_location=None (server default)
mock_mlflow.create_experiment.assert_called_once_with(name="test-experiment")
mock_mlflow.create_experiment.assert_called_once_with(
name="test-experiment", artifact_location=None
)
mock_mlflow.start_run.assert_called_once_with(run_name="test-run")

@patch("nemo_rl.utils.logger.mlflow")
def test_init_with_custom_log_dir(self, mock_mlflow):
"""Test initialization with custom log_dir sets artifact_location."""
# Ensure active_run returns None so initialization logic runs
mock_mlflow.active_run.return_value = None

cfg = {
"experiment_name": "test-experiment",
"run_name": "test-run",
Expand All @@ -663,6 +673,11 @@ def test_init_with_custom_log_dir(self, mock_mlflow):
@patch("nemo_rl.utils.logger.mlflow")
def test_init_with_artifact_location_in_config(self, mock_mlflow):
"""Test initialization with artifact_location in config takes precedence over log_dir."""
# Ensure active_run returns None so initialization logic runs
mock_mlflow.active_run.return_value = None
# Mock is_tracking_uri_set to return False so set_tracking_uri is called
mock_mlflow.is_tracking_uri_set.return_value = False

cfg = {
"experiment_name": "test-experiment",
"run_name": "test-run",
Expand All @@ -682,7 +697,12 @@ def test_init_with_artifact_location_in_config(self, mock_mlflow):

@patch("nemo_rl.utils.logger.mlflow")
def test_init_with_artifact_location_none_in_config(self, mock_mlflow):
"""Test initialization with artifact_location=None in config uses server default."""
"""Test initialization with artifact_location=None in config uses log_dir."""
# Ensure active_run returns None so initialization logic runs
mock_mlflow.active_run.return_value = None
# Mock is_tracking_uri_set to return False so set_tracking_uri is called
mock_mlflow.is_tracking_uri_set.return_value = False

cfg = {
"experiment_name": "test-experiment",
"run_name": "test-run",
Expand All @@ -693,17 +713,21 @@ def test_init_with_artifact_location_none_in_config(self, mock_mlflow):

MLflowLogger(cfg, log_dir="/fallback/path")

# Verify create_experiment was called without artifact_location
# (None is explicitly set, so we don't pass it to MLflow)
# Verify create_experiment was called with log_dir since config is None
mock_mlflow.create_experiment.assert_called_once_with(
name=cfg["experiment_name"], artifact_location=cfg["artifact_location"]
name=cfg["experiment_name"], artifact_location="/fallback/path"
)
mock_mlflow.set_tracking_uri.assert_called_once_with(cfg["tracking_uri"])
mock_mlflow.start_run.assert_called_once_with(run_name=cfg["run_name"])

@patch("nemo_rl.utils.logger.mlflow")
def test_init_without_artifact_location_uses_log_dir(self, mock_mlflow):
"""Test initialization without artifact_location in config uses log_dir."""
# Ensure active_run returns None so initialization logic runs
mock_mlflow.active_run.return_value = None
# Mock is_tracking_uri_set to return False so set_tracking_uri is called
mock_mlflow.is_tracking_uri_set.return_value = False

cfg = {
"experiment_name": "test-experiment",
"run_name": "test-run",
Expand Down Expand Up @@ -1672,7 +1696,10 @@ def test_log_plot_token_mult_prob_error(

@patch("nemo_rl.utils.logger.WandbLogger")
@patch("nemo_rl.utils.logger.TensorboardLogger")
def test_init_mlflow_only(self, mock_tb_logger, mock_wandb_logger, temp_dir):
@patch("nemo_rl.utils.logger.MLflowLogger")
def test_init_mlflow_only(
self, mock_mlflow_logger, mock_tb_logger, mock_wandb_logger, temp_dir
):
"""Test initialization with only MLflowLogger enabled."""
cfg = {
"wandb_enabled": False,
Expand All @@ -1692,6 +1719,7 @@ def test_init_mlflow_only(self, mock_tb_logger, mock_wandb_logger, temp_dir):
assert len(logger.loggers) == 1
mock_wandb_logger.assert_not_called()
mock_tb_logger.assert_not_called()
mock_mlflow_logger.assert_called_once()

@patch("nemo_rl.utils.logger.WandbLogger")
@patch("nemo_rl.utils.logger.TensorboardLogger")
Expand Down
Loading