diff --git a/game.py b/game.py index fcf66bf9..0946076d 100644 --- a/game.py +++ b/game.py @@ -15,6 +15,7 @@ class Game: def __init__(self, seed=None): self.bus = bus + self.bus.reset() self.seed = seed if self.seed is not None: random.seed(self.seed) diff --git a/message_bus_tools.py b/message_bus_tools.py index 72888ccc..99747fb2 100644 --- a/message_bus_tools.py +++ b/message_bus_tools.py @@ -39,9 +39,12 @@ class MessageBus(): registering a callback function that will be called when that message is published. ''' def __init__(self, debug=True): - self.subscribers = dict(dict()) # noqa: C408 self.debug = debug - self.death_messages = [] # what is this? + self.reset() + + def reset(self): + self.subscribers: dict[Message, dict[callable, int]] = dict(dict()) + self.death_messages = [] self.unsubscribe_set = set() self.subscribe_set = set() self.lock_count = 0 @@ -49,28 +52,32 @@ def __init__(self, debug=True): def _clear_subscribes(self): if self.lock_count > 0: return - for event_type, callback, uid in self.subscribe_set: - self.subscribe(event_type, callback, uid) + for event_type, callback, uid, priority in self.subscribe_set: + self.subscribe(event_type, callback, uid, priority) self.subscribe_set.clear() - def subscribe(self, event_type: Message, callback, uid): + def subscribe(self, event_type: Message, callback, uid, priority=50): + '''Subscribes a callback to a message. The callback will be called when the message is published. + Priority is a number from 1 to 100, with 1 being the highest priority and 100 being the lowest. + ''' + assert 1 <= priority <= 100, "Priority must be between 1 and 100" if self.lock_count > 0: if self.debug: ansiprint(f"MESSAGEBUS: {event_type} | Locked. Adding {callback.__qualname__} to subscribe list.") - self.subscribe_set.add((event_type, callback, uid)) + self.subscribe_set.add((event_type, callback, uid, priority)) else: if event_type not in self.subscribers: self.subscribers[event_type] = {} - self.subscribers[event_type][uid] = callback + self.subscribers[event_type][uid] = (callback, priority) if self.debug: - ansiprint(f"MESSAGEBUS: {event_type} | Subscribed {callback.__qualname__}") + ansiprint(f"MESSAGEBUS: {event_type} | Subscribed {callback.__qualname__} with priority {priority}") def _clear_unsubscribes(self): if self.lock_count > 0: return for event_type, uid in self.unsubscribe_set: if self.debug: - ansiprint(f"MESSAGEBUS: Unsubscribing {self.subscribers[event_type][uid].__qualname__} from {', '.join(event_type).replace(', ', '')}") + ansiprint(f"MESSAGEBUS: Unsubscribing {self.subscribers[event_type][uid][0].__qualname__} from {', '.join(event_type).replace(', ', '')}") self.unsubscribe(event_type, uid) self.unsubscribe_set.clear() @@ -80,18 +87,18 @@ def unsubscribe(self, event_type, uid): ansiprint(f"MESSAGEBUS: Locked. Adding {event_type} - {uid} to unsubscribe list.") self.unsubscribe_set.add((event_type, uid)) else: - if uid in self.subscribers[event_type]: + if event_type in self.subscribers and uid in self.subscribers[event_type]: if self.debug: - ansiprint(f"MESSAGEBUS: Unsubscribed {self.subscribers[event_type][uid].__qualname__} from {', '.join(event_type).replace(', ', '')}") + ansiprint(f"MESSAGEBUS: Unsubscribed {self.subscribers[event_type][uid][0].__qualname__} from {', '.join(event_type).replace(', ', '')}") del self.subscribers[event_type][uid] def publish(self, event_type: Message, data): self.lock_count += 1 if event_type in self.subscribers: - for uid, callback in self.subscribers[event_type].items(): - _ = uid + sorted_callbacks = sorted([(x[1][0], x[1][1]) for x in self.subscribers[event_type].items()], key=lambda x: x[1]) + for callback, priority in sorted_callbacks: if self.debug: - ansiprint(f"MESSAGEBUS: {event_type} | Calling {callback.__qualname__}") + ansiprint(f"MESSAGEBUS: {event_type} | Calling {callback.__qualname__} with priority {priority}") callback(event_type, data) self.lock_count -= 1 self._clear_subscribes() @@ -100,10 +107,16 @@ def publish(self, event_type: Message, data): class Registerable(): registers = [] + priorities = [] def register(self, bus): - for message in self.registers: - bus.subscribe(message, self.callback, self.uid) + if self.priorities: + assert len(self.registers) == len(self.priorities), "Registers and priorities must be the same length." + for message, priority in zip(self.registers, self.priorities): + bus.subscribe(message, self.callback, self.uid, priority) + else: + for message in self.registers: + bus.subscribe(message, self.callback, self.uid) self.subscribed = True def unsubscribe(self, event_types: list[Message]=None): diff --git a/tests/test_game.py b/tests/test_game.py index 7cb9d100..c7df6c99 100644 --- a/tests/test_game.py +++ b/tests/test_game.py @@ -33,7 +33,7 @@ def repeat_check(repeat_catcher, last_return, current_return) -> tuple[int, bool def autoplayer(game: game.Game): '''Returns a patched input function that can play the game, maybe. - Usage: + Usage: with monkeypatch.context() as m: m.setattr('builtins.input', autoplayer(game)) ''' @@ -53,7 +53,7 @@ def patched_input(*args, **kwargs): # Handle Start Node if mygame.game_map.current.type == definitions.EncounterType.START: choice, reason = str(random.choice(range(1, len(mygame.game_map.current.children)))), "Start node" - + # Handle dead player = mygame.player if player.state == definitions.State.DEAD: @@ -88,7 +88,7 @@ def patched_input(*args, **kwargs): tmp = all_possible_choices.copy() tmp.remove(choice) choice, reason = random.choice(tmp), "Player is stuck in a loop" - + last_return = choice print(f"AutoPlayer: {choice} ({reason})") return choice @@ -97,7 +97,7 @@ def patched_input(*args, **kwargs): @pytest.mark.timeout(10) -@pytest.mark.parametrize("seed", list(range(3))) +@pytest.mark.parametrize("seed", list(range(20))) def test_e2e(seed, monkeypatch, sleepless): '''Test the game from start to finish Plays with (more or less) random inputs to test the game. @@ -111,6 +111,7 @@ def test_e2e(seed, monkeypatch, sleepless): try: start = time.time() + assert len(mygame.bus.subscribers) == 0 mygame.start() except Exception as e: ansiprint(f"Failed with seed: {seed}") diff --git a/tests/test_message_bus.py b/tests/test_message_bus.py index cf30e356..380f0ea8 100644 --- a/tests/test_message_bus.py +++ b/tests/test_message_bus.py @@ -112,7 +112,7 @@ def side_effect(*args, **kwargs): bus.publish(Message.BEFORE_ATTACK, "second time") assert callbackA.call_count == 2, "Callback A should have been called twice" - callbackB.assert_called_once_with(Message.BEFORE_ATTACK, "second time") + callbackB.assert_called_once_with(Message.BEFORE_ATTACK, "second time") def test_can_unsubscribe_during_publish(self): @@ -134,7 +134,7 @@ def side_effect(*args, **kwargs): callbackA.assert_called_once_with(Message.BEFORE_ATTACK, "data") assert callbackB.call_count == 2, "Callback B should have been called twice" - + def test_can_unsubscribe_during_nested_publish(self): # What happens when you try to unsubscribe from a nested publish? bus = MessageBus(debug=True) @@ -143,9 +143,9 @@ def side_effect_A(*args, **kwargs): bus.publish(Message.AFTER_ATTACK, "calls B") bus.unsubscribe(Message.BEFORE_ATTACK, 1) callbackA = MagicMock(__qualname__="callbackA", side_effect=side_effect_A) - + def side_effect_B(*args, **kwargs): - print("B called.") + print("B called.") bus.unsubscribe(Message.AFTER_ATTACK, 2) callbackB = MagicMock(__qualname__="callbackB", side_effect=side_effect_B) @@ -160,5 +160,25 @@ def side_effect_B(*args, **kwargs): bus.publish(Message.BEFORE_ATTACK, "data") # No additional calls should be made (i.e. unsubscribe was successful) - callbackA.assert_called_once() - callbackB.assert_called_once() \ No newline at end of file + callbackA.assert_called_once() + callbackB.assert_called_once() + + def test_priority_calls_are_respected(self): + bus = MessageBus(debug=True) + ordering = [] + def callbackA(_, data): + ordering.append("A") + def callbackB(_, data): + ordering.append("B") + def callbackC(_, data): + ordering.append("C") + + bus.subscribe(Message.BEFORE_ATTACK, callbackC, uid=1, priority=100) + bus.subscribe(Message.BEFORE_ATTACK, callbackA, uid=2, priority=50) + bus.subscribe(Message.BEFORE_ATTACK, callbackB, uid=3, priority=1) + + bus.publish(Message.BEFORE_ATTACK, "data") + + # Need to assert that the order is correct + assert ordering == ["B", "A", "C"], "Callbacks should be called in order of priority" +