Skip to content

Commit

Permalink
add test for regression task and within-factor metric
Browse files Browse the repository at this point in the history
  • Loading branch information
favyen2 committed Mar 7, 2025
1 parent 9219c95 commit bd95171
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 10 deletions.
78 changes: 68 additions & 10 deletions rslearn/train/tasks/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@ class RegressionTask(BasicTask):
def __init__(
self,
property_name: str,
filters: list[tuple[str, str]] | None,
filters: list[tuple[str, str]] | None = None,
allow_invalid: bool = False,
scale_factor: float = 1,
metric_mode: str = "mse",
use_within_factor_metric: bool = False,
within_factor: float = 0.1,
**kwargs: Any,
) -> None:
"""Initialize a new RegressionTask.
Expand All @@ -39,6 +41,10 @@ def __init__(
at a window, simply mark the example invalid for this task
scale_factor: multiply the label value by this factor
metric_mode: what metric to use, either mse or l1
use_within_factor_metric: include metric that reports percentage of
examples where output is within a factor of the ground truth.
within_factor: the factor for within factor metric. If it's 0.2, and ground
truth is 5.0, then values from 5.0*0.8 to 5.0*1.2 are accepted.
kwargs: other arguments to pass to BasicTask
"""
super().__init__(**kwargs)
Expand All @@ -47,6 +53,8 @@ def __init__(
self.allow_invalid = allow_invalid
self.scale_factor = scale_factor
self.metric_mode = metric_mode
self.use_within_factor_metric = use_within_factor_metric
self.within_factor = within_factor

if not self.filters:
self.filters = []
Expand Down Expand Up @@ -152,17 +160,24 @@ def visualize(

def get_metrics(self) -> MetricCollection:
    """Get the metrics for this task.

    Returns:
        a MetricCollection containing the error metric selected by
        metric_mode ("mse" or "l1"), scaled back to label units via
        RegressionMetricWrapper, plus the within-factor accuracy metric
        when use_within_factor_metric is enabled.

    Raises:
        ValueError: if metric_mode is not one of "mse" or "l1".
    """
    metric_dict: dict[str, Metric] = {}

    if self.metric_mode == "mse":
        metric_dict["mse"] = RegressionMetricWrapper(
            metric=torchmetrics.MeanSquaredError(), scale_factor=self.scale_factor
        )
    elif self.metric_mode == "l1":
        metric_dict["l1"] = RegressionMetricWrapper(
            metric=torchmetrics.MeanAbsoluteError(), scale_factor=self.scale_factor
        )
    else:
        # Fail fast on a typo instead of silently returning a collection
        # with no error metric in it.
        raise ValueError(f"unknown metric_mode {self.metric_mode}")

    if self.use_within_factor_metric:
        metric_dict["within_factor"] = RegressionMetricWrapper(
            metric=WithinFactorMetric(self.within_factor),
            scale_factor=self.scale_factor,
        )

    return MetricCollection(metric_dict)


class RegressionHead(torch.nn.Module):
Expand Down Expand Up @@ -268,3 +283,46 @@ def reset(self) -> None:
def plot(self, *args: list[Any], **kwargs: dict[str, Any]) -> Any:
"""Returns a plot of the metric."""
return self.metric.plot(*args, **kwargs)


class WithinFactorMetric(Metric):
    """Percentage of examples with estimate within some factor of ground truth."""

    def __init__(self, factor: float) -> None:
        """Initialize a new WithinFactorMetric.

        Args:
            factor: the tolerance, so if the estimate is within this factor of
                the ground truth then it is marked correct. For example, if the
                factor is 0.2 and the ground truth is 5.0, then predictions
                from 5.0*0.8 to 5.0*1.2 are accepted.
        """
        super().__init__()
        self.factor = factor
        # Register the counters as metric states so torchmetrics moves them to
        # the correct device and sums them across processes under DDP; plain
        # Python ints would be neither synced nor reset by Metric.reset().
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, labels: torch.Tensor) -> None:
        """Update metric.

        Args:
            preds: the predictions
            labels: the ground truth data
        """
        # NOTE(review): assumes non-negative ground truth; for negative labels
        # the lower/upper bounds invert — confirm against callers.
        decisions = (preds >= labels * (1 - self.factor)) & (
            preds <= labels * (1 + self.factor)
        )
        self.correct += torch.count_nonzero(decisions)
        # numel (not len) so multi-dimensional inputs count every element,
        # matching count_nonzero above; len() would only count the first dim.
        self.total += decisions.numel()

    def compute(self) -> Any:
        """Returns the fraction of examples within the factor."""
        if self.total == 0:
            # No examples seen yet; avoid division by zero.
            return torch.tensor(0.0)
        return self.correct / self.total

    def plot(self, *args: list[Any], **kwargs: dict[str, Any]) -> Any:
        """Returns a plot of the metric."""
        return None
27 changes: 27 additions & 0 deletions tests/unit/train/tasks/test_regression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import pytest
import torch

from rslearn.const import WGS84_PROJECTION
from rslearn.train.tasks.regression import RegressionTask
from rslearn.utils.feature import Feature


def test_process_output() -> None:
    """Verify that RegressionTask.process_output yields the expected Feature."""
    prop = "property_name"
    factor = 0.01
    task = RegressionTask(
        property_name=prop,
        scale_factor=factor,
    )
    target_value = 5
    # The task divides model outputs by the scale factor, so feeding it the
    # pre-scaled value should recover target_value in the feature properties.
    model_output = torch.tensor(target_value * factor)
    window_metadata = {
        "projection": WGS84_PROJECTION,
        "bounds": [0, 0, 1, 1],
    }
    result = task.process_output(model_output, window_metadata)
    assert len(result) == 1
    produced = result[0]
    assert isinstance(produced, Feature)
    assert produced.properties[prop] == pytest.approx(target_value)

0 comments on commit bd95171

Please sign in to comment.