diff --git a/.changelog/calm-lakes-wink.md b/.changelog/calm-lakes-wink.md new file mode 100644 index 0000000..ba1278b --- /dev/null +++ b/.changelog/calm-lakes-wink.md @@ -0,0 +1,5 @@ +--- +pytempo: major +--- + +Added TIP-1011 `authorizeKey` support with `KeyRestrictions` struct (T3+) as the new `authorize_key` method, and renamed the previous flat-params variant to `authorize_key_legacy` for pre-T3 compatibility. Updated `IAccountKeychain` ABI with the new function signature and `LegacyAuthorizeKeySelectorChanged` error. diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0415ea6..3bf3ae7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -196,3 +196,4 @@ jobs: run: uv run pytest -v -s tests/test_integration.py env: TEMPO_RPC_URL: ${{ secrets.TEMPO_TESTNET_RPC_URL }} + TEMPO_HARDFORK: T2 diff --git a/README.md b/README.md index dce04df..2d01011 100644 --- a/README.md +++ b/README.md @@ -212,7 +212,7 @@ data = as_bytes("0xabcdef") # -> b'\xab\xcd\xef' ### `TempoTransaction` -Immutable, strongly-typed transaction (frozen dataclass). +Immutable, strongly-typed transaction (frozen attrs model). **Factory Methods:** @@ -237,6 +237,7 @@ Immutable, strongly-typed transaction (frozen dataclass). **Methods:** - `sign(private_key, for_fee_payer=False)` - Sign transaction (returns new instance) +- `sign_access_key(access_key_private_key, root_account)` - Sign with access key (returns new instance) - `encode()` - Encode to bytes for transmission - `hash()` - Get transaction hash - `get_signing_hash(for_fee_payer=False)` - Get hash to sign diff --git a/docs/guides/access-keys.md b/docs/guides/access-keys.md index 1784122..2d08caa 100644 --- a/docs/guides/access-keys.md +++ b/docs/guides/access-keys.md @@ -9,18 +9,18 @@ Create a {py:class}`~pytempo.KeyAuthorization`, sign it with the root account, a ```python from pytempo import ( TempoTransaction, Call, - create_key_authorization, SignatureType, + KeyAuthorization, SignatureType, TokenLimit, ) # Create authorization for a new access key -auth = create_key_authorization( +auth = KeyAuthorization( key_id="0xAccessKeyAddress...", chain_id=42429, key_type=SignatureType.SECP256K1, expiry=1893456000, # optional: expires ~2030 - limits=[ - {"token": "0xUSDCAddress...", "limit": 1000 * 10**6}, - ], + limits=( + TokenLimit(token="0xUSDCAddress...", limit=1000 * 10**6), + ), ) # Sign with root account @@ -32,16 +32,16 @@ tx = TempoTransaction.create( gas_limit=100_000, max_fee_per_gas=2_000_000_000, calls=(Call.create(to="0xRecipient...", value=1000),), - key_authorization=signed_auth.rlp_encode(), + key_authorization=signed_auth, ) ``` ## Signing with an access key -Use {py:func}`~pytempo.sign_tx_access_key` to sign a transaction as an access key holder: +Use {py:meth}`~pytempo.TempoTransaction.sign_access_key` to sign a transaction as an access key holder: ```python -from pytempo import TempoTransaction, Call, sign_tx_access_key +from pytempo import TempoTransaction, Call tx = TempoTransaction.create( chain_id=42429, @@ -50,8 +50,7 @@ tx = TempoTransaction.create( calls=(Call.create(to="0xRecipient...", value=1000),), ) -signed_tx = sign_tx_access_key( - tx, +signed_tx = tx.sign_access_key( access_key_private_key="0xAccessKeyPrivateKey...", root_account="0xRootAccountAddress...", ) @@ -78,9 +77,44 @@ remaining = AccountKeychain.get_remaining_limit( print(f"Remaining: {remaining}") ``` +## On-chain key authorization with call scopes (T3+) + +For on-chain key provisioning via the AccountKeychain precompile, you can restrict which contracts and functions the key is allowed to call: + +```python +from pytempo import CallScope, SignatureType +from pytempo.contracts import AccountKeychain, ALPHA_USD + +call = AccountKeychain.authorize_key( + key_id="0xAccessKeyAddress...", + signature_type=SignatureType.SECP256K1, + expiry=2**64 - 1, + allow_any_calls=False, + allowed_calls=( + CallScope.transfer(target=ALPHA_USD), + CallScope.approve(target=ALPHA_USD), + ), +) +``` + +Available call scope constructors: + +- `CallScope.unrestricted(target=...)` — allow all functions on a target +- `CallScope.transfer(target=...)` — allow `transfer(address,uint256)` on a TIP20 token +- `CallScope.approve(target=...)` — allow `approve(address,uint256)` on a TIP20 token +- `CallScope.transfer_with_memo(target=...)` — allow `transferWithMemo(address,uint256,bytes32)` on a TIP20 token + +```{note} +Before T3 is activated, pass ``legacy=True`` to use the pre-T3 encoding:: + + call = AccountKeychain.authorize_key(..., legacy=True) + +Remove ``legacy=True`` once T3 is live. +``` + ## Signature types -The {py:class}`~pytempo.SignatureType` constants define supported key types: +The {py:class}`~pytempo.SignatureType` enum defines supported key types: | Constant | Value | Description | |---|---|---| diff --git a/pytempo/__init__.py b/pytempo/__init__.py index 2f51165..3d10c06 100644 --- a/pytempo/__init__.py +++ b/pytempo/__init__.py @@ -24,7 +24,9 @@ INNER_SIGNATURE_LENGTH, KEYCHAIN_SIGNATURE_LENGTH, KEYCHAIN_SIGNATURE_TYPE, + CallScope, KeyAuthorization, + KeychainSignature, SignatureType, SignedKeyAuthorization, TokenLimit, @@ -42,10 +44,13 @@ Address, BytesLike, Hash32, + Selector, as_address, as_bytes, as_hash32, as_optional_address, + as_selector, + validate_nonempty_address, ) __version__ = "0.3.1" @@ -54,11 +59,14 @@ # Types "Address", "Hash32", + "Selector", "BytesLike", "as_address", "as_bytes", "as_hash32", "as_optional_address", + "as_selector", + "validate_nonempty_address", # Models "Call", "AccessListItem", @@ -72,11 +80,13 @@ "SignedKeyAuthorization", "SignatureType", "TokenLimit", - "create_key_authorization", - # Keychain signing + "CallScope", + "KeychainSignature", + # Keychain signing (deprecated free functions) "KEYCHAIN_SIGNATURE_TYPE", "KEYCHAIN_SIGNATURE_LENGTH", "INNER_SIGNATURE_LENGTH", "build_keychain_signature", + "create_key_authorization", "sign_tx_access_key", ] diff --git a/pytempo/contracts/abis/IAccountKeychain.json b/pytempo/contracts/abis/IAccountKeychain.json index e186edc..d9d31fd 100644 --- a/pytempo/contracts/abis/IAccountKeychain.json +++ b/pytempo/contracts/abis/IAccountKeychain.json @@ -1,14 +1,29 @@ [ + { + "inputs": [], + "name": "CallNotAllowed", + "type": "error" + }, { "inputs": [], "name": "ExpiryInPast", "type": "error" }, + { + "inputs": [], + "name": "InvalidCallScope", + "type": "error" + }, { "inputs": [], "name": "InvalidSignatureType", "type": "error" }, + { + "inputs": [], + "name": "InvalidSpendingLimit", + "type": "error" + }, { "inputs": [], "name": "KeyAlreadyExists", @@ -26,12 +41,34 @@ }, { "inputs": [], - "name": "KeyInactive", + "name": "KeyNotFound", "type": "error" }, { - "inputs": [], - "name": "KeyNotFound", + "inputs": [ + { + "internalType": "bytes4", + "name": "newSelector", + "type": "bytes4" + } + ], + "name": "LegacyAuthorizeKeySelectorChanged", + "type": "error" + }, + { + "inputs": [ + { + "internalType": "uint8", + "name": "expected", + "type": "uint8" + }, + { + "internalType": "uint8", + "name": "actual", + "type": "uint8" + } + ], + "name": "SignatureTypeMismatch", "type": "error" }, { @@ -49,6 +86,43 @@ "name": "ZeroPublicKey", "type": "error" }, + { + "anonymous": false, + "inputs": [ + { + "indexed": true, + "internalType": "address", + "name": "account", + "type": "address" + }, + { + "indexed": true, + "internalType": "address", + "name": "publicKey", + "type": "address" + }, + { + "indexed": true, + "internalType": "address", + "name": "token", + "type": "address" + }, + { + "indexed": false, + "internalType": "uint256", + "name": "amount", + "type": "uint256" + }, + { + "indexed": false, + "internalType": "uint256", + "name": "remainingLimit", + "type": "uint256" + } + ], + "name": "AccessKeySpend", + "type": "event" + }, { "anonymous": false, "inputs": [ @@ -165,7 +239,7 @@ "type": "uint256" } ], - "internalType": "struct IAccountKeychain.TokenLimit[]", + "internalType": "struct IAccountKeychain.LegacyTokenLimit[]", "name": "limits", "type": "tuple[]" } @@ -175,6 +249,150 @@ "stateMutability": "nonpayable", "type": "function" }, + { + "inputs": [ + { + "internalType": "address", + "name": "keyId", + "type": "address" + }, + { + "internalType": "enum IAccountKeychain.SignatureType", + "name": "signatureType", + "type": "uint8" + }, + { + "components": [ + { + "internalType": "uint64", + "name": "expiry", + "type": "uint64" + }, + { + "internalType": "bool", + "name": "enforceLimits", + "type": "bool" + }, + { + "components": [ + { + "internalType": "address", + "name": "token", + "type": "address" + }, + { + "internalType": "uint256", + "name": "amount", + "type": "uint256" + }, + { + "internalType": "uint64", + "name": "period", + "type": "uint64" + } + ], + "internalType": "struct IAccountKeychain.TokenLimit[]", + "name": "limits", + "type": "tuple[]" + }, + { + "internalType": "bool", + "name": "allowAnyCalls", + "type": "bool" + }, + { + "components": [ + { + "internalType": "address", + "name": "target", + "type": "address" + }, + { + "components": [ + { + "internalType": "bytes4", + "name": "selector", + "type": "bytes4" + }, + { + "internalType": "address[]", + "name": "recipients", + "type": "address[]" + } + ], + "internalType": "struct IAccountKeychain.SelectorRule[]", + "name": "selectorRules", + "type": "tuple[]" + } + ], + "internalType": "struct IAccountKeychain.CallScope[]", + "name": "allowedCalls", + "type": "tuple[]" + } + ], + "internalType": "struct IAccountKeychain.KeyRestrictions", + "name": "config", + "type": "tuple" + } + ], + "name": "authorizeKey", + "outputs": [], + "stateMutability": "nonpayable", + "type": "function" + }, + { + "inputs": [ + { + "internalType": "address", + "name": "account", + "type": "address" + }, + { + "internalType": "address", + "name": "keyId", + "type": "address" + } + ], + "name": "getAllowedCalls", + "outputs": [ + { + "internalType": "bool", + "name": "isScoped", + "type": "bool" + }, + { + "components": [ + { + "internalType": "address", + "name": "target", + "type": "address" + }, + { + "components": [ + { + "internalType": "bytes4", + "name": "selector", + "type": "bytes4" + }, + { + "internalType": "address[]", + "name": "recipients", + "type": "address[]" + } + ], + "internalType": "struct IAccountKeychain.SelectorRule[]", + "name": "selectorRules", + "type": "tuple[]" + } + ], + "internalType": "struct IAccountKeychain.CallScope[]", + "name": "scopes", + "type": "tuple[]" + } + ], + "stateMutability": "view", + "type": "function" + }, { "inputs": [ { @@ -248,13 +466,47 @@ "outputs": [ { "internalType": "uint256", - "name": "", + "name": "remaining", "type": "uint256" } ], "stateMutability": "view", "type": "function" }, + { + "inputs": [ + { + "internalType": "address", + "name": "account", + "type": "address" + }, + { + "internalType": "address", + "name": "keyId", + "type": "address" + }, + { + "internalType": "address", + "name": "token", + "type": "address" + } + ], + "name": "getRemainingLimitWithPeriod", + "outputs": [ + { + "internalType": "uint256", + "name": "remaining", + "type": "uint256" + }, + { + "internalType": "uint64", + "name": "periodEnd", + "type": "uint64" + } + ], + "stateMutability": "view", + "type": "function" + }, { "inputs": [], "name": "getTransactionKey", @@ -268,6 +520,24 @@ "stateMutability": "view", "type": "function" }, + { + "inputs": [ + { + "internalType": "address", + "name": "keyId", + "type": "address" + }, + { + "internalType": "address", + "name": "target", + "type": "address" + } + ], + "name": "removeAllowedCalls", + "outputs": [], + "stateMutability": "nonpayable", + "type": "function" + }, { "inputs": [ { @@ -281,6 +551,48 @@ "stateMutability": "nonpayable", "type": "function" }, + { + "inputs": [ + { + "internalType": "address", + "name": "keyId", + "type": "address" + }, + { + "components": [ + { + "internalType": "address", + "name": "target", + "type": "address" + }, + { + "components": [ + { + "internalType": "bytes4", + "name": "selector", + "type": "bytes4" + }, + { + "internalType": "address[]", + "name": "recipients", + "type": "address[]" + } + ], + "internalType": "struct IAccountKeychain.SelectorRule[]", + "name": "selectorRules", + "type": "tuple[]" + } + ], + "internalType": "struct IAccountKeychain.CallScope[]", + "name": "scopes", + "type": "tuple[]" + } + ], + "name": "setAllowedCalls", + "outputs": [], + "stateMutability": "nonpayable", + "type": "function" + }, { "inputs": [ { diff --git a/pytempo/contracts/account_keychain.py b/pytempo/contracts/account_keychain.py index 07dfcce..6756fba 100644 --- a/pytempo/contracts/account_keychain.py +++ b/pytempo/contracts/account_keychain.py @@ -3,13 +3,24 @@ Returns :class:`~pytempo.Call` objects ready to use in a :class:`~pytempo.TempoTransaction`:: + from pytempo import SignatureType from pytempo.contracts import AccountKeychain + # T3+ (default) call = AccountKeychain.authorize_key( key_id=access_key.address, - signature_type=0, + signature_type=SignatureType.SECP256K1, expiry=2**64 - 1, ) + + # Pre-T3 + call = AccountKeychain.authorize_key( + key_id=access_key.address, + signature_type=SignatureType.SECP256K1, + expiry=2**64 - 1, + legacy=True, + ) + tx = TempoTransaction.create(..., calls=(call,)) """ @@ -18,6 +29,7 @@ from eth_utils import to_checksum_address +from pytempo.keychain import CallScope, SignatureType from pytempo.models import Call from ._encode import encode_calldata @@ -39,26 +51,69 @@ class AccountKeychain: def authorize_key( *, key_id: str, - signature_type: int, + signature_type: SignatureType, expiry: int, enforce_limits: bool = False, - limits: Optional[Sequence[tuple[str, int]]] = None, + limits: Optional[Sequence[tuple[str, int] | tuple[str, int, int]]] = None, + allow_any_calls: bool = True, + allowed_calls: Optional[Sequence[CallScope]] = None, + legacy: bool = False, ) -> Call: - """Build an ``authorizeKey(address,uint8,uint64,bool,(address,uint256)[])`` call. + """Build an ``authorizeKey`` call. + + Uses the TIP-1011 ``KeyRestrictions`` struct encoding by default (T3+). + Pass ``legacy=True`` for the pre-T3 flat-parameter encoding. Args: key_id: The access key address to authorize. - signature_type: 0 = Secp256k1, 1 = P256, 2 = WebAuthn. + signature_type: Type of key being authorized (SignatureType.SECP256K1, P256, or WEBAUTHN) expiry: Unix timestamp when key expires (use ``2**64 - 1`` for never). enforce_limits: Whether to enforce spending limits. - limits: List of ``(token_address, amount)`` tuples for spending limits. + limits: List of ``(token_address, amount)`` or ``(token_address, amount, period)`` tuples. + Period defaults to 0 (one-time limit) if omitted. + allow_any_calls: Whether the key can call any contract (default True). + Ignored when ``legacy=True``. + allowed_calls: List of :class:`~pytempo.CallScope` restricting + which contracts/functions the key can call. + Only used when ``allow_any_calls`` is False. + Ignored when ``legacy=True``. + legacy: Use pre-T3 flat-parameter encoding. Pass ``True`` until T3 is activated, then remove this argument. """ - limit_tuples = list(limits) if limits else [] - data = encode_calldata( - _ABI, - "authorizeKey", - [key_id, signature_type, expiry, enforce_limits, limit_tuples], - ) + if legacy: + limit_tuples = ( + [(t, a) for t, a, *_ in ((*lim, 0)[:3] for lim in limits)] + if limits + else [] + ) + data = encode_calldata( + _ABI, + "authorizeKey", + [key_id, int(signature_type), expiry, enforce_limits, limit_tuples], + ) + else: + limit_tuples = ( + [(t, a, p) for t, a, p in ((*lim, 0)[:3] for lim in limits)] + if limits + else [] + ) + call_tuples = ( + [(bytes(s.target), [(bytes(s.selector), [])]) for s in allowed_calls] + if allowed_calls + else [] + ) + config = ( + expiry, + enforce_limits, + limit_tuples, + allow_any_calls, + call_tuples, + ) + data = encode_calldata( + _ABI, + "authorizeKey", + [key_id, int(signature_type), config], + ) + return Call.create(to=ACCOUNT_KEYCHAIN_ADDRESS, data=data) @staticmethod diff --git a/pytempo/keychain.py b/pytempo/keychain.py index bf48b04..f576c1a 100644 --- a/pytempo/keychain.py +++ b/pytempo/keychain.py @@ -24,43 +24,62 @@ Format: ``[chain_id, key_type, key_id, expiry?, limits?]`` """ -from dataclasses import dataclass -from typing import Optional +from __future__ import annotations +from enum import IntEnum +from typing import TYPE_CHECKING, ClassVar + +import attrs import rlp from eth_account import Account -from eth_utils import keccak, to_bytes, to_checksum_address -from rlp.sedes import Binary, big_endian_int +from eth_utils import keccak, to_checksum_address -# RLP sedes -address_sedes = Binary.fixed_length(20, allow_empty=True) -uint256_sedes = big_endian_int +from .types import ( + Address, + BytesLike, + Selector, + as_address, + as_selector, + validate_nonempty_address, +) + +if TYPE_CHECKING: + from .models import Signature, TempoTransaction + +# --------------------------------------------------------------------------- +# Signature type enum +# --------------------------------------------------------------------------- -# Keychain signature constants -KEYCHAIN_SIGNATURE_TYPE = 0x04 -INNER_SIGNATURE_LENGTH = 65 # r (32) + s (32) + v (1) -KEYCHAIN_SIGNATURE_LENGTH = 86 # type (1) + address (20) + inner (65) +class SignatureType(IntEnum): + """Signature type for access keys.""" -# Signature types for access keys -class SignatureType: - """Signature type constants for access keys.""" + SECP256K1 = 0 + P256 = 1 + WEBAUTHN = 2 - SECP256K1 = 0 # Standard Ethereum signature - P256 = 1 # NIST P-256 / secp256r1 (passkeys) - WEBAUTHN = 2 # WebAuthn/FIDO2 + def to_json_name(self) -> str: + return _SIG_TYPE_JSON_NAMES[self] -class TokenLimitRLP(rlp.Serializable): - """RLP serializable token spending limit.""" +_SIG_TYPE_JSON_NAMES = { + SignatureType.SECP256K1: "secp256k1", + SignatureType.P256: "p256", + SignatureType.WEBAUTHN: "webAuthn", +} - fields = [ - ("token", address_sedes), - ("limit", uint256_sedes), - ] +# --------------------------------------------------------------------------- +# Token limit +# --------------------------------------------------------------------------- -@dataclass + +def _validate_u256(instance: object, attribute: object, value: int) -> None: + if not (0 <= value <= 2**256 - 1): + raise ValueError(f"limit must be in [0, 2**256 - 1], got {value}") + + +@attrs.define(frozen=True) class TokenLimit: """Token spending limit for access keys. @@ -72,18 +91,106 @@ class TokenLimit: limit: Maximum spending amount for this token (enforced over the key's lifetime) """ - token: str - limit: int + token: Address = attrs.field( + converter=as_address, validator=validate_nonempty_address + ) + limit: int = attrs.field(validator=_validate_u256) + + def to_rlp(self) -> list: + return [bytes(self.token), self.limit] + + +# --------------------------------------------------------------------------- +# Call scope +# --------------------------------------------------------------------------- + +_WILDCARD_SELECTOR = Selector(b"\x00\x00\x00\x00") +_TIP20_PREFIX = bytes.fromhex("20C000000000000000000000") + +# Allowed TIP-20 selectors for call-scoped access keys. +_TIP20_TRANSFER = Selector(bytes.fromhex("a9059cbb")) +_TIP20_APPROVE = Selector(bytes.fromhex("095ea7b3")) +_TIP20_TRANSFER_WITH_MEMO = Selector(bytes.fromhex("95777d59")) + + +def _validate_tip20_address(target: BytesLike) -> Address: + addr = as_address(target) + if not bytes(addr).startswith(_TIP20_PREFIX): + raise ValueError( + f"target must be a TIP20 address (prefix 0x20C0...00), " + f"got 0x{bytes(addr)[:12].hex()}" + ) + return addr + + +@attrs.define(frozen=True) +class CallScope: + """Call scope restriction for access keys (TIP-1011). + + Restricts an access key to only call specific contract functions. + Used in ``AccountKeychain.authorize_key()`` when ``allow_any_calls`` is False. - def to_rlp(self) -> TokenLimitRLP: - """Convert to RLP-serializable format.""" - token_bytes = ( - to_bytes(hexstr=self.token) if isinstance(self.token, str) else self.token + Construct via the named constructors: + + - ``CallScope.unrestricted(target=...)`` — allow all functions on a target. + - ``CallScope.transfer(target=...)`` — allow ``transfer`` on a TIP20 token. + - ``CallScope.approve(target=...)`` — allow ``approve`` on a TIP20 token. + - ``CallScope.transfer_with_memo(target=...)`` — allow ``transferWithMemo`` + on a TIP20 token. + + Args: + target: Contract address the key is allowed to call. + selector: 4-byte function selector. Only applicable for TIP20 tokens. + """ + + target: Address = attrs.field( + converter=as_address, validator=validate_nonempty_address + ) + selector: Selector = attrs.field(converter=as_selector) + + @classmethod + def unrestricted(cls, *, target: BytesLike) -> CallScope: + """Allow all functions on a target (any contract, including TIP20).""" + return cls(target=target, selector=_WILDCARD_SELECTOR) + + @classmethod + def transfer(cls, *, target: BytesLike) -> CallScope: + """Allow ``transfer(address,uint256)`` on a TIP20 token target.""" + return cls(target=_validate_tip20_address(target), selector=_TIP20_TRANSFER) + + @classmethod + def approve(cls, *, target: BytesLike) -> CallScope: + """Allow ``approve(address,uint256)`` on a TIP20 token target.""" + return cls(target=_validate_tip20_address(target), selector=_TIP20_APPROVE) + + @classmethod + def transfer_with_memo(cls, *, target: BytesLike) -> CallScope: + """Allow ``transferWithMemo(address,uint256,bytes32)`` on a TIP20 token target.""" + return cls( + target=_validate_tip20_address(target), + selector=_TIP20_TRANSFER_WITH_MEMO, ) - return TokenLimitRLP(token=token_bytes, limit=self.limit) -@dataclass +# --------------------------------------------------------------------------- +# Key authorization +# --------------------------------------------------------------------------- + + +def _convert_limits( + value: tuple[TokenLimit, ...] | list[TokenLimit] | None, +) -> tuple[TokenLimit, ...] | None: + return None if value is None else tuple(value) + + +def _validate_optional_expiry( + instance: object, attribute: object, value: int | None +) -> None: + if value is not None and value < 0: + raise ValueError(f"expiry must be >= 0, got {value}") + + +@attrs.define(frozen=True) class KeyAuthorization: """Key authorization for provisioning access keys. @@ -91,69 +198,76 @@ class KeyAuthorization: The transaction must be signed by the root key to authorize adding this access key. Args: + key_id: Key identifier (address derived from the public key) chain_id: Chain ID for replay protection (0 = valid on any chain) key_type: Type of key being authorized (SignatureType.SECP256K1, P256, or WEBAUTHN) - key_id: Key identifier (address derived from the public key) expiry: Unix timestamp when key expires (None = never expires) - limits: Token spending limits (None = unlimited, [] = no spending, [...] = specific limits) + limits: Token spending limits (None = unlimited, () = no spending, tuple of :class:`TokenLimit` = specific limits) """ - chain_id: int - key_type: int - key_id: str - expiry: Optional[int] = None - limits: Optional[list[TokenLimit]] = None - - def rlp_encode(self) -> bytes: - """RLP encode the key authorization. - - Format: [chain_id, key_type, key_id, expiry?, limits?] - - expiry and limits are optional trailing fields - """ - key_id_bytes = ( - to_bytes(hexstr=self.key_id) - if isinstance(self.key_id, str) - else self.key_id - ) + key_id: Address = attrs.field( + converter=as_address, validator=validate_nonempty_address + ) + chain_id: int = attrs.field(default=0, validator=attrs.validators.ge(0)) + key_type: SignatureType = attrs.field( + default=SignatureType.SECP256K1, converter=SignatureType + ) + expiry: int | None = attrs.field(default=None, validator=_validate_optional_expiry) + limits: tuple[TokenLimit, ...] | None = attrs.field( + default=None, converter=_convert_limits + ) + def as_rlp_payload(self) -> list: + """Return the RLP-encodable list representation.""" # Build list with required fields - items: list = [self.chain_id, self.key_type, key_id_bytes] + items: list = [self.chain_id, int(self.key_type), bytes(self.key_id)] # Add optional trailing fields - if self.expiry is not None or self.limits is not None: - items.append(self.expiry if self.expiry is not None else b"") + # expiry=0 is treated the same as expiry=None (never expires) + has_expiry = self.expiry is not None and self.expiry != 0 + if has_expiry or self.limits is not None: + items.append(self.expiry if has_expiry else b"") if self.limits is not None: items.append([limit.to_rlp() for limit in self.limits]) - return rlp.encode(items) + return items + + def rlp_encode(self) -> bytes: + """RLP encode the key authorization.""" + return rlp.encode(self.as_rlp_payload()) def signature_hash(self) -> bytes: """Compute the authorization message hash for signing.""" return keccak(self.rlp_encode()) - def sign(self, private_key: str) -> "SignedKeyAuthorization": + def sign(self, private_key: str) -> SignedKeyAuthorization: """Sign the key authorization with the root account's private key. Args: - private_key: Root account private key (hex string with 0x prefix) + private_key: Root account private key (hex string, as used by ``Account.from_key``) Returns: SignedKeyAuthorization that can be attached to a transaction """ + from .models import Signature + msg_hash = self.signature_hash() account = Account.from_key(private_key) signed = account.unsafe_sign_hash(msg_hash) return SignedKeyAuthorization( authorization=self, - v=signed.v, - r=signed.r, - s=signed.s, + signature=Signature(r=signed.r, s=signed.s, v=signed.v), ) -@dataclass +# --------------------------------------------------------------------------- +# Signed key authorization +# --------------------------------------------------------------------------- + + +@attrs.define(frozen=True) class SignedKeyAuthorization: """Signed key authorization that can be attached to a transaction. @@ -161,53 +275,30 @@ class SignedKeyAuthorization: """ authorization: KeyAuthorization - v: int - r: int - s: int - - def rlp_encode(self) -> bytes: - """RLP encode the signed key authorization. - - Format: [[chain_id, key_type, key_id, expiry?, limits?], signature_bytes] - - The KeyAuthorization is encoded as a nested list, then signature as bytes. - This matches the Rust SignedKeyAuthorization struct with #[rlp(trailing)]. - """ - key_id_bytes = ( - to_bytes(hexstr=self.authorization.key_id) - if isinstance(self.authorization.key_id, str) - else self.authorization.key_id - ) + signature: Signature # from .models - # Build authorization as nested list - auth_items: list = [ - self.authorization.chain_id, - self.authorization.key_type, - key_id_bytes, - ] + @property + def v(self) -> int: + """Signature v value (deprecated, use ``self.signature.v``).""" + return self.signature.v - # Add optional trailing fields (expiry, limits) - if ( - self.authorization.expiry is not None - or self.authorization.limits is not None - ): - auth_items.append( - self.authorization.expiry - if self.authorization.expiry is not None - else b"" - ) + @property + def r(self) -> int: + """Signature r value (deprecated, use ``self.signature.r``).""" + return self.signature.r - if self.authorization.limits is not None: - auth_items.append([limit.to_rlp() for limit in self.authorization.limits]) + @property + def s(self) -> int: + """Signature s value (deprecated, use ``self.signature.s``).""" + return self.signature.s - # Build signature bytes: r (32) || s (32) || v (1) = 65 bytes - r_bytes = self.r.to_bytes(32, "big") - s_bytes = self.s.to_bytes(32, "big") - v_byte = bytes([self.v]) - signature_bytes = r_bytes + s_bytes + v_byte + def as_rlp_payload(self) -> list: + """Return the RLP-encodable list representation.""" + return [self.authorization.as_rlp_payload(), self.signature.to_bytes()] - # Encode as [auth_list, signature_bytes] - return rlp.encode([auth_items, signature_bytes]) + def rlp_encode(self) -> bytes: + """RLP encode the signed key authorization.""" + return rlp.encode(self.as_rlp_payload()) def to_json(self) -> dict: """Convert to JSON format for eth_estimateGas and other RPC calls. @@ -215,19 +306,15 @@ def to_json(self) -> dict: Returns: Dict with camelCase keys matching Tempo's JSON-RPC format. """ - KEY_TYPE_NAMES = {0: "secp256k1", 1: "p256", 2: "webAuthn"} - - result = { + result: dict = { "chainId": hex(self.authorization.chain_id), - "keyType": KEY_TYPE_NAMES.get( - self.authorization.key_type, str(self.authorization.key_type) - ), - "keyId": self.authorization.key_id, + "keyType": self.authorization.key_type.to_json_name(), + "keyId": to_checksum_address(bytes(self.authorization.key_id)), "signature": { "type": "secp256k1", - "r": hex(self.r), - "s": hex(self.s), - "v": self.v, + "r": hex(self.signature.r), + "s": hex(self.signature.s), + "v": self.signature.v, }, } @@ -236,58 +323,131 @@ def to_json(self) -> dict: if self.authorization.limits is not None: result["limits"] = [ - {"token": limit.token, "limit": hex(limit.limit)} + { + "token": to_checksum_address(bytes(limit.token)), + "limit": hex(limit.limit), + } for limit in self.authorization.limits ] return result def recover_signer(self) -> str: - """Recover the address that signed this authorization. + """Recover the checksummed address that signed this authorization.""" + msg_hash = self.authorization.signature_hash() + recovered = Account._recover_hash( + msg_hash, + vrs=(self.signature.v, self.signature.r, self.signature.s), + ) + return to_checksum_address(recovered) - Returns: - Checksummed address of the signer (root account) + +# --------------------------------------------------------------------------- +# Keychain signature (0x04 envelope) +# --------------------------------------------------------------------------- + + +@attrs.define(frozen=True) +class KeychainSignature: + """Keychain V2 signature: ``0x04 || root_account (20) || inner (65)``. + + Args: + root_account: Address of the root account the access key signs on behalf of. + inner: The secp256k1 signature from the access key. + """ + + TYPE_BYTE: ClassVar[int] = 0x04 + LENGTH: ClassVar[int] = 86 + + root_account: Address = attrs.field( + converter=as_address, validator=validate_nonempty_address + ) + inner: Signature # from .models + + def to_bytes(self) -> bytes: + return ( + bytes([self.TYPE_BYTE]) + bytes(self.root_account) + self.inner.to_bytes() + ) + + @classmethod + def from_bytes(cls, raw: BytesLike) -> KeychainSignature: + """Parse a 86-byte keychain signature.""" + from .models import Signature + from .types import as_bytes as _as_bytes + + b = _as_bytes(raw) + if len(b) != cls.LENGTH: + raise ValueError( + f"keychain signature must be {cls.LENGTH} bytes, got {len(b)}" + ) + if b[0] != cls.TYPE_BYTE: + raise ValueError(f"expected type byte 0x04, got {b[0]:#04x}") + return cls( + root_account=b[1:21], + inner=Signature.from_bytes(b[21:86]), + ) + + @classmethod + def sign( + cls, + msg_hash: bytes, + access_key_private_key: str, + root_account: BytesLike, + ) -> KeychainSignature: + """Build a Keychain V2 signature for a message hash. + + The access key signs ``keccak256(0x04 || sig_hash || user_address)`` + instead of the raw sig_hash, providing domain separation. + + Args: + msg_hash: 32-byte transaction signature hash. + access_key_private_key: Private key of the access key (hex string, + as used by ``Account.from_key``). + root_account: Address of the root account (hex string or bytes). """ - msg_hash = self.authorization.signature_hash() + from .models import Signature - # Reconstruct signature for recovery - signature = ( - self.r.to_bytes(32, "big") + self.s.to_bytes(32, "big") + bytes([self.v]) + if len(msg_hash) != 32: + raise ValueError(f"msg_hash must be 32 bytes, got {len(msg_hash)}") + + root_bytes = as_address(root_account) + + signing_hash = keccak(bytes([cls.TYPE_BYTE]) + msg_hash + bytes(root_bytes)) + + account = Account.from_key(access_key_private_key) + signed_msg = account.unsafe_sign_hash(signing_hash) + + return cls( + root_account=root_bytes, + inner=Signature(r=signed_msg.r, s=signed_msg.s, v=signed_msg.v), ) - recovered = Account._recover_hash(msg_hash, signature=signature) - return to_checksum_address(recovered) + +# --------------------------------------------------------------------------- +# Convenience aliases / constants for backwards compat +# --------------------------------------------------------------------------- + +KEYCHAIN_SIGNATURE_TYPE = KeychainSignature.TYPE_BYTE +INNER_SIGNATURE_LENGTH = 65 # r (32) + s (32) + v (1) +KEYCHAIN_SIGNATURE_LENGTH = KeychainSignature.LENGTH + + +# --------------------------------------------------------------------------- +# Deprecated free functions — thin wrappers for backwards compat +# --------------------------------------------------------------------------- def create_key_authorization( key_id: str, chain_id: int = 0, key_type: int = SignatureType.SECP256K1, - expiry: Optional[int] = None, - limits: Optional[list[dict]] = None, + expiry: int | None = None, + limits: list[dict] | None = None, ) -> KeyAuthorization: """Create a key authorization for provisioning an access key. - Args: - key_id: Address of the access key to authorize - chain_id: Chain ID for replay protection (0 = valid on any chain) - key_type: Signature type (default: SECP256K1) - expiry: Unix timestamp when key expires (None = never expires) - limits: List of token limits as dicts with ``token`` and ``limit`` keys. - Use ``None`` for unlimited or ``[]`` for no spending. - - Returns: - KeyAuthorization that can be signed by the root account - - Example: - >>> auth = create_key_authorization( - ... key_id="0xAccessKeyAddress...", - ... chain_id=42429, - ... expiry=1893456000, # Year 2030 - ... limits=[{"token": "0xUSDC...", "limit": 1000 * 10**6}], - ... ) - >>> signed = auth.sign("0xRootPrivateKey...") - >>> tx = TempoTransaction.create(..., key_authorization=signed.rlp_encode()) + .. deprecated:: + Use ``KeyAuthorization(...)`` directly. """ token_limits = None if limits is not None: @@ -296,9 +456,9 @@ def create_key_authorization( ] return KeyAuthorization( + key_id=key_id, chain_id=chain_id, key_type=key_type, - key_id=key_id, expiry=expiry, limits=token_limits, ) @@ -311,63 +471,23 @@ def build_keychain_signature( ) -> bytes: """Build a Keychain V2 signature for a message hash. - The access key signs ``keccak256(0x04 || sig_hash || user_address)`` - instead of the raw sig_hash, providing domain separation. - - Args: - msg_hash: 32-byte transaction signature hash - access_key_private_key: Private key of the access key (hex string with 0x prefix) - root_account: Address of the root account (hex string with 0x prefix) - - Returns: - 86-byte Keychain signature: 0x04 || root_account (20 bytes) || inner_sig (65 bytes) + .. deprecated:: + Use ``KeychainSignature.sign()`` instead, which returns a structured + :class:`KeychainSignature` object. """ - root_account_bytes = to_bytes(hexstr=root_account) - - # Compute V2 signing hash: keccak256(0x04 || sig_hash || user_address) - signing_hash = keccak( - bytes([KEYCHAIN_SIGNATURE_TYPE]) + msg_hash + root_account_bytes - ) - - # Sign with the access key - account = Account.from_key(access_key_private_key) - signed_msg = account.unsafe_sign_hash(signing_hash) - - # Build the inner secp256k1 signature (65 bytes): r || s || v - inner_sig = ( - signed_msg.r.to_bytes(32, "big") - + signed_msg.s.to_bytes(32, "big") - + bytes([signed_msg.v]) - ) - - # Build Keychain signature: 0x04 || root_account (20 bytes) || inner_sig (65 bytes) - keychain_sig = bytes([KEYCHAIN_SIGNATURE_TYPE]) + root_account_bytes + inner_sig - - assert len(keychain_sig) == KEYCHAIN_SIGNATURE_LENGTH, ( - f"Expected {KEYCHAIN_SIGNATURE_LENGTH} bytes, got {len(keychain_sig)}" - ) - - return keychain_sig + return KeychainSignature.sign( + msg_hash, access_key_private_key, root_account + ).to_bytes() -def sign_tx_access_key(tx, access_key_private_key: str, root_account: str): +def sign_tx_access_key( + tx: TempoTransaction, + access_key_private_key: str, + root_account: str, +) -> TempoTransaction: """Sign a Tempo transaction using access key mode (Keychain signature). - Returns a new TempoTransaction with the keychain signature applied. - - Args: - tx: TempoTransaction to sign - access_key_private_key: Private key of the access key (hex string with 0x prefix) - root_account: Address of the root account (hex string with 0x prefix) - - Returns: - New TempoTransaction with the keychain signature applied + .. deprecated:: + Use ``tx.sign_access_key()`` instead. """ - import attrs - - tx_with_sender = attrs.evolve(tx, sender_address=to_bytes(hexstr=root_account)) - msg_hash = tx_with_sender.get_signing_hash(for_fee_payer=False) - keychain_sig = build_keychain_signature( - msg_hash, access_key_private_key, root_account - ) - return attrs.evolve(tx_with_sender, sender_signature=keychain_sig) + return tx.sign_access_key(access_key_private_key, root_account) diff --git a/pytempo/models.py b/pytempo/models.py index 9564529..dfbd3be 100644 --- a/pytempo/models.py +++ b/pytempo/models.py @@ -1,6 +1,8 @@ """Strongly-typed data models for Tempo transactions.""" -from typing import Optional +from __future__ import annotations + +from typing import TYPE_CHECKING import attrs import rlp @@ -17,21 +19,17 @@ as_optional_address, ) +if TYPE_CHECKING: + from .keychain import KeychainSignature, SignedKeyAuthorization + def _validate_call_value( - instance: "Call", attribute: attrs.Attribute, value: int + instance: Call, attribute: attrs.Attribute, value: int ) -> None: if value < 0: raise ValueError("call.value must be >= 0") -def _validate_call_to( - instance: "Call", attribute: attrs.Attribute, value: Address -) -> None: - if len(bytes(value)) not in (0, 20): - raise ValueError("call.to must be 20 bytes (or empty for contract creation)") - - @attrs.define(frozen=True) class Call: """Single call in a batch transaction.""" @@ -49,13 +47,13 @@ def create( to: BytesLike, value: int = 0, data: BytesLike = b"", - ) -> "Call": + ) -> Call: """Create a Call with automatic type coercion.""" return cls(to=to, value=value, data=data) def _validate_access_list_address( - instance: "AccessListItem", attribute: attrs.Attribute, value: Address + instance: AccessListItem, attribute: attrs.Attribute, value: Address ) -> None: if len(bytes(value)) != 20: raise ValueError("access list address must be 20 bytes") @@ -84,7 +82,7 @@ def create( cls, address: BytesLike, storage_keys: tuple[BytesLike, ...] = (), - ) -> "AccessListItem": + ) -> AccessListItem: """Create an AccessListItem with automatic type coercion.""" return cls(address=address, storage_keys=storage_keys) @@ -94,14 +92,14 @@ def create( def _validate_signature_r( - instance: "Signature", attribute: attrs.Attribute, value: int + instance: Signature, attribute: attrs.Attribute, value: int ) -> None: if not (0 < value < SECP256K1_N): raise ValueError(f"signature r must be in range (0, secp256k1_n), got {value}") def _validate_signature_s( - instance: "Signature", attribute: attrs.Attribute, value: int + instance: Signature, attribute: attrs.Attribute, value: int ) -> None: if not (0 < value <= SECP256K1_HALF_N): raise ValueError( @@ -110,7 +108,7 @@ def _validate_signature_s( def _validate_signature_v( - instance: "Signature", attribute: attrs.Attribute, value: int + instance: Signature, attribute: attrs.Attribute, value: int ) -> None: if value not in (0, 1, 27, 28): raise ValueError(f"signature v must be 0, 1, 27, or 28, got {value}") @@ -142,7 +140,7 @@ def to_rlp_list(self) -> list: return [v_normalized, self.r, self.s] @classmethod - def from_bytes(cls, sig_bytes: bytes) -> "Signature": + def from_bytes(cls, sig_bytes: bytes) -> Signature: """Parse a 65-byte signature and validate r/s/v ranges. Raises: @@ -214,28 +212,24 @@ class TempoTransaction: nonce_key: int = 0 nonce: int = 0 - valid_before: Optional[int] = None - valid_after: Optional[int] = None + valid_before: int | None = None + valid_after: int | None = None - fee_token: Optional[Address] = attrs.field( - default=None, converter=as_optional_address - ) + fee_token: Address | None = attrs.field(default=None, converter=as_optional_address) - sender_address: Optional[Address] = attrs.field( + sender_address: Address | None = attrs.field( default=None, converter=as_optional_address ) awaiting_fee_payer: bool = False - fee_payer_signature: Optional[Signature | bytes] = None - sender_signature: Optional[Signature | bytes] = None + fee_payer_signature: Signature | None = None + sender_signature: Signature | KeychainSignature | None = None tempo_authorization_list: tuple[bytes, ...] = attrs.field( factory=tuple, converter=_convert_tempo_auth_list ) - key_authorization: Optional[bytes] = attrs.field( - default=None, converter=lambda x: as_bytes(x) if x is not None else None - ) + key_authorization: SignedKeyAuthorization | None = None # ------------------------------------------------------------------------- # Factory methods @@ -251,37 +245,16 @@ def create( max_priority_fee_per_gas: int = 0, nonce: int = 0, nonce_key: int = 0, - valid_before: Optional[int] = None, - valid_after: Optional[int] = None, - fee_token: Optional[BytesLike] = None, + valid_before: int | None = None, + valid_after: int | None = None, + fee_token: BytesLike | None = None, awaiting_fee_payer: bool = False, calls: tuple[Call, ...] = (), access_list: tuple[AccessListItem, ...] = (), tempo_authorization_list: tuple[BytesLike, ...] = (), - key_authorization: Optional[BytesLike] = None, - ) -> "TempoTransaction": - """ - Create a transaction with automatic type coercion. - - Args: - chain_id: Chain ID (default: 1) - gas_limit: Gas limit (default: 21_000) - max_fee_per_gas: Max fee per gas in wei - max_priority_fee_per_gas: Max priority fee per gas in wei - nonce: Transaction nonce - nonce_key: Nonce key for 2D nonce system - valid_before: Expiration timestamp (optional) - valid_after: Activation timestamp (optional) - fee_token: Fee token address as hex string or bytes (optional) - awaiting_fee_payer: Whether transaction awaits fee payer signature - calls: Tuple of Call objects - access_list: Tuple of AccessListItem objects - tempo_authorization_list: Tuple of authorization bytes - key_authorization: Signed key authorization bytes (optional) - - Returns: - New TempoTransaction instance - """ + key_authorization: SignedKeyAuthorization | None = None, + ) -> TempoTransaction: + """Create a transaction with automatic type coercion.""" return cls( chain_id=chain_id, gas_limit=gas_limit, @@ -300,7 +273,7 @@ def create( ) @classmethod - def from_dict(cls, d: dict) -> "TempoTransaction": + def from_dict(cls, d: dict) -> TempoTransaction: """ Parse a transaction from a dict with camelCase or snake_case keys. @@ -408,7 +381,7 @@ def validate(self, *, require_sender: bool = False) -> None: def _has_fee_payer(self) -> bool: return self.fee_payer_signature is not None or self.awaiting_fee_payer - def _encode_optional_uint(self, v: Optional[int]) -> bytes | int: + def _encode_optional_uint(self, v: int | None) -> bytes | int: return b"" if v is None else v def get_signing_hash(self, for_fee_payer: bool = False) -> bytes: @@ -448,8 +421,7 @@ def _signing_hash_sender(self) -> bytes: ] if self.key_authorization is not None: - key_auth_decoded = rlp.decode(self.key_authorization) - fields.append(key_auth_decoded) + fields.append(self.key_authorization.as_rlp_payload()) return keccak(bytes([self.TRANSACTION_TYPE]) + rlp.encode(fields)) @@ -473,37 +445,18 @@ def _signing_hash_fee_payer(self) -> bytes: ] if self.key_authorization is not None: - key_auth_decoded = rlp.decode(self.key_authorization) - fields.append(key_auth_decoded) + fields.append(self.key_authorization.as_rlp_payload()) return keccak(bytes([self.FEE_PAYER_MAGIC_BYTE]) + rlp.encode(fields)) def encode(self) -> bytes: - """ - Encode complete transaction: 0x76 || rlp([14 fields]) - - Returns: - Encoded transaction with type prefix - """ + """Encode complete transaction: ``0x76 || rlp([fields])``.""" self.validate() - def sender_sig_to_bytes(sig: Optional[Signature | bytes]) -> bytes: - if sig is None: - return b"" - if isinstance(sig, bytes): - return sig - return sig.to_bytes() - - def fee_payer_sig_to_rlp(sig: Optional[Signature | bytes]) -> list | bytes: - """Encode fee_payer_signature as RLP list [v, r, s] or empty bytes.""" - if sig is None: - return b"" - if isinstance(sig, bytes): - return sig - return sig.to_rlp_list() - - sender_sig = sender_sig_to_bytes(self.sender_signature) - fee_payer_sig = fee_payer_sig_to_rlp(self.fee_payer_signature) + sender_sig = self.sender_signature.to_bytes() if self.sender_signature else b"" + fee_payer_sig = ( + self.fee_payer_signature.to_rlp_list() if self.fee_payer_signature else b"" + ) fields = [ self.chain_id, @@ -521,12 +474,8 @@ def fee_payer_sig_to_rlp(sig: Optional[Signature | bytes]) -> list | bytes: list(self.tempo_authorization_list), ] - # key_authorization is a trailing optional field (only include when present) - # The field is already RLP-encoded, so we need to decode it first to get the - # raw structure, otherwise it gets double-encoded as a bytes string. if self.key_authorization is not None: - key_auth_decoded = rlp.decode(self.key_authorization) - fields.append(key_auth_decoded) + fields.append(self.key_authorization.as_rlp_payload()) fields.append(sender_sig) else: fields.append(sender_sig) @@ -537,11 +486,10 @@ def hash(self) -> bytes: """Get transaction hash.""" return keccak(self.encode()) - def vrs(self) -> tuple[Optional[int], Optional[int], Optional[int]]: + def vrs(self) -> tuple[int | None, int | None, int | None]: """Get v, r, s values for secp256k1 signatures. - Returns (None, None, None) if signature is not a Signature object - (e.g., for keychain signatures stored as raw bytes). + Returns (None, None, None) for keychain signatures. """ if isinstance(self.sender_signature, Signature): return ( @@ -551,15 +499,15 @@ def vrs(self) -> tuple[Optional[int], Optional[int], Optional[int]]: ) return (None, None, None) - def sign(self, private_key: str, for_fee_payer: bool = False) -> "TempoTransaction": - """ - Sign the transaction with secp256k1 private key. + def sign(self, private_key: str, for_fee_payer: bool = False) -> TempoTransaction: + """Sign the transaction with secp256k1 private key. Returns a new TempoTransaction with the signature applied. Args: - private_key: Private key as hex string - for_fee_payer: If True, sign as fee payer; else sign as sender + private_key: Private key as hex string (as used by + ``Account.from_key``). + for_fee_payer: If True, sign as fee payer; else sign as sender. """ account = Account.from_key(private_key) @@ -575,24 +523,41 @@ def sign(self, private_key: str, for_fee_payer: bool = False) -> "TempoTransacti sender_addr = as_address(account.address) return attrs.evolve(self, sender_signature=sig, sender_address=sender_addr) + def sign_access_key( + self, + access_key_private_key: str, + root_account: str, + ) -> TempoTransaction: + """Sign the transaction using an access key (Keychain signature). + + Returns a new TempoTransaction with the keychain signature applied. + + Args: + access_key_private_key: Private key of the access key (hex string, + as used by ``Account.from_key``). + root_account: Address of the root account (hex string). + """ + from .keychain import KeychainSignature + + root_addr = as_address(root_account) + tx_with_sender = attrs.evolve(self, sender_address=root_addr) + msg_hash = tx_with_sender.get_signing_hash(for_fee_payer=False) + sig = KeychainSignature.sign(msg_hash, access_key_private_key, root_addr) + return attrs.evolve(tx_with_sender, sender_signature=sig) + def to_estimate_gas_request( self, sender: str, - key_id: Optional[str] = None, - key_authorization: Optional[dict] = None, + key_id: str | None = None, + key_authorization: dict | SignedKeyAuthorization | None = None, ) -> dict: """Build an eth_estimateGas request dict from this transaction. Args: - sender: Address of the sender (hex string) - key_id: Optional access key address for keychain signature gas estimation - key_authorization: Optional SignedKeyAuthorization.to_json() dict - - Returns: - Dict suitable for w3.eth.estimate_gas() - - Example: - >>> gas = w3.eth.estimate_gas(tx.to_estimate_gas_request(sender)) + sender: Address of the sender (hex string). + key_id: Optional access key address for keychain signature gas estimation. + key_authorization: Optional :class:`SignedKeyAuthorization` or + pre-built JSON dict. """ if not self.calls: raise ValueError("Transaction must have at least one call") @@ -619,6 +584,9 @@ def to_estimate_gas_request( request["keyId"] = key_id if key_authorization is not None: - request["keyAuthorization"] = key_authorization + if isinstance(key_authorization, dict): + request["keyAuthorization"] = key_authorization + else: + request["keyAuthorization"] = key_authorization.to_json() return request diff --git a/pytempo/types.py b/pytempo/types.py index 9ecb62f..4ffd352 100644 --- a/pytempo/types.py +++ b/pytempo/types.py @@ -9,6 +9,7 @@ Address = NewType("Address", bytes) Hash32 = NewType("Hash32", bytes) +Selector = NewType("Selector", bytes) BytesLike = Union[bytes, str] @@ -91,3 +92,26 @@ def as_hash32(value: BytesLike) -> Hash32: if len(b) != 32: raise ValueError(f"hash32 must be 32 bytes, got {len(b)}") return Hash32(b) + + +def as_selector(value: BytesLike) -> Selector: + """Convert hex string or bytes to a validated 4-byte function selector. + + Use as: attrs.field(converter=as_selector) + + Raises: + TypeError: If value is not a string or bytes-like object (rejects int). + ValueError: If selector is not exactly 4 bytes. + """ + b = as_bytes(value) + if len(b) != 4: + raise ValueError(f"selector must be exactly 4 bytes, got {len(b)}") + return Selector(b) + + +def validate_nonempty_address( + instance: object, attribute: object, value: Address +) -> None: + """Attrs validator: address must be exactly 20 bytes (not empty).""" + if len(bytes(value)) != 20: + raise ValueError("address must be exactly 20 bytes") diff --git a/tests/test_integration.py b/tests/test_integration.py index c4de348..e2e76b3 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -70,6 +70,12 @@ def w3(rpc_url): return Web3(Web3.HTTPProvider(rpc_url)) +@pytest.fixture(scope="module") +def is_t2(): + """Check if running against a T2 network via TEMPO_HARDFORK env var (default T3).""" + return os.environ.get("TEMPO_HARDFORK", "T3") == "T2" + + @pytest.fixture(scope="module") def chain_id(w3): """Get the chain ID from the connected node.""" @@ -396,11 +402,10 @@ def test_add_access_key_with_key_authorization(self, w3, chain_id, funded_accoun expiry = int(time.time()) + 3600 auth = KeyAuthorization( + key_id=access_key.address, chain_id=chain_id, key_type=SignatureType.SECP256K1, - key_id=access_key.address, expiry=expiry, - limits=None, ) signed_auth = auth.sign(funded_account.key.hex()) @@ -413,7 +418,7 @@ def test_add_access_key_with_key_authorization(self, w3, chain_id, funded_accoun max_fee_per_gas=max_fee, max_priority_fee_per_gas=priority_fee, calls=(Call.create(to=COUNTER_CONTRACT, data=COUNTER_INCREMENT),), - key_authorization=signed_auth.rlp_encode(), + key_authorization=signed_auth, ) # Estimate gas from the transaction @@ -421,7 +426,7 @@ def test_add_access_key_with_key_authorization(self, w3, chain_id, funded_accoun tx.to_estimate_gas_request( funded_account.address, key_id=access_key.address, - key_authorization=signed_auth.to_json(), + key_authorization=signed_auth, ) ) @@ -434,11 +439,10 @@ def test_add_access_key_with_key_authorization(self, w3, chain_id, funded_accoun max_fee_per_gas=max_fee, max_priority_fee_per_gas=priority_fee, calls=(Call.create(to=COUNTER_CONTRACT, data=COUNTER_INCREMENT),), - key_authorization=signed_auth.rlp_encode(), + key_authorization=signed_auth, ) - signed_tx = sign_tx_access_key( - tx, + signed_tx = tx.sign_access_key( access_key_private_key=access_key.key.hex(), root_account=funded_account.address, ) @@ -459,11 +463,10 @@ def test_sign_tx_with_existing_access_key(self, w3, chain_id, funded_account): expiry = int(time.time()) + 3600 auth = KeyAuthorization( + key_id=access_key.address, chain_id=chain_id, key_type=SignatureType.SECP256K1, - key_id=access_key.address, expiry=expiry, - limits=None, ) signed_auth = auth.sign(funded_account.key.hex()) @@ -479,7 +482,7 @@ def test_sign_tx_with_existing_access_key(self, w3, chain_id, funded_account): max_fee_per_gas=max_fee, max_priority_fee_per_gas=priority_fee, calls=(Call.create(to=COUNTER_CONTRACT, data=COUNTER_INCREMENT),), - key_authorization=signed_auth.rlp_encode(), + key_authorization=signed_auth, ) # Estimate gas from the transaction @@ -487,7 +490,7 @@ def test_sign_tx_with_existing_access_key(self, w3, chain_id, funded_account): tx1.to_estimate_gas_request( funded_account.address, key_id=access_key.address, - key_authorization=signed_auth.to_json(), + key_authorization=signed_auth, ) ) @@ -500,11 +503,10 @@ def test_sign_tx_with_existing_access_key(self, w3, chain_id, funded_account): max_fee_per_gas=max_fee, max_priority_fee_per_gas=priority_fee, calls=(Call.create(to=COUNTER_CONTRACT, data=COUNTER_INCREMENT),), - key_authorization=signed_auth.rlp_encode(), + key_authorization=signed_auth, ) - signed_tx1 = sign_tx_access_key( - tx1, + signed_tx1 = tx1.sign_access_key( access_key_private_key=access_key.key.hex(), root_account=funded_account.address, ) @@ -542,8 +544,7 @@ def test_sign_tx_with_existing_access_key(self, w3, chain_id, funded_account): calls=(Call.create(to=COUNTER_CONTRACT, data=COUNTER_INCREMENT),), ) - signed_tx2 = sign_tx_access_key( - tx2, + signed_tx2 = tx2.sign_access_key( access_key_private_key=access_key.key.hex(), root_account=funded_account.address, ) @@ -780,7 +781,7 @@ class TestKeychainSelectors: authorizeKey → getKey → revokeKey round-trip via the precompile. """ - def test_authorize_get_revoke_round_trip(self, w3, chain_id, funded_account): + def test_authorize_get_revoke_round_trip(self, w3, chain_id, funded_account, is_t2): """Authorize an access key, verify via getKey, revoke, verify revoked.""" max_fee, priority_fee = get_gas_params(w3) @@ -796,6 +797,7 @@ def test_authorize_get_revoke_round_trip(self, w3, chain_id, funded_account): expiry=expiry, enforce_limits=False, limits=[], + legacy=is_t2, ) tx = TempoTransaction.create( chain_id=chain_id, @@ -946,7 +948,7 @@ def test_inline_key_auth_with_limits(self, w3, chain_id, funded_account): max_fee_per_gas=max_fee, max_priority_fee_per_gas=priority_fee, calls=(Call.create(to=COUNTER_CONTRACT, data=COUNTER_INCREMENT),), - key_authorization=signed_auth.rlp_encode(), + key_authorization=signed_auth, ) gas_estimate = w3.eth.estimate_gas( @@ -965,7 +967,7 @@ def test_inline_key_auth_with_limits(self, w3, chain_id, funded_account): max_fee_per_gas=max_fee, max_priority_fee_per_gas=priority_fee, calls=(Call.create(to=COUNTER_CONTRACT, data=COUNTER_INCREMENT),), - key_authorization=signed_auth.rlp_encode(), + key_authorization=signed_auth, ) signed_tx = sign_tx_access_key( diff --git a/tests/test_keychain.py b/tests/test_keychain.py index 1b023d0..f9bfa4b 100644 --- a/tests/test_keychain.py +++ b/tests/test_keychain.py @@ -6,7 +6,8 @@ from eth_account import Account from eth_utils import to_bytes -from pytempo import Call, TempoTransaction +from pytempo import Call, CallScope, TempoTransaction +from pytempo.contracts import ALPHA_USD from pytempo.contracts.account_keychain import AccountKeychain from pytempo.contracts.addresses import ACCOUNT_KEYCHAIN_ADDRESS from pytempo.keychain import ( @@ -16,10 +17,11 @@ KEYCHAIN_SIGNATURE_TYPE, # Key authorization classes KeyAuthorization, + KeychainSignature, SignatureType, SignedKeyAuthorization, TokenLimit, - # Signing functions + # Signing functions (deprecated wrappers) build_keychain_signature, create_key_authorization, sign_tx_access_key, @@ -257,10 +259,12 @@ def test_signature_length_is_86_bytes(self): calls=(Call.create(to="0x" + "c" * 40, value=1000),), ) - signed = sign_tx_access_key(tx, access_key_private, root_account) + signed = tx.sign_access_key(access_key_private, root_account) - assert len(signed.sender_signature) == KEYCHAIN_SIGNATURE_LENGTH - assert len(signed.sender_signature) == 86 + sig = signed.sender_signature + assert isinstance(sig, KeychainSignature) + assert len(sig.to_bytes()) == KEYCHAIN_SIGNATURE_LENGTH + assert len(sig.to_bytes()) == 86 def test_signature_starts_with_0x04(self): """First byte must be 0x04 (Keychain V2 type identifier).""" @@ -274,10 +278,11 @@ def test_signature_starts_with_0x04(self): calls=(Call.create(to="0x" + "c" * 40, value=1000),), ) - signed = sign_tx_access_key(tx, access_key_private, root_account) + signed = tx.sign_access_key(access_key_private, root_account) - assert signed.sender_signature[0] == KEYCHAIN_SIGNATURE_TYPE - assert signed.sender_signature[0] == 0x04 + sig_bytes = signed.sender_signature.to_bytes() + assert sig_bytes[0] == KEYCHAIN_SIGNATURE_TYPE + assert sig_bytes[0] == 0x04 def test_root_account_embedded_in_signature(self): """Bytes 1-21 must contain the root account address.""" @@ -291,15 +296,14 @@ def test_root_account_embedded_in_signature(self): calls=(Call.create(to="0x" + "c" * 40, value=1000),), ) - signed = sign_tx_access_key(tx, access_key_private, root_account) - - embedded_address = signed.sender_signature[1:21] - expected_address = to_bytes(hexstr=root_account) + signed = tx.sign_access_key(access_key_private, root_account) - assert embedded_address == expected_address + sig = signed.sender_signature + assert isinstance(sig, KeychainSignature) + assert bytes(sig.root_account) == to_bytes(hexstr=root_account) def test_inner_signature_is_65_bytes(self): - """Bytes 21-86 must be 65-byte inner signature (r || s || v).""" + """Inner signature must be 65 bytes (r || s || v).""" access_key_private = "0x" + "a" * 64 root_account = "0x" + "b" * 40 @@ -310,11 +314,12 @@ def test_inner_signature_is_65_bytes(self): calls=(Call.create(to="0x" + "c" * 40, value=1000),), ) - signed = sign_tx_access_key(tx, access_key_private, root_account) + signed = tx.sign_access_key(access_key_private, root_account) - inner_sig = signed.sender_signature[21:] - assert len(inner_sig) == INNER_SIGNATURE_LENGTH - assert len(inner_sig) == 65 + sig = signed.sender_signature + assert isinstance(sig, KeychainSignature) + assert len(sig.inner.to_bytes()) == INNER_SIGNATURE_LENGTH + assert len(sig.inner.to_bytes()) == 65 def test_sender_address_set_to_root_account(self): """sender_address must be set to root account.""" @@ -328,7 +333,7 @@ def test_sender_address_set_to_root_account(self): calls=(Call.create(to="0x" + "c" * 40, value=1000),), ) - signed = sign_tx_access_key(tx, access_key_private, root_account) + signed = tx.sign_access_key(access_key_private, root_account) assert bytes(signed.sender_address) == to_bytes(hexstr=root_account) @@ -348,10 +353,10 @@ def test_different_signature_length(self): calls=(Call.create(to="0x" + "c" * 40, value=1000),), ) - keychain_signed = sign_tx_access_key(tx, access_key_private, root_account) + keychain_signed = tx.sign_access_key(access_key_private, root_account) normal_signed = tx.sign(access_key_private) - assert len(keychain_signed.sender_signature) == 86 # Keychain + assert len(keychain_signed.sender_signature.to_bytes()) == 86 # Keychain assert len(normal_signed.sender_signature.to_bytes()) == 65 # Normal secp256k1 def test_different_type_prefix(self): @@ -366,14 +371,14 @@ def test_different_type_prefix(self): calls=(Call.create(to="0x" + "c" * 40, value=1000),), ) - keychain_signed = sign_tx_access_key(tx, access_key_private, root_account) + keychain_signed = tx.sign_access_key(access_key_private, root_account) normal_signed = tx.sign(access_key_private) - assert keychain_signed.sender_signature[0] == 0x04 + assert keychain_signed.sender_signature.to_bytes()[0] == 0x04 assert normal_signed.sender_signature.to_bytes()[0] != 0x04 - def test_encoded_transactions_different(self): - """Encoded transactions should be different.""" + def test_both_produce_valid_encoded_tx(self): + """Both signing methods should produce a valid encoded tx.""" access_key_private = "0x" + "a" * 64 root_account = "0x" + "b" * 40 @@ -384,10 +389,12 @@ def test_encoded_transactions_different(self): calls=(Call.create(to="0x" + "c" * 40, value=1000),), ) - keychain_signed = sign_tx_access_key(tx, access_key_private, root_account) + keychain_signed = tx.sign_access_key(access_key_private, root_account) normal_signed = tx.sign(access_key_private) assert keychain_signed.encode() != normal_signed.encode() + assert keychain_signed.encode()[0] == 0x76 + assert normal_signed.encode()[0] == 0x76 class TestBuildKeychainSignature: @@ -428,8 +435,40 @@ def test_different_hash_different_signature(self): assert sig1[21:] != sig2[21:] +class TestKeychainSignatureType: + """Tests for KeychainSignature structured type.""" + + def test_roundtrip_bytes(self): + """to_bytes / from_bytes roundtrip should preserve data.""" + access_key_private = "0x" + "a" * 64 + root_account = "0x" + "b" * 40 + msg_hash = b"\x00" * 32 + + sig = KeychainSignature.sign(msg_hash, access_key_private, root_account) + raw = sig.to_bytes() + parsed = KeychainSignature.from_bytes(raw) + + assert bytes(parsed.root_account) == bytes(sig.root_account) + assert parsed.inner == sig.inner + + def test_from_bytes_rejects_wrong_length(self): + with pytest.raises(ValueError, match="86 bytes"): + KeychainSignature.from_bytes(b"\x00" * 85) + + def test_from_bytes_rejects_wrong_type_byte(self): + raw = b"\x05" + b"\x00" * 85 + with pytest.raises(ValueError, match="0x04"): + KeychainSignature.from_bytes(raw) + + def test_frozen(self): + """KeychainSignature should be immutable.""" + sig = KeychainSignature.sign(b"\x00" * 32, "0x" + "a" * 64, "0x" + "b" * 40) + with pytest.raises(AttributeError): + sig.root_account = b"\x00" * 20 # type: ignore[misc] + + class TestSignatureType: - """Tests for SignatureType constants.""" + """Tests for SignatureType enum.""" def test_secp256k1_is_zero(self): assert SignatureType.SECP256K1 == 0 @@ -440,28 +479,48 @@ def test_p256_is_one(self): def test_webauthn_is_two(self): assert SignatureType.WEBAUTHN == 2 + def test_rejects_invalid_value(self): + with pytest.raises(ValueError): + SignatureType(999) + + def test_json_names(self): + assert SignatureType.SECP256K1.to_json_name() == "secp256k1" + assert SignatureType.P256.to_json_name() == "p256" + assert SignatureType.WEBAUTHN.to_json_name() == "webAuthn" + class TestTokenLimit: - """Tests for TokenLimit dataclass.""" + """Tests for TokenLimit attrs model.""" - def test_to_rlp(self): - """Should convert to RLP-serializable format.""" + def test_accepts_hex_string(self): + """Should convert hex string to Address.""" limit = TokenLimit(token="0x" + "a" * 40, limit=1000) - rlp_obj = limit.to_rlp() + assert bytes(limit.token) == bytes.fromhex("a" * 40) + assert limit.limit == 1000 - assert rlp_obj.token == bytes.fromhex("a" * 40) - assert rlp_obj.limit == 1000 + def test_rejects_empty_token(self): + with pytest.raises(ValueError, match="20 bytes"): + TokenLimit(token="0x", limit=1000) + + def test_rejects_negative_limit(self): + with pytest.raises(ValueError): + TokenLimit(token="0x" + "a" * 40, limit=-1) + + def test_frozen(self): + limit = TokenLimit(token="0x" + "a" * 40, limit=1000) + with pytest.raises(AttributeError): + limit.limit = 2000 # type: ignore[misc] class TestKeyAuthorization: - """Tests for KeyAuthorization dataclass.""" + """Tests for KeyAuthorization attrs model.""" def test_rlp_encode_minimal(self): """Should RLP encode with minimal fields.""" auth = KeyAuthorization( + key_id="0x" + "b" * 40, chain_id=42429, key_type=SignatureType.SECP256K1, - key_id="0x" + "b" * 40, ) encoded = auth.rlp_encode() @@ -471,9 +530,9 @@ def test_rlp_encode_minimal(self): def test_rlp_encode_with_expiry(self): """Should RLP encode with expiry.""" auth = KeyAuthorization( + key_id="0x" + "b" * 40, chain_id=42429, key_type=SignatureType.SECP256K1, - key_id="0x" + "b" * 40, expiry=1893456000, ) @@ -483,9 +542,9 @@ def test_rlp_encode_with_expiry(self): def test_rlp_encode_with_limits(self): """Should RLP encode with token limits.""" auth = KeyAuthorization( + key_id="0x" + "b" * 40, chain_id=42429, key_type=SignatureType.SECP256K1, - key_id="0x" + "b" * 40, limits=[TokenLimit(token="0x" + "c" * 40, limit=1000)], ) @@ -495,9 +554,9 @@ def test_rlp_encode_with_limits(self): def test_signature_hash_deterministic(self): """Should produce deterministic hash.""" auth = KeyAuthorization( + key_id="0x" + "b" * 40, chain_id=42429, key_type=SignatureType.SECP256K1, - key_id="0x" + "b" * 40, ) hash1 = auth.signature_hash() @@ -509,14 +568,14 @@ def test_signature_hash_deterministic(self): def test_signature_hash_different_for_different_auth(self): """Different authorizations should have different hashes.""" auth1 = KeyAuthorization( + key_id="0x" + "b" * 40, chain_id=42429, key_type=SignatureType.SECP256K1, - key_id="0x" + "b" * 40, ) auth2 = KeyAuthorization( + key_id="0x" + "c" * 40, chain_id=42429, key_type=SignatureType.SECP256K1, - key_id="0x" + "c" * 40, ) assert auth1.signature_hash() != auth2.signature_hash() @@ -525,30 +584,51 @@ def test_sign_returns_signed_authorization(self): """Should return a SignedKeyAuthorization.""" private_key = "0x" + "a" * 64 auth = KeyAuthorization( + key_id="0x" + "b" * 40, chain_id=42429, key_type=SignatureType.SECP256K1, - key_id="0x" + "b" * 40, ) signed = auth.sign(private_key) assert isinstance(signed, SignedKeyAuthorization) assert signed.authorization == auth - assert signed.v in (27, 28) - assert signed.r > 0 - assert signed.s > 0 + assert signed.signature.v in (27, 28) + assert signed.signature.r > 0 + assert signed.signature.s > 0 + + def test_rejects_empty_key_id(self): + with pytest.raises(ValueError, match="20 bytes"): + KeyAuthorization(key_id="0x") + + def test_rejects_invalid_key_type(self): + with pytest.raises(ValueError): + KeyAuthorization(key_id="0x" + "b" * 40, key_type=999) + + def test_converter_accepts_int_key_type(self): + """IntEnum converter should accept plain ints for valid values.""" + auth = KeyAuthorization(key_id="0x" + "b" * 40, key_type=1) + assert auth.key_type is SignatureType.P256 + + def test_frozen(self): + auth = KeyAuthorization( + key_id="0x" + "b" * 40, + chain_id=42429, + ) + with pytest.raises(AttributeError): + auth.chain_id = 1 # type: ignore[misc] class TestSignedKeyAuthorization: - """Tests for SignedKeyAuthorization dataclass.""" + """Tests for SignedKeyAuthorization attrs model.""" def test_rlp_encode(self): """Should RLP encode the signed authorization.""" private_key = "0x" + "a" * 64 auth = KeyAuthorization( + key_id="0x" + "b" * 40, chain_id=42429, key_type=SignatureType.SECP256K1, - key_id="0x" + "b" * 40, ) signed = auth.sign(private_key) @@ -562,9 +642,9 @@ def test_recover_signer(self): account = Account.from_key(private_key) auth = KeyAuthorization( + key_id="0x" + "b" * 40, chain_id=42429, key_type=SignatureType.SECP256K1, - key_id="0x" + "b" * 40, ) signed = auth.sign(private_key) @@ -577,9 +657,9 @@ def test_recover_signer_with_expiry_and_limits(self): account = Account.from_key(private_key) auth = KeyAuthorization( + key_id="0x" + "b" * 40, chain_id=42429, key_type=SignatureType.P256, - key_id="0x" + "b" * 40, expiry=1893456000, limits=[TokenLimit(token="0x" + "c" * 40, limit=1000000)], ) @@ -588,56 +668,151 @@ def test_recover_signer_with_expiry_and_limits(self): recovered = signed.recover_signer() assert recovered.lower() == account.address.lower() + def test_to_json(self): + """Should produce valid JSON dict.""" + private_key = "0x" + "a" * 64 + auth = KeyAuthorization( + key_id="0x" + "b" * 40, + chain_id=42429, + key_type=SignatureType.SECP256K1, + expiry=1893456000, + limits=[TokenLimit(token="0x" + "c" * 40, limit=1000)], + ) + signed = auth.sign(private_key) + j = signed.to_json() -class TestCreateKeyAuthorization: - """Tests for create_key_authorization helper function.""" - - def test_creates_basic_authorization(self): - """Should create a basic KeyAuthorization.""" - auth = create_key_authorization(key_id="0x" + "b" * 40) - - assert auth.chain_id == 0 - assert auth.key_type == SignatureType.SECP256K1 - assert auth.key_id == "0x" + "b" * 40 - assert auth.expiry is None - assert auth.limits is None + assert j["keyType"] == "secp256k1" + assert "signature" in j + assert "expiry" in j + assert "limits" in j + assert len(j["limits"]) == 1 - def test_creates_with_all_options(self): - """Should create with all options specified.""" - auth = create_key_authorization( + def test_frozen(self): + private_key = "0x" + "a" * 64 + auth = KeyAuthorization( key_id="0x" + "b" * 40, chain_id=42429, - key_type=SignatureType.WEBAUTHN, - expiry=1893456000, - limits=[{"token": "0x" + "c" * 40, "limit": 1000}], ) + signed = auth.sign(private_key) + with pytest.raises(AttributeError): + signed.authorization = None # type: ignore[misc] - assert auth.chain_id == 42429 - assert auth.key_type == SignatureType.WEBAUTHN - assert auth.expiry == 1893456000 - assert len(auth.limits) == 1 - assert auth.limits[0].token == "0x" + "c" * 40 - assert auth.limits[0].limit == 1000 + +class TestSignTxWorkflow: + """Tests for the full sign → encode workflow.""" def test_sign_and_use_workflow(self): """Test the full workflow: create, sign, encode.""" private_key = "0x" + "a" * 64 account = Account.from_key(private_key) - # Create authorization - auth = create_key_authorization( + auth = KeyAuthorization( key_id="0x" + "b" * 40, chain_id=42429, expiry=1893456000, ) - # Sign it signed = auth.sign(private_key) - # Verify signer assert signed.recover_signer().lower() == account.address.lower() - # Encode for transaction encoded = signed.rlp_encode() assert isinstance(encoded, bytes) assert len(encoded) > 0 + + def test_deprecated_sign_tx_access_key(self): + """Deprecated wrapper should still work.""" + access_key_private = "0x" + "a" * 64 + root_account = "0x" + "b" * 40 + + tx = TempoTransaction.create( + chain_id=42431, + gas_limit=21000, + nonce=0, + calls=(Call.create(to="0x" + "c" * 40, value=1000),), + ) + + signed = sign_tx_access_key(tx, access_key_private, root_account) + assert isinstance(signed.sender_signature, KeychainSignature) + assert len(signed.sender_signature.to_bytes()) == 86 + + +class TestCallScopeConstructors: + """Tests for CallScope named constructors.""" + + def test_transfer_selector(self): + s = CallScope.transfer(target=ALPHA_USD) + assert bytes(s.selector) == bytes.fromhex("a9059cbb") + + def test_approve_selector(self): + s = CallScope.approve(target=ALPHA_USD) + assert bytes(s.selector) == bytes.fromhex("095ea7b3") + + def test_transfer_with_memo_selector(self): + s = CallScope.transfer_with_memo(target=ALPHA_USD) + assert bytes(s.selector) == bytes.fromhex("95777d59") + + def test_unrestricted_selector(self): + s = CallScope.unrestricted(target="0x" + "aa" * 20) + assert bytes(s.selector) == b"\x00\x00\x00\x00" + + def test_unrestricted_allows_tip20(self): + s = CallScope.unrestricted(target=ALPHA_USD) + assert bytes(s.target).startswith(bytes.fromhex("20C000000000000000000000")) + + def test_tip20_rejects_non_tip20_address(self): + with pytest.raises(ValueError, match="TIP20"): + CallScope.transfer(target="0x" + "aa" * 20) + + def test_frozen(self): + s = CallScope.transfer(target=ALPHA_USD) + with pytest.raises(AttributeError): + s.selector = b"\x00" * 4 # type: ignore[misc] + + +class TestRecoverSigner: + """Tests for SignedKeyAuthorization.recover_signer.""" + + def test_roundtrip(self): + private_key = "0x" + "a" * 64 + account = Account.from_key(private_key) + + auth = KeyAuthorization( + key_id="0x" + "b" * 40, + chain_id=42429, + ) + signed = auth.sign(private_key) + assert signed.recover_signer().lower() == account.address.lower() + + +class TestDeprecatedVrsShims: + """Tests for v/r/s property shims on SignedKeyAuthorization.""" + + def test_v_r_s_match_signature(self): + auth = KeyAuthorization(key_id="0x" + "b" * 40, chain_id=42429) + signed = auth.sign("0x" + "a" * 64) + + assert signed.v == signed.signature.v + assert signed.r == signed.signature.r + assert signed.s == signed.signature.s + + +class TestCreateKeyAuthorizationCompat: + """Tests for deprecated create_key_authorization wrapper.""" + + def test_matches_direct_construction(self): + via_wrapper = create_key_authorization( + key_id="0x" + "b" * 40, + chain_id=42429, + key_type=SignatureType.SECP256K1, + expiry=1893456000, + limits=[{"token": "0x" + "c" * 40, "limit": 1000}], + ) + via_direct = KeyAuthorization( + key_id="0x" + "b" * 40, + chain_id=42429, + key_type=SignatureType.SECP256K1, + expiry=1893456000, + limits=(TokenLimit(token="0x" + "c" * 40, limit=1000),), + ) + assert via_wrapper.rlp_encode() == via_direct.rlp_encode() diff --git a/tests/test_typed_models.py b/tests/test_typed_models.py index 9f6fce2..a5ebadf 100644 --- a/tests/test_typed_models.py +++ b/tests/test_typed_models.py @@ -64,7 +64,7 @@ def test_as_hash32_rejects_int(self): class TestCall: - """Test Call dataclass.""" + """Test Call model.""" def test_create_call(self): call = Call.create( @@ -88,7 +88,7 @@ def test_call_as_rlp_list(self): class TestAccessListItem: - """Test AccessListItem dataclass.""" + """Test AccessListItem model.""" def test_create_access_list_item(self): item = AccessListItem.create( @@ -101,7 +101,7 @@ def test_create_access_list_item(self): class TestSignature: - """Test Signature dataclass.""" + """Test Signature model.""" def test_signature_to_bytes(self): sig = Signature(r=1, s=2, v=27)