Skip to content

Commit b130463

Browse files
committed
feat: automaticall call the update handler if it is provided
1 parent b464dcb commit b130463

File tree

2 files changed

+27
-3
lines changed

2 files changed

+27
-3
lines changed

postgresql_watcher/watcher.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def __init__(
6969
self.child_conn: Connection | None = None
7070
self.subscription_process: Process | None = None
7171
self._create_subscription_process(start_listening)
72-
self.update_callback: Optional[Callable] = None
72+
self.update_callback: Optional[Callable[[None], None]] = None
7373

7474
def __del__(self) -> None:
7575
self._cleanup_connections_and_processes()
@@ -129,7 +129,7 @@ def _cleanup_connections_and_processes(self) -> None:
129129
self.subscription_process.terminate()
130130
self.subscription_process = None
131131

132-
def set_update_callback(self, update_handler: Callable):
132+
def set_update_callback(self, update_handler: Optional[Callable[[None], None]]):
133133
"""
134134
Set the handler called, when the Watcher detects an update.
135135
Recommendation: `casbin_enforcer.adapter.load_policy`
@@ -164,7 +164,10 @@ def should_reload(self) -> bool:
164164
try:
165165
if self.parent_conn.poll():
166166
message = int(self.parent_conn.recv())
167-
return message == _ChannelSubscriptionMessage.RECEIVED_UPDATE
167+
received_update = message == _ChannelSubscriptionMessage.RECEIVED_UPDATE
168+
if received_update and self.update_callback is not None:
169+
self.update_callback()
170+
return received_update
168171
except EOFError:
169172
self.logger.warning(
170173
"Child casbin-watcher subscribe process has stopped, "

tests/test_postgresql_watcher.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import sys
22
import unittest
3+
from unittest.mock import MagicMock
34
from multiprocessing.connection import Pipe
45
from time import sleep
56
import logging
@@ -80,6 +81,26 @@ def test_no_update_mutiple_pg_watcher(self):
8081
sleep(CASBIN_CHANNEL_SELECT_TIMEOUT * 2)
8182
for watcher in other_watchers:
8283
self.assertFalse(watcher.should_reload())
84+
self.assertFalse(main_watcher.should_reload())
85+
86+
def test_update_handler_called(self):
87+
channel_name = "test_update_handler_called"
88+
main_watcher = get_watcher(channel_name)
89+
handler = MagicMock()
90+
main_watcher.set_update_callback(handler)
91+
main_watcher.update()
92+
sleep(CASBIN_CHANNEL_SELECT_TIMEOUT * 2)
93+
self.assertTrue(main_watcher.should_reload())
94+
self.assertTrue(handler.call_count == 1)
95+
96+
def test_update_handler_not_called(self):
97+
channel_name = "test_update_handler_not_called"
98+
main_watcher = get_watcher(channel_name)
99+
handler = MagicMock()
100+
main_watcher.set_update_callback(handler)
101+
sleep(CASBIN_CHANNEL_SELECT_TIMEOUT * 2)
102+
self.assertFalse(main_watcher.should_reload())
103+
self.assertTrue(handler.call_count == 0)
83104

84105

85106
if __name__ == "__main__":

0 commit comments

Comments
 (0)