Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 25 additions & 4 deletions smr_alignment/src/smr_alignment/metric_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import math
from typing import Dict, List, Optional, Sequence, Tuple
from itertools import combinations
import time

import numpy as np
import cvxpy as cp
Expand All @@ -37,6 +38,7 @@ def __init__(self, G: nx.DiGraph):
self.G = G
self.metrics_: Optional[List[str]] = None
self.weights_: Optional[Dict[str, float]] = None
self.optimization_stats: Optional[Dict[str, str]] = None

# ---------- utilities ----------

Expand Down Expand Up @@ -121,13 +123,20 @@ def optimize_weights(
eps: float = 1e-8,
sample_triplets: Optional[int] = None,
entropy_lambda: float = 0.0,
solver: str = "ECOS",
solver: Optional[str] = "ECOS",
verbose: bool = False,
) -> Dict[str, float]:
"""
Solve: minimize sum_i pos(alpha_i · w) + entropy_lambda * sum w_i log w_i
s.t. w >= 0, sum w = 1

If `solver = None`: Let `cvxpy` choose the solver
"""
t0 = time.perf_counter()
n_nodes = self.G.number_of_nodes()
n_edges = self.G.number_of_edges()
print(f"[mao] start optimize_weights | nodes={n_nodes} edges={n_edges} eps={eps}")

pair_logs = self._collect_pair_logs(eps)
metrics = self.metrics_ or []
if not metrics:
Expand Down Expand Up @@ -160,6 +169,7 @@ def optimize_weights(
prob.solve(solver=solver, verbose=verbose)
except Exception:
prob.solve(solver="SCS", verbose=verbose)
used_solver = getattr(prob.solver_stats, "solver_name", str(solver))

if w.value is None: # type: ignore
raise RuntimeError("Weight optimization failed; check solver output/logs.")
Expand All @@ -172,6 +182,16 @@ def optimize_weights(
w_arr /= s

self.weights_ = {m: float(w_arr[i]) for i, m in enumerate(metrics)}
print(f"[mao] solve done | status={prob.status} | objective={prob.value:.6f} "
f"| solver={used_solver}")
print(f"[mao] learned weights: {self.weights_}")
print(f"[mao] total time: {time.perf_counter() - t0:.3f}s")
self.optimization_stats = {
"status": str(prob.status),
"solver": str(used_solver),
"total_time_seconds": f"{time.perf_counter() - t0:.3f}",
"number_of_edges": str(int(A.shape[0])),
}
return self.weights_

# ---------- apply unified score ----------
Expand Down Expand Up @@ -213,7 +233,7 @@ def apply_unified_scores(
# Unified cost/similarity (log-space blend = weighted geometric mean in real space)
c_hat = 0.0
for m, w_m in edge_w.items():
s_m = self._clamp_score(ms[m], eps) # or self._clamp01(...)
s_m = self._clamp_score(ms[m], eps)
c_hat += w_m * (-math.log(s_m))
s_hat = math.exp(-c_hat)

Expand All @@ -228,6 +248,7 @@ def apply_unified_scores(
metric_scores_weighted[m] = s_m ** w_m
d["metric_scores_weighted"] = metric_scores_weighted

# Working weight: max of weighted metric scores (your requested policy)
# Working weight: max of weighted metric scores
if replace_weight and metric_scores_weighted:
d["weight"] = max(metric_scores_weighted.values())
d["weight_before_optimization"] = d["weight"]
d["weight"] = s_hat
6 changes: 3 additions & 3 deletions smr_alignment/tests/test_metric_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,9 @@ def test_apply_unified_scores_sets_expected_fields_and_weight_policy(self):
self.assertAlmostEqual(d["metric_scores_weighted"]["m_good"], s_good ** 0.97, places=10)
self.assertAlmostEqual(d["metric_scores_weighted"]["m_bad"], s_bad ** 0.03, places=10)

# working weight policy: max of weighted metric scores
expected_operational = max(s_good ** 0.97, s_bad ** 0.03)
self.assertAlmostEqual(d["weight"], expected_operational, places=10)
# NEW policy: working weight equals unified similarity (s_hat), not max per-metric
self.assertIn("weight_before_optimization", d) # keep old value for auditability
self.assertAlmostEqual(d["weight"], s_hat_expected, places=10)

def test_no_metrics_returns_empty_weights_and_no_crash_on_apply(self):
G2 = nx.DiGraph()
Expand Down