Skip to content

Commit e255a2d

Browse files
authored
replace model with random in AgentSet init (#2350)
* replace model with random in AgentSet `__init__` also closing #2323
1 parent 0082da2 commit e255a2d

File tree

3 files changed

+47
-49
lines changed

3 files changed

+47
-49
lines changed

mesa/agent.py

+13-18
Original file line numberDiff line numberDiff line change
@@ -99,14 +99,18 @@ class AgentSet(MutableSet, Sequence):
9999
which means that agents not referenced elsewhere in the program may be automatically removed from the AgentSet.
100100
"""
101101

102-
def __init__(self, agents: Iterable[Agent], model: Model):
102+
def __init__(self, agents: Iterable[Agent], random: Random | None = None):
103103
"""Initializes the AgentSet with a collection of agents and a reference to the model.
104104
105105
Args:
106106
agents (Iterable[Agent]): An iterable of Agent objects to be included in the set.
107-
model (Model): The ABM model instance to which this AgentSet belongs.
107+
random (Random): the random number generator
108108
"""
109-
self.model = model
109+
if random is None:
110+
random = (
111+
Random()
112+
) # FIXME see issue 1981, how to get the central rng from model
113+
self.random = random
110114
self._agents = weakref.WeakKeyDictionary({agent: None for agent in agents})
111115

112116
def __len__(self) -> int:
@@ -177,7 +181,7 @@ def agent_generator(filter_func, agent_type, at_most):
177181

178182
agents = agent_generator(filter_func, agent_type, at_most)
179183

180-
return AgentSet(agents, self.model) if not inplace else self._update(agents)
184+
return AgentSet(agents, self.random) if not inplace else self._update(agents)
181185

182186
def shuffle(self, inplace: bool = False) -> AgentSet:
183187
"""Randomly shuffle the order of agents in the AgentSet.
@@ -200,7 +204,7 @@ def shuffle(self, inplace: bool = False) -> AgentSet:
200204
return self
201205
else:
202206
return AgentSet(
203-
(agent for ref in weakrefs if (agent := ref()) is not None), self.model
207+
(agent for ref in weakrefs if (agent := ref()) is not None), self.random
204208
)
205209

206210
def sort(
@@ -225,7 +229,7 @@ def sort(
225229
sorted_agents = sorted(self._agents.keys(), key=key, reverse=not ascending)
226230

227231
return (
228-
AgentSet(sorted_agents, self.model)
232+
AgentSet(sorted_agents, self.random)
229233
if not inplace
230234
else self._update(sorted_agents)
231235
)
@@ -477,26 +481,17 @@ def __getstate__(self):
477481
Returns:
478482
dict: A dictionary representing the state of the AgentSet.
479483
"""
480-
return {"agents": list(self._agents.keys()), "model": self.model}
484+
return {"agents": list(self._agents.keys()), "random": self.random}
481485

482486
def __setstate__(self, state):
483487
"""Set the state of the AgentSet during deserialization.
484488
485489
Args:
486490
state (dict): A dictionary representing the state to restore.
487491
"""
488-
self.model = state["model"]
492+
self.random = state["random"]
489493
self._update(state["agents"])
490494

491-
@property
492-
def random(self) -> Random:
493-
"""Provide access to the model's random number generator.
494-
495-
Returns:
496-
Random: The random number generator associated with the model.
497-
"""
498-
return self.model.random
499-
500495
def groupby(self, by: Callable | str, result_type: str = "agentset") -> GroupBy:
501496
"""Group agents by the specified attribute or return from the callable.
502497
@@ -529,7 +524,7 @@ def groupby(self, by: Callable | str, result_type: str = "agentset") -> GroupBy:
529524

530525
if result_type == "agentset":
531526
return GroupBy(
532-
{k: AgentSet(v, model=self.model) for k, v in groups.items()}
527+
{k: AgentSet(v, random=self.random) for k, v in groups.items()}
533528
)
534529
else:
535530
return GroupBy(groups)

mesa/model.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,6 @@ def __init__(self, *args: Any, seed: float | None = None, **kwargs: Any) -> None
5252
self.running = True
5353
self.steps: int = 0
5454

55-
self._setup_agent_registration()
56-
5755
self._seed = seed
5856
if self._seed is None:
5957
# We explicitly specify the seed here so that we know its value in
@@ -65,6 +63,9 @@ def __init__(self, *args: Any, seed: float | None = None, **kwargs: Any) -> None
6563
self._user_step = self.step
6664
self.step = self._wrapped_step
6765

66+
# setup agent registration data structures
67+
self._setup_agent_registration()
68+
6869
def _wrapped_step(self, *args: Any, **kwargs: Any) -> None:
6970
"""Automatically increments time and steps after calling the user's step method."""
7071
# Automatically increment time and step counters
@@ -119,7 +120,9 @@ def _setup_agent_registration(self):
119120
self._agents_by_type: dict[
120121
type[Agent], AgentSet
121122
] = {} # a dict with an agentset for each class of agents
122-
self._all_agents = AgentSet([], self) # an agenset with all agents
123+
self._all_agents = AgentSet(
124+
[], random=self.random
125+
) # an agenset with all agents
123126

124127
def register_agent(self, agent):
125128
"""Register the agent with the model.
@@ -153,7 +156,7 @@ def register_agent(self, agent):
153156
[
154157
agent,
155158
],
156-
self,
159+
random=self.random,
157160
)
158161

159162
self._all_agents.add(agent)

tests/test_agent.py

+27-27
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def test_agentset():
6262
model = Model()
6363
agents = [AgentTest(model) for _ in range(10)]
6464

65-
agentset = AgentSet(agents, model)
65+
agentset = AgentSet(agents, random=model.random)
6666

6767
assert agents[0] in agentset
6868
assert len(agentset) == len(agents)
@@ -118,7 +118,7 @@ def test_function(agent):
118118

119119
# because AgentSet uses weakrefs, we need hard refs as well....
120120
other_agents, another_set = pickle.loads( # noqa: S301
121-
pickle.dumps([agents, AgentSet(agents, model)])
121+
pickle.dumps([agents, AgentSet(agents, random=model.random)])
122122
)
123123
assert all(
124124
a1.unique_id == a2.unique_id for a1, a2 in zip(another_set, other_agents)
@@ -129,19 +129,19 @@ def test_function(agent):
129129
def test_agentset_initialization():
130130
"""Test agentset initialization."""
131131
model = Model()
132-
empty_agentset = AgentSet([], model)
132+
empty_agentset = AgentSet([], random=model.random)
133133
assert len(empty_agentset) == 0
134134

135135
agents = [AgentTest(model) for _ in range(10)]
136-
agentset = AgentSet(agents, model)
136+
agentset = AgentSet(agents, random=model.random)
137137
assert len(agentset) == 10
138138

139139

140140
def test_agentset_serialization():
141141
"""Test pickleability of agentset."""
142142
model = Model()
143143
agents = [AgentTest(model) for _ in range(5)]
144-
agentset = AgentSet(agents, model)
144+
agentset = AgentSet(agents, random=model.random)
145145

146146
serialized = pickle.dumps(agentset)
147147
deserialized = pickle.loads(serialized) # noqa: S301
@@ -156,7 +156,7 @@ def test_agent_membership():
156156
"""Test agent membership in AgentSet."""
157157
model = Model()
158158
agents = [AgentTest(model) for _ in range(5)]
159-
agentset = AgentSet(agents, model)
159+
agentset = AgentSet(agents, random=model.random)
160160

161161
assert agents[0] in agentset
162162
assert AgentTest(model) not in agentset
@@ -166,7 +166,7 @@ def test_agent_add_remove_discard():
166166
"""Test adding, removing and discarding agents from AgentSet."""
167167
model = Model()
168168
agent = AgentTest(model)
169-
agentset = AgentSet([], model)
169+
agentset = AgentSet([], random=model.random)
170170

171171
agentset.add(agent)
172172
assert agent in agentset
@@ -186,7 +186,7 @@ def test_agentset_get_item():
186186
"""Test integer based access to AgentSet."""
187187
model = Model()
188188
agents = [AgentTest(model) for _ in range(10)]
189-
agentset = AgentSet(agents, model)
189+
agentset = AgentSet(agents, random=model.random)
190190

191191
assert agentset[0] == agents[0]
192192
assert agentset[-1] == agents[-1]
@@ -200,7 +200,7 @@ def test_agentset_do_str():
200200
"""Test AgentSet.do with str."""
201201
model = Model()
202202
agents = [AgentTest(model) for _ in range(10)]
203-
agentset = AgentSet(agents, model)
203+
agentset = AgentSet(agents, random=model.random)
204204

205205
with pytest.raises(AttributeError):
206206
agentset.do("non_existing_method")
@@ -213,7 +213,7 @@ def test_agentset_do_str():
213213
n = 10
214214
model = Model()
215215
agents = [AgentDoTest(model) for _ in range(n)]
216-
agentset = AgentSet(agents, model)
216+
agentset = AgentSet(agents, random=model.random)
217217
for agent in agents:
218218
agent.agent_set = agentset
219219

@@ -223,7 +223,7 @@ def test_agentset_do_str():
223223
# setup
224224
model = Model()
225225
agents = [AgentDoTest(model) for _ in range(10)]
226-
agentset = AgentSet(agents, model)
226+
agentset = AgentSet(agents, random=model.random)
227227
for agent in agents:
228228
agent.agent_set = agentset
229229

@@ -235,7 +235,7 @@ def test_agentset_do_callable():
235235
"""Test AgentSet.do with callable."""
236236
model = Model()
237237
agents = [AgentTest(model) for _ in range(10)]
238-
agentset = AgentSet(agents, model)
238+
agentset = AgentSet(agents, random=model.random)
239239

240240
# Test callable with non-existent function
241241
with pytest.raises(AttributeError):
@@ -249,7 +249,7 @@ def test_agentset_do_callable():
249249
n = 10
250250
model = Model()
251251
agents = [AgentDoTest(model) for _ in range(n)]
252-
agentset = AgentSet(agents, model)
252+
agentset = AgentSet(agents, random=model.random)
253253
for agent in agents:
254254
agent.agent_set = agentset
255255

@@ -260,7 +260,7 @@ def test_agentset_do_callable():
260260
# setup again for lambda function tests
261261
model = Model()
262262
agents = [AgentDoTest(model) for _ in range(10)]
263-
agentset = AgentSet(agents, model)
263+
agentset = AgentSet(agents, random=model.random)
264264
for agent in agents:
265265
agent.agent_set = agentset
266266

@@ -278,7 +278,7 @@ def remove_function(agent):
278278
# setup again for actual function tests
279279
model = Model()
280280
agents = [AgentDoTest(model) for _ in range(n)]
281-
agentset = AgentSet(agents, model)
281+
agentset = AgentSet(agents, random=model.random)
282282
for agent in agents:
283283
agent.agent_set = agentset
284284

@@ -289,7 +289,7 @@ def remove_function(agent):
289289
# setup again for actual function tests
290290
model = Model()
291291
agents = [AgentDoTest(model) for _ in range(10)]
292-
agentset = AgentSet(agents, model)
292+
agentset = AgentSet(agents, random=model.random)
293293
for agent in agents:
294294
agent.agent_set = agentset
295295

@@ -354,7 +354,7 @@ def test_agentset_agg():
354354
agent.energy = i + 1
355355
agent.wealth = 10 * (i + 1)
356356

357-
agentset = AgentSet(agents, model)
357+
agentset = AgentSet(agents, random=model.random)
358358

359359
# Test min aggregation
360360
min_energy = agentset.agg("energy", min)
@@ -391,7 +391,7 @@ def __init__(self, model, age=None):
391391

392392
model = Model()
393393
agents = [TestAgentWithAttribute(model, age=i) for i in range(5)]
394-
agentset = AgentSet(agents, model)
394+
agentset = AgentSet(agents, random=model.random)
395395

396396
# Set a new attribute "health" and an existing attribute "age" for all agents
397397
agentset.set("health", 100).set("age", 50).set("status", "active")
@@ -410,7 +410,7 @@ def test_agentset_map_str():
410410
"""Test AgentSet.map with strings."""
411411
model = Model()
412412
agents = [AgentTest(model) for _ in range(10)]
413-
agentset = AgentSet(agents, model)
413+
agentset = AgentSet(agents, random=model.random)
414414

415415
with pytest.raises(AttributeError):
416416
agentset.do("non_existing_method")
@@ -423,7 +423,7 @@ def test_agentset_map_callable():
423423
"""Test AgentSet.map with callable."""
424424
model = Model()
425425
agents = [AgentTest(model) for _ in range(10)]
426-
agentset = AgentSet(agents, model)
426+
agentset = AgentSet(agents, random=model.random)
427427

428428
# Test callable with non-existent function
429429
with pytest.raises(AttributeError):
@@ -450,7 +450,7 @@ def test_method(self):
450450
self.called = True
451451

452452
agents = [TestAgentShuffleDo(model) for _ in range(100)]
453-
agentset = AgentSet(agents, model)
453+
agentset = AgentSet(agents, random=model.random)
454454

455455
# Test shuffle_do with a string method name
456456
agentset.shuffle_do("test_method")
@@ -477,7 +477,7 @@ def test_agentset_get_attribute():
477477
"""Test AgentSet.get for attributes."""
478478
model = Model()
479479
agents = [AgentTest(model) for _ in range(10)]
480-
agentset = AgentSet(agents, model)
480+
agentset = AgentSet(agents, random=model.random)
481481

482482
unique_ids = agentset.get("unique_id")
483483
assert unique_ids == [agent.unique_id for agent in agents]
@@ -491,7 +491,7 @@ def test_agentset_get_attribute():
491491
agent = AgentTest(model)
492492
agent.i = i**2
493493
agents.append(agent)
494-
agentset = AgentSet(agents, model)
494+
agentset = AgentSet(agents, random=model.random)
495495

496496
values = agentset.get(["unique_id", "i"])
497497

@@ -521,7 +521,7 @@ def test_agentset_select_by_type():
521521

522522
# Combine the two types of agents
523523
mixed_agents = test_agents + other_agents
524-
agentset = AgentSet(mixed_agents, model)
524+
agentset = AgentSet(mixed_agents, random=model.random)
525525

526526
# Test selection by type
527527
selected_test_agents = agentset.select(agent_type=AgentTest)
@@ -544,11 +544,11 @@ def test_agentset_shuffle():
544544
model = Model()
545545
test_agents = [AgentTest(model) for _ in range(12)]
546546

547-
agentset = AgentSet(test_agents, model=model)
547+
agentset = AgentSet(test_agents, random=model.random)
548548
agentset = agentset.shuffle()
549549
assert not all(a1 == a2 for a1, a2 in zip(test_agents, agentset))
550550

551-
agentset = AgentSet(test_agents, model=model)
551+
agentset = AgentSet(test_agents, random=model.random)
552552
agentset.shuffle(inplace=True)
553553
assert not all(a1 == a2 for a1, a2 in zip(test_agents, agentset))
554554

@@ -567,7 +567,7 @@ def get_unique_identifier(self):
567567

568568
model = Model()
569569
agents = [TestAgent(model) for _ in range(10)]
570-
agentset = AgentSet(agents, model)
570+
agentset = AgentSet(agents, random=model.random)
571571

572572
groups = agentset.groupby("even")
573573
assert len(groups.groups[True]) == 5

0 commit comments

Comments
 (0)