Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 21 additions & 6 deletions benchmarks/doc_understanding.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
except Exception: # pragma: no cover - fallback when torch is unavailable
torch = None # type: ignore[assignment]

from nd_llm.bottleneck import CompressionResult, IBottleneck
from nd_llm.orchestration import CompressionRecord, Orchestrator, UsageEvent
from nd_llm.bottleneck import CompressionResult, IBottleneck, RegistryAwareBudgetAllocator
from nd_llm.orchestration import Orchestrator, UsageEvent
from nd_llm.stm import STM
from nd_llm.utils import (
DEFAULT_BACKEND,
Expand Down Expand Up @@ -158,8 +158,16 @@ def run_funsd_benchmark(
data_root: Optional[Path | str] = None,
use_sample: bool = True,
seed: int = 0,
allocator_kwargs: Optional[Mapping[str, Any]] = None,
) -> Dict[str, Any]:
"""Evaluate FUNSD documents for numeric-answer retention under budget constraints."""
"""Evaluate FUNSD documents for numeric-answer retention under budget constraints.

Parameters
----------
allocator_kwargs:
Optional mapping forwarded to :class:`RegistryAwareBudgetAllocator` so callers can
specify advanced controls such as ``field_weights`` or ``field_min_quota``.
"""

registry = build_funsd_registry()
build_funsd_encoders(registry)
Expand All @@ -183,6 +191,7 @@ def run_funsd_benchmark(
seed=seed,
ablations=ablations,
mi_field_priorities=("text", "layout"),
allocator_kwargs=allocator_kwargs,
)
runs.append(budget_run)

Expand Down Expand Up @@ -254,15 +263,18 @@ def _evaluate_budget(
seed: int,
ablations: Optional[Mapping[str, "AblationFn"]] = None,
mi_field_priorities: Optional[Sequence[str]] = None,
allocator_kwargs: Optional[Mapping[str, Any]] = None,
) -> BudgetRun:
bottleneck = IBottleneck(target_budget=int(budget))
budget_allocator = RegistryAwareBudgetAllocator(**dict(allocator_kwargs or {}))
bottleneck = IBottleneck(target_budget=int(budget), budget_allocator=budget_allocator)

ablation_totals: Dict[str, Dict[str, Any]] = {}
ablation_bottlenecks: Dict[str, IBottleneck] = {}
if ablations:
ablation_totals = {name: _make_ablation_totals() for name in ablations}
ablation_bottlenecks = {
name: IBottleneck(target_budget=int(budget)) for name in ablations
name: IBottleneck(target_budget=int(budget), budget_allocator=budget_allocator)
for name in ablations
}

with TemporaryDirectory(prefix="ndllm-bench-") as tmp:
Expand Down Expand Up @@ -377,7 +389,10 @@ def _evaluate_budget(
)
ab_bottleneck = ablation_bottlenecks.get(name)
if ab_bottleneck is None:
ab_bottleneck = IBottleneck(target_budget=int(budget))
ab_bottleneck = IBottleneck(
target_budget=int(budget),
budget_allocator=budget_allocator,
)
ablation_bottlenecks[name] = ab_bottleneck

ab_result = ab_bottleneck.compress(
Expand Down
66 changes: 65 additions & 1 deletion nd_llm/bottleneck/ib.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@ def __init__(
salience_bonus: float = 1.5,
key_weight: float = 0.3,
min_weight: float = 0.05,
field_weights: Optional[Mapping[str, Any]] = None,
field_min_quota: Optional[Mapping[str, Any]] = None,
) -> None:
if salience_bonus <= 0:
raise ValueError("salience_bonus must be positive")
Expand All @@ -177,6 +179,24 @@ def __init__(
self.salience_bonus = salience_bonus
self.key_weight = key_weight
self.min_weight = min_weight
self.field_weights: Dict[str, float] = {}
for name, value in (field_weights or {}).items():
try:
weight = float(value)
except (TypeError, ValueError):
continue
if weight <= 0:
continue
self.field_weights[str(name)] = weight
self.field_min_quota: Dict[str, int] = {}
for name, value in (field_min_quota or {}).items():
try:
quota = int(value)
except (TypeError, ValueError):
continue
if quota <= 0:
continue
self.field_min_quota[str(name)] = quota

def __call__(
self,
Expand All @@ -188,10 +208,21 @@ def __call__(
available = sum(token_counts.values())
limit = min(int(total_budget), available)
weights: Dict[str, float] = {}
quotas: Dict[str, int] = {}
if self.field_min_quota:
for field in scores:
quota = self.field_min_quota.get(field, 0)
if quota <= 0:
continue
quotas[field] = min(int(quota), token_counts.get(field, 0))

for field in scores:
info = metadata.get(field, {})
budget_weight = info.get("budget_weight") if isinstance(info, Mapping) else None
if isinstance(budget_weight, (int, float)):
override_weight = self.field_weights.get(field)
if override_weight is not None:
weight = float(override_weight)
elif isinstance(budget_weight, (int, float)):
weight = float(budget_weight)
else:
keys = info.get("keys", []) if isinstance(info, Mapping) else []
Expand Down Expand Up @@ -255,6 +286,39 @@ def __call__(
allocations[field] += 1
assigned += 1

if quotas:
donors: Dict[str, int] = {
field: allocations.get(field, 0) - quotas.get(field, 0) for field in scores
}
for field, quota in quotas.items():
current = allocations.get(field, 0)
if quota <= 0 or current >= quota:
continue
needed = quota - current
remaining_capacity = token_counts.get(field, 0) - current
if remaining_capacity <= 0:
continue
needed = min(needed, remaining_capacity)
while needed > 0:
donor_field = None
donor_surplus = 0
for candidate, surplus in donors.items():
if candidate == field or surplus <= 0:
continue
if surplus > donor_surplus or (
surplus == donor_surplus and donor_field is not None and candidate < donor_field
):
donor_field = candidate
donor_surplus = surplus
if donor_field is None or donor_surplus <= 0:
break
transfer = min(needed, donor_surplus)
allocations[donor_field] -= transfer
donors[donor_field] -= transfer
allocations[field] += transfer
donors[field] = allocations[field] - quotas.get(field, 0)
needed -= transfer

return allocations, {field: weights.get(field, 0.0) for field in scores}


Expand Down
Loading