Skip to content

Commit a7dc9b2

Browse files
Corvincepre-commit-ci[bot]quaquelEwoutH
authored
Generalize CellAgent (#2292)
* add some agents * restructure and rename * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Restructure mixins * rename and update Grid2DMovement * use direction map instead of match * Add Patch * tests for all new stuff * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update cell_agent.py * Update test_cell_space.py * Rename Patch to FixedAgent Co-authored-by: Ewout ter Hoeven <[email protected]> * Rename Patch to FixedAgent in tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Use FixedAgent in examples/benchmarks * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jan Kwakkel <[email protected]> Co-authored-by: Ewout ter Hoeven <[email protected]>
1 parent 4e45300 commit a7dc9b2

File tree

6 files changed

+158
-17
lines changed

6 files changed

+158
-17
lines changed

benchmarks/WolfSheep/wolf_sheep.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import math
1111

1212
from mesa import Model
13-
from mesa.experimental.cell_space import CellAgent, OrthogonalVonNeumannGrid
13+
from mesa.experimental.cell_space import CellAgent, FixedAgent, OrthogonalVonNeumannGrid
1414
from mesa.experimental.devs import ABMSimulator
1515

1616

@@ -87,7 +87,7 @@ def feed(self):
8787
sheep_to_eat.remove()
8888

8989

90-
class GrassPatch(CellAgent):
90+
class GrassPatch(FixedAgent):
9191
"""A patch of grass that grows at a fixed rate and it is eaten by sheep."""
9292

9393
@property

mesa/experimental/cell_space/__init__.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@
66
"""
77

88
from mesa.experimental.cell_space.cell import Cell
9-
from mesa.experimental.cell_space.cell_agent import CellAgent
9+
from mesa.experimental.cell_space.cell_agent import (
10+
CellAgent,
11+
FixedAgent,
12+
Grid2DMovingAgent,
13+
)
1014
from mesa.experimental.cell_space.cell_collection import CellCollection
1115
from mesa.experimental.cell_space.discrete_space import DiscreteSpace
1216
from mesa.experimental.cell_space.grid import (
@@ -22,6 +26,8 @@
2226
"CellCollection",
2327
"Cell",
2428
"CellAgent",
29+
"Grid2DMovingAgent",
30+
"FixedAgent",
2531
"DiscreteSpace",
2632
"Grid",
2733
"HexGrid",

mesa/experimental/cell_space/cell.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@
77
from random import Random
88
from typing import TYPE_CHECKING, Any
99

10+
from mesa.experimental.cell_space.cell_agent import CellAgent
1011
from mesa.experimental.cell_space.cell_collection import CellCollection
1112
from mesa.space import PropertyLayer
1213

1314
if TYPE_CHECKING:
1415
from mesa.agent import Agent
15-
from mesa.experimental.cell_space.cell_agent import CellAgent
1616

1717
Coordinate = tuple[int, ...]
1818

@@ -69,7 +69,7 @@ def __init__(
6969
self.agents: list[
7070
Agent
7171
] = [] # TODO:: change to AgentSet or weakrefs? (neither is very performant, )
72-
self.capacity: int = capacity
72+
self.capacity: int | None = capacity
7373
self.properties: dict[Coordinate, object] = {}
7474
self.random = random
7575
self._mesa_property_layers: dict[str, PropertyLayer] = {}
@@ -136,7 +136,7 @@ def __repr__(self): # noqa
136136
return f"Cell({self.coordinate}, {self.agents})"
137137

138138
@cached_property
139-
def neighborhood(self) -> CellCollection:
139+
def neighborhood(self) -> CellCollection[Cell]:
140140
"""Returns the direct neighborhood of the cell.
141141
142142
This is equivalent to cell.get_neighborhood(radius=1)
@@ -148,7 +148,7 @@ def neighborhood(self) -> CellCollection:
148148
@cache # noqa: B019
149149
def get_neighborhood(
150150
self, radius: int = 1, include_center: bool = False
151-
) -> CellCollection:
151+
) -> CellCollection[Cell]:
152152
"""Returns a list of all neighboring cells for the given radius.
153153
154154
For getting the direct neighborhood (i.e., radius=1) you can also use

mesa/experimental/cell_space/cell_agent.py

+93-9
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,24 @@
22

33
from __future__ import annotations
44

5-
from typing import TYPE_CHECKING
5+
from typing import TYPE_CHECKING, Protocol
66

7-
from mesa import Agent
7+
from mesa.agent import Agent
88

99
if TYPE_CHECKING:
10-
from mesa.experimental.cell_space.cell import Cell
10+
from mesa.experimental.cell_space import Cell
11+
12+
13+
class HasCellProtocol(Protocol):
14+
"""Protocol for discrete space cell holders."""
15+
16+
cell: Cell
1117

1218

1319
class HasCell:
1420
"""Descriptor for cell movement behavior."""
1521

16-
_mesa_cell: Cell = None
22+
_mesa_cell: Cell | None = None
1723

1824
@property
1925
def cell(self) -> Cell | None: # noqa: D102
@@ -33,17 +39,95 @@ def cell(self, cell: Cell | None) -> None:
3339
cell.add_agent(self)
3440

3541

36-
class CellAgent(Agent, HasCell):
42+
class BasicMovement:
43+
"""Mixin for moving agents in discrete space."""
44+
45+
def move_to(self: HasCellProtocol, cell: Cell) -> None:
46+
"""Move to a new cell."""
47+
self.cell = cell
48+
49+
def move_relative(self: HasCellProtocol, direction: tuple[int, ...]):
50+
"""Move to a cell relative to the current cell.
51+
52+
Args:
53+
direction: The direction to move in.
54+
"""
55+
new_cell = self.cell.connections.get(direction)
56+
if new_cell is not None:
57+
self.cell = new_cell
58+
else:
59+
raise ValueError(f"No cell in direction {direction}")
60+
61+
62+
class FixedCell(HasCell):
63+
"""Mixin for agents that are fixed to a cell."""
64+
65+
@property
66+
def cell(self) -> Cell | None: # noqa: D102
67+
return self._mesa_cell
68+
69+
@cell.setter
70+
def cell(self, cell: Cell) -> None:
71+
if self.cell is not None:
72+
raise ValueError("Cannot move agent in FixedCell")
73+
self._mesa_cell = cell
74+
75+
cell.add_agent(self)
76+
77+
78+
class CellAgent(Agent, HasCell, BasicMovement):
3779
"""Cell Agent is an extension of the Agent class and adds behavior for moving in discrete spaces.
3880
3981
Attributes:
40-
unique_id (int): A unique identifier for this agent.
41-
model (Model): The model instance to which the agent belongs
42-
pos: (Position | None): The position of the agent in the space
43-
cell: (Cell | None): the cell which the agent occupies
82+
cell (Cell): The cell the agent is currently in.
4483
"""
4584

4685
def remove(self):
4786
"""Remove the agent from the model."""
4887
super().remove()
4988
self.cell = None # ensures that we are also removed from cell
89+
90+
91+
class FixedAgent(Agent, FixedCell):
92+
"""A patch in a 2D grid."""
93+
94+
def remove(self):
95+
"""Remove the agent from the model."""
96+
super().remove()
97+
98+
# fixme we leave self._mesa_cell on the original value
99+
# so you cannot hijack remove() to move patches
100+
self.cell.remove_agent(self)
101+
102+
103+
class Grid2DMovingAgent(CellAgent):
104+
"""Mixin for moving agents in 2D grids."""
105+
106+
# fmt: off
107+
DIRECTION_MAP = {
108+
"n": (-1, 0), "north": (-1, 0), "up": (-1, 0),
109+
"s": (1, 0), "south": (1, 0), "down": (1, 0),
110+
"e": (0, 1), "east": (0, 1), "right": (0, 1),
111+
"w": (0, -1), "west": (0, -1), "left": (0, -1),
112+
"ne": (-1, 1), "northeast": (-1, 1), "upright": (-1, 1),
113+
"nw": (-1, -1), "northwest": (-1, -1), "upleft": (-1, -1),
114+
"se": (1, 1), "southeast": (1, 1), "downright": (1, 1),
115+
"sw": (1, -1), "southwest": (1, -1), "downleft": (1, -1)
116+
}
117+
# fmt: on
118+
119+
def move(self, direction: str, distance: int = 1):
120+
"""Move the agent in a cardinal direction.
121+
122+
Args:
123+
direction: The cardinal direction to move in.
124+
distance: The distance to move.
125+
"""
126+
direction = direction.lower() # Convert direction to lowercase
127+
128+
if direction not in self.DIRECTION_MAP:
129+
raise ValueError(f"Invalid direction: {direction}")
130+
131+
move_vector = self.DIRECTION_MAP[direction]
132+
for _ in range(distance):
133+
self.move_relative(move_vector)

mesa/experimental/devs/examples/wolf_sheep.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Example of using ABM simulator for Wolf-Sheep Predation Model."""
22

33
import mesa
4+
from mesa.experimental.cell_space import FixedAgent
45
from mesa.experimental.devs.simulator import ABMSimulator
56

67

@@ -90,7 +91,7 @@ def feed(self):
9091
sheep_to_eat.die()
9192

9293

93-
class GrassPatch(mesa.Agent):
94+
class GrassPatch(FixedAgent):
9495
"""A patch of grass that grows at a fixed rate and it is eaten by sheep."""
9596

9697
@property

tests/test_cell_space.py

+50
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
Cell,
1111
CellAgent,
1212
CellCollection,
13+
FixedAgent,
14+
Grid2DMovingAgent,
1315
HexGrid,
1416
Network,
1517
OrthogonalMooreGrid,
@@ -641,3 +643,51 @@ def test_cell_agent(): # noqa: D103
641643
assert agent not in model._all_agents
642644
assert agent not in cell1.agents
643645
assert agent not in cell2.agents
646+
647+
model = Model()
648+
agent = CellAgent(model)
649+
agent.cell = cell1
650+
agent.move_to(cell2)
651+
assert agent not in cell1.agents
652+
assert agent in cell2.agents
653+
654+
655+
def test_grid2DMovingAgent(): # noqa: D103
656+
# we first test on a moore grid because all directions are defined
657+
grid = OrthogonalMooreGrid((10, 10), torus=False)
658+
659+
model = Model()
660+
agent = Grid2DMovingAgent(model)
661+
662+
agent.cell = grid[4, 4]
663+
agent.move("up")
664+
assert agent.cell == grid[3, 4]
665+
666+
grid = OrthogonalVonNeumannGrid((10, 10), torus=False)
667+
668+
model = Model()
669+
agent = Grid2DMovingAgent(model)
670+
agent.cell = grid[4, 4]
671+
672+
with pytest.raises(ValueError): # test for invalid direction
673+
agent.move("upright")
674+
675+
with pytest.raises(ValueError): # test for unknown direction
676+
agent.move("back")
677+
678+
679+
def test_patch(): # noqa: D103
680+
cell1 = Cell((1,), capacity=None, random=random.Random())
681+
cell2 = Cell((2,), capacity=None, random=random.Random())
682+
683+
# connect
684+
# add_agent
685+
model = Model()
686+
agent = FixedAgent(model)
687+
agent.cell = cell1
688+
689+
with pytest.raises(ValueError):
690+
agent.cell = cell2
691+
692+
agent.remove()
693+
assert agent not in model._agents

0 commit comments

Comments
 (0)