diff --git a/game.py b/game.py
index fcf66bf9..0946076d 100644
--- a/game.py
+++ b/game.py
@@ -15,6 +15,7 @@
class Game:
def __init__(self, seed=None):
self.bus = bus
+ self.bus.reset()
self.seed = seed
if self.seed is not None:
random.seed(self.seed)
diff --git a/message_bus_tools.py b/message_bus_tools.py
index 72888ccc..99747fb2 100644
--- a/message_bus_tools.py
+++ b/message_bus_tools.py
@@ -39,9 +39,12 @@ class MessageBus():
registering a callback function that will be called when that message is published.
'''
def __init__(self, debug=True):
- self.subscribers = dict(dict()) # noqa: C408
self.debug = debug
- self.death_messages = [] # what is this?
+ self.reset()
+
+ def reset(self):
+ self.subscribers: dict[Message, dict[callable, int]] = dict(dict())
+ self.death_messages = []
self.unsubscribe_set = set()
self.subscribe_set = set()
self.lock_count = 0
@@ -49,28 +52,32 @@ def __init__(self, debug=True):
def _clear_subscribes(self):
if self.lock_count > 0:
return
- for event_type, callback, uid in self.subscribe_set:
- self.subscribe(event_type, callback, uid)
+ for event_type, callback, uid, priority in self.subscribe_set:
+ self.subscribe(event_type, callback, uid, priority)
self.subscribe_set.clear()
- def subscribe(self, event_type: Message, callback, uid):
+ def subscribe(self, event_type: Message, callback, uid, priority=50):
+ '''Subscribes a callback to a message. The callback will be called when the message is published.
+ Priority is a number from 1 to 100, with 1 being the highest priority and 100 being the lowest.
+ '''
+ assert 1 <= priority <= 100, "Priority must be between 1 and 100"
if self.lock_count > 0:
if self.debug:
ansiprint(f"MESSAGEBUS: {event_type} | Locked. Adding {callback.__qualname__} to subscribe list.")
- self.subscribe_set.add((event_type, callback, uid))
+ self.subscribe_set.add((event_type, callback, uid, priority))
else:
if event_type not in self.subscribers:
self.subscribers[event_type] = {}
- self.subscribers[event_type][uid] = callback
+ self.subscribers[event_type][uid] = (callback, priority)
if self.debug:
- ansiprint(f"MESSAGEBUS: {event_type} | Subscribed {callback.__qualname__}")
+ ansiprint(f"MESSAGEBUS: {event_type} | Subscribed {callback.__qualname__} with priority {priority}")
def _clear_unsubscribes(self):
if self.lock_count > 0:
return
for event_type, uid in self.unsubscribe_set:
if self.debug:
- ansiprint(f"MESSAGEBUS: Unsubscribing {self.subscribers[event_type][uid].__qualname__} from {', '.join(event_type).replace(', ', '')}")
+ ansiprint(f"MESSAGEBUS: Unsubscribing {self.subscribers[event_type][uid][0].__qualname__} from {', '.join(event_type).replace(', ', '')}")
self.unsubscribe(event_type, uid)
self.unsubscribe_set.clear()
@@ -80,18 +87,18 @@ def unsubscribe(self, event_type, uid):
ansiprint(f"MESSAGEBUS: Locked. Adding {event_type} - {uid} to unsubscribe list.")
self.unsubscribe_set.add((event_type, uid))
else:
- if uid in self.subscribers[event_type]:
+ if event_type in self.subscribers and uid in self.subscribers[event_type]:
if self.debug:
- ansiprint(f"MESSAGEBUS: Unsubscribed {self.subscribers[event_type][uid].__qualname__} from {', '.join(event_type).replace(', ', '')}")
+ ansiprint(f"MESSAGEBUS: Unsubscribed {self.subscribers[event_type][uid][0].__qualname__} from {', '.join(event_type).replace(', ', '')}")
del self.subscribers[event_type][uid]
def publish(self, event_type: Message, data):
self.lock_count += 1
if event_type in self.subscribers:
- for uid, callback in self.subscribers[event_type].items():
- _ = uid
+ sorted_callbacks = sorted([(x[1][0], x[1][1]) for x in self.subscribers[event_type].items()], key=lambda x: x[1])
+ for callback, priority in sorted_callbacks:
if self.debug:
- ansiprint(f"MESSAGEBUS: {event_type} | Calling {callback.__qualname__}")
+ ansiprint(f"MESSAGEBUS: {event_type} | Calling {callback.__qualname__} with priority {priority}")
callback(event_type, data)
self.lock_count -= 1
self._clear_subscribes()
@@ -100,10 +107,16 @@ def publish(self, event_type: Message, data):
class Registerable():
registers = []
+ priorities = []
def register(self, bus):
- for message in self.registers:
- bus.subscribe(message, self.callback, self.uid)
+ if self.priorities:
+ assert len(self.registers) == len(self.priorities), "Registers and priorities must be the same length."
+ for message, priority in zip(self.registers, self.priorities):
+ bus.subscribe(message, self.callback, self.uid, priority)
+ else:
+ for message in self.registers:
+ bus.subscribe(message, self.callback, self.uid)
self.subscribed = True
def unsubscribe(self, event_types: list[Message]=None):
diff --git a/tests/test_game.py b/tests/test_game.py
index 7cb9d100..c7df6c99 100644
--- a/tests/test_game.py
+++ b/tests/test_game.py
@@ -33,7 +33,7 @@ def repeat_check(repeat_catcher, last_return, current_return) -> tuple[int, bool
def autoplayer(game: game.Game):
'''Returns a patched input function that can play the game, maybe.
- Usage:
+ Usage:
with monkeypatch.context() as m:
m.setattr('builtins.input', autoplayer(game))
'''
@@ -53,7 +53,7 @@ def patched_input(*args, **kwargs):
# Handle Start Node
if mygame.game_map.current.type == definitions.EncounterType.START:
choice, reason = str(random.choice(range(1, len(mygame.game_map.current.children)))), "Start node"
-
+
# Handle dead
player = mygame.player
if player.state == definitions.State.DEAD:
@@ -88,7 +88,7 @@ def patched_input(*args, **kwargs):
tmp = all_possible_choices.copy()
tmp.remove(choice)
choice, reason = random.choice(tmp), "Player is stuck in a loop"
-
+
last_return = choice
print(f"AutoPlayer: {choice} ({reason})")
return choice
@@ -97,7 +97,7 @@ def patched_input(*args, **kwargs):
@pytest.mark.timeout(10)
-@pytest.mark.parametrize("seed", list(range(3)))
+@pytest.mark.parametrize("seed", list(range(20)))
def test_e2e(seed, monkeypatch, sleepless):
'''Test the game from start to finish
Plays with (more or less) random inputs to test the game.
@@ -111,6 +111,7 @@ def test_e2e(seed, monkeypatch, sleepless):
try:
start = time.time()
+ assert len(mygame.bus.subscribers) == 0
mygame.start()
except Exception as e:
ansiprint(f"Failed with seed: {seed}")
diff --git a/tests/test_message_bus.py b/tests/test_message_bus.py
index cf30e356..380f0ea8 100644
--- a/tests/test_message_bus.py
+++ b/tests/test_message_bus.py
@@ -112,7 +112,7 @@ def side_effect(*args, **kwargs):
bus.publish(Message.BEFORE_ATTACK, "second time")
assert callbackA.call_count == 2, "Callback A should have been called twice"
- callbackB.assert_called_once_with(Message.BEFORE_ATTACK, "second time")
+ callbackB.assert_called_once_with(Message.BEFORE_ATTACK, "second time")
def test_can_unsubscribe_during_publish(self):
@@ -134,7 +134,7 @@ def side_effect(*args, **kwargs):
callbackA.assert_called_once_with(Message.BEFORE_ATTACK, "data")
assert callbackB.call_count == 2, "Callback B should have been called twice"
-
+
def test_can_unsubscribe_during_nested_publish(self):
# What happens when you try to unsubscribe from a nested publish?
bus = MessageBus(debug=True)
@@ -143,9 +143,9 @@ def side_effect_A(*args, **kwargs):
bus.publish(Message.AFTER_ATTACK, "calls B")
bus.unsubscribe(Message.BEFORE_ATTACK, 1)
callbackA = MagicMock(__qualname__="callbackA", side_effect=side_effect_A)
-
+
def side_effect_B(*args, **kwargs):
- print("B called.")
+ print("B called.")
bus.unsubscribe(Message.AFTER_ATTACK, 2)
callbackB = MagicMock(__qualname__="callbackB", side_effect=side_effect_B)
@@ -160,5 +160,25 @@ def side_effect_B(*args, **kwargs):
bus.publish(Message.BEFORE_ATTACK, "data")
# No additional calls should be made (i.e. unsubscribe was successful)
- callbackA.assert_called_once()
- callbackB.assert_called_once()
\ No newline at end of file
+ callbackA.assert_called_once()
+ callbackB.assert_called_once()
+
+ def test_priority_calls_are_respected(self):
+ bus = MessageBus(debug=True)
+ ordering = []
+ def callbackA(_, data):
+ ordering.append("A")
+ def callbackB(_, data):
+ ordering.append("B")
+ def callbackC(_, data):
+ ordering.append("C")
+
+ bus.subscribe(Message.BEFORE_ATTACK, callbackC, uid=1, priority=100)
+ bus.subscribe(Message.BEFORE_ATTACK, callbackA, uid=2, priority=50)
+ bus.subscribe(Message.BEFORE_ATTACK, callbackB, uid=3, priority=1)
+
+ bus.publish(Message.BEFORE_ATTACK, "data")
+
+ # Need to assert that the order is correct
+ assert ordering == ["B", "A", "C"], "Callbacks should be called in order of priority"
+