diff --git a/paymaster-relayer/.env.example b/paymaster-relayer/.env.example index 6169e90..2efff86 100644 --- a/paymaster-relayer/.env.example +++ b/paymaster-relayer/.env.example @@ -14,19 +14,22 @@ ROFL_ADAPTER_ADDRESS=0xYourROFLAdapterOnSapphire # =========================================== # BASE SEPOLIA (Chain ID: 84532) -> relayer-base # =========================================== -BASE_SOURCE_RPC_URL=https://base-sepolia.g.alchemy.com/v2/YOUR_KEY +BASE_SOURCE_RPC_URLS=https://base-sepolia.g.alchemy.com/v2/YOUR_KEY +# Add failover RPCs with commas: https://url1,https://url2,https://url3 BASE_VAULT_ADDRESS=0xYourPaymasterVaultOnBase # =========================================== # ETHEREUM SEPOLIA (Chain ID: 11155111) -> relayer-eth # =========================================== -ETH_SOURCE_RPC_URL=https://eth-sepolia.g.alchemy.com/v2/YOUR_KEY +ETH_SOURCE_RPC_URLS=https://eth-sepolia.g.alchemy.com/v2/YOUR_KEY +# Add failover RPCs with commas: https://url1,https://url2,https://url3 ETH_VAULT_ADDRESS=0xYourPaymasterVaultOnEthereum # =========================================== # ARBITRUM SEPOLIA (Chain ID: 421614) -> relayer-arb # =========================================== -ARB_SOURCE_RPC_URL=https://arb-sepolia.g.alchemy.com/v2/YOUR_KEY +ARB_SOURCE_RPC_URLS=https://arb-sepolia.g.alchemy.com/v2/YOUR_KEY +# Add failover RPCs with commas: https://url1,https://url2,https://url3 ARB_VAULT_ADDRESS=0xYourPaymasterVaultOnArbitrum # =========================================== diff --git a/paymaster-relayer/README.md b/paymaster-relayer/README.md index 8fc5f4d..d3f720d 100644 --- a/paymaster-relayer/README.md +++ b/paymaster-relayer/README.md @@ -42,7 +42,7 @@ ROFL deployment in one go. | Variable | Required | Description | |----------|----------|-------------| -| `SOURCE_RPC_URL` | Yes | Source chain RPC endpoint (HTTP) | +| `SOURCE_RPC_URLS` | Yes | Source chain RPC endpoints (comma-separated for failover) | | `PAYMASTER_VAULT_ADDRESS` | Yes | PaymasterVault contract address on source chain | | `TARGET_RPC_URL` | Yes | Sapphire RPC endpoint | | `PAYMASTER_PROXY_ADDRESS` | Yes | CrossChainPaymaster proxy address on Sapphire | diff --git a/paymaster-relayer/compose.local.yaml b/paymaster-relayer/compose.local.yaml index c319c3b..7069f65 100644 --- a/paymaster-relayer/compose.local.yaml +++ b/paymaster-relayer/compose.local.yaml @@ -5,7 +5,7 @@ x-relayer-common: &relayer-common build: context: . dockerfile: Dockerfile - image: ghcr.io/oasisprotocol/rofl-paymaster-relayer:multichain-test + image: ghcr.io/oasisprotocol/rofl-paymaster-relayer:multirpc-test platform: linux/amd64 entrypoint: /bin/sh -c 'uv run python -m paymaster_relayer --local' restart: on-failure @@ -31,7 +31,7 @@ services: environment: <<: *common-env CHAIN_NAME: base - SOURCE_RPC_URL: ${BASE_SOURCE_RPC_URL} + SOURCE_RPC_URLS: ${BASE_SOURCE_RPC_URLS} PAYMASTER_VAULT_ADDRESS: ${BASE_VAULT_ADDRESS} POLLING_INTERVAL: ${BASE_POLLING_INTERVAL:-6} @@ -43,7 +43,7 @@ services: environment: <<: *common-env CHAIN_NAME: ethereum - SOURCE_RPC_URL: ${ETH_SOURCE_RPC_URL} + SOURCE_RPC_URLS: ${ETH_SOURCE_RPC_URLS} PAYMASTER_VAULT_ADDRESS: ${ETH_VAULT_ADDRESS} POLLING_INTERVAL: ${ETH_POLLING_INTERVAL:-12} @@ -55,6 +55,6 @@ services: environment: <<: *common-env CHAIN_NAME: arbitrum - SOURCE_RPC_URL: ${ARB_SOURCE_RPC_URL} + SOURCE_RPC_URLS: ${ARB_SOURCE_RPC_URLS} PAYMASTER_VAULT_ADDRESS: ${ARB_VAULT_ADDRESS} POLLING_INTERVAL: ${ARB_POLLING_INTERVAL:-3} diff --git a/paymaster-relayer/compose.yaml b/paymaster-relayer/compose.yaml index 3b076fa..3028682 100644 --- a/paymaster-relayer/compose.yaml +++ b/paymaster-relayer/compose.yaml @@ -31,7 +31,7 @@ services: environment: <<: *common-env CHAIN_NAME: base - SOURCE_RPC_URL: ${BASE_SOURCE_RPC_URL} + SOURCE_RPC_URLS: ${BASE_SOURCE_RPC_URLS} PAYMASTER_VAULT_ADDRESS: ${BASE_VAULT_ADDRESS} POLLING_INTERVAL: ${BASE_POLLING_INTERVAL:-6} @@ -43,7 +43,7 @@ services: environment: <<: *common-env CHAIN_NAME: ethereum - SOURCE_RPC_URL: ${ETH_SOURCE_RPC_URL} + SOURCE_RPC_URLS: ${ETH_SOURCE_RPC_URLS} PAYMASTER_VAULT_ADDRESS: ${ETH_VAULT_ADDRESS} POLLING_INTERVAL: ${ETH_POLLING_INTERVAL:-12} @@ -55,6 +55,6 @@ services: environment: <<: *common-env CHAIN_NAME: arbitrum - SOURCE_RPC_URL: ${ARB_SOURCE_RPC_URL} + SOURCE_RPC_URLS: ${ARB_SOURCE_RPC_URLS} PAYMASTER_VAULT_ADDRESS: ${ARB_VAULT_ADDRESS} POLLING_INTERVAL: ${ARB_POLLING_INTERVAL:-3} diff --git a/paymaster-relayer/paymaster_relayer/__main__.py b/paymaster-relayer/paymaster_relayer/__main__.py index cb474db..e840670 100644 --- a/paymaster-relayer/paymaster_relayer/__main__.py +++ b/paymaster-relayer/paymaster_relayer/__main__.py @@ -19,7 +19,9 @@ async def main(): ) args = parser.parse_args() - logger.info(f"=== Paymaster Relayer Starting {'(LOCAL MODE)' if args.local else ''} ===") + logger.info( + f"=== Paymaster Relayer Starting {'(LOCAL MODE)' if args.local else ''} ===" + ) relayer = None @@ -29,11 +31,19 @@ async def main(): except ValueError as e: logger.error(f"Configuration error: {e}") logger.error("Required environment variables:") - logger.error(" - SOURCE_RPC_URL: Source chain RPC endpoint (e.g., Ethereum)") + logger.error( + " - SOURCE_RPC_URLS: Source chain RPC endpoints (comma-separated)" + ) logger.error(" - TARGET_RPC_URL: Target chain RPC endpoint (e.g., Sapphire)") - logger.error(" - PAYMASTER_VAULT_ADDRESS: PaymasterVault contract address (source)") - logger.error(" - PAYMASTER_PROXY_ADDRESS: CrossChainPaymaster contract address (target)") - logger.error(" - ROFL_ADAPTER_ADDRESS: ROFLAdapter contract address (target, HashStored)") + logger.error( + " - PAYMASTER_VAULT_ADDRESS: PaymasterVault contract address (source)" + ) + logger.error( + " - PAYMASTER_PROXY_ADDRESS: CrossChainPaymaster contract address (target)" + ) + logger.error( + " - ROFL_ADAPTER_ADDRESS: ROFLAdapter contract address (target, HashStored)" + ) if args.local: logger.error(" - PRIVATE_KEY: Private key for signing transactions") sys.exit(1) diff --git a/paymaster-relayer/paymaster_relayer/config.py b/paymaster-relayer/paymaster_relayer/config.py index 238abd9..3fe0453 100644 --- a/paymaster-relayer/paymaster_relayer/config.py +++ b/paymaster-relayer/paymaster_relayer/config.py @@ -9,12 +9,35 @@ import os from dataclasses import dataclass, field +from .utils.multi_rpc_provider import sanitize_url + + +def parse_rpc_urls() -> list[str]: + """ + Parse SOURCE_RPC_URLS env var into a list of RPC endpoint URLs. + + Splits on commas, strips whitespace, and filters empty entries. + + Returns: + List of non-empty, trimmed URLs + + Raises: + ValueError: If SOURCE_RPC_URLS is missing or contains no valid URLs + """ + raw = os.environ.get("SOURCE_RPC_URLS", "") + urls = [url.strip() for url in raw.split(",") if url.strip()] + + if not urls: + raise ValueError("SOURCE_RPC_URLS environment variable is missing or empty") + + return urls + @dataclass(frozen=True, slots=True) class SourceChainConfig: """Configuration for the source chain (e.g., Base/Sepolia).""" - rpc_url: str + rpc_urls: list[str] paymaster_vault_address: str @@ -37,7 +60,9 @@ class MonitoringConfig: retry_count: int = 3 lookback_blocks: int = 9 process_batch_size: int = 10 # max events to process in one batch - max_block_range: int = 10 # max blocks per get_logs request (Alchemy free tier limit) + max_block_range: int = ( + 10 # max blocks per get_logs request (Alchemy free tier limit) + ) def __post_init__(self) -> None: """Validate monitoring configuration.""" @@ -69,7 +94,6 @@ def __post_init__(self) -> None: f"Lookback blocks too high (max 1000), got {self.lookback_blocks}" ) - # Validate batch size if self.process_batch_size <= 0: raise ValueError( @@ -90,6 +114,7 @@ def __post_init__(self) -> None: f"Max block range too large (max 10000), got {self.max_block_range}" ) + @dataclass(frozen=True, slots=True) class RelayerConfig: """Main configuration class for the ROFL Relayer.""" @@ -110,13 +135,14 @@ def from_env(cls, local_mode: bool = False) -> "RelayerConfig": Raises: ValueError: If required environment variables are missing """ - # Source chain configuration - source_rpc_url = os.environ.get("SOURCE_RPC_URL") - if not source_rpc_url: + # Source chain configuration - parse comma-delimited RPC URLs + try: + source_rpc_urls = parse_rpc_urls() + except ValueError: raise ValueError( - "SOURCE_RPC_URL environment variable is required. " - "Example: https://ethereum-sepolia.publicnode.com" - ) + "SOURCE_RPC_URLS environment variable is required (comma-separated). " + "Example: SOURCE_RPC_URLS=https://rpc1.example.com,https://rpc2.example.com" + ) from None paymaster_vault_address = os.environ.get("PAYMASTER_VAULT_ADDRESS") if not paymaster_vault_address: @@ -173,7 +199,7 @@ def from_env(cls, local_mode: bool = False) -> "RelayerConfig": # Create configuration objects source_chain = SourceChainConfig( - rpc_url=source_rpc_url, + rpc_urls=source_rpc_urls, paymaster_vault_address=paymaster_vault_address, ) @@ -197,11 +223,13 @@ def log_config(self) -> None: print(f"Mode: {'LOCAL' if self.local_mode else 'ROFL'}") print("\n[Source Chain]") - print(f" RPC URL: {self.source_chain.rpc_url}") + print(f" RPC URLs ({len(self.source_chain.rpc_urls)} configured):") + for i, url in enumerate(self.source_chain.rpc_urls, 1): + print(f" [{i}] {sanitize_url(url)}") print(f" PaymasterVault: {self.source_chain.paymaster_vault_address}") print("\n[Target Chain]") - print(f" RPC URL: {self.target_chain.rpc_url}") + print(f" RPC URL: {sanitize_url(self.target_chain.rpc_url)}") print(f" CrossChainPaymaster: {self.target_chain.paymaster_address}") print(f" ROFLAdapter: {self.target_chain.rofl_adapter_address}") print( diff --git a/paymaster-relayer/paymaster_relayer/event_processor.py b/paymaster-relayer/paymaster_relayer/event_processor.py index 7367798..dee7f6e 100644 --- a/paymaster-relayer/paymaster_relayer/event_processor.py +++ b/paymaster-relayer/paymaster_relayer/event_processor.py @@ -182,7 +182,9 @@ async def process_hash_stored(self, event: EventData) -> tuple[int, str] | None: self.stored_hashes.popitem(last=False) self.stored_hashes[block_id] = block_hash - logger.info(f"Hash stored - Chain {domain} Block {block_id}: {block_hash[:10]}...") + logger.info( + f"Hash stored - Chain {domain} Block {block_id}: {block_hash[:10]}..." + ) matching_payments: list[PaymentEvent] = self.pending_payments.get( block_id, [] @@ -264,7 +266,9 @@ async def process_matched_payment(self, payment_event: PaymentEvent) -> bool: # This prevents race between HashStored handler and retry task if not self._remove_from_pending(payment_event): # Already removed by another task - skip to avoid duplicate submission - logger.debug(f"Payment {payment_event.tx_hash[:10]}... already being processed") + logger.debug( + f"Payment {payment_event.tx_hash[:10]}... already being processed" + ) return False paymaster_address = self.config.target_chain.paymaster_address diff --git a/paymaster-relayer/paymaster_relayer/proof_manager.py b/paymaster-relayer/paymaster_relayer/proof_manager.py index 32c7852..9a71c8e 100644 --- a/paymaster-relayer/paymaster_relayer/proof_manager.py +++ b/paymaster-relayer/paymaster_relayer/proof_manager.py @@ -20,6 +20,7 @@ if TYPE_CHECKING: from .utils.contract_utility import ContractUtility + from .utils.multi_rpc_provider import MultiRpcProvider from .utils.rofl_utility import ROFLUtility logger = logging.getLogger(__name__) @@ -38,7 +39,7 @@ class ProofManager: def __init__( self, - w3_source: Web3, + source_provider: "MultiRpcProvider", contract_util: "ContractUtility", rofl_util: "ROFLUtility | None" = None, ): @@ -46,11 +47,11 @@ def __init__( Initialize the ProofManager. Args: - w3_source: Web3 instance for the source chain + source_provider: MultiRpcProvider for source chain with failover contract_util: Utility for contract interactions rofl_util: ROFL utility for transaction submission (optional) """ - self.w3_source = w3_source + self.source_provider = source_provider self.contract_util = contract_util self.rofl_util = rofl_util @@ -72,7 +73,9 @@ def _get_transaction_local_index(self, payment_event: PaymentEvent) -> int: Returns: Transaction-local index (position within transaction's logs) """ - receipt = self.w3_source.eth.get_transaction_receipt(HexStr(payment_event.tx_hash)) + receipt = self.source_provider.execute_with_failover( + lambda w3: w3.eth.get_transaction_receipt(HexStr(payment_event.tx_hash)) + ) if not receipt or "logs" not in receipt: logger.warning(f"No logs found in transaction {payment_event.tx_hash}") return 0 @@ -90,7 +93,9 @@ def _get_transaction_local_index(self, payment_event: PaymentEvent) -> int: return i # If not found (shouldn't happen), default to 0 - logger.warning("PaymentInitiated not found in transaction logs, defaulting to index 0") + logger.warning( + "PaymentInitiated not found in transaction logs, defaulting to index 0" + ) return 0 async def generate_proof(self, payment_event: PaymentEvent) -> list[Any]: @@ -109,18 +114,28 @@ async def generate_proof(self, payment_event: PaymentEvent) -> list[Any]: ValueError: If receipt or block not found, or proof generation fails """ # Calculate transaction-local log index from event content - log_index = self._get_transaction_local_index(payment_event) + log_index = await asyncio.to_thread( + self._get_transaction_local_index, payment_event + ) logger.info( f"Generating proof for tx {payment_event.tx_hash}, transaction-local log index {log_index}" ) # 1. Fetch receipt and block - receipt = self.w3_source.eth.get_transaction_receipt(HexStr(payment_event.tx_hash)) + receipt = await asyncio.to_thread( + self.source_provider.execute_with_failover, + lambda w3: w3.eth.get_transaction_receipt(HexStr(payment_event.tx_hash)), + ) if not receipt: - raise ValueError(f"Transaction receipt not found for {payment_event.tx_hash}") + raise ValueError( + f"Transaction receipt not found for {payment_event.tx_hash}" + ) block_number = receipt["blockNumber"] - block = self.w3_source.eth.get_block(block_number, full_transactions=True) + block = await asyncio.to_thread( + self.source_provider.execute_with_failover, + lambda w3: w3.eth.get_block(block_number, full_transactions=True), + ) if not block: raise ValueError(f"Block not found for block number {block_number}") @@ -129,7 +144,7 @@ async def generate_proof(self, payment_event: PaymentEvent) -> list[Any]: ) # 2. Get all receipts in block - receipts = self._get_block_receipts(block_number) + receipts = await asyncio.to_thread(self._get_block_receipts, block_number) logger.info(f"Fetched {len(receipts)} receipts from block") @@ -166,7 +181,10 @@ async def generate_proof(self, payment_event: PaymentEvent) -> list[Any]: # 6. Encode block header encoded_block_header = BlockchainEncoder.encode_block_header(block) - chain_id = int(self.w3_source.eth.chain_id) + chain_id = await asyncio.to_thread( + self.source_provider.execute_with_failover, + lambda w3: int(w3.eth.chain_id), + ) # 7. Create proof structure for Hashi proof = [ @@ -185,7 +203,9 @@ async def generate_proof(self, payment_event: PaymentEvent) -> list[Any]: ) return proof - async def submit_proof(self, proof: list[Any], paymaster_address: str) -> str | None: + async def submit_proof( + self, proof: list[Any], paymaster_address: str + ) -> str | None: """ Submit proof to CrossChainPaymaster contract. @@ -314,7 +334,9 @@ def _get_block_receipts(self, block_number: int) -> list[TxReceipt]: ValueError: If block receipts cannot be fetched """ try: - receipts = self.w3_source.eth.get_block_receipts(block_number) + receipts = self.source_provider.execute_with_failover( + lambda w3: w3.eth.get_block_receipts(block_number) + ) except Exception as e: logger.error(f"Failed to fetch receipts for block {block_number}: {e}") raise ValueError( diff --git a/paymaster-relayer/paymaster_relayer/relayer.py b/paymaster-relayer/paymaster_relayer/relayer.py index 238e14e..022f52f 100644 --- a/paymaster-relayer/paymaster_relayer/relayer.py +++ b/paymaster-relayer/paymaster_relayer/relayer.py @@ -10,12 +10,11 @@ import logging import os -from web3 import Web3 - from .config import RelayerConfig from .event_processor import EventProcessor from .proof_manager import ProofManager from .utils.contract_utility import ContractUtility +from .utils.multi_rpc_provider import MultiRpcProvider from .utils.polling_event_listener import PollingEventListener from .utils.rofl_utility import RoflUtility @@ -53,7 +52,9 @@ def __init__(self, config: RelayerConfig): self._init_utilities() # Get source chain ID for filtering HashStored events - source_chain_id = self.w3_source.eth.chain_id + source_chain_id = self.source_provider.execute_with_failover( + lambda w3: w3.eth.chain_id + ) self.event_processor = EventProcessor( proof_manager=self.proof_manager, @@ -69,12 +70,8 @@ def _init_utilities(self) -> None: """ Initialize utility classes for proof generation. """ - # Initialize Web3 for source chain - self.w3_source = Web3(Web3.HTTPProvider(self.config.source_chain.rpc_url)) - if not self.w3_source.is_connected(): - raise Exception( - f"Failed to connect to source chain at {self.config.source_chain.rpc_url}" - ) + # Initialize multi-RPC provider for source chain with failover + self.source_provider = MultiRpcProvider(self.config.source_chain.rpc_urls) # Initialize contract utility for target chain self.contract_util = ContractUtility( @@ -85,12 +82,14 @@ def _init_utilities(self) -> None: self.rofl_util = None if self.config.local_mode else RoflUtility() self.proof_manager = ProofManager( - w3_source=self.w3_source, + source_provider=self.source_provider, contract_util=self.contract_util, rofl_util=self.rofl_util, ) - source_chain_id = self.w3_source.eth.chain_id + source_chain_id = self.source_provider.execute_with_failover( + lambda w3: w3.eth.chain_id + ) logger.info( f"Paymaster Relayer initialized ({'LOCAL' if self.config.local_mode else 'ROFL'} mode, source chain: {source_chain_id})" ) @@ -121,9 +120,9 @@ async def init_event_monitoring(self) -> None: paymaster_vault_abi = self.contract_util.get_contract_abi("PaymasterVault") rofl_adapter_abi = self.contract_util.get_contract_abi("ROFLAdapter") - # Initialize PaymentInitiated event listener (source chain) + # Initialize PaymentInitiated event listener (source chain with multi-RPC failover) self.payment_listener = PollingEventListener( - rpc_url=self.config.source_chain.rpc_url, + provider=self.source_provider, contract_address=self.config.source_chain.paymaster_vault_address, event_name="PaymentInitiated", abi=paymaster_vault_abi, @@ -137,7 +136,7 @@ async def init_event_monitoring(self) -> None: # Initialize ROFLAdapter event listener (target chain - Sapphire) self.hash_listener = PollingEventListener( - rpc_url=self.config.target_chain.rpc_url, + provider=MultiRpcProvider([self.config.target_chain.rpc_url]), contract_address=self.config.target_chain.rofl_adapter_address, event_name="HashStored", abi=rofl_adapter_abi, @@ -250,3 +249,8 @@ def stop(self) -> None: """Stop the relayer service.""" self.running = False self.shutdown_event.set() + + # Signal providers to stop retrying + self.source_provider.shutdown() + if self.hash_listener: + self.hash_listener.provider.shutdown() diff --git a/paymaster-relayer/paymaster_relayer/utils/contract_utility.py b/paymaster-relayer/paymaster_relayer/utils/contract_utility.py index 08328ba..6799e0e 100644 --- a/paymaster-relayer/paymaster_relayer/utils/contract_utility.py +++ b/paymaster-relayer/paymaster_relayer/utils/contract_utility.py @@ -63,9 +63,7 @@ def get_contract_abi(self, contract_name: str) -> list[dict[str, Any]]: json.JSONDecodeError: If the contract file is invalid JSON """ contract_path: Path = ( - Path(__file__).parent.parent.parent - / "abis" - / f"{contract_name}.json" + Path(__file__).parent.parent.parent / "abis" / f"{contract_name}.json" ).resolve() with contract_path.open() as file: diff --git a/paymaster-relayer/paymaster_relayer/utils/multi_rpc_provider.py b/paymaster-relayer/paymaster_relayer/utils/multi_rpc_provider.py new file mode 100644 index 0000000..ec61548 --- /dev/null +++ b/paymaster-relayer/paymaster_relayer/utils/multi_rpc_provider.py @@ -0,0 +1,334 @@ +""" +Multi-RPC Web3 Provider with sequential failover. + +Wraps Web3 with automatic failover across multiple RPC endpoints. +""" + +import logging +import threading +from collections.abc import Callable +from typing import TypeVar +from urllib.parse import urlparse + +from requests.exceptions import ConnectionError as RequestsConnectionError +from requests.exceptions import Timeout as RequestsTimeout +from web3 import Web3 +from web3.exceptions import ProviderConnectionError, Web3RPCError + +logger = logging.getLogger(__name__) + +T = TypeVar("T") + +# Errors that indicate network/transport issues and should trigger failover. +# Application errors (ValueError, KeyError, etc.) are NOT included - they should propagate. +RETRYABLE_ERRORS = ( + Web3RPCError, + ProviderConnectionError, + RequestsConnectionError, + RequestsTimeout, + OSError, + TimeoutError, + ConnectionResetError, +) + + +def sanitize_url(url: str) -> str: + """Redact sensitive parts of URL for safe logging. + + Strips path, query params, and fragments to avoid leaking API keys. + Returns only scheme://host:port format. + """ + try: + parsed = urlparse(url) + # Include port if non-standard + netloc = parsed.hostname or "unknown" + if parsed.port: + netloc = f"{netloc}:{parsed.port}" + return f"{parsed.scheme}://{netloc}/***" + except Exception: + return "***redacted***" + + +class MultiRpcProvider: + """ + Web3 provider wrapper with sequential failover across multiple RPC endpoints. + + Automatically switches to the next provider when the current one fails. + Implements exponential backoff when all providers fail and cycles back. + """ + + def __init__( + self, + rpc_urls: list[str], + max_backoff: int = 30, + ) -> None: + """ + Initialize with a list of RPC URLs. + + Args: + rpc_urls: List of RPC endpoint URLs in priority order + max_backoff: Maximum backoff time in seconds (default 30) + + Raises: + ValueError: If rpc_urls is empty + Exception: If shutdown is requested during initialization + """ + if not rpc_urls: + raise ValueError("At least one RPC URL is required") + + self._rpc_urls = rpc_urls + self._max_backoff = max_backoff + self._current_index = 0 + self._web3: Web3 | None = None + self._cycle_count = 0 + self._expected_chain_id: int | None = None + self._shutdown_event = threading.Event() + + # Try to connect on initialization + self._connect_with_failover() + + # Cache chain ID from first successful connection for failover validation + try: + self._expected_chain_id = self._web3.eth.chain_id + logger.info(f"Connected to chain ID: {self._expected_chain_id}") + except Exception as e: + logger.warning(f"Could not determine chain ID on init: {e}") + + def _connect_with_failover(self) -> None: + """Attempt to connect, cycling through providers with backoff indefinitely.""" + attempts = 0 + total_providers = len(self._rpc_urls) + log_thresholds = [10, 50, 100] # Escalating warnings + + while not self._shutdown_event.is_set(): + url = self._rpc_urls[self._current_index] + logger.info(f"Connecting to RPC: {sanitize_url(url)}") + + try: + self._web3 = Web3(Web3.HTTPProvider(url)) + if self._web3.is_connected(): + logger.info(f"Connected to RPC: {sanitize_url(url)}") + self._cycle_count = 0 # Reset on successful connection + return + else: + logger.warning(f"Failed to connect to {sanitize_url(url)}") + except Exception as e: + logger.warning(f"Connection error for {sanitize_url(url)}: {e}") + + # Move to next provider + attempts += 1 + self._current_index = (self._current_index + 1) % total_providers + + # Check if we've completed a cycle + if self._current_index == 0 and attempts > 0: + self._cycle_count += 1 + backoff = min(2**self._cycle_count, self._max_backoff) + + # Progressive warning logs at thresholds + if self._cycle_count in log_thresholds: + logger.error( + f"All providers have failed for {self._cycle_count} cycles " + f"(~{self._cycle_count * backoff}s total). Still retrying..." + ) + elif ( + self._cycle_count > max(log_thresholds) + and self._cycle_count % 100 == 0 + ): + logger.error(f"Still failing after {self._cycle_count} cycles...") + else: + logger.warning( + f"All providers failed. Cycle {self._cycle_count}. " + f"Backing off for {backoff}s..." + ) + + # Interruptible sleep using Event.wait() + if self._shutdown_event.wait(timeout=backoff): + raise Exception("Shutdown requested during initialization") + + raise Exception("Shutdown requested during initialization") + + def _validate_chain_id(self, url: str) -> None: + """Validate that the current provider's chain ID matches the expected one. + + Called after each successful failover to catch misconfigured providers. + + Args: + url: The URL of the provider being validated (for error messages) + + Raises: + ValueError: If chain ID doesn't match the expected chain ID + """ + if self._expected_chain_id is None or self._web3 is None: + return + + try: + actual_chain_id = self._web3.eth.chain_id + if actual_chain_id != self._expected_chain_id: + raise ValueError( + f"Chain ID mismatch on failover to {sanitize_url(url)}: " + f"expected {self._expected_chain_id}, got {actual_chain_id}. " + "All providers must point to the same chain." + ) + except ValueError: + raise # Re-raise chain ID mismatches + except Exception as e: + logger.warning(f"Could not verify chain ID for {sanitize_url(url)}: {e}") + + @property + def current_provider_index(self) -> int: + """Return the index of the current active provider.""" + return self._current_index + + @property + def current_url(self) -> str: + """Return the URL of the current active provider.""" + return self._rpc_urls[self._current_index] + + @property + def current_url_sanitized(self) -> str: + """Return the sanitized URL of the current active provider (safe for logging).""" + return sanitize_url(self._rpc_urls[self._current_index]) + + @property + def chain_id(self) -> int: + """Return the chain ID from the current provider.""" + return self.get_web3().eth.chain_id + + def is_connected(self) -> bool: + """Check if currently connected to an RPC provider.""" + return self._web3 is not None and self._web3.is_connected() + + def get_web3(self) -> Web3: + """Return the current Web3 instance.""" + if self._web3 is None: + raise Exception("Not connected to any RPC provider") + return self._web3 + + def shutdown(self) -> None: + """Signal shutdown to interrupt retry loops gracefully.""" + self._shutdown_event.set() + logger.info("Shutdown signal received, interrupting retry loops...") + + def _failover_to_next(self) -> None: + """Switch to the next provider in the list, retrying indefinitely.""" + total_providers = len(self._rpc_urls) + start_index = self._current_index + log_thresholds = [10, 50, 100] + + while not self._shutdown_event.is_set(): + self._current_index = (self._current_index + 1) % total_providers + + # Track cycle completion before attempting provider + completed_cycle = self._current_index == start_index + if completed_cycle: + self._cycle_count += 1 + + url = self._rpc_urls[self._current_index] + logger.info(f"Failing over to: {sanitize_url(url)}") + + try: + self._web3 = Web3(Web3.HTTPProvider(url)) + if self._web3.is_connected(): + self._validate_chain_id(url) + logger.info(f"Successfully failed over to: {sanitize_url(url)}") + self._cycle_count = 0 # Reset on successful connection + return + except ValueError: + raise # Chain ID mismatch is fatal + except Exception as e: + logger.warning( + f"Failover connection error for {sanitize_url(url)}: {e}" + ) + + # Sleep AFTER confirming provider is still down on cycle completion + if completed_cycle: + backoff = min(2**self._cycle_count, self._max_backoff) + + # Progressive warning logs + if self._cycle_count in log_thresholds: + logger.error( + f"All providers have failed for {self._cycle_count} cycles " + f"(~{self._cycle_count * backoff}s total). Still retrying..." + ) + elif ( + self._cycle_count > max(log_thresholds) + and self._cycle_count % 100 == 0 + ): + logger.error(f"Still failing after {self._cycle_count} cycles...") + else: + logger.warning( + f"All providers failed. Cycle {self._cycle_count}. " + f"Backing off for {backoff}s..." + ) + + # Interruptible sleep + if self._shutdown_event.wait(timeout=backoff): + raise Exception("Shutdown requested during failover") + + raise Exception("Shutdown requested during failover") + + def _is_rate_limit_error(self, error: Exception) -> bool: + """Check if error is a rate limit (429) error.""" + if isinstance(error, Web3RPCError): + error_str = str(error).lower() + rpc_response = getattr(error, "rpc_response", {}) or {} + error_data = rpc_response.get("error", {}) + error_code = error_data.get("code", 0) + + return ( + "429" in error_str + or "rate" in error_str + or "too many" in error_str + or error_code == 429 + ) + return False + + def execute_with_failover( + self, + operation: Callable[[Web3], T], + ) -> T: + """ + Execute an operation with automatic failover on failure. + + Retries infinitely until success or shutdown signal. + + Args: + operation: Callable that takes a Web3 instance and returns a result + + Returns: + The result of the operation + + Raises: + Exception: If shutdown is requested + Other exceptions: Application errors propagate immediately + """ + rate_limit_retried = False + retries = 0 + + while not self._shutdown_event.is_set(): + try: + result = operation(self.get_web3()) + self._cycle_count = 0 # Reset on successful operation + return result + except RETRYABLE_ERRORS as e: + retries += 1 + logger.warning( + f"Operation failed on {self.current_url_sanitized} " + f"(retry {retries}): {e}" + ) + + # Handle rate limiting with one retry + if self._is_rate_limit_error(e) and not rate_limit_retried: + logger.info("Rate limited, retrying after 1s...") + if self._shutdown_event.wait(timeout=1): # Interruptible 1s wait + raise Exception( + "Shutdown requested during rate limit retry" + ) from None + rate_limit_retried = True + continue + + # Failover for rate limit (after retry), server errors, or other errors + rate_limit_retried = False # Reset for next provider + self._failover_to_next() + + raise Exception("Shutdown requested during operation execution") diff --git a/paymaster-relayer/paymaster_relayer/utils/polling_event_listener.py b/paymaster-relayer/paymaster_relayer/utils/polling_event_listener.py index 522bcbf..9e298a0 100644 --- a/paymaster-relayer/paymaster_relayer/utils/polling_event_listener.py +++ b/paymaster-relayer/paymaster_relayer/utils/polling_event_listener.py @@ -3,24 +3,30 @@ """ +from __future__ import annotations + import asyncio import logging from collections.abc import Callable -from typing import Any +from typing import TYPE_CHECKING, Any from web3 import Web3 from web3.types import EventData +if TYPE_CHECKING: + from .multi_rpc_provider import MultiRpcProvider + class PollingEventListener: """ Utility for polling blockchain events via HTTP RPC. + Uses MultiRpcProvider for automatic failover support. """ def __init__( self, - rpc_url: str, + provider: MultiRpcProvider, contract_address: str, event_name: str, abi: list[dict[str, Any]], @@ -31,29 +37,26 @@ def __init__( Initialize the polling event listener. Args: - rpc_url: HTTP RPC endpoint URL + provider: MultiRpcProvider instance for RPC failover contract_address: Address of the contract to monitor event_name: Name of the event to listen for abi: Contract ABI lookback_blocks: Number of blocks to look back on startup max_block_range: Max blocks per get_logs request (None = no limit) """ - self.rpc_url = rpc_url self.contract_address = Web3.to_checksum_address(contract_address) self.event_name = event_name self.lookback_blocks = lookback_blocks self.max_block_range = max_block_range + self.abi = abi + self.provider = provider - # Initialize Web3 connection - self.w3 = Web3(Web3.HTTPProvider(rpc_url)) - - # Create contract instance - self.contract = self.w3.eth.contract(address=self.contract_address, abi=abi) - - # Get the event object - if not hasattr(self.contract.events, event_name): + # Validate event exists in ABI + temp_contract = self.provider.get_web3().eth.contract( + address=self.contract_address, abi=abi + ) + if not hasattr(temp_contract.events, event_name): raise ValueError(f"Event {event_name} not found in contract ABI") - self.event_obj = getattr(self.contract.events, event_name) # State tracking self.last_processed_block: int | None = None @@ -62,20 +65,12 @@ def __init__( # Setup logging self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}") - def _get_logs_chunked(self, from_block: int, to_block: int) -> list[EventData]: - """ - Fetch logs in chunks to respect RPC provider block range limits. - - Args: - from_block: Starting block number - to_block: Ending block number (inclusive) - - Returns: - List of all events across the block range - """ + def _get_logs_chunked_with_w3( + self, w3: Web3, from_block: int, to_block: int + ) -> list[EventData]: + """Fetch logs in chunks using a specific Web3 instance for consistency.""" if self.max_block_range is None or to_block - from_block < self.max_block_range: - # No chunking needed - return list(self.event_obj.get_logs(from_block=from_block, to_block=to_block)) + return self._get_logs_range_with_w3(w3, from_block, to_block) all_events: list[EventData] = [] current_from = from_block @@ -85,12 +80,47 @@ def _get_logs_chunked(self, from_block: int, to_block: int) -> list[EventData]: self.logger.debug( f"Fetching logs chunk: blocks {current_from}-{current_to}" ) - events = self.event_obj.get_logs(from_block=current_from, to_block=current_to) + events = self._get_logs_range_with_w3(w3, current_from, current_to) all_events.extend(events) current_from = current_to + 1 return all_events + def _get_logs_range_with_w3( + self, w3: Web3, from_block: int, to_block: int + ) -> list[EventData]: + """Get logs for a single block range using a specific Web3 instance.""" + contract = w3.eth.contract(address=self.contract_address, abi=self.abi) + event_obj = getattr(contract.events, self.event_name) + return list(event_obj.get_logs(from_block=from_block, to_block=to_block)) + + def _sync_cycle(self, w3: Web3, lookback: int) -> tuple[int, list[EventData]]: + """Execute a complete initial sync cycle using a single Web3 instance.""" + current_block = w3.eth.block_number + from_block = max(0, current_block - lookback) + events = self._get_logs_chunked_with_w3(w3, from_block, current_block) + return current_block, events + + def _poll_cycle(self, w3: Web3) -> tuple[int, list[EventData]]: + """Execute a complete poll cycle using a single Web3 instance. + + Ensures block number and logs come from the same provider, + preventing inconsistency if failover occurs between calls. + """ + current_block = w3.eth.block_number + + if self.last_processed_block and current_block <= self.last_processed_block: + return current_block, [] + + from_block = ( + (self.last_processed_block + 1) + if self.last_processed_block + else current_block + ) + + events = self._get_logs_chunked_with_w3(w3, from_block, current_block) + return current_block, events + async def initial_sync(self, callback: Callable[[EventData], Any]) -> None: """ Perform initial sync to catch up on recent events. @@ -99,17 +129,16 @@ async def initial_sync(self, callback: Callable[[EventData], Any]) -> None: callback: Async function to call for each event found """ try: - current_block = self.w3.eth.block_number - from_block = max(0, current_block - self.lookback_blocks) + # Single failover context for consistency + current_block, events = await asyncio.to_thread( + self.provider.execute_with_failover, + lambda w3: self._sync_cycle(w3, self.lookback_blocks), + ) self.logger.info( - f"Initial sync for {self.event_name} events " - f"from block {from_block} to {current_block}" + f"Initial sync for {self.event_name} events up to block {current_block}" ) - # Get historical events (chunked if max_block_range is set) - events = self._get_logs_chunked(from_block, current_block) - if events: self.logger.info( f"Found {len(events)} historical {self.event_name} events" @@ -134,25 +163,21 @@ async def poll_for_events(self, callback: Callable[[EventData], Any]) -> None: callback: Async function to call for each new event """ try: - current_block = self.w3.eth.block_number + # Single failover context ensures block number and + # logs come from the same provider (same sync state) + current_block, events = await asyncio.to_thread( + self.provider.execute_with_failover, + self._poll_cycle, + ) # Skip if no new blocks if self.last_processed_block and current_block <= self.last_processed_block: return - from_block = ( - (self.last_processed_block + 1) - if self.last_processed_block - else current_block - ) - - # Get new events (chunked if max_block_range is set) - events = self._get_logs_chunked(from_block, current_block) - if events: self.logger.info( f"Found {len(events)} new {self.event_name} events " - f"in blocks {from_block}-{current_block}" + f"in blocks up to {current_block}" ) for event in events: await callback(event) @@ -217,5 +242,5 @@ def get_status(self) -> dict[str, Any]: "last_processed_block": self.last_processed_block, "contract_address": self.contract_address, "event_name": self.event_name, - "rpc_url": self.rpc_url, + "rpc_url": self.provider.current_url_sanitized, } diff --git a/paymaster-relayer/tests/test_config.py b/paymaster-relayer/tests/test_config.py new file mode 100644 index 0000000..161c161 --- /dev/null +++ b/paymaster-relayer/tests/test_config.py @@ -0,0 +1,93 @@ +""" +Tests for config parsing of comma-delimited RPC URLs. + +Tests the parsing of SOURCE_RPC_URLS environment variable. +""" + +import os +import sys + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +import pytest + + +class TestParseRpcUrls: + """Tests for parse_rpc_urls function.""" + + def test_single_url(self, monkeypatch): + """Parse a single URL with no commas.""" + from paymaster_relayer.config import parse_rpc_urls + + monkeypatch.setenv("SOURCE_RPC_URLS", "https://rpc1.example.com") + + urls = parse_rpc_urls() + + assert urls == ["https://rpc1.example.com"] + + def test_multiple_urls(self, monkeypatch): + """Parse comma-separated URLs.""" + from paymaster_relayer.config import parse_rpc_urls + + monkeypatch.setenv( + "SOURCE_RPC_URLS", + "https://rpc1.example.com,https://rpc2.example.com,https://rpc3.example.com", + ) + + urls = parse_rpc_urls() + + assert urls == [ + "https://rpc1.example.com", + "https://rpc2.example.com", + "https://rpc3.example.com", + ] + + def test_whitespace_handling(self, monkeypatch): + """Spaces around commas are trimmed.""" + from paymaster_relayer.config import parse_rpc_urls + + monkeypatch.setenv( + "SOURCE_RPC_URLS", + " https://rpc1.example.com , https://rpc2.example.com ", + ) + + urls = parse_rpc_urls() + + assert urls == [ + "https://rpc1.example.com", + "https://rpc2.example.com", + ] + + def test_empty_entries_filtered(self, monkeypatch): + """Empty entries from consecutive/trailing commas are filtered out.""" + from paymaster_relayer.config import parse_rpc_urls + + monkeypatch.setenv( + "SOURCE_RPC_URLS", + "https://rpc1.example.com,,https://rpc2.example.com,", + ) + + urls = parse_rpc_urls() + + assert urls == [ + "https://rpc1.example.com", + "https://rpc2.example.com", + ] + + def test_error_when_not_set(self, monkeypatch): + """Error when SOURCE_RPC_URLS environment variable is missing.""" + from paymaster_relayer.config import parse_rpc_urls + + monkeypatch.delenv("SOURCE_RPC_URLS", raising=False) + + with pytest.raises(ValueError, match=r"missing or empty"): + parse_rpc_urls() + + def test_error_when_all_empty(self, monkeypatch): + """Error when value is only commas/whitespace.""" + from paymaster_relayer.config import parse_rpc_urls + + monkeypatch.setenv("SOURCE_RPC_URLS", ",,") + + with pytest.raises(ValueError, match=r"missing or empty"): + parse_rpc_urls() diff --git a/paymaster-relayer/tests/test_multi_rpc_integration.py b/paymaster-relayer/tests/test_multi_rpc_integration.py new file mode 100644 index 0000000..47b6418 --- /dev/null +++ b/paymaster-relayer/tests/test_multi_rpc_integration.py @@ -0,0 +1,233 @@ +""" +Tests for MultiRpcProvider integration with relayer components. + +Tests that relayer, event listener, and contract utility correctly use +the multi-RPC provider for failover. +""" + +import os +import sys + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +from unittest.mock import MagicMock, patch + + +class TestRelayerMultiRpcIntegration: + """Tests for relayer integration with MultiRpcProvider.""" + + def test_relayer_initializes_with_multi_rpc_provider(self, monkeypatch): + """Relayer should use MultiRpcProvider for source chain.""" + from paymaster_relayer.relayer import ROFLRelayer + + # Set required env vars + monkeypatch.setenv( + "SOURCE_RPC_URLS", "https://rpc1.example.com,https://rpc2.example.com" + ) + monkeypatch.setenv("TARGET_RPC_URL", "https://sapphire.example.com") + monkeypatch.setenv("PAYMASTER_VAULT_ADDRESS", "0x" + "1" * 40) + monkeypatch.setenv("PAYMASTER_PROXY_ADDRESS", "0x" + "2" * 40) + monkeypatch.setenv("ROFL_ADAPTER_ADDRESS", "0x" + "3" * 40) + monkeypatch.setenv("PRIVATE_KEY", "0x" + "4" * 64) + + with ( + patch("paymaster_relayer.relayer.MultiRpcProvider") as mock_provider, + patch("paymaster_relayer.relayer.ContractUtility"), + patch("paymaster_relayer.relayer.ProofManager"), + ): + mock_instance = MagicMock() + mock_instance.is_connected.return_value = True + mock_instance.get_web3.return_value = MagicMock() + mock_provider.return_value = mock_instance + + ROFLRelayer.from_env(local_mode=True) + + # Verify MultiRpcProvider was created with source chain URLs + mock_provider.assert_called_once() + call_args = mock_provider.call_args[0][0] + assert "https://rpc1.example.com" in call_args + assert "https://rpc2.example.com" in call_args + + def test_relayer_uses_provider_web3_instance(self, monkeypatch): + """Relayer should get Web3 instance from MultiRpcProvider.""" + from paymaster_relayer.relayer import ROFLRelayer + + monkeypatch.setenv("SOURCE_RPC_URLS", "https://rpc1.example.com") + monkeypatch.setenv("TARGET_RPC_URL", "https://sapphire.example.com") + monkeypatch.setenv("PAYMASTER_VAULT_ADDRESS", "0x" + "1" * 40) + monkeypatch.setenv("PAYMASTER_PROXY_ADDRESS", "0x" + "2" * 40) + monkeypatch.setenv("ROFL_ADAPTER_ADDRESS", "0x" + "3" * 40) + monkeypatch.setenv("PRIVATE_KEY", "0x" + "4" * 64) + + with ( + patch("paymaster_relayer.relayer.MultiRpcProvider") as mock_provider, + patch("paymaster_relayer.relayer.ContractUtility"), + patch("paymaster_relayer.relayer.ProofManager"), + ): + mock_web3 = MagicMock() + mock_instance = MagicMock() + mock_instance.is_connected.return_value = True + mock_instance.get_web3.return_value = mock_web3 + mock_provider.return_value = mock_instance + + relayer = ROFLRelayer.from_env(local_mode=True) + + # Verify relayer has source_provider attribute + assert relayer.source_provider is mock_instance + + +class TestPollingEventListenerMultiRpcIntegration: + """Tests for PollingEventListener integration with MultiRpcProvider.""" + + def test_relayer_passes_provider_web3_to_event_listener(self, monkeypatch): + """Relayer passes Web3 instance from MultiRpcProvider to PollingEventListener.""" + from paymaster_relayer.utils.multi_rpc_provider import MultiRpcProvider + + # The integration pattern: relayer creates provider, gets Web3, passes to listener + urls = ["https://rpc1.example.com", "https://rpc2.example.com"] + + with patch( + "paymaster_relayer.utils.multi_rpc_provider.Web3" + ) as mock_web3_class: + mock_instance = MagicMock() + mock_instance.is_connected.return_value = True + mock_instance.eth.block_number = 12345 + mock_instance.eth.chain_id = 11155111 + mock_web3_class.return_value = mock_instance + + provider = MultiRpcProvider(urls) + w3 = provider.get_web3() + + # Verify the Web3 instance is from the provider + assert w3 is mock_instance + assert provider.current_provider_index == 0 + + def test_provider_failover_updates_web3_for_relayer(self): + """When provider fails over, subsequent get_web3() returns new instance.""" + from paymaster_relayer.utils.multi_rpc_provider import MultiRpcProvider + + urls = ["https://rpc1.example.com", "https://rpc2.example.com"] + + with patch( + "paymaster_relayer.utils.multi_rpc_provider.Web3" + ) as mock_web3_class: + mock_instance1 = MagicMock() + mock_instance1.is_connected.return_value = True + mock_instance1.eth.block_number = 100 + mock_instance1.eth.chain_id = 11155111 + + mock_instance2 = MagicMock() + mock_instance2.is_connected.return_value = True + mock_instance2.eth.block_number = 200 + mock_instance2.eth.chain_id = 11155111 + + mock_web3_class.side_effect = [mock_instance1, mock_instance2] + + provider = MultiRpcProvider(urls) + + # First instance + w3_first = provider.get_web3() + assert w3_first.eth.block_number == 100 + + # Force failover + provider._failover_to_next() + + # New instance after failover + w3_after = provider.get_web3() + assert w3_after.eth.block_number == 200 + assert provider.current_provider_index == 1 + + +class TestFailoverDuringOperations: + """Tests for failover behavior during active operations.""" + + def test_failover_triggers_during_event_monitoring(self): + """Failover should trigger when provider fails during event fetch.""" + from web3.exceptions import Web3RPCError + + from paymaster_relayer.utils.multi_rpc_provider import MultiRpcProvider + + urls = ["https://rpc1.example.com", "https://rpc2.example.com"] + + with patch( + "paymaster_relayer.utils.multi_rpc_provider.Web3" + ) as mock_web3_class: + # First provider connects but fails on operation + mock_instance1 = MagicMock() + mock_instance1.is_connected.return_value = True + mock_instance1.eth.chain_id = 11155111 + + # Second provider works + mock_instance2 = MagicMock() + mock_instance2.is_connected.return_value = True + mock_instance2.eth.get_logs.return_value = [{"event": "test"}] + mock_instance2.eth.chain_id = 11155111 + + mock_web3_class.side_effect = [mock_instance1, mock_instance2] + + provider = MultiRpcProvider(urls) + + # Simulate failure during get_logs + call_count = [0] + + def failing_then_success(w3): + call_count[0] += 1 + if call_count[0] == 1: + raise Web3RPCError( + message="Server Error", + rpc_response={"error": {"code": -32000, "message": "Internal"}}, + ) + return w3.eth.get_logs({}) + + result = provider.execute_with_failover(failing_then_success) + + # Should have failed over and succeeded + assert provider.current_provider_index == 1 + assert result == [{"event": "test"}] + + def test_relayer_recovers_after_provider_switch(self, monkeypatch): + """Relayer should continue working after provider failover.""" + from paymaster_relayer.utils.multi_rpc_provider import MultiRpcProvider + + urls = ["https://rpc1.example.com", "https://rpc2.example.com"] + + with patch( + "paymaster_relayer.utils.multi_rpc_provider.Web3" + ) as mock_web3_class: + mock_instance1 = MagicMock() + mock_instance1.is_connected.return_value = True + mock_instance1.eth.block_number = 100 + mock_instance1.eth.chain_id = 11155111 + + mock_instance2 = MagicMock() + mock_instance2.is_connected.return_value = True + mock_instance2.eth.block_number = 101 + mock_instance2.eth.chain_id = 11155111 + + mock_web3_class.side_effect = [mock_instance1, mock_instance2] + + provider = MultiRpcProvider(urls) + + # First operation works + result1 = provider.execute_with_failover(lambda w3: w3.eth.block_number) + assert result1 == 100 + assert provider.current_provider_index == 0 + + # Simulate first provider going down with a network error + # Note: Must use a retryable error type (OSError) - generic exceptions + # correctly propagate without triggering failover + mock_instance1.eth.block_number = property( + lambda self: (_ for _ in ()).throw(OSError("Connection lost")) + ) + + # Force failover by making the operation fail with a network error + def get_block_with_failover(w3): + if provider.current_provider_index == 0: + raise OSError("Connection lost") + return w3.eth.block_number + + result2 = provider.execute_with_failover(get_block_with_failover) + + # Should have switched to second provider + assert provider.current_provider_index == 1 + assert result2 == 101 diff --git a/paymaster-relayer/tests/test_multi_rpc_provider.py b/paymaster-relayer/tests/test_multi_rpc_provider.py new file mode 100644 index 0000000..e0d6b92 --- /dev/null +++ b/paymaster-relayer/tests/test_multi_rpc_provider.py @@ -0,0 +1,524 @@ +""" +Tests for MultiRpcProvider failover logic. + +Tests the sequential failover behavior across multiple RPC providers. +""" + +import os +import sys + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +from unittest.mock import MagicMock, patch + +import pytest +from web3.exceptions import Web3RPCError + + +class TestMultiRpcProvider: + """Tests for MultiRpcProvider class.""" + + def test_happy_path_first_provider_works(self): + """First provider works, no failover needed.""" + from paymaster_relayer.utils.multi_rpc_provider import MultiRpcProvider + + urls = ["https://rpc1.example.com", "https://rpc2.example.com"] + + with patch("paymaster_relayer.utils.multi_rpc_provider.Web3") as mock_web3: + # First provider connects successfully + mock_instance = MagicMock() + mock_instance.is_connected.return_value = True + mock_instance.eth.block_number = 12345 + mock_instance.eth.chain_id = 11155111 + mock_web3.return_value = mock_instance + + provider = MultiRpcProvider(urls) + + assert provider.current_provider_index == 0 + assert provider.is_connected() + # Should only have created one Web3 instance + assert mock_web3.call_count == 1 + + def test_single_failover_connection_error(self): + """First provider fails with connection error, second succeeds.""" + from paymaster_relayer.utils.multi_rpc_provider import MultiRpcProvider + + urls = ["https://rpc1.example.com", "https://rpc2.example.com"] + + with patch("paymaster_relayer.utils.multi_rpc_provider.Web3") as mock_web3: + # First call fails, second succeeds + mock_fail = MagicMock() + mock_fail.is_connected.return_value = False + + mock_success = MagicMock() + mock_success.is_connected.return_value = True + mock_success.eth.block_number = 12345 + mock_success.eth.chain_id = 11155111 + + mock_web3.side_effect = [mock_fail, mock_success] + + provider = MultiRpcProvider(urls) + + # Should have failed over to second provider + assert provider.current_provider_index == 1 + assert provider.is_connected() + + def test_multiple_failover_first_two_fail(self): + """First two providers fail, third succeeds.""" + from paymaster_relayer.utils.multi_rpc_provider import MultiRpcProvider + + urls = [ + "https://rpc1.example.com", + "https://rpc2.example.com", + "https://rpc3.example.com", + ] + + with patch("paymaster_relayer.utils.multi_rpc_provider.Web3") as mock_web3: + mock_fail1 = MagicMock() + mock_fail1.is_connected.return_value = False + + mock_fail2 = MagicMock() + mock_fail2.is_connected.return_value = False + + mock_success = MagicMock() + mock_success.is_connected.return_value = True + mock_success.eth.block_number = 12345 + mock_success.eth.chain_id = 11155111 + + mock_web3.side_effect = [mock_fail1, mock_fail2, mock_success] + + provider = MultiRpcProvider(urls) + + assert provider.current_provider_index == 2 + assert provider.is_connected() + + def test_http_5xx_triggers_immediate_failover(self): + """HTTP 5xx error triggers immediate failover without retry.""" + from paymaster_relayer.utils.multi_rpc_provider import MultiRpcProvider + + urls = ["https://rpc1.example.com", "https://rpc2.example.com"] + + with patch("paymaster_relayer.utils.multi_rpc_provider.Web3") as mock_web3: + mock_instance1 = MagicMock() + mock_instance1.is_connected.return_value = True + mock_instance1.eth.chain_id = 11155111 + + mock_instance2 = MagicMock() + mock_instance2.is_connected.return_value = True + mock_instance2.eth.block_number = 12345 + mock_instance2.eth.chain_id = 11155111 + + mock_web3.side_effect = [mock_instance1, mock_instance2] + + provider = MultiRpcProvider(urls) + + # Simulate 5xx error on first provider during operation + def raise_5xx(*args, **kwargs): + raise Web3RPCError( + message="Internal Server Error", + rpc_response={ + "error": {"code": -32000, "message": "Internal error"} + }, + ) + + mock_instance1.eth.get_block = raise_5xx + + # Execute with failover should switch to second provider + provider.execute_with_failover(lambda w3: w3.eth.get_block("latest")) + + assert provider.current_provider_index == 1 + + def test_rate_limit_429_retry_then_failover(self): + """Rate limit (429) retries once, then fails over.""" + from paymaster_relayer.utils.multi_rpc_provider import MultiRpcProvider + + urls = ["https://rpc1.example.com", "https://rpc2.example.com"] + + with patch("paymaster_relayer.utils.multi_rpc_provider.Web3") as mock_web3: + mock_instance1 = MagicMock() + mock_instance1.is_connected.return_value = True + mock_instance1.eth.chain_id = 11155111 + + mock_instance2 = MagicMock() + mock_instance2.is_connected.return_value = True + mock_instance2.eth.block_number = 99999 + mock_instance2.eth.chain_id = 11155111 + + mock_web3.side_effect = [mock_instance1, mock_instance2] + + # Track Event.wait calls + wait_calls = [] + + with patch("threading.Event") as mock_event_class: + mock_event = MagicMock() + mock_event.is_set.return_value = False + + def track_wait(timeout): + wait_calls.append(timeout) + return False # Don't signal shutdown + + mock_event.wait.side_effect = track_wait + mock_event_class.return_value = mock_event + + provider = MultiRpcProvider(urls) + + # Track call count to simulate retry then fail + call_count = [0] + + def raise_429(*args, **kwargs): + call_count[0] += 1 + raise Web3RPCError( + message="Too Many Requests", + rpc_response={ + "error": {"code": 429, "message": "Rate limited"} + }, + ) + + mock_instance1.eth.get_block = raise_429 + + # Execute should retry once (with 1s wait), then failover + provider.execute_with_failover(lambda w3: w3.eth.get_block("latest")) + + # Should have retried once before failover + assert call_count[0] == 2 # Initial + 1 retry + assert 1 in wait_calls # 1 second backoff for rate limit + assert provider.current_provider_index == 1 + + def test_full_cycle_with_exponential_backoff_then_recovery(self): + """All providers fail, cycle back with exponential backoff, then recover.""" + from paymaster_relayer.utils.multi_rpc_provider import MultiRpcProvider + + urls = ["https://rpc1.example.com", "https://rpc2.example.com"] + + with patch("paymaster_relayer.utils.multi_rpc_provider.Web3") as mock_web3: + # All providers fail initially, then first succeeds on retry + mock_fail1 = MagicMock() + mock_fail1.is_connected.return_value = False + + mock_fail2 = MagicMock() + mock_fail2.is_connected.return_value = False + + mock_success = MagicMock() + mock_success.is_connected.return_value = True + mock_success.eth.block_number = 12345 + mock_success.eth.chain_id = 11155111 + + # First cycle: fail, fail. Second cycle: success + mock_web3.side_effect = [mock_fail1, mock_fail2, mock_success] + + # Mock Event to capture backoff + backoff_captured = [] + + with patch("threading.Event") as mock_event_class: + mock_event = MagicMock() + mock_event.is_set.return_value = False + + def capture_wait(timeout): + backoff_captured.append(timeout) + return False # Don't signal shutdown + + mock_event.wait.side_effect = capture_wait + mock_event_class.return_value = mock_event + + provider = MultiRpcProvider(urls) + + # Should have cycled back to first with backoff + assert provider.current_provider_index == 0 + assert provider.is_connected() + + # Verify exponential backoff was applied (2 seconds for first cycle) + assert 2 in backoff_captured + + def test_backoff_capped_at_30_seconds(self): + """Exponential backoff is capped at 30 seconds (new default).""" + from paymaster_relayer.utils.multi_rpc_provider import MultiRpcProvider + + urls = ["https://rpc1.example.com"] + + with patch("paymaster_relayer.utils.multi_rpc_provider.Web3") as mock_web3: + # Mock Event.wait to capture backoff values + backoff_values = [] + + def mock_wait(timeout): + backoff_values.append(timeout) + return False # Don't signal shutdown + + # Create many failing mocks, then one success + fails = [MagicMock() for _ in range(10)] + for f in fails: + f.is_connected.return_value = False + + success = MagicMock() + success.is_connected.return_value = True + success.eth.block_number = 12345 + success.eth.chain_id = 11155111 + + mock_web3.side_effect = [*fails, success] + + # Mock the shutdown event's wait method + with patch("threading.Event") as mock_event_class: + mock_event = MagicMock() + mock_event.is_set.return_value = False + mock_event.wait.side_effect = mock_wait + mock_event_class.return_value = mock_event + + MultiRpcProvider(urls) + + # Verify backoff was capped at 30 seconds + # Backoff sequence: 2, 4, 8, 16, 30, 30, 30, 30, 30, 30 + assert max(backoff_values) <= 30 + # Verify we hit the cap multiple times + assert backoff_values.count(30) >= 5 + + def test_current_url_property(self): + """current_url property returns the current provider URL.""" + from paymaster_relayer.utils.multi_rpc_provider import MultiRpcProvider + + urls = ["https://rpc1.example.com", "https://rpc2.example.com"] + + with patch("paymaster_relayer.utils.multi_rpc_provider.Web3") as mock_web3: + mock_instance = MagicMock() + mock_instance.is_connected.return_value = True + mock_instance.eth.chain_id = 11155111 + mock_web3.return_value = mock_instance + + provider = MultiRpcProvider(urls) + + assert provider.current_url == "https://rpc1.example.com" + + def test_get_web3_returns_current_instance(self): + """get_web3() returns the current Web3 instance.""" + from paymaster_relayer.utils.multi_rpc_provider import MultiRpcProvider + + urls = ["https://rpc1.example.com"] + + with patch("paymaster_relayer.utils.multi_rpc_provider.Web3") as mock_web3: + mock_instance = MagicMock() + mock_instance.is_connected.return_value = True + mock_instance.eth.chain_id = 11155111 + mock_web3.return_value = mock_instance + + provider = MultiRpcProvider(urls) + + assert provider.get_web3() is mock_instance + + def test_shutdown_interrupts_init_retry(self): + """Shutdown signal interrupts infinite retry during initialization.""" + from paymaster_relayer.utils.multi_rpc_provider import MultiRpcProvider + + urls = ["https://rpc1.example.com"] + + with patch("paymaster_relayer.utils.multi_rpc_provider.Web3") as mock_web3: + # All providers fail + mock_fail = MagicMock() + mock_fail.is_connected.return_value = False + mock_web3.return_value = mock_fail + + # Mock Event to simulate shutdown after first backoff + with patch("threading.Event") as mock_event_class: + mock_event = MagicMock() + shutdown_after_first_wait = [False] + + def wait_then_shutdown(timeout): + if not shutdown_after_first_wait[0]: + shutdown_after_first_wait[0] = True + return False # First wait - don't signal shutdown yet + return True # Subsequent waits - signal shutdown + + mock_event.is_set.return_value = False + mock_event.wait.side_effect = wait_then_shutdown + mock_event_class.return_value = mock_event + + with pytest.raises( + Exception, match="Shutdown requested during initialization" + ): + MultiRpcProvider(urls) + + def test_shutdown_interrupts_runtime_failover(self): + """Shutdown signal interrupts infinite retry during runtime failover.""" + from paymaster_relayer.utils.multi_rpc_provider import MultiRpcProvider + + urls = ["https://rpc1.example.com", "https://rpc2.example.com"] + + with patch("paymaster_relayer.utils.multi_rpc_provider.Web3") as mock_web3: + # First provider connects successfully + mock_success = MagicMock() + mock_success.is_connected.return_value = True + mock_success.eth.chain_id = 11155111 + + # Failover providers all fail + mock_fail = MagicMock() + mock_fail.is_connected.return_value = False + + mock_web3.side_effect = [mock_success, mock_fail, mock_fail] + + with patch("threading.Event") as mock_event_class: + mock_event = MagicMock() + failover_wait_count = [0] + + def wait_then_shutdown(timeout): + if failover_wait_count[0] < 1: + failover_wait_count[0] += 1 + return False # First wait during failover - continue + return True # Second wait - signal shutdown + + mock_event.is_set.return_value = False + mock_event.wait.side_effect = wait_then_shutdown + mock_event_class.return_value = mock_event + + provider = MultiRpcProvider(urls) + + # Trigger failover by raising retryable error + mock_success.eth.get_block.side_effect = Web3RPCError( + message="Server error", + rpc_response={ + "error": {"code": -32000, "message": "Internal error"} + }, + ) + + with pytest.raises( + Exception, match="Shutdown requested during failover" + ): + provider.execute_with_failover( + lambda w3: w3.eth.get_block("latest") + ) + + def test_progressive_logging_at_thresholds(self): + """Progressive ERROR logs appear at cycle thresholds 10, 50, 100.""" + from paymaster_relayer.utils.multi_rpc_provider import MultiRpcProvider + + urls = ["https://rpc1.example.com"] + + with ( + patch("paymaster_relayer.utils.multi_rpc_provider.Web3") as mock_web3, + patch("paymaster_relayer.utils.multi_rpc_provider.logger") as mock_logger, + patch("threading.Event") as mock_event_class, + ): + # Create failing providers for many cycles + fails = [MagicMock() for _ in range(15)] + for f in fails: + f.is_connected.return_value = False + + # Then success + success = MagicMock() + success.is_connected.return_value = True + success.eth.chain_id = 11155111 + mock_web3.side_effect = [*fails, success] + + # Mock Event to allow 15 cycles + mock_event = MagicMock() + mock_event.is_set.return_value = False + mock_event.wait.return_value = False # Never shutdown + mock_event_class.return_value = mock_event + + MultiRpcProvider(urls) + + # Check that ERROR logs appeared at threshold 10 + error_calls = [ + call + for call in mock_logger.error.call_args_list + if "10 cycles" in str(call) + ] + assert len(error_calls) >= 1 + + def test_execute_with_failover_infinite_retry(self): + """execute_with_failover retries infinitely until success.""" + from paymaster_relayer.utils.multi_rpc_provider import MultiRpcProvider + + urls = ["https://rpc1.example.com", "https://rpc2.example.com"] + + with patch("paymaster_relayer.utils.multi_rpc_provider.Web3") as mock_web3: + # Initial connection succeeds + mock_instance1 = MagicMock() + mock_instance1.is_connected.return_value = True + mock_instance1.eth.chain_id = 11155111 + + # Failover instances all fail initially + mock_fail = MagicMock() + mock_fail.is_connected.return_value = False + + # Eventually one succeeds + mock_instance2 = MagicMock() + mock_instance2.is_connected.return_value = True + mock_instance2.eth.chain_id = 11155111 + mock_instance2.eth.get_block.return_value = {"number": 99999} + + # Sequence: init success, 5 failover fails, then failover success + mock_web3.side_effect = [ + mock_instance1, + mock_fail, + mock_fail, + mock_fail, + mock_fail, + mock_fail, + mock_instance2, + ] + + with patch("threading.Event") as mock_event_class: + mock_event = MagicMock() + mock_event.is_set.return_value = False + mock_event.wait.return_value = False # Never shutdown + mock_event_class.return_value = mock_event + + provider = MultiRpcProvider(urls) + + # Trigger failover by raising retryable error multiple times + call_count = [0] + + def raise_error_then_succeed(*args, **kwargs): + call_count[0] += 1 + if call_count[0] <= 5: + raise Web3RPCError( + message="Server error", + rpc_response={ + "error": {"code": -32000, "message": "Internal error"} + }, + ) + return {"number": 99999} + + mock_instance1.eth.get_block = raise_error_then_succeed + + # Should eventually succeed after many retries + result = provider.execute_with_failover( + lambda w3: w3.eth.get_block("latest") + ) + assert result["number"] == 99999 + + def test_shutdown_interrupts_rate_limit_retry(self): + """Shutdown signal interrupts rate limit retry sleep.""" + from paymaster_relayer.utils.multi_rpc_provider import MultiRpcProvider + + urls = ["https://rpc1.example.com"] + + with patch("paymaster_relayer.utils.multi_rpc_provider.Web3") as mock_web3: + mock_instance = MagicMock() + mock_instance.is_connected.return_value = True + mock_instance.eth.chain_id = 11155111 + mock_web3.return_value = mock_instance + + with patch("threading.Event") as mock_event_class: + mock_event = MagicMock() + wait_count = [0] + + def wait_then_shutdown(timeout): + wait_count[0] += 1 + return wait_count[0] == 1 # First wait - signal shutdown + + mock_event.is_set.return_value = False + mock_event.wait.side_effect = wait_then_shutdown + mock_event_class.return_value = mock_event + + provider = MultiRpcProvider(urls) + + # Trigger rate limit error + mock_instance.eth.get_block.side_effect = Web3RPCError( + message="Too Many Requests", + rpc_response={"error": {"code": 429, "message": "Rate limited"}}, + ) + + with pytest.raises( + Exception, match="Shutdown requested during rate limit retry" + ): + provider.execute_with_failover( + lambda w3: w3.eth.get_block("latest") + ) diff --git a/paymaster-relayer/tests/test_proof_generation.py b/paymaster-relayer/tests/test_proof_generation.py index e72ec6b..1c27ceb 100644 --- a/paymaster-relayer/tests/test_proof_generation.py +++ b/paymaster-relayer/tests/test_proof_generation.py @@ -10,7 +10,9 @@ import os import sys from pathlib import Path +from unittest.mock import MagicMock, patch +import pytest from web3 import Web3 sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) @@ -18,6 +20,37 @@ from paymaster_relayer.models import PaymentEvent from paymaster_relayer.proof_manager import ProofManager from paymaster_relayer.utils.contract_utility import ContractUtility +from paymaster_relayer.utils.multi_rpc_provider import MultiRpcProvider + + +@pytest.fixture(autouse=True) +def mock_threading_event_for_tests(): + """ + Mock threading.Event to prevent infinite retry in tests. + + This fixture ensures that MultiRpcProvider instances created in tests + will fail fast instead of retrying infinitely when RPC connections fail. + """ + with patch( + "paymaster_relayer.utils.multi_rpc_provider.threading.Event" + ) as mock_event_class: + # Each test gets its own mock event instance with fresh counter + def create_mock_event(): + mock_event = MagicMock() + # Allow up to 10 wait calls per event instance + wait_count = [0] + + def limit_retries(timeout=None): + wait_count[0] += 1 + return wait_count[0] > 10 # Signal shutdown after 10 attempts + + mock_event.is_set.return_value = False + mock_event.wait.side_effect = limit_retries + return mock_event + + # Return a new mock event for each call + mock_event_class.side_effect = create_mock_event + yield mock_event_class async def test_proof_matches_typescript(): @@ -30,7 +63,9 @@ async def test_proof_matches_typescript(): proof_path = Path(__file__).parent.parent.parent / "pay" / "proof.json" if not proof_path.exists(): print(f"โŒ TypeScript proof not found at {proof_path}") - print(" Please run 'hardhat pay:generate-proof' in contracts to create proof.json") + print( + " Please run 'hardhat pay:generate-proof' in contracts to create proof.json" + ) return False with open(proof_path) as f: @@ -49,7 +84,7 @@ async def test_proof_matches_typescript(): # Initialize Web3 connection to source chain source_rpc = os.environ.get( - "SOURCE_RPC_URL", "https://ethereum-sepolia.publicnode.com" + "SOURCE_RPC_URLS", "https://ethereum-sepolia.publicnode.com" ) print(f"\n๐ŸŒ Connecting to source chain: {source_rpc}") @@ -67,7 +102,9 @@ async def test_proof_matches_typescript(): # Find the PaymentInitiated event in the logs # PaymentInitiated event signature - payment_topic = Web3.keccak(text="PaymentInitiated(address,address,address,uint256,bytes32)") + payment_topic = Web3.keccak( + text="PaymentInitiated(address,address,address,uint256,bytes32)" + ) payer = None event_block_number = None @@ -78,7 +115,9 @@ async def test_proof_matches_typescript(): payer_bytes = log["topics"][1][-20:] # Last 20 bytes is the address payer = Web3.to_checksum_address(payer_bytes) event_block_number = receipt["blockNumber"] - print(f" Found PaymentInitiated event - Payer: {payer}, Block: {event_block_number}") + print( + f" Found PaymentInitiated event - Payer: {payer}, Block: {event_block_number}" + ) break if payer is None and event_block_number is None: @@ -91,9 +130,10 @@ async def test_proof_matches_typescript(): rpc_url="http://localhost:8545" ) # Dummy URL for ABI-only mode - # Create ProofManager + # Create ProofManager with multi-RPC provider + source_provider = MultiRpcProvider([source_rpc]) proof_manager = ProofManager( - w3_source=web3_source, + source_provider=source_provider, contract_util=contract_util, rofl_util=None, # Testing without ROFL ) @@ -216,64 +256,66 @@ def normalize_hex(value): async def test_proof_generation_errors(): """ - Test error handling in proof generation. - """ - print("\n๐Ÿงช Testing error handling") + Test error handling in proof generation with mocked Web3 layer. - # Initialize Web3 connection - source_rpc = os.environ.get( - "SOURCE_RPC_URL", "https://ethereum-sepolia.publicnode.com" - ) - web3_source = Web3(Web3.HTTPProvider(source_rpc)) + This test verifies that ProofManager properly propagates errors from + the RPC layer without making real network connections. + """ + # Mock ContractUtility (only needs ABIs) + contract_util = ContractUtility(rpc_url="http://localhost:8545") - if not web3_source.is_connected(): - print("โš ๏ธ Skipping error tests - no connection to source chain") - return + # Mock MultiRpcProvider to avoid real network connections + mock_provider = MagicMock(spec=MultiRpcProvider) - # Initialize ProofManager - contract_util = ContractUtility( - rpc_url="http://localhost:8545" - ) # Dummy URL for ABI-only mode + # Create ProofManager with mocked provider proof_manager = ProofManager( - w3_source=web3_source, + source_provider=mock_provider, contract_util=contract_util, - rofl_util=None, # Testing without ROFL + rofl_util=None, ) - # Test with invalid transaction hash - print("\n๐Ÿ“ Testing with invalid transaction hash...") - try: - invalid_event = PaymentEvent( - tx_hash="0xinvalid", - block_number=0, - payer="0x0000000000000000000000000000000000000000", - recipient="0x0000000000000000000000000000000000000000", - token="0x0000000000000000000000000000000000000000", - amount=0, - ) + # Test 1: Invalid transaction hash + # Mock execute_with_failover to raise ValueError when w3.eth.get_transaction_receipt is called + def mock_invalid_hash_operation(operation): + mock_w3 = MagicMock() + mock_w3.eth.get_transaction_receipt.side_effect = ValueError("Invalid transaction hash format") + return operation(mock_w3) + + mock_provider.execute_with_failover.side_effect = mock_invalid_hash_operation + + invalid_event = PaymentEvent( + tx_hash="0xinvalid", + block_number=0, + payer="0x0000000000000000000000000000000000000000", + recipient="0x0000000000000000000000000000000000000000", + token="0x0000000000000000000000000000000000000000", + amount=0, + ) + + with pytest.raises(ValueError, match="Invalid transaction hash format"): await proof_manager.generate_proof(invalid_event) - print("โŒ Should have raised an error for invalid hash") - except Exception as e: - print(f"โœ… Correctly raised error: {type(e).__name__}") - # Test with non-existent transaction - print("\n๐Ÿ“ Testing with non-existent transaction...") - try: - fake_hash = "0x" + "0" * 64 - fake_event = PaymentEvent( - tx_hash=fake_hash, - block_number=0, - payer="0x0000000000000000000000000000000000000000", - recipient="0x0000000000000000000000000000000000000000", - token="0x0000000000000000000000000000000000000000", - amount=0, - ) - await proof_manager.generate_proof(fake_event) - print("โŒ Should have raised an error for non-existent tx") - except Exception as e: - print(f"โœ… Correctly raised error: {type(e).__name__}") + # Test 2: Non-existent transaction + # Mock execute_with_failover to raise ValueError for non-existent transaction + def mock_nonexistent_tx_operation(operation): + mock_w3 = MagicMock() + mock_w3.eth.get_transaction_receipt.side_effect = ValueError("Transaction not found") + return operation(mock_w3) + + mock_provider.execute_with_failover.side_effect = mock_nonexistent_tx_operation + + fake_hash = "0x" + "0" * 64 + fake_event = PaymentEvent( + tx_hash=fake_hash, + block_number=0, + payer="0x0000000000000000000000000000000000000000", + recipient="0x0000000000000000000000000000000000000000", + token="0x0000000000000000000000000000000000000000", + amount=0, + ) - print("\nโœ… Error handling tests completed") + with pytest.raises(ValueError, match="Transaction not found"): + await proof_manager.generate_proof(fake_event) async def main(): diff --git a/paymaster-relayer/tests/test_relayer.py b/paymaster-relayer/tests/test_relayer.py index 190a866..073f979 100644 --- a/paymaster-relayer/tests/test_relayer.py +++ b/paymaster-relayer/tests/test_relayer.py @@ -10,6 +10,7 @@ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from paymaster_relayer.relayer import ROFLRelayer +from paymaster_relayer.utils.multi_rpc_provider import MultiRpcProvider from paymaster_relayer.utils.polling_event_listener import PollingEventListener @@ -37,7 +38,7 @@ async def test_polling_listener_structure(): # Create listener instance listener = PollingEventListener( - rpc_url="https://ethereum-sepolia.publicnode.com", + provider=MultiRpcProvider(["https://ethereum-sepolia.publicnode.com"]), contract_address="0x0000000000000000000000000000000000000000", event_name="PaymentInitiated", abi=test_abi, @@ -61,7 +62,7 @@ async def test_relayer_with_real_contracts(): print("=" * 60) # Set up environment with real contract addresses - os.environ["SOURCE_RPC_URL"] = "https://ethereum-sepolia.publicnode.com" + os.environ["SOURCE_RPC_URLS"] = "https://ethereum-sepolia.publicnode.com" os.environ["TARGET_RPC_URL"] = "https://testnet.sapphire.oasis.io" os.environ["PAYMASTER_VAULT_ADDRESS"] = "0x0000000000000000000000000000000000000000" os.environ["PAYMASTER_PROXY_ADDRESS"] = "0x0000000000000000000000000000000000000000" diff --git a/paymaster-relayer/tests/test_retry_logic.py b/paymaster-relayer/tests/test_retry_logic.py index 301b176..5b72f43 100644 --- a/paymaster-relayer/tests/test_retry_logic.py +++ b/paymaster-relayer/tests/test_retry_logic.py @@ -26,6 +26,7 @@ FUTURE_PRICE_TIMESTAMP_ERROR_B64, ProofManager, ) +from paymaster_relayer.utils.multi_rpc_provider import MultiRpcProvider # Valid Ethereum addresses for testing TEST_ADDRESS = "0x0000000000000000000000000000000000000001" @@ -57,8 +58,15 @@ def mock_w3_source(self): @pytest.fixture def proof_manager(self, mock_w3_source, mock_contract_util, mock_rofl_util): """Create a ProofManager with mocked dependencies.""" + # Mock MultiRpcProvider + mock_provider = MagicMock(spec=MultiRpcProvider) + mock_provider.execute_with_failover = MagicMock( + side_effect=lambda op: op(mock_w3_source) + ) + mock_provider.get_web3.return_value = mock_w3_source + return ProofManager( - w3_source=mock_w3_source, + source_provider=mock_provider, contract_util=mock_contract_util, rofl_util=mock_rofl_util, ) @@ -156,7 +164,9 @@ async def test_future_price_timestamp_fails_after_max_retries( assert mock_sleep.call_count == FUTURE_PRICE_MAX_RETRIES - 1 # Should have tried max_retries times - assert mock_rofl_util.submit_tx.call_count == FUTURE_PRICE_MAX_RETRIES + assert ( + mock_rofl_util.submit_tx.call_count == FUTURE_PRICE_MAX_RETRIES + ) # Should return None (failure) assert result is None @@ -174,8 +184,15 @@ def proof_manager(self): mock_contract_util.w3.eth.gas_price = 1000000000 mock_w3_source = MagicMock() + # Mock MultiRpcProvider + mock_provider = MagicMock(spec=MultiRpcProvider) + mock_provider.execute_with_failover = MagicMock( + side_effect=lambda op: op(mock_w3_source) + ) + mock_provider.get_web3.return_value = mock_w3_source + return ProofManager( - w3_source=mock_w3_source, + source_provider=mock_provider, contract_util=mock_contract_util, rofl_util=mock_rofl_util, )