Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions mesa_geo/property_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import numpy as np


class PropertyLayer:
"""
PropertyLayer manages raster attributes and cell data separately
from RasterLayer to improve modularity and maintainability.
"""

def __init__(self, width, height, raster_layer):
self.width = width
self.height = height
self.raster_layer = raster_layer

# store attribute names
self.attributes = set()

# store raster data
self.data = {}

def apply_raster(self, data, attr_name=None):
if data.ndim == 2:
data = data[np.newaxis, ...]

num_bands = data.shape[0]

if attr_name is None:
attr_names = [f"attribute_{i}" for i in range(num_bands)]
elif isinstance(attr_name, str):
attr_names = [f"{attr_name}_{i}" for i in range(num_bands)]
else:
attr_names = attr_name

for band_idx, attr in enumerate(attr_names):
self.attributes.add(attr)

for grid_x in range(self.width):
for grid_y in range(self.height):
setattr(
self.raster_layer.cells[grid_x][grid_y],
attr,
data[band_idx, self.height - grid_y - 1, grid_x],
)

self.data[attr] = data[band_idx]

def get_raster(self, attr_name=None):
if attr_name is None:
return self.data

if isinstance(attr_name, str):
return self.data[attr_name]

return {name: self.data[name] for name in attr_name}
95 changes: 11 additions & 84 deletions mesa_geo/raster_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@

from mesa_geo.geo_base import GeoBase

# for the property_layer (custom code)
from mesa_geo.property_layer import PropertyLayer


class RasterBase(GeoBase):
"""
Expand Down Expand Up @@ -378,7 +381,8 @@ def __init__(
self.model = model
self.cell_cls = cell_cls
self._initialize_cells()
self._attributes = set()
# self._attributes = set()
self.property_layer = PropertyLayer(width, height, self)
self._neighborhood_cache = {}

def _update_transform(self) -> None:
Expand Down Expand Up @@ -444,7 +448,8 @@ def attributes(self) -> set[str]:
:return: Attributes of the cells in the raster layer.
:rtype: Set[str]
"""
return self._attributes
# Delegating to PropertyLayer to handle attributes
return self.property_layer.attributes

@overload
def __getitem__(self, index: int) -> list[Cell]: ...
Expand Down Expand Up @@ -531,58 +536,8 @@ def apply_raster(
names are generated. Default is None.
:raises ValueError: If the shape of the data does not match the raster.
"""

if data.ndim != 3 or data.shape[1:] != (self.height, self.width):
raise ValueError(
f"Data shape does not match raster shape. "
f"Expected (*, {self.height}, {self.width}), received {data.shape}."
)
num_bands = data.shape[0]

if num_bands == 1:
if isinstance(attr_name, Sequence) and not isinstance(attr_name, str):
if len(attr_name) != 1:
raise ValueError(
"attr_name sequence length must match the number of raster bands; "
f"expected {num_bands} band names, got {len(attr_name)}."
)
names = [attr_name[0]]
else:
names = [cast(str | None, attr_name)]
else:
if isinstance(attr_name, Sequence) and not isinstance(attr_name, str):
if len(attr_name) != num_bands:
raise ValueError(
"attr_name sequence length must match the number of raster bands; "
f"expected {num_bands} band names, got {len(attr_name)}."
)
names = list(attr_name)
elif isinstance(attr_name, str):
names = [f"{attr_name}_{band_idx + 1}" for band_idx in range(num_bands)]
else:
names = [None] * num_bands

def _default_attr_name() -> str:
base = f"attribute_{len(self.cell_cls.__dict__)}"
if base not in self._attributes:
return base
suffix = 1
candidate = f"{base}_{suffix}"
while candidate in self._attributes:
suffix += 1
candidate = f"{base}_{suffix}"
return candidate

for band_idx, name in enumerate(names):
attr = _default_attr_name() if name is None else name
self._attributes.add(attr)
for grid_x in range(self.width):
for grid_y in range(self.height):
setattr(
self.cells[grid_x][grid_y],
attr,
data[band_idx, self.height - grid_y - 1, grid_x],
)
# Delegating to PropertyLayer for separation of concerns
self.property_layer.apply_raster(data, attr_name)

def get_raster(self, attr_name: str | Sequence[str] | None = None) -> np.ndarray:
"""
Expand All @@ -594,36 +549,8 @@ def get_raster(self, attr_name: str | Sequence[str] | None = None) -> np.ndarray
(bands, height, width).
:rtype: np.ndarray
"""

if isinstance(attr_name, str) and attr_name not in self.attributes:
raise ValueError(
f"Attribute {attr_name} does not exist. "
f"Choose from {self.attributes}, or set `attr_name` to `None` to retrieve all."
)
if isinstance(attr_name, Sequence) and not isinstance(attr_name, str):
missing = [name for name in attr_name if name not in self.attributes]
if missing:
raise ValueError(
f"Attribute {missing[0]} does not exist. "
f"Choose from {self.attributes}, or set `attr_name` to `None` to retrieve all."
)
if attr_name is None:
num_bands = len(self.attributes)
attr_names = self.attributes
elif isinstance(attr_name, Sequence) and not isinstance(attr_name, str):
num_bands = len(attr_name)
attr_names = list(attr_name)
else:
num_bands = 1
attr_names = [attr_name]
data = np.empty((num_bands, self.height, self.width))
for ind, name in enumerate(attr_names):
for grid_x in range(self.width):
for grid_y in range(self.height):
data[ind, self.height - grid_y - 1, grid_x] = getattr(
self.cells[grid_x][grid_y], name
)
return data
# Delegating to PropertyLayer for separation of concerns
return self.property_layer.get_raster(attr_name)

def get_random_xy(
self,
Expand Down
89 changes: 89 additions & 0 deletions tests/test_property_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import numpy as np
from mesa import Model

from mesa_geo.raster_layers import RasterLayer


class DummyModel(Model):
pass


def test_property_layer_basic():

model = DummyModel()

raster = RasterLayer(
width=5,
height=5,
crs="EPSG:4326",
total_bounds=[0, 0, 5, 5],
model=model,
)

data = np.ones((5, 5))

raster.apply_raster(data, "temperature")

# Fixed assertion based on PropertyLayer naming convention
assert "temperature_0" in raster.property_layer.attributes


def test_property_layer_multiband():

model = DummyModel()

raster = RasterLayer(
width=5,
height=5,
crs="EPSG:4326",
total_bounds=[0, 0, 5, 5],
model=model,
)

data = np.ones((2, 5, 5))

raster.apply_raster(data, "band")

assert len(raster.property_layer.attributes) == 2


def test_cell_attribute():

model = DummyModel()

raster = RasterLayer(
width=5,
height=5,
crs="EPSG:4326",
total_bounds=[0, 0, 5, 5],
model=model,
)

data = np.ones((5, 5))

raster.apply_raster(data, "temperature")

cell = raster.cells[0][0]

assert hasattr(cell, "temperature_0")


def test_get_raster():

model = DummyModel()

raster = RasterLayer(
width=5,
height=5,
crs="EPSG:4326",
total_bounds=[0, 0, 5, 5],
model=model,
)

data = np.ones((5, 5))

raster.apply_raster(data, "temperature")

result = raster.get_raster("temperature_0")

assert result is not None