Updates needed for Sentinel-2 vessel attribute prediction. #151

Open. Wants to merge 2 commits into master.
7 changes: 6 additions & 1 deletion rslearn/main.py
@@ -23,7 +23,7 @@
from rslearn.tile_stores import get_tile_store_with_layer
from rslearn.train.data_module import RslearnDataModule
from rslearn.train.lightning_module import RslearnLightningModule
from rslearn.utils import Projection, STGeometry, parse_disabled_layers
from rslearn.utils import Projection, STGeometry

logger = get_logger(__name__)

@@ -64,6 +64,11 @@ def parse_time_range(
return (parse_time(start), parse_time(end))


def parse_disabled_layers(disabled_layers: str) -> list[str]:
"""Parse the disabled layers string."""
return disabled_layers.split(",") if disabled_layers else []


@register_handler("dataset", "add_windows")
def add_windows() -> None:
"""Handler for the rslearn dataset add_windows command."""
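For reference, a minimal sketch of how the relocated helper behaves, based on the implementation shown above; the layer names in the sample strings are made up for illustration:

# Illustrative usage of parse_disabled_layers; the layer names are hypothetical.
from rslearn.main import parse_disabled_layers

assert parse_disabled_layers("sentinel2,landsat") == ["sentinel2", "landsat"]
assert parse_disabled_layers("") == []
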
20 changes: 13 additions & 7 deletions rslearn/train/prediction_writer.py
@@ -4,6 +4,7 @@
from typing import Any

import numpy as np
import numpy.typing as npt
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import BasePredictionWriter
from upath import UPath
@@ -112,14 +113,12 @@ def write_on_batch_end(
"""
assert isinstance(pl_module, RslearnLightningModule)
metadatas = batch[2]
outputs = [
outputs: list[Any] = [
pl_module.task.process_output(output, metadata)
for output, metadata in zip(prediction, metadatas)
]

for output, metadata in zip(outputs, metadatas):
if not isinstance(output, dict):
raise ValueError(f"Unsupported output type {type(output)}")
for k in self.selector:
output = output[k]

window_bounds = metadata["window_bounds"]

if self.layer_config.layer_type == LayerType.RASTER:
if window_name not in self.pending_outputs and isinstance(
output, np.ndarray
):
if not isinstance(output, np.ndarray):
raise ValueError(
"expected output for raster layer to be numpy array"
)

if window_name not in self.pending_outputs:
self.pending_outputs[window_name] = np.zeros(
(
output.shape[0],
@@ -151,9 +153,13 @@
)

elif self.layer_config.layer_type == LayerType.VECTOR:
if not isinstance(output, list):
raise ValueError(
"expected output for vector layer to be list of features"
)

if window_name not in self.pending_outputs:
self.pending_outputs[window_name] = []

self.pending_outputs[window_name].extend(output)

if metadata["patch_idx"] < metadata["num_patches"] - 1:
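Net effect of this hunk: write_on_batch_end now raises immediately when a task's processed output does not match the configured layer type, rather than folding the type check into the pending-output initialization. A rough sketch of the contract this enforces, assuming LayerType only distinguishes raster and vector here; the helper below is illustrative and not part of the diff:

# Hypothetical helper summarizing the checks added above (illustrative only).
import numpy as np

def validate_processed_output(layer_type: str, output: object) -> None:
    # Raster layers need a numpy array; vector layers need a list of features.
    if layer_type == "raster" and not isinstance(output, np.ndarray):
        raise ValueError("expected output for raster layer to be numpy array")
    if layer_type == "vector" and not isinstance(output, list):
        raise ValueError("expected output for vector layer to be list of features")
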
107 changes: 96 additions & 11 deletions rslearn/train/tasks/regression.py
@@ -4,12 +4,14 @@

import numpy as np
import numpy.typing as npt
import shapely
import torch
import torchmetrics
from PIL import Image, ImageDraw
from torchmetrics import Metric, MetricCollection

from rslearn.utils import Feature
from rslearn.utils.feature import Feature
from rslearn.utils.geometry import STGeometry

from .task import BasicTask

@@ -20,10 +22,12 @@ class RegressionTask(BasicTask):
def __init__(
self,
property_name: str,
filters: list[tuple[str, str]] | None,
filters: list[tuple[str, str]] | None = None,
allow_invalid: bool = False,
scale_factor: float = 1,
metric_mode: str = "mse",
use_within_factor_metric: bool = False,
within_factor: float = 0.1,
**kwargs: Any,
) -> None:
"""Initialize a new RegressionTask.
@@ -37,6 +41,10 @@ def __init__(
at a window, simply mark the example invalid for this task
scale_factor: multiply the label value by this factor
metric_mode: what metric to use, either mse or l1
use_within_factor_metric: include a metric that reports the percentage of
examples where the output is within a factor of the ground truth.
within_factor: the factor for the within-factor metric. If it is 0.2 and the
ground truth is 5.0, then values from 4.0 (5.0*0.8) to 6.0 (5.0*1.2) are accepted.
kwargs: other arguments to pass to BasicTask
"""
super().__init__(**kwargs)
@@ -45,6 +53,8 @@ def __init__(
self.allow_invalid = allow_invalid
self.scale_factor = scale_factor
self.metric_mode = metric_mode
self.use_within_factor_metric = use_within_factor_metric
self.within_factor = within_factor

if not self.filters:
self.filters = []
@@ -92,6 +102,31 @@ def process_inputs(
"valid": torch.tensor(0, dtype=torch.float32),
}

def process_output(
self, raw_output: Any, metadata: dict[str, Any]
) -> npt.NDArray[Any] | list[Feature]:
"""Processes an output into raster or vector data.

Args:
raw_output: the output from prediction head.
metadata: metadata about the patch being read

Returns:
either raster or vector data.
"""
output = raw_output.item() / self.scale_factor
feature = Feature(
STGeometry(
metadata["projection"],
shapely.Point(metadata["bounds"][0], metadata["bounds"][1]),
None,
),
{
self.property_name: output,
},
)
return [feature]

def visualize(
self,
input_dict: dict[str, Any],
@@ -125,17 +160,24 @@ def visualize(

def get_metrics(self) -> MetricCollection:
"""Get the metrics for this task."""
metric_dict: dict[str, Metric] = {}

if self.metric_mode == "mse":
metric = torchmetrics.MeanSquaredError()
metric_dict["mse"] = RegressionMetricWrapper(
metric=torchmetrics.MeanSquaredError(), scale_factor=self.scale_factor
)
elif self.metric_mode == "l1":
metric = torchmetrics.MeanAbsoluteError()
return MetricCollection(
{
self.metric_mode: RegressionMetricWrapper(
metric=metric, scale_factor=self.scale_factor
)
}
)
metric_dict["l1"] = RegressionMetricWrapper(
metric=torchmetrics.MeanAbsoluteError(), scale_factor=self.scale_factor
)

if self.use_within_factor_metric:
metric_dict["within_factor"] = RegressionMetricWrapper(
metric=WithinFactorMetric(self.within_factor),
scale_factor=self.scale_factor,
)

return MetricCollection(metric_dict)


class RegressionHead(torch.nn.Module):
@@ -241,3 +283,46 @@ def reset(self) -> None:
def plot(self, *args: list[Any], **kwargs: dict[str, Any]) -> Any:
"""Returns a plot of the metric."""
return self.metric.plot(*args, **kwargs)


class WithinFactorMetric(Metric):
"""Percentage of examples with estimate within some factor of ground truth."""

def __init__(self, factor: float) -> None:
"""Initialize a new RegressionMetricWrapper.

Args:
factor: an estimate is marked correct if it is within this factor of the
ground truth.
"""
super().__init__()
self.factor = factor
self.correct = 0
self.total = 0

def update(self, preds: torch.Tensor, labels: torch.Tensor) -> None:
"""Update metric.

Args:
preds: the predictions
labels: the ground truth data
"""
decisions = (preds >= labels * (1 - self.factor)) & (
preds <= labels * (1 + self.factor)
)
self.correct += torch.count_nonzero(decisions)
self.total += len(decisions)

def compute(self) -> Any:
"""Returns the computed metric."""
return torch.tensor(self.correct / self.total)

def reset(self) -> None:
"""Reset metric."""
super().reset()
self.correct = 0
self.total = 0

def plot(self, *args: list[Any], **kwargs: dict[str, Any]) -> Any:
"""Returns a plot of the metric."""
return None
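To make the new within-factor metric concrete, a small worked example; the numbers and the "length" property name are illustrative, and the assertions assume the single-process counters shown above:

# Illustrative only: WithinFactorMetric counting and how get_metrics() composes it.
import torch

from rslearn.train.tasks.regression import RegressionTask, WithinFactorMetric

metric = WithinFactorMetric(factor=0.1)
preds = torch.tensor([9.5, 12.0, 100.0])
labels = torch.tensor([10.0, 10.0, 101.0])
# 9.5 lies in [9.0, 11.0] and 100.0 lies in [90.9, 111.1], but 12.0 falls outside
# [9.0, 11.0], so 2 of 3 examples count as correct.
metric.update(preds, labels)
assert torch.isclose(metric.compute(), torch.tensor(2 / 3))

# With use_within_factor_metric=True and the default metric_mode of "mse",
# get_metrics() returns a MetricCollection with "mse" and "within_factor" entries.
task = RegressionTask(
    property_name="length",
    use_within_factor_metric=True,
    within_factor=0.1,
)
metrics = task.get_metrics()
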
4 changes: 2 additions & 2 deletions rslearn/train/tasks/task.py
@@ -39,15 +39,15 @@ def process_inputs(

def process_output(
self, raw_output: Any, metadata: dict[str, Any]
) -> npt.NDArray[Any] | list[Feature]:
) -> npt.NDArray[Any] | list[Feature] | dict[str, Any]:
"""Processes an output into raster or vector data.

Args:
raw_output: the output from prediction head.
metadata: metadata about the patch being read

Returns:
either raster or vector data.
raster data, vector data, or multi-task dictionary output.
"""
raise NotImplementedError

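The widened return type accommodates tasks that bundle several sub-task outputs into one dictionary; the prediction writer's selector loop (for k in self.selector: output = output[k]) then drills down to the entry to write. A hedged sketch of that shape, with made-up keys and placeholder values:

# Hypothetical multi-task output; keys and values are made up for illustration.
from typing import Any

import numpy as np

multi_task_output: dict[str, Any] = {
    "length": ["feature placeholder"],    # stand-in for a list of vector Features
    "segmentation": np.zeros((1, 4, 4)),  # stand-in for a raster array
}

selected: Any = multi_task_output
for k in ["length"]:  # plays the role of self.selector in the prediction writer
    selected = selected[k]
assert selected == ["feature placeholder"]
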
3 changes: 0 additions & 3 deletions rslearn/utils/__init__.py
@@ -13,7 +13,6 @@
from .get_utm_ups_crs import get_utm_ups_crs
from .grid_index import GridIndex
from .time import daterange
from .utils import open_atomic, parse_disabled_layers

logger = get_logger(__name__)

@@ -27,7 +26,5 @@
"get_utm_ups_crs",
"is_same_resolution",
"logger",
"open_atomic",
"parse_disabled_layers",
"shp_intersects",
)
30 changes: 0 additions & 30 deletions rslearn/utils/utils.py

This file was deleted.

27 changes: 27 additions & 0 deletions tests/unit/train/tasks/test_regression.py
@@ -0,0 +1,27 @@
import pytest
import torch

from rslearn.const import WGS84_PROJECTION
from rslearn.train.tasks.regression import RegressionTask
from rslearn.utils.feature import Feature


def test_process_output() -> None:
"""Ensure that RegressionTask.process_output produces correct Feature."""
property_name = "property_name"
scale_factor = 0.01
task = RegressionTask(
property_name=property_name,
scale_factor=scale_factor,
)
expected_value = 5
raw_output = torch.tensor(expected_value * scale_factor)
metadata = dict(
projection=WGS84_PROJECTION,
bounds=[0, 0, 1, 1],
)
features = task.process_output(raw_output, metadata)
assert len(features) == 1
feature = features[0]
assert isinstance(feature, Feature)
assert feature.properties[property_name] == pytest.approx(expected_value)
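A companion test for the new metric is not part of this diff, but a sketch in the same style could look like the following; it assumes the torch and pytest imports already at the top of this file:

def test_within_factor_metric() -> None:
    """Sketch only, not included in this PR: WithinFactorMetric counts correct examples."""
    from rslearn.train.tasks.regression import WithinFactorMetric

    metric = WithinFactorMetric(factor=0.1)
    # 9.5 is within 10% of 10.0, while 12.0 is not, so half the examples are correct.
    metric.update(torch.tensor([9.5, 12.0]), torch.tensor([10.0, 10.0]))
    assert metric.compute().item() == pytest.approx(0.5)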