SampledMetrics callback #105

Open · wants to merge 1 commit into base: main
2 changes: 2 additions & 0 deletions src/cfp/training/__init__.py
@@ -4,6 +4,7 @@
    ComputationCallback,
    LoggingCallback,
    Metrics,
    SampledMetrics,
    PCADecodedMetrics,
    WandbLogger,
)
@@ -19,4 +20,5 @@
    "CallbackRunner",
    "PCADecodedMetrics",
    "PCADecoder",
    "SampledMetrics",
]
96 changes: 96 additions & 0 deletions src/cfp/training/_callbacks.py
@@ -7,6 +7,7 @@
import jax.tree_util as jtu
import numpy as np
from numpy.typing import ArrayLike
import random

from cfp.metrics._metrics import compute_e_distance, compute_r_squared, compute_scalar_mmd, compute_sinkhorn_div

@@ -18,6 +19,7 @@
    "WandbLogger",
    "CallbackRunner",
    "PCADecodedMetrics",
    "SampledMetrics",
]


@@ -210,6 +212,100 @@ def on_train_end(
            Predicted data
        """
        return self.on_log_iteration(validation_data, predicted_data)



class SampledMetrics(ComputationCallback):
"""Callback to compute metrics on sampled validation data during training

Parameters
----------
sample_size : int
Number of samples to use for metric computation
metrics : list
List of metrics to compute. Supported metrics are "sinkhorn_div", "e_distance", and "mmd".
metric_aggregations : list
List of aggregation functions to use for each metric. Supported aggregations are "mean" and "median".
log_prefix : str
Prefix to add to the log keys.
"""

    def __init__(
        self,
        sample_size: int,
        metrics: list[Literal["sinkhorn_div", "e_distance", "mmd"]],
        metric_aggregations: list[Literal["mean", "median"]] | None = None,
        log_prefix: str = "sampled_",
    ):
        self.sample_size = sample_size
        self.metrics = metrics
        self.metric_aggregation = (
            ["mean"] if metric_aggregations is None else metric_aggregations
        )
        self.log_prefix = log_prefix

        for metric in metrics:
            if metric not in ["sinkhorn_div", "e_distance", "mmd"]:
                raise ValueError(
                    f"Metric {metric} not supported. Supported metrics are 'sinkhorn_div', 'e_distance', and 'mmd'"
                )

    def on_train_begin(self, *args: Any, **kwargs: Any) -> Any:
        """Called at the beginning of training."""
        pass

    def sample_data(self, data: ArrayLike) -> ArrayLike:
        """Randomly subsample ``data`` to at most ``sample_size`` rows."""
        if len(data) <= self.sample_size:
            return data
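        # Sample indices without replacement; assumes ``data`` supports
        # NumPy-style integer-array indexing (e.g. a numpy array).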
        indices = random.sample(range(len(data)), self.sample_size)
        return data[indices]

    def on_log_iteration(
        self,
        validation_data: dict[str, dict[str, ArrayLike]],
        predicted_data: dict[str, dict[str, ArrayLike]],
    ) -> dict[str, float]:
        """Called at each validation/log iteration to compute metrics on sampled data.

        Args:
            validation_data: Validation data
            predicted_data: Predicted data
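
        Returns:
            Dictionary of metric values, keyed as
            ``f"{log_prefix}{condition}_{metric}_{aggregation}"``.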
"""
        metrics = {}
        for metric in self.metrics:
            for k in validation_data:
                sampled_validation = self.sample_data(validation_data[k])
                sampled_predicted = self.sample_data(predicted_data[k])

                if metric == "sinkhorn_div":
                    result = compute_sinkhorn_div(sampled_validation, sampled_predicted)
                elif metric == "e_distance":
                    result = compute_e_distance(sampled_validation, sampled_predicted)
                elif metric == "mmd":
                    result = compute_scalar_mmd(sampled_validation, sampled_predicted)

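                # ``agg_fn_to_func`` is assumed to map aggregation names to
                # reducers (e.g. "mean" -> np.mean, "median" -> np.median);
                # it is not defined in this diff, so it presumably lives
                # elsewhere in ``_callbacks.py``.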
                for agg_fn in self.metric_aggregation:
                    metrics[f"{self.log_prefix}{k}_{metric}_{agg_fn}"] = agg_fn_to_func[agg_fn](result)

        return metrics

    def on_train_end(
        self,
        validation_data: dict[str, dict[str, ArrayLike]],
        predicted_data: dict[str, dict[str, ArrayLike]],
    ) -> dict[str, float]:
        """Called at the end of training to compute metrics.

        Parameters
        ----------
        validation_data : dict
            Validation data
        predicted_data : dict
            Predicted data
        """
        return self.on_log_iteration(validation_data, predicted_data)



class PCADecodedMetrics(Metrics):
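
For reviewers, a minimal usage sketch of the new callback (not part of the diff). The condition key "cond_a", the array shapes, and the direct call to on_log_iteration are illustrative only. Note that although the annotations declare nested dicts, the implementation indexes validation_data[k] directly, so flat condition-to-array dicts are used here:

import numpy as np

from cfp.training import SampledMetrics

# Subsample 512 points per condition and log the mean and median of the
# MMD and energy distance computed on the subsample.
callback = SampledMetrics(
    sample_size=512,
    metrics=["mmd", "e_distance"],
    metric_aggregations=["mean", "median"],
)

# Illustrative validation/predicted arrays for a single condition.
rng = np.random.default_rng(0)
validation_data = {"cond_a": rng.normal(size=(2000, 50))}
predicted_data = {"cond_a": rng.normal(size=(2000, 50))}

logs = callback.on_log_iteration(validation_data, predicted_data)
# Keys look like "sampled_cond_a_mmd_mean" and "sampled_cond_a_e_distance_median".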