Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changelog/plain-cats-growl.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
pytempo: minor
---

Added `SelectorRule` class for per-selector recipient filtering in call scope restrictions. Extended `CallScope` factory methods (`transfer`, `approve`, `transfer_with_memo`) to accept optional `recipients` lists, and added `CallScope.with_selector` for arbitrary 4-byte selector scoping. Added `AccountKeychain.set_allowed_calls` and `remove_allowed_calls` static methods, and added validation guards to `authorize_key` rejecting conflicting `legacy`/`allowed_calls` combinations.
2 changes: 2 additions & 0 deletions pytempo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
CallScope,
KeyAuthorization,
KeychainSignature,
SelectorRule,
SignatureType,
SignedKeyAuthorization,
TokenLimit,
Expand Down Expand Up @@ -81,6 +82,7 @@
"SignatureType",
"TokenLimit",
"CallScope",
"SelectorRule",
"KeychainSignature",
# Keychain signing (deprecated free functions)
"KEYCHAIN_SIGNATURE_TYPE",
Expand Down
42 changes: 39 additions & 3 deletions pytempo/contracts/account_keychain.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,15 @@ def authorize_key(
Ignored when ``legacy=True``.
legacy: Use pre-T3 flat-parameter encoding. Pass ``True`` until T3 is activated, then remove this argument.
"""
if legacy and (allowed_calls or not allow_any_calls):
raise ValueError("legacy=True does not support call restrictions")

if allowed_calls and allow_any_calls:
raise ValueError(
"allowed_calls was provided but allow_any_calls=True; "
"pass allow_any_calls=False to create a scoped key"
)

if legacy:
limit_tuples = (
[(t, a) for t, a, *_ in ((*lim, 0)[:3] for lim in limits)]
Expand All @@ -97,9 +106,7 @@ def authorize_key(
else []
)
call_tuples = (
[(bytes(s.target), [(bytes(s.selector), [])]) for s in allowed_calls]
if allowed_calls
else []
[s.to_abi_tuple() for s in allowed_calls] if allowed_calls else []
)
config = (
expiry,
Expand All @@ -122,6 +129,35 @@ def revoke_key(*, key_id: str) -> Call:
data = encode_calldata(_ABI, "revokeKey", [key_id])
return Call.create(to=ACCOUNT_KEYCHAIN_ADDRESS, data=data)

@staticmethod
def set_allowed_calls(
*,
key_id: str,
scopes: Sequence[CallScope],
) -> Call:
"""Build a ``setAllowedCalls(address,CallScope[])`` call.

Args:
key_id: The access key address.
scopes: List of :class:`~pytempo.CallScope` to set as the allowlist.
"""
call_tuples = [s.to_abi_tuple() for s in scopes]
data = encode_calldata(_ABI, "setAllowedCalls", [key_id, call_tuples])
return Call.create(to=ACCOUNT_KEYCHAIN_ADDRESS, data=data)

@staticmethod
def remove_allowed_calls(*, key_id: str, target: str) -> Call:
"""Build a ``removeAllowedCalls(address,address)`` call.

Removes all call-scope rules targeting ``target`` from the key's allowlist.

Args:
key_id: The access key address.
target: The contract address to remove from the allowlist.
"""
data = encode_calldata(_ABI, "removeAllowedCalls", [key_id, target])
return Call.create(to=ACCOUNT_KEYCHAIN_ADDRESS, data=data)

@staticmethod
def update_spending_limit(*, key_id: str, token: str, new_limit: int) -> Call:
"""Build an ``updateSpendingLimit(address,address,uint256)`` call."""
Expand Down
126 changes: 116 additions & 10 deletions pytempo/keychain.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,37 @@ def _validate_tip20_address(target: BytesLike) -> Address:
return addr


def _convert_addresses(
value: tuple[Address, ...] | list[BytesLike] | None,
) -> tuple[Address, ...]:
if value is None:
return ()
return tuple(as_address(v) for v in value)


@attrs.define(frozen=True)
class SelectorRule:
"""A single function-selector restriction with optional recipient filtering.

Args:
selector: 4-byte function selector.
recipients: Allowed first-argument addresses. Empty means any recipient.
"""

selector: Selector = attrs.field(converter=as_selector)
recipients: tuple[Address, ...] = attrs.field(
factory=tuple, converter=_convert_addresses
)


def _convert_selector_rules(
value: tuple[SelectorRule, ...] | list[SelectorRule] | None,
) -> tuple[SelectorRule, ...]:
if value is None:
return ()
return tuple(value)


@attrs.define(frozen=True)
class CallScope:
"""Call scope restriction for access keys (TIP-1011).
Expand All @@ -137,38 +168,113 @@ class CallScope:
- ``CallScope.approve(target=...)`` — allow ``approve`` on a TIP20 token.
- ``CallScope.transfer_with_memo(target=...)`` — allow ``transferWithMemo``
on a TIP20 token.
- ``CallScope.with_selector(target=..., selector=...)`` — allow an arbitrary
4-byte selector on any contract.

Args:
target: Contract address the key is allowed to call.
selector: 4-byte function selector. Only applicable for TIP20 tokens.
selector: 4-byte function selector (kept for backwards compatibility).
selector_rules: Full selector rules with optional recipient filtering.
When empty, falls back to ``selector`` as a single wildcard-recipient rule.
"""

target: Address = attrs.field(
converter=as_address, validator=validate_nonempty_address
)
selector: Selector = attrs.field(converter=as_selector)
selector_rules: tuple[SelectorRule, ...] = attrs.field(
factory=tuple, converter=_convert_selector_rules
)

def to_abi_tuple(self) -> tuple:
"""Convert to ABI-encodable tuple ``(target, [(selector, recipients), ...])``.

If ``selector_rules`` is empty, falls back to a single rule from ``selector``
with no recipient restriction (backwards-compatible behaviour).
"""
rules = self.selector_rules
if not rules:
rules = (SelectorRule(selector=self.selector),)
return (
bytes(self.target),
[(bytes(r.selector), [bytes(a) for a in r.recipients]) for r in rules],
)

@classmethod
def unrestricted(cls, *, target: BytesLike) -> CallScope:
"""Allow all functions on a target (any contract, including TIP20)."""
return cls(target=target, selector=_WILDCARD_SELECTOR)

@classmethod
def transfer(cls, *, target: BytesLike) -> CallScope:
"""Allow ``transfer(address,uint256)`` on a TIP20 token target."""
return cls(target=_validate_tip20_address(target), selector=_TIP20_TRANSFER)
def with_selector(
cls,
*,
target: BytesLike,
selector: BytesLike,
) -> CallScope:
"""Allow calls matching an arbitrary 4-byte function selector.

Args:
target: Contract address.
selector: 4-byte function selector.
"""
sel = as_selector(selector)
rule = SelectorRule(selector=sel)
return cls(target=target, selector=sel, selector_rules=(rule,))

@classmethod
def transfer(
cls,
*,
target: BytesLike,
recipients: list[BytesLike] = (),
) -> CallScope:
"""Allow ``transfer(address,uint256)`` on a TIP20 token target.

Args:
target: TIP20 token address.
recipients: Allowed transfer recipients. Empty means any recipient.
"""
addr = _validate_tip20_address(target)
rule = SelectorRule(selector=_TIP20_TRANSFER, recipients=recipients)
return cls(target=addr, selector=_TIP20_TRANSFER, selector_rules=(rule,))

@classmethod
def approve(cls, *, target: BytesLike) -> CallScope:
"""Allow ``approve(address,uint256)`` on a TIP20 token target."""
return cls(target=_validate_tip20_address(target), selector=_TIP20_APPROVE)
def approve(
cls,
*,
target: BytesLike,
recipients: list[BytesLike] = (),
) -> CallScope:
"""Allow ``approve(address,uint256)`` on a TIP20 token target.

Args:
target: TIP20 token address.
recipients: Allowed spender addresses. Empty means any spender.
"""
addr = _validate_tip20_address(target)
rule = SelectorRule(selector=_TIP20_APPROVE, recipients=recipients)
return cls(target=addr, selector=_TIP20_APPROVE, selector_rules=(rule,))

@classmethod
def transfer_with_memo(cls, *, target: BytesLike) -> CallScope:
"""Allow ``transferWithMemo(address,uint256,bytes32)`` on a TIP20 token target."""
def transfer_with_memo(
cls,
*,
target: BytesLike,
recipients: list[BytesLike] = (),
) -> CallScope:
"""Allow ``transferWithMemo(address,uint256,bytes32)`` on a TIP20 token target.

Args:
target: TIP20 token address.
recipients: Allowed transfer recipients. Empty means any recipient.
"""
addr = _validate_tip20_address(target)
rule = SelectorRule(selector=_TIP20_TRANSFER_WITH_MEMO, recipients=recipients)
return cls(
target=_validate_tip20_address(target),
target=addr,
selector=_TIP20_TRANSFER_WITH_MEMO,
selector_rules=(rule,),
)


Expand Down
129 changes: 128 additions & 1 deletion tests/test_keychain.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from eth_account import Account
from eth_utils import to_bytes

from pytempo import Call, CallScope, TempoTransaction
from pytempo import Call, CallScope, SelectorRule, TempoTransaction
from pytempo.contracts import ALPHA_USD
from pytempo.contracts.account_keychain import AccountKeychain
from pytempo.contracts.addresses import ACCOUNT_KEYCHAIN_ADDRESS
Expand Down Expand Up @@ -769,6 +769,133 @@ def test_frozen(self):
with pytest.raises(AttributeError):
s.selector = b"\x00" * 4 # type: ignore[misc]

def test_with_selector(self):
target = "0x" + "aa" * 20
sel = bytes.fromhex("aabbccdd")
s = CallScope.with_selector(target=target, selector=sel)
assert bytes(s.selector) == sel
assert len(s.selector_rules) == 1
assert s.selector_rules[0].recipients == ()

def test_with_selector_no_recipients(self):
target = "0x" + "aa" * 20
sel = bytes.fromhex("aabbccdd")
s = CallScope.with_selector(target=target, selector=sel)
assert len(s.selector_rules) == 1
assert s.selector_rules[0].recipients == ()

def test_transfer_with_recipients(self):
recipient = "0x" + "bb" * 20
s = CallScope.transfer(target=ALPHA_USD, recipients=[recipient])
assert len(s.selector_rules) == 1
assert len(s.selector_rules[0].recipients) == 1

def test_to_abi_tuple_fallback(self):
"""CallScope without selector_rules falls back to selector field."""
s = CallScope(target="0x" + "aa" * 20, selector=bytes.fromhex("aabbccdd"))
target_bytes, rules = s.to_abi_tuple()
assert len(rules) == 1
assert rules[0][0] == bytes.fromhex("aabbccdd")
assert rules[0][1] == []

def test_to_abi_tuple_with_rules(self):
"""CallScope with selector_rules uses them directly."""
recipient = "0x" + "bb" * 20
s = CallScope.transfer(target=ALPHA_USD, recipients=[recipient])
target_bytes, rules = s.to_abi_tuple()
assert len(rules) == 1
assert rules[0][0] == bytes.fromhex("a9059cbb")
assert len(rules[0][1]) == 1


class TestSelectorRule:
"""Tests for SelectorRule."""

def test_empty_recipients(self):
r = SelectorRule(selector=bytes.fromhex("aabbccdd"))
assert r.recipients == ()

def test_with_recipients(self):
addr = "0x" + "aa" * 20
r = SelectorRule(selector=bytes.fromhex("aabbccdd"), recipients=[addr])
assert len(r.recipients) == 1

def test_frozen(self):
r = SelectorRule(selector=bytes.fromhex("aabbccdd"))
with pytest.raises(AttributeError):
r.selector = b"\x00" * 4 # type: ignore[misc]


class TestSetAndRemoveAllowedCalls:
"""Tests for AccountKeychain.set_allowed_calls and remove_allowed_calls."""

def test_set_allowed_calls_encodes(self):
key_id = "0x" + "11" * 20
scope = CallScope.transfer(target=ALPHA_USD)
call = AccountKeychain.set_allowed_calls(key_id=key_id, scopes=[scope])
assert call.to is not None
assert call.data is not None

def test_remove_allowed_calls_encodes(self):
key_id = "0x" + "11" * 20
target = "0x" + "22" * 20
call = AccountKeychain.remove_allowed_calls(key_id=key_id, target=target)
assert call.to is not None
assert call.data is not None

def test_set_allowed_calls_with_recipients(self):
key_id = "0x" + "11" * 20
recipient = "0x" + "33" * 20
scope = CallScope.transfer(target=ALPHA_USD, recipients=[recipient])
call = AccountKeychain.set_allowed_calls(key_id=key_id, scopes=[scope])
assert call.data is not None

def test_set_allowed_calls_with_selector(self):
key_id = "0x" + "11" * 20
scope = CallScope.with_selector(
target="0x" + "22" * 20,
selector=bytes.fromhex("aabbccdd"),
)
call = AccountKeychain.set_allowed_calls(key_id=key_id, scopes=[scope])
assert call.data is not None


class TestAuthorizeKeyGuards:
"""Tests for authorize_key argument validation."""

def test_rejects_allowed_calls_with_allow_any_calls_true(self):
scope = CallScope.transfer(target=ALPHA_USD)
with pytest.raises(ValueError, match="allow_any_calls"):
AccountKeychain.authorize_key(
key_id="0x" + "11" * 20,
signature_type=SignatureType.SECP256K1,
expiry=2**64 - 1,
allowed_calls=[scope],
)

def test_rejects_legacy_with_call_restrictions(self):
scope = CallScope.transfer(target=ALPHA_USD)
with pytest.raises(ValueError, match="legacy"):
AccountKeychain.authorize_key(
key_id="0x" + "11" * 20,
signature_type=SignatureType.SECP256K1,
expiry=2**64 - 1,
allowed_calls=[scope],
allow_any_calls=False,
legacy=True,
)

def test_accepts_allowed_calls_with_allow_any_calls_false(self):
scope = CallScope.transfer(target=ALPHA_USD)
call = AccountKeychain.authorize_key(
key_id="0x" + "11" * 20,
signature_type=SignatureType.SECP256K1,
expiry=2**64 - 1,
allowed_calls=[scope],
allow_any_calls=False,
)
assert call.data is not None


class TestRecoverSigner:
"""Tests for SignedKeyAuthorization.recover_signer."""
Expand Down
Loading