diff --git a/src/network/nonce_tracker.py b/src/network/nonce_tracker.py index 7a343ffb..50a90e34 100644 --- a/src/network/nonce_tracker.py +++ b/src/network/nonce_tracker.py @@ -115,6 +115,161 @@ def _evaluate_routing_topology(self) -> None: logger.error("CRITICAL FAILURE: Comprehensive Horizon node matrix completely unreachable. No healthy nodes found.") - def get_active_endpoint_url(self) -> str: - """Returns the currently active, validated node URL for ledger submissions.""" - return self.active_node.url \ No newline at end of file + +class RPCNodeFailoverSupervisor: + """Proactive RPC node failover supervisor that monitors node connectivity. + + It maintains a list of endpoints and runs a background thread to check their + latency and health using lightweight JSON-RPC requests. If the active node + experiences a latency drop or fails, the supervisor instantly shifts the + active traffic to the fastest available secondary node. + + Complexity: + Time: O(1) for active endpoint lookup, O(N) for checking N endpoints. + Space: O(N) to store latency stats for N endpoints. + """ + + def __init__( + self, + endpoints: Optional[List[str]] = None, + check_interval_sec: float = 2.0, + latency_threshold_ms: float = 500.0, + ping_timeout_sec: float = 1.0, + ) -> None: + self.check_interval_sec = check_interval_sec + self.latency_threshold_ms = latency_threshold_ms + self.ping_timeout_sec = ping_timeout_sec + + if endpoints is None: + primary = os.environ.get("RPC_URL") + fallbacks = os.environ.get("FALLBACK_RPC_URLS") + loaded = [] + if primary: + loaded.append(primary.strip()) + if fallbacks: + for f in fallbacks.split(","): + if f.strip(): + loaded.append(f.strip()) + if not loaded: + loaded = [ + "https://rpc.testnet.stellar.org", + "https://rpc.mainnet.stellar.org", + ] + self.endpoints = loaded + else: + self.endpoints = list(endpoints) + + self._lock = threading.Lock() + self._active_endpoint = self.endpoints[0] if self.endpoints else "" + self._latencies: Dict[str, float] = {ep: 0.0 for ep in self.endpoints} + self._healthy_endpoints: set = set(self.endpoints) + + self._stop_event = threading.Event() + self._monitor_thread: Optional[threading.Thread] = None + + def start(self) -> None: + """Start the background monitoring thread.""" + with self._lock: + if self._monitor_thread is not None and self._monitor_thread.is_alive(): + return + self._stop_event.clear() + self._monitor_thread = threading.Thread( + target=self._run_monitor, + name="RPCNodeFailoverSupervisor-Monitor", + daemon=True, + ) + self._monitor_thread.start() + logger.info("[RPCNodeFailoverSupervisor] Started proactive background monitoring.") + + def stop(self) -> None: + """Stop the background monitoring thread.""" + self._stop_event.set() + if self._monitor_thread is not None: + self._monitor_thread.join(timeout=1.0) + self._monitor_thread = None + logger.info("[RPCNodeFailoverSupervisor] Stopped background monitoring.") + + def get_active_endpoint(self) -> str: + """Return the currently selected active RPC endpoint.""" + with self._lock: + return self._active_endpoint + + def _ping_node(self, endpoint: str) -> Optional[float]: + """Perform a fast, lightweight check on a single node and return its latency in ms.""" + try: + start = time.time() + response = requests.post( + endpoint, + json={"jsonrpc": "2.0", "id": 1, "method": "getHealth"}, + timeout=self.ping_timeout_sec, + ) + latency_ms = (time.time() - start) * 1000.0 + if response.status_code == 200: + data = response.json() + if "result" in data or "error" in data: + return latency_ms + return None + except Exception: + return None + + def _run_monitor(self) -> None: + """Main loop for the background monitoring thread.""" + while not self._stop_event.is_set(): + temp_latencies = {} + temp_healthy = set() + + for ep in self.endpoints: + latency = self._ping_node(ep) + if latency is not None: + temp_latencies[ep] = latency + temp_healthy.add(ep) + else: + temp_latencies[ep] = float("inf") + + with self._lock: + self._latencies.update(temp_latencies) + self._healthy_endpoints = temp_healthy + + active_ok = False + active_latency = self._latencies.get(self._active_endpoint, float("inf")) + + if ( + self._active_endpoint in self._healthy_endpoints + and active_latency <= self.latency_threshold_ms + ): + active_ok = True + + if not active_ok: + best_endpoint = self._active_endpoint + best_latency = active_latency + + for ep in self.endpoints: + ep_latency = self._latencies.get(ep, float("inf")) + if ep in self._healthy_endpoints and ep_latency < best_latency: + best_endpoint = ep + best_latency = ep_latency + + if best_endpoint != self._active_endpoint: + logger.warning( + "[RPCNodeFailoverSupervisor] Shifted traffic from %s (latency: %.1fms) to %s (latency: %.1fms)", + self._active_endpoint, + active_latency, + best_endpoint, + best_latency, + ) + self._active_endpoint = best_endpoint + + self._stop_event.wait(self.check_interval_sec) + + +rpc_supervisor = RPCNodeFailoverSupervisor() + + +__all__ = [ + "NonceTracker", + "NonceWindow", + "nonce_tracker", + "nonce_window", + "RPCNodeFailoverSupervisor", + "rpc_supervisor", +] diff --git a/src/network/rpc_client.py b/src/network/rpc_client.py index a17721af..cdc5e3a2 100644 --- a/src/network/rpc_client.py +++ b/src/network/rpc_client.py @@ -1,37 +1,62 @@ import logging import requests from typing import List, Dict, Any +from network.nonce_tracker import RPCNodeFailoverSupervisor logger = logging.getLogger(__name__) + class FailoverRouter: - """ - Automated RPC Endpoint Switching Routine. + """Automated RPC Endpoint Switching Routine. + Automatically switches data transmission paths to backup node endpoints - if a target fails to respond within a 3500ms window. + using a proactive RPC supervisor to avoid connection timeouts. """ - + def __init__(self, primary_endpoint: str, backup_endpoints: List[str]): self.primary_endpoint = primary_endpoint self.backup_endpoints = backup_endpoints self.timeout_sec = 3.5 # 3500ms window + self.supervisor = RPCNodeFailoverSupervisor( + endpoints=[primary_endpoint] + backup_endpoints, + check_interval_sec=2.0, + latency_threshold_ms=500.0, + ping_timeout_sec=1.0, + ) + self.supervisor.start() def transmit(self, path: str, payload: Dict[str, Any]) -> Dict[str, Any]: - endpoints = [self.primary_endpoint] + self.backup_endpoints - + active_url = self.supervisor.get_active_endpoint() + endpoints = [active_url] + [ + ep for ep in self.supervisor.endpoints if ep != active_url + ] + for url in endpoints: target_url = f"{url.rstrip('/')}/{path.lstrip('/')}" try: response = requests.post( - target_url, - json=payload, - timeout=self.timeout_sec + target_url, json=payload, timeout=self.timeout_sec ) response.raise_for_status() return response.json() except requests.exceptions.Timeout: - logger.warning(f"Node {target_url} timed out after {self.timeout_sec}s. Switching to backup.") + logger.warning( + f"Node {target_url} timed out after {self.timeout_sec}s. Switching to backup." + ) except requests.exceptions.RequestException as e: - logger.warning(f"Node {target_url} failed: {e}. Switching to backup.") - + logger.warning( + f"Node {target_url} failed: {e}. Switching to backup." + ) + raise ConnectionError("All RPC endpoints failed to respond.") + + def close(self) -> None: + """Stop the proactive supervisor thread.""" + try: + self.supervisor.stop() + except Exception: + pass + + def __del__(self) -> None: + self.close() + diff --git a/src/utils/state.py b/src/utils/state.py index cb92ea05..3da66fcc 100644 --- a/src/utils/state.py +++ b/src/utils/state.py @@ -1,4 +1,3 @@ -'''state.py """Utility module providing a process‑safe state register for internal worker flags. The register maintains a mapping from arbitrary string identifiers (e.g. ``asset_pair`` @@ -25,24 +24,104 @@ inter‑process communication) without side effects. """ +import os +import json +import tempfile import multiprocessing -from typing import Dict +from typing import Dict, Optional + +try: + import fcntl +except ImportError: + fcntl = None class StateRegister: """Process‑safe registry for boolean activity flags. - Attributes - ---------- - _flags: Dict[str, bool] - Internal mapping from a key to its active/inactive state. - _lock: multiprocessing.Lock - Inter-process mutex guarding all modifications and reads of ``_flags``. + Attributes: + _filepath: Filepath to the local operational metadata file layout. + _lock: Inter-process mutex guarding all modifications and reads of the state file. """ - def __init__(self) -> None: - self._flags: Dict[str, bool] = {} + def __init__(self, filepath: str = "state_register.json") -> None: + self._filepath = filepath + self._lock_filepath = filepath + ".lock" self._lock = multiprocessing.Lock() + with self._lock: + dir_name = os.path.dirname(self._filepath) or "." + if dir_name and not os.path.exists(dir_name): + os.makedirs(dir_name, exist_ok=True) + self._execute_with_file_lock(self._init_file) + + def _init_file(self) -> None: + if not os.path.exists(self._filepath): + self._write_state_unlocked({}) + + def _execute_with_file_lock(self, func, *args, **kwargs): + """Execute a function while holding an advisory file lock on Linux.""" + if fcntl is None: + return func(*args, **kwargs) + with open(self._lock_filepath, "w") as lock_file: + try: + fcntl.flock(lock_file, fcntl.LOCK_EX) + return func(*args, **kwargs) + finally: + try: + fcntl.flock(lock_file, fcntl.LOCK_UN) + except Exception: + pass + + def _load_state(self) -> Dict[str, bool]: + """Load state map from the file. + + Complexity: + Time: O(S) where S is the size of the JSON file (de-serialization). + Space: O(S) memory footprint to hold the parsed state map. + """ + return self._execute_with_file_lock(self._load_state_unlocked) + + def _load_state_unlocked(self) -> Dict[str, bool]: + if not os.path.exists(self._filepath): + return {} + try: + with open(self._filepath, "r", encoding="utf-8") as f: + content = f.read().strip() + if not content: + return {} + return json.loads(content) + except Exception: + return {} + + def _write_state(self, flags: Dict[str, bool]) -> None: + """Atomically persist state map to the file using a temporary file. + + Complexity: + Time: O(S) where S is the size of the JSON file (serialization). + Space: O(S) for temporary buffers. + """ + self._execute_with_file_lock(self._write_state_unlocked, flags) + + def _write_state_unlocked(self, flags: Dict[str, bool]) -> None: + dir_name = os.path.dirname(self._filepath) or "." + fd, temp_path = tempfile.mkstemp( + dir=dir_name, + prefix=f".{os.path.basename(self._filepath)}.", + suffix=".tmp" + ) + try: + with os.fdopen(fd, "w", encoding="utf-8") as f: + json.dump(flags, f) + f.flush() + os.fsync(f.fileno()) + os.replace(temp_path, self._filepath) + except Exception: + if os.path.exists(temp_path): + try: + os.unlink(temp_path) + except OSError: + pass + raise def is_active(self, key: str) -> bool: """Return ``True`` if the flag for *key* is set, ``False`` otherwise. @@ -50,7 +129,8 @@ def is_active(self, key: str) -> bool: This method acquires the internal lock to guarantee a consistent view. """ with self._lock: - return self._flags.get(key, False) + flags = self._load_state() + return flags.get(key, False) def activate(self, key: str) -> None: """Mark the flag for *key* as active (``True``). @@ -58,7 +138,9 @@ def activate(self, key: str) -> None: If the key does not yet exist, it is created. """ with self._lock: - self._flags[key] = True + flags = self._load_state() + flags[key] = True + self._write_state(flags) def try_acquire(self, key: str) -> bool: """Atomically check if *key* is inactive and, if so, activate it. @@ -68,10 +150,13 @@ def try_acquire(self, key: str) -> bool: already ``True``. """ with self._lock: - if self._flags.get(key, False): + flags = self._load_state() + if flags.get(key, False): return False - self._flags[key] = True + flags[key] = True + self._write_state(flags) return True + def deactivate(self, key: str) -> None: """Mark the flag for *key* as inactive (``False``). @@ -79,7 +164,9 @@ def deactivate(self, key: str) -> None: without raising ``KeyError``. """ with self._lock: - self._flags[key] = False + flags = self._load_state() + flags[key] = False + self._write_state(flags) # Alias for clarity when releasing a worker lock def release(self, key: str) -> None: @@ -88,13 +175,16 @@ def release(self, key: str) -> None: This can be used by ingestion code to explicitly free the allocation flag. """ self.deactivate(key) + def clear(self, key: str) -> None: """Remove *key* from the registry entirely. After removal, ``is_active`` will return ``False`` for the key. """ with self._lock: - self._flags.pop(key, None) + flags = self._load_state() + flags.pop(key, None) + self._write_state(flags) def snapshot(self) -> Dict[str, bool]: """Return a shallow copy of the current flags mapping. @@ -103,7 +193,7 @@ def snapshot(self) -> Dict[str, bool]: iterate over the result without further synchronization. """ with self._lock: - return dict(self._flags) + return self._load_state() # Optional convenience context manager for safe activation/deactivation def guard(self, key: str): @@ -131,6 +221,6 @@ def __exit__(self, exc_type, exc_val, exc_tb): # Propagate any exception return False + # Create a module‑level singleton for convenient import state_register = StateRegister() -''' diff --git a/tests/test_nonce_tracker.py b/tests/test_nonce_tracker.py index 843cac8c..1da66545 100644 --- a/tests/test_nonce_tracker.py +++ b/tests/test_nonce_tracker.py @@ -292,3 +292,123 @@ def test_invalid_window_size_raises() -> None: NonceWindow(window_size=0) with pytest.raises(ValueError): NonceWindow(window_size=-1) + + +def test_rpc_node_failover_supervisor_basic(monkeypatch) -> None: + import time + from network.nonce_tracker import RPCNodeFailoverSupervisor + import requests + + endpoints = [ + "https://rpc-primary.stellar.org", + "https://rpc-secondary.stellar.org", + ] + + class MockResponse: + status_code = 200 + + def json(self): + return {"result": {"network": "testnet"}} + + mock_calls = [] + + def mock_post(url, json=None, timeout=None): + mock_calls.append(url) + return MockResponse() + + monkeypatch.setattr(requests, "post", mock_post) + + supervisor = RPCNodeFailoverSupervisor( + endpoints=endpoints, + check_interval_sec=0.1, + latency_threshold_ms=100.0, + ping_timeout_sec=0.5, + ) + + assert supervisor.get_active_endpoint() == endpoints[0] + + supervisor.start() + time.sleep(0.3) + supervisor.stop() + + assert len(mock_calls) > 0 + assert endpoints[0] in mock_calls + + +def test_rpc_node_failover_supervisor_latency_failover(monkeypatch) -> None: + import time + from network.nonce_tracker import RPCNodeFailoverSupervisor + import requests + + endpoints = [ + "https://rpc-primary.stellar.org", + "https://rpc-secondary.stellar.org", + ] + + def mock_post(url, json=None, timeout=None): + class MockResponse: + status_code = 200 + + def json(self): + return {"result": {"network": "testnet"}} + + if "primary" in url: + time.sleep(0.15) + return MockResponse() + + monkeypatch.setattr(requests, "post", mock_post) + + supervisor = RPCNodeFailoverSupervisor( + endpoints=endpoints, + check_interval_sec=0.1, + latency_threshold_ms=100.0, + ping_timeout_sec=0.5, + ) + + assert supervisor.get_active_endpoint() == endpoints[0] + + supervisor.start() + time.sleep(0.3) + supervisor.stop() + + assert supervisor.get_active_endpoint() == endpoints[1] + + +def test_rpc_node_failover_supervisor_failure_failover(monkeypatch) -> None: + import time + from network.nonce_tracker import RPCNodeFailoverSupervisor + import requests + + endpoints = [ + "https://rpc-primary.stellar.org", + "https://rpc-secondary.stellar.org", + ] + + def mock_post(url, json=None, timeout=None): + class MockResponse: + status_code = 200 + + def json(self): + return {"result": {"network": "testnet"}} + + if "primary" in url: + raise requests.exceptions.ConnectionError("Connection refused") + return MockResponse() + + monkeypatch.setattr(requests, "post", mock_post) + + supervisor = RPCNodeFailoverSupervisor( + endpoints=endpoints, + check_interval_sec=0.1, + latency_threshold_ms=100.0, + ping_timeout_sec=0.5, + ) + + assert supervisor.get_active_endpoint() == endpoints[0] + + supervisor.start() + time.sleep(0.3) + supervisor.stop() + + assert supervisor.get_active_endpoint() == endpoints[1] + diff --git a/tests/test_rpc_client.py b/tests/test_rpc_client.py new file mode 100644 index 00000000..a2a67ec4 --- /dev/null +++ b/tests/test_rpc_client.py @@ -0,0 +1,47 @@ +import os +import sys +import pytest +import time + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) + +from network.rpc_client import FailoverRouter +import requests + + +def test_failover_router_uses_supervisor_active_endpoint(monkeypatch) -> None: + primary = "https://rpc-primary.stellar.org" + backup = "https://rpc-backup.stellar.org" + + class MockResponse: + status_code = 200 + + def json(self): + return {"result": {"status": "healthy"}} + + def raise_for_status(self): + pass + + monkeypatch.setattr( + requests, + "post", + lambda url, json=None, timeout=None: MockResponse(), + ) + + router = FailoverRouter(primary_endpoint=primary, backup_endpoints=[backup]) + time.sleep(0.1) + + transmit_calls = [] + + def mock_transmit_post(url, json=None, timeout=None): + transmit_calls.append(url) + return MockResponse() + + monkeypatch.setattr(requests, "post", mock_transmit_post) + + router.transmit("/submit", {"tx": "xyz"}) + + assert len(transmit_calls) == 1 + assert "submit" in transmit_calls[0] + + router.supervisor.stop() diff --git a/tests/test_state.py b/tests/test_state.py new file mode 100644 index 00000000..f29d15ac --- /dev/null +++ b/tests/test_state.py @@ -0,0 +1,70 @@ +import os +import sys +import time +import pytest +from multiprocessing import Process + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) + +from utils.state import StateRegister + + +def test_state_register_basic_operations(tmp_path) -> None: + filepath = str(tmp_path / "test_state.json") + register = StateRegister(filepath=filepath) + + assert not register.is_active("worker-1") + register.activate("worker-1") + assert register.is_active("worker-1") + + # Try acquire should fail when active + assert not register.try_acquire("worker-1") + + register.deactivate("worker-1") + assert not register.is_active("worker-1") + + # Try acquire should succeed when inactive + assert register.try_acquire("worker-1") + assert register.is_active("worker-1") + + register.clear("worker-1") + assert not register.is_active("worker-1") + + +def test_state_register_snapshot_and_release(tmp_path) -> None: + filepath = str(tmp_path / "test_state.json") + register = StateRegister(filepath=filepath) + + register.activate("worker-a") + register.activate("worker-b") + + snap = register.snapshot() + assert snap.get("worker-a") is True + assert snap.get("worker-b") is True + + register.release("worker-a") + assert not register.is_active("worker-a") + + +def _worker_process_task(register: StateRegister, worker_id: str) -> None: + # Loop and write to test concurrent stress + for i in range(50): + key = f"key-{worker_id}-{i}" + register.activate(key) + assert register.is_active(key) + register.deactivate(key) + + +def test_state_register_multiprocess_safety(tmp_path) -> None: + filepath = str(tmp_path / "test_state_multiprocess.json") + register = StateRegister(filepath=filepath) + + processes = [] + for idx in range(4): + p = Process(target=_worker_process_task, args=(register, f"worker-{idx}")) + processes.append(p) + p.start() + + for p in processes: + p.join() + assert p.exitcode == 0